Skip to content

Commit

Permalink
Replace BaseHTTPMiddleware with much faster ASGI equalivalent
Browse files Browse the repository at this point in the history
  • Loading branch information
puehringer committed Dec 22, 2022
1 parent 068e108 commit 0c34551
Show file tree
Hide file tree
Showing 9 changed files with 75 additions and 75 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ python-multipart==0.0.5
rdkit==2022.3.5
requests==2.28.1
SQLAlchemy==1.4.45
starlette-context==0.3.5
urllib3==1.26.9
uvicorn[standard]==0.20.0
yamlreader==3.0.4
12 changes: 8 additions & 4 deletions tdp_core/dbmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from . import manager
from .dbview import DBConnector
from .middleware.close_web_sessions_middleware import CloseWebSessionsMiddleware
from .middleware.request_context_middleware import get_request
from .middleware.request_context_plugin import get_request

_log = logging.getLogger(__name__)

Expand All @@ -27,7 +27,8 @@ def init_app(self, app: FastAPI):

for p in manager.registry.list("tdp-sql-database-definition"):
config: Dict[str, Any] = manager.settings.get_nested(p.configKey) # type: ignore
connector: DBConnector = p.load().factory()
# Only instantiate the connector if it has a module factory, otherwise use an empty one
connector: DBConnector = p.load().factory() if p.module else DBConnector()
if not connector.dburl:
connector.dburl = config["dburl"]
if not connector.statement_timeout:
Expand Down Expand Up @@ -93,11 +94,14 @@ def create_web_session(self, engine_or_id: Union[Engine, str]) -> Session:
"""
session = self.create_session(engine_or_id)

r = get_request()
if not r:
raise Exception("No request found, did you use a create_web_sesssion outside of a request?")
try:
existing_sessions = get_request().state.db_sessions
existing_sessions = r.state.db_sessions
except (KeyError, AttributeError):
existing_sessions = []
get_request().state.db_sessions = existing_sessions
r.state.db_sessions = existing_sessions
existing_sessions.append(session)

return session
16 changes: 2 additions & 14 deletions tdp_core/dbview.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,7 @@ class DBConnector(object):
basic connector object
"""

