diff --git a/requirements.txt b/requirements.txt index cfa90326d..1697da90e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,6 +20,7 @@ python-multipart==0.0.5 rdkit==2022.3.5 requests==2.28.1 SQLAlchemy==1.4.45 +starlette-context==0.3.5 urllib3==1.26.9 uvicorn[standard]==0.20.0 yamlreader==3.0.4 diff --git a/tdp_core/dbmanager.py b/tdp_core/dbmanager.py index 594c45bde..6e2a9a30a 100644 --- a/tdp_core/dbmanager.py +++ b/tdp_core/dbmanager.py @@ -8,7 +8,7 @@ from . import manager from .dbview import DBConnector from .middleware.close_web_sessions_middleware import CloseWebSessionsMiddleware -from .middleware.request_context_middleware import get_request +from .middleware.request_context_plugin import get_request _log = logging.getLogger(__name__) @@ -27,7 +27,8 @@ def init_app(self, app: FastAPI): for p in manager.registry.list("tdp-sql-database-definition"): config: Dict[str, Any] = manager.settings.get_nested(p.configKey) # type: ignore - connector: DBConnector = p.load().factory() + # Only instantiate the connector if it has a module factory, otherwise use an empty one + connector: DBConnector = p.load().factory() if p.module else DBConnector() if not connector.dburl: connector.dburl = config["dburl"] if not connector.statement_timeout: @@ -93,11 +94,14 @@ def create_web_session(self, engine_or_id: Union[Engine, str]) -> Session: """ session = self.create_session(engine_or_id) + r = get_request() + if not r: + raise Exception("No request found, did you use a create_web_sesssion outside of a request?") try: - existing_sessions = get_request().state.db_sessions + existing_sessions = r.state.db_sessions except (KeyError, AttributeError): existing_sessions = [] - get_request().state.db_sessions = existing_sessions + r.state.db_sessions = existing_sessions existing_sessions.append(session) return session diff --git a/tdp_core/dbview.py b/tdp_core/dbview.py index 616e84beb..1dd7bf33e 100644 --- a/tdp_core/dbview.py +++ b/tdp_core/dbview.py @@ -602,7 +602,7 @@ class DBConnector(object): basic connector object """ - def __init__(self, views, agg_score=None, mappings=None): + def __init__(self, views={}, agg_score=None, mappings=None): """ :param views: the dict of query views :param agg_score: optional specify how aggregation should be handled @@ -621,19 +621,7 @@ def dump(self, name): def create_engine(self, config) -> Engine: engine_options = config.get("engine", {}) - engine = sqlalchemy.create_engine(self.dburl, **engine_options) - try: - # Assuming that gevent monkey patched the builtin - # threading library, we're likely good to use - # SQLAlchemy's QueuePool, which is the default - # pool class. However, we need to make it use - # threadlocal connections - # https://github.com/kljensen/async-flask-sqlalchemy-example/blob/master/server.py - engine.pool._use_threadlocal = True # type: ignore - except Exception: - pass - - return engine + return sqlalchemy.create_engine(self.dburl, pool_size=30, pool_pre_ping=True, **engine_options) def create_sessionmaker(self, engine) -> sessionmaker: return sessionmaker(bind=engine) diff --git a/tdp_core/middleware/close_web_sessions_middleware.py b/tdp_core/middleware/close_web_sessions_middleware.py index 068fb5e9d..c5ac1daed 100644 --- a/tdp_core/middleware/close_web_sessions_middleware.py +++ b/tdp_core/middleware/close_web_sessions_middleware.py @@ -1,18 +1,23 @@ -from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint -from starlette.requests import Request +from fastapi import FastAPI +from .request_context_plugin import get_request -class CloseWebSessionsMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): - response = await call_next(request) - try: - for db_session in request.state.db_sessions: - try: - db_session.close() - except Exception: - pass - except (KeyError, AttributeError): - pass +# Use basic ASGI middleware instead of BaseHTTPMiddleware as it is significantly faster: https://github.com/tiangolo/fastapi/issues/2696#issuecomment-768224643 +class CloseWebSessionsMiddleware: + def __init__(self, app: FastAPI): + self.app = app - return response + async def __call__(self, scope, receive, send): + await self.app(scope, receive, send) + + r = get_request() + if r: + try: + for db_session in r.state.db_sessions: + try: + db_session.close() + except Exception: + pass + except (KeyError, AttributeError): + pass diff --git a/tdp_core/middleware/exception_handler_middleware.py b/tdp_core/middleware/exception_handler_middleware.py index 31a24848e..66ee38c48 100644 --- a/tdp_core/middleware/exception_handler_middleware.py +++ b/tdp_core/middleware/exception_handler_middleware.py @@ -1,20 +1,23 @@ import logging -from fastapi import HTTPException +from fastapi import FastAPI, HTTPException from fastapi.exception_handlers import http_exception_handler -from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint -from starlette.requests import Request from ..server.utils import detail_from_exception +from .request_context_plugin import get_request -class ExceptionHandlerMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): +# Use basic ASGI middleware instead of BaseHTTPMiddleware as it is significantly faster: https://github.com/tiangolo/fastapi/issues/2696#issuecomment-768224643 +class ExceptionHandlerMiddleware: + def __init__(self, app: FastAPI): + self.app = app + + async def __call__(self, scope, receive, send): try: - return await call_next(request) + await self.app(scope, receive, send) except Exception as e: logging.exception("An error occurred in FastAPI") return await http_exception_handler( - request, + get_request(), # type: ignore e if isinstance(e, HTTPException) else HTTPException(status_code=500, detail=detail_from_exception(e)), ) diff --git a/tdp_core/middleware/request_context_middleware.py b/tdp_core/middleware/request_context_middleware.py deleted file mode 100644 index d102332d6..000000000 --- a/tdp_core/middleware/request_context_middleware.py +++ /dev/null @@ -1,21 +0,0 @@ -from contextvars import ContextVar -from typing import Optional - -from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint -from starlette.requests import Request - -REQUEST_CTX_KEY = "fastapi_request" - -_request_ctx_var: ContextVar[Optional[Request]] = ContextVar(REQUEST_CTX_KEY, default=None) - - -def get_request() -> Request: - return _request_ctx_var.get() # type: ignore TODO: It is None in non-request context - - -class RequestContextMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): - request_ctx_key = _request_ctx_var.set(request) - response = await call_next(request) - _request_ctx_var.reset(request_ctx_key) - return response diff --git a/tdp_core/middleware/request_context_plugin.py b/tdp_core/middleware/request_context_plugin.py new file mode 100644 index 000000000..ed2b65b33 --- /dev/null +++ b/tdp_core/middleware/request_context_plugin.py @@ -0,0 +1,17 @@ +from typing import Optional + +from starlette.requests import HTTPConnection, Request +from starlette_context import context +from starlette_context.plugins.base import Plugin + + +def get_request() -> Request | None: + return context.get("request") + + +class RequestContextPlugin(Plugin): + # The returned value will be inserted in the context with this key + key = "request" + + async def process_request(self, request: Request | HTTPConnection) -> Optional[Request | HTTPConnection]: + return request diff --git a/tdp_core/security/manager.py b/tdp_core/security/manager.py index 3a26a358a..5af45a013 100644 --- a/tdp_core/security/manager.py +++ b/tdp_core/security/manager.py @@ -9,7 +9,7 @@ from fastapi.security.utils import get_authorization_scheme_param from .. import manager -from ..middleware.request_context_middleware import get_request +from ..middleware.request_context_plugin import get_request from .model import ANONYMOUS_USER, LogoutReturnValue, User from .store.base_store import BaseStore @@ -119,17 +119,18 @@ def _delegate_stores_until_not_none(self, store_method_name: str, *args): @property def current_user(self) -> Optional[User]: try: - req = get_request() - # Fetch the existing user from the request if there is any - try: - user = req.state.user - if user: - return user - except (KeyError, AttributeError): - pass - # If there is no user, try to load it from the request and store it in the request - user = req.state.user = self.load_from_request(get_request()) - return user + r = get_request() + if r: + # Fetch the existing user from the request if there is any + try: + user = r.state.user + if user: + return user + except (KeyError, AttributeError): + pass + # If there is no user, try to load it from the request and store it in the request + user = r.state.user = self.load_from_request(r) + return user except HTTPException: return None except Exception: diff --git a/tdp_core/server/visyn_server.py b/tdp_core/server/visyn_server.py index 3a0239657..bb5c2f510 100644 --- a/tdp_core/server/visyn_server.py +++ b/tdp_core/server/visyn_server.py @@ -8,6 +8,7 @@ from fastapi.middleware.wsgi import WSGIMiddleware from pydantic import create_model from pydantic.utils import deep_update +from starlette_context.middleware import RawContextMiddleware from ..settings.constants import default_logging_dict @@ -60,7 +61,6 @@ def create_visyn_server( ) from ..middleware.exception_handler_middleware import ExceptionHandlerMiddleware - from ..middleware.request_context_middleware import RequestContextMiddleware # TODO: For some reason, a @app.exception_handler(Exception) is not called here. We use a middleware instead. app.add_middleware(ExceptionHandlerMiddleware) @@ -143,8 +143,10 @@ def create_visyn_server( for p in plugins: p.plugin.init_app(app) - # Add middleware to access Request "outside" - app.add_middleware(RequestContextMiddleware) + from ..middleware.request_context_plugin import RequestContextPlugin + + # Use starlette-context to store the current request globally, i.e. accessible via context['request'] + app.add_middleware(RawContextMiddleware, plugins=(RequestContextPlugin(),)) # TODO: Move up? app.add_api_route("/health", health) # type: ignore