From 2a60325acbe9a82cbcd70266d8f183e7c7a5e260 Mon Sep 17 00:00:00 2001 From: Igor Magalhaes Date: Sat, 10 Feb 2024 15:59:01 -0300 Subject: [PATCH] startup and shutdown replaced with lifespan --- src/app/core/setup.py | 69 ++++++++++++++++++++++++++++++------------- 1 file changed, 49 insertions(+), 20 deletions(-) diff --git a/src/app/core/setup.py b/src/app/core/setup.py index c2ceadc..c126cd3 100644 --- a/src/app/core/setup.py +++ b/src/app/core/setup.py @@ -1,3 +1,5 @@ +from collections.abc import AsyncGenerator, Callable +from contextlib import _AsyncGeneratorContextManager, asynccontextmanager from typing import Any import anyio @@ -68,6 +70,50 @@ async def set_threadpool_tokens(number_of_tokens: int = 100) -> None: limiter.total_tokens = number_of_tokens +def lifespan_factory( + settings: ( + DatabaseSettings + | RedisCacheSettings + | AppSettings + | ClientSideCacheSettings + | RedisQueueSettings + | RedisRateLimiterSettings + | EnvironmentSettings + ), + create_tables_on_start: bool = True, +) -> Callable[[FastAPI], _AsyncGeneratorContextManager[Any]]: + """Factory to create a lifespan async context manager for a FastAPI app.""" + + @asynccontextmanager + async def lifespan(app: FastAPI) -> AsyncGenerator: + await set_threadpool_tokens() + + if isinstance(settings, DatabaseSettings) and create_tables_on_start: + await create_tables() + + if isinstance(settings, RedisCacheSettings): + await create_redis_cache_pool() + + if isinstance(settings, RedisQueueSettings): + await create_redis_queue_pool() + + if isinstance(settings, RedisRateLimiterSettings): + await create_redis_rate_limit_pool() + + yield + + if isinstance(settings, RedisCacheSettings): + await close_redis_cache_pool() + + if isinstance(settings, RedisQueueSettings): + await close_redis_queue_pool() + + if isinstance(settings, RedisRateLimiterSettings): + await close_redis_rate_limit_pool() + + return lifespan + + # -------------- application -------------- def create_application( router: APIRouter, @@ -136,30 +182,13 @@ def create_application( if isinstance(settings, EnvironmentSettings): kwargs.update({"docs_url": None, "redoc_url": None, "openapi_url": None}) - application = FastAPI(**kwargs) - - # --- application created --- - application.include_router(router) - application.add_event_handler("startup", set_threadpool_tokens) + lifespan = lifespan_factory(settings, create_tables_on_start=create_tables_on_start) - if isinstance(settings, DatabaseSettings) and create_tables_on_start: - application.add_event_handler("startup", create_tables) - - if isinstance(settings, RedisCacheSettings): - application.add_event_handler("startup", create_redis_cache_pool) - application.add_event_handler("shutdown", close_redis_cache_pool) + application = FastAPI(lifespan=lifespan, **kwargs) if isinstance(settings, ClientSideCacheSettings): application.add_middleware(ClientCacheMiddleware, max_age=settings.CLIENT_CACHE_MAX_AGE) - if isinstance(settings, RedisQueueSettings): - application.add_event_handler("startup", create_redis_queue_pool) - application.add_event_handler("shutdown", close_redis_queue_pool) - - if isinstance(settings, RedisRateLimiterSettings): - application.add_event_handler("startup", create_redis_rate_limit_pool) - application.add_event_handler("shutdown", close_redis_rate_limit_pool) - if isinstance(settings, EnvironmentSettings): if settings.ENVIRONMENT != EnvironmentOption.PRODUCTION: docs_router = APIRouter() @@ -181,4 +210,4 @@ async def openapi() -> dict[str, Any]: application.include_router(docs_router) - return application + return application