def __init__(self, views, agg_score=None, mappings=None):
def __init__(self, views={}, agg_score=None, mappings=None):
"""
:param views: the dict of query views
:param agg_score: optional specify how aggregation should be handled
Expand All @@ -621,19 +621,7 @@ def dump(self, name):

def create_engine(self, config) -> Engine:
engine_options = config.get("engine", {})
engine = sqlalchemy.create_engine(self.dburl, **engine_options)
try:
# Assuming that gevent monkey patched the builtin
# threading library, we're likely good to use
# SQLAlchemy's QueuePool, which is the default
# pool class. However, we need to make it use
# threadlocal connections
# https://github.com/kljensen/async-flask-sqlalchemy-example/blob/master/server.py
engine.pool._use_threadlocal = True # type: ignore
except Exception:
pass

return engine
return sqlalchemy.create_engine(self.dburl, pool_size=30, pool_pre_ping=True, **engine_options)

def create_sessionmaker(self, engine) -> sessionmaker:
return sessionmaker(bind=engine)
33 changes: 19 additions & 14 deletions tdp_core/middleware/close_web_sessions_middleware.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from fastapi import FastAPI

from .request_context_plugin import get_request

class CloseWebSessionsMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
response = await call_next(request)

try:
for db_session in request.state.db_sessions:
try:
db_session.close()
except Exception:
pass
except (KeyError, AttributeError):
pass
# Use basic ASGI middleware instead of BaseHTTPMiddleware as it is significantly faster: https://github.com/tiangolo/fastapi/issues/2696#issuecomment-768224643
class CloseWebSessionsMiddleware:
def __init__(self, app: FastAPI):
self.app = app

return response
async def __call__(self, scope, receive, send):
await self.app(scope, receive, send)

r = get_request()
if r:
try:
for db_session in r.state.db_sessions:
try:
db_session.close()
except Exception:
pass
except (KeyError, AttributeError):
pass
17 changes: 10 additions & 7 deletions tdp_core/middleware/exception_handler_middleware.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
import logging

from fastapi import HTTPException
from fastapi import FastAPI, HTTPException
from fastapi.exception_handlers import http_exception_handler
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request

from ..server.utils import detail_from_exception
from .request_context_plugin import get_request


class ExceptionHandlerMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint):
# Use basic ASGI middleware instead of BaseHTTPMiddleware as it is significantly faster: https://github.com/tiangolo/fastapi/issues/2696#issuecomment-768224643
class ExceptionHandlerMiddleware:
def __init__(self, app: FastAPI):
self.app = app

async def __call__(self, scope, receive, send):
try:
return await call_next(request)
await self.app(scope, receive, send)
except Exception as e:
logging.exception("An error occurred in FastAPI")
return await http_exception_handler(
request,
get_request(), # type: ignore
e if isinstance(e, HTTPException) else HTTPException(status_code=500, detail=detail_from_exception(e)),
)
21 changes: 0 additions & 21 deletions tdp_core/middleware/request_context_middleware.py

This file was deleted.

17 changes: 17 additions & 0 deletions tdp_core/middleware/request_context_plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from typing import Optional

from starlette.requests import HTTPConnection, Request
from starlette_context import context
from starlette_context.plugins.base import Plugin


def get_request() -> Request | None:
return context.get("request")


class RequestContextPlugin(Plugin):
# The returned value will be inserted in the context with this key
key = "request"

async def process_request(self, request: Request | HTTPConnection) -> Optional[Request | HTTPConnection]:
return request
25 changes: 13 additions & 12 deletions tdp_core/security/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from fastapi.security.utils import get_authorization_scheme_param

from .. import manager
from ..middleware.request_context_middleware import get_request
from ..middleware.request_context_plugin import get_request
from .model import ANONYMOUS_USER, LogoutReturnValue, User
from .store.base_store import BaseStore

Expand Down Expand Up @@ -119,17 +119,18 @@ def _delegate_stores_until_not_none(self, store_method_name: str, *args):
@property
def current_user(self) -> Optional[User]:
try:
req = get_request()
# Fetch the existing user from the request if there is any
try:
user = req.state.user
if user:
return user
except (KeyError, AttributeError):
pass
# If there is no user, try to load it from the request and store it in the request
user = req.state.user = self.load_from_request(get_request())
return user
r = get_request()
if r:
# Fetch the existing user from the request if there is any
try:
user = r.state.user
if user:
return user
except (KeyError, AttributeError):
pass
# If there is no user, try to load it from the request and store it in the request
user = r.state.user = self.load_from_request(r)
return user
except HTTPException:
return None
except Exception:
Expand Down
8 changes: 5 additions & 3 deletions tdp_core/server/visyn_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from fastapi.middleware.wsgi import WSGIMiddleware
from pydantic import create_model
from pydantic.utils import deep_update
from starlette_context.middleware import RawContextMiddleware

from ..settings.constants import default_logging_dict

Expand Down Expand Up @@ -60,7 +61,6 @@ def create_visyn_server(
)

from ..middleware.exception_handler_middleware import ExceptionHandlerMiddleware
from ..middleware.request_context_middleware import RequestContextMiddleware

# TODO: For some reason, a @app.exception_handler(Exception) is not called here. We use a middleware instead.
app.add_middleware(ExceptionHandlerMiddleware)
Expand Down Expand Up @@ -143,8 +143,10 @@ def create_visyn_server(
for p in plugins:
p.plugin.init_app(app)

# Add middleware to access Request "outside"
app.add_middleware(RequestContextMiddleware)
from ..middleware.request_context_plugin import RequestContextPlugin

# Use starlette-context to store the current request globally, i.e. accessible via context['request']
app.add_middleware(RawContextMiddleware, plugins=(RequestContextPlugin(),))

# TODO: Move up?
app.add_api_route("/health", health) # type: ignore
Expand Down

0 comments on commit 0c34551

Please # to comment.