diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index ce61964246..bab5121b28 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -7,13 +7,14 @@ import torch import uvicorn -from fastapi import FastAPI +from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html -from fastapi.responses import HTMLResponse +from fastapi.responses import HTMLResponse, RedirectResponse from fastapi_events.handlers.local import local_handler from fastapi_events.middleware import EventHandlerASGIMiddleware +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint from torch.backends.mps import is_available as is_mps_available # for PyCharm: @@ -78,6 +79,29 @@ async def lifespan(app: FastAPI): lifespan=lifespan, ) + +class RedirectRootWithQueryStringMiddleware(BaseHTTPMiddleware): + """When a request is made to the root path with a query string, redirect to the root path without the query string. + + For example, to force a Gradio app to use dark mode, users may append `?__theme=dark` to the URL. Their browser may + have this query string saved in history or a bookmark, so when the user navigates to `http://127.0.0.1:9090/`, the + browser takes them to `http://127.0.0.1:9090/?__theme=dark`. + + This breaks the static file serving in the UI, so we redirect the user to the root path without the query string. + """ + + async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): + if request.url.path == "/" and request.url.query: + return RedirectResponse(url="/") + + response = await call_next(request) + return response + + +# Add the middleware +app.add_middleware(RedirectRootWithQueryStringMiddleware) + + # Add event handler event_handler_id: int = id(app) app.add_middleware(