Skip to content

Commit

Permalink
startup and shutdown replaced with lifespan
Browse files Browse the repository at this point in the history
  • Loading branch information
igorbenav committed Feb 10, 2024
1 parent 444ce98 commit 2a60325
Showing 1 changed file with 49 additions and 20 deletions.
69 changes: 49 additions & 20 deletions src/app/core/setup.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections.abc import AsyncGenerator, Callable
from contextlib import _AsyncGeneratorContextManager, asynccontextmanager
from typing import Any

import anyio
Expand Down Expand Up @@ -68,6 +70,50 @@ async def set_threadpool_tokens(number_of_tokens: int = 100) -> None:
limiter.total_tokens = number_of_tokens


def lifespan_factory(
settings: (
DatabaseSettings
| RedisCacheSettings
| AppSettings
| ClientSideCacheSettings
| RedisQueueSettings
| RedisRateLimiterSettings
| EnvironmentSettings
),
create_tables_on_start: bool = True,
) -> Callable[[FastAPI], _AsyncGeneratorContextManager[Any]]:
"""Factory to create a lifespan async context manager for a FastAPI app."""

@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
await set_threadpool_tokens()

if isinstance(settings, DatabaseSettings) and create_tables_on_start:
await create_tables()

if isinstance(settings, RedisCacheSettings):
await create_redis_cache_pool()

if isinstance(settings, RedisQueueSettings):
await create_redis_queue_pool()

if isinstance(settings, RedisRateLimiterSettings):
await create_redis_rate_limit_pool()

yield

if isinstance(settings, RedisCacheSettings):
await close_redis_cache_pool()

if isinstance(settings, RedisQueueSettings):
await close_redis_queue_pool()

if isinstance(settings, RedisRateLimiterSettings):
await close_redis_rate_limit_pool()

return lifespan


# -------------- application --------------
def create_application(
router: APIRouter,
Expand Down Expand Up @@ -136,30 +182,13 @@ def create_application(
if isinstance(settings, EnvironmentSettings):
kwargs.update({"docs_url": None, "redoc_url": None, "openapi_url": None})

application = FastAPI(**kwargs)

# --- application created ---
application.include_router(router)
application.add_event_handler("startup", set_threadpool_tokens)
lifespan = lifespan_factory(settings, create_tables_on_start=create_tables_on_start)

if isinstance(settings, DatabaseSettings) and create_tables_on_start:
application.add_event_handler("startup", create_tables)

if isinstance(settings, RedisCacheSettings):
application.add_event_handler("startup", create_redis_cache_pool)
application.add_event_handler("shutdown", close_redis_cache_pool)
application = FastAPI(lifespan=lifespan, **kwargs)

if isinstance(settings, ClientSideCacheSettings):
application.add_middleware(ClientCacheMiddleware, max_age=settings.CLIENT_CACHE_MAX_AGE)

if isinstance(settings, RedisQueueSettings):
application.add_event_handler("startup", create_redis_queue_pool)
application.add_event_handler("shutdown", close_redis_queue_pool)

if isinstance(settings, RedisRateLimiterSettings):
application.add_event_handler("startup", create_redis_rate_limit_pool)
application.add_event_handler("shutdown", close_redis_rate_limit_pool)

if isinstance(settings, EnvironmentSettings):
if settings.ENVIRONMENT != EnvironmentOption.PRODUCTION:
docs_router = APIRouter()
Expand All @@ -181,4 +210,4 @@ async def openapi() -> dict[str, Any]:

application.include_router(docs_router)

return application
return application

0 comments on commit 2a60325

Please # to comment.