From b5babd349a27ab4a5b8fd05b5e416c4c7e2e4bf8 Mon Sep 17 00:00:00 2001 From: Igor Magalhaes Date: Wed, 24 Jan 2024 01:56:15 -0300 Subject: [PATCH] python upgrade, CRUDBase replaced with fastcrud --- README.md | 13 +- pyproject.toml | 1 + src/app/core/db/crud_token_blacklist.py | 5 +- src/app/core/exceptions/http_exceptions.py | 45 -- src/app/core/worker/functions.py | 1 + src/app/core/worker/settings.py | 3 +- src/app/crud/crud_base.py | 467 --------------------- src/app/crud/crud_posts.py | 5 +- src/app/crud/crud_rate_limit.py | 5 +- src/app/crud/crud_tier.py | 5 +- src/app/crud/crud_users.py | 5 +- src/app/crud/helper.py | 124 ------ 12 files changed, 26 insertions(+), 653 deletions(-) delete mode 100644 src/app/crud/crud_base.py delete mode 100644 src/app/crud/helper.py diff --git a/README.md b/README.md index edfeb42..9117e71 100644 --- a/README.md +++ b/README.md @@ -55,8 +55,8 @@ - 🏬 Easy redis caching - 👜 Easy client-side caching - 🚦 ARQ integration for task queue -- ⚙️ Efficient querying (only queries what's needed) with support for joins -- ⎘ Out of the box pagination support +- ⚙️ Efficient and robust queries with fastcrud +- ⎘ Out of the box offset and cursor pagination support with fastcrud - 🛑 Rate Limiter dependency - 👮 FastAPI docs behind authentication and hidden based on the environment - 🦾 Easily extendable @@ -749,14 +749,15 @@ poetry run alembic upgrade head ### 5.6 CRUD -Inside `app/crud`, create a new `crud_entities.py` inheriting from `CRUDBase` for each new entity: +Inside `app/crud`, create a new `crud_entities.py` inheriting from `FastCRUD` for each new entity: ```python -from app.crud.crud_base import CRUDBase +from fastcrud import FastCRUD + from app.models.entity import Entity from app.schemas.entity import EntityCreateInternal, EntityUpdate, EntityUpdateInternal, EntityDelete -CRUDEntity = CRUDBase[Entity, EntityCreateInternal, EntityUpdate, EntityUpdateInternal, EntityDelete] +CRUDEntity = FastCRUD[Entity, EntityCreateInternal, EntityUpdate, EntityUpdateInternal, EntityDelete] crud_entity = CRUDEntity(Entity) ``` @@ -767,7 +768,7 @@ So, for users: from app.model.user import User from app.schemas.user import UserCreateInternal, UserUpdate, UserUpdateInternal, UserDelete -CRUDUser = CRUDBase[User, UserCreateInternal, UserUpdate, UserUpdateInternal, UserDelete] +CRUDUser = FastCRUD[User, UserCreateInternal, UserUpdate, UserUpdateInternal, UserDelete] crud_users = CRUDUser(User) ``` diff --git a/pyproject.toml b/pyproject.toml index 1463706..35dcad8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ redis = "^5.0.1" arq = "^0.25.0" gunicorn = "^21.2.0" bcrypt = "^4.1.1" +fastcrud = "^0.1.5" [build-system] diff --git a/src/app/core/db/crud_token_blacklist.py b/src/app/core/db/crud_token_blacklist.py index bd15058..a67aec6 100644 --- a/src/app/core/db/crud_token_blacklist.py +++ b/src/app/core/db/crud_token_blacklist.py @@ -1,6 +1,7 @@ -from ...crud.crud_base import CRUDBase +from fastcrud import FastCRUD + from ..db.token_blacklist import TokenBlacklist from ..schemas import TokenBlacklistCreate, TokenBlacklistUpdate -CRUDTokenBlacklist = CRUDBase[TokenBlacklist, TokenBlacklistCreate, TokenBlacklistUpdate, TokenBlacklistUpdate, None] +CRUDTokenBlacklist = FastCRUD[TokenBlacklist, TokenBlacklistCreate, TokenBlacklistUpdate, TokenBlacklistUpdate, None] crud_token_blacklist = CRUDTokenBlacklist(TokenBlacklist) diff --git a/src/app/core/exceptions/http_exceptions.py b/src/app/core/exceptions/http_exceptions.py index 167934a..e69de29 100644 --- a/src/app/core/exceptions/http_exceptions.py +++ b/src/app/core/exceptions/http_exceptions.py @@ -1,45 +0,0 @@ -from http import HTTPStatus - -from fastapi import HTTPException, status - - -class CustomException(HTTPException): - def __init__(self, status_code: int = status.HTTP_500_INTERNAL_SERVER_ERROR, detail: str | None = None): - if not detail: - detail = HTTPStatus(status_code).description - super().__init__(status_code=status_code, detail=detail) - - -class BadRequestException(CustomException): - def __init__(self, detail: str | None = None): - super().__init__(status_code=status.HTTP_400_BAD_REQUEST, detail=detail) - - -class NotFoundException(CustomException): - def __init__(self, detail: str | None = None): - super().__init__(status_code=status.HTTP_404_NOT_FOUND, detail=detail) - - -class ForbiddenException(CustomException): - def __init__(self, detail: str | None = None): - super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail) - - -class UnauthorizedException(CustomException): - def __init__(self, detail: str | None = None): - super().__init__(status_code=status.HTTP_401_UNAUTHORIZED, detail=detail) - - -class UnprocessableEntityException(CustomException): - def __init__(self, detail: str | None = None): - super().__init__(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=detail) - - -class DuplicateValueException(CustomException): - def __init__(self, detail: str | None = None): - super().__init__(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=detail) - - -class RateLimitException(CustomException): - def __init__(self, detail: str | None = None): - super().__init__(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail=detail) diff --git a/src/app/core/worker/functions.py b/src/app/core/worker/functions.py index 88dcf3b..296d894 100644 --- a/src/app/core/worker/functions.py +++ b/src/app/core/worker/functions.py @@ -1,5 +1,6 @@ import asyncio import logging + import uvloop from arq.worker import Worker diff --git a/src/app/core/worker/settings.py b/src/app/core/worker/settings.py index 89556c4..bb21894 100644 --- a/src/app/core/worker/settings.py +++ b/src/app/core/worker/settings.py @@ -1,6 +1,7 @@ from arq.connections import RedisSettings + from ...core.config import settings -from .functions import sample_background_task, startup, shutdown +from .functions import sample_background_task, shutdown, startup REDIS_QUEUE_HOST = settings.REDIS_QUEUE_HOST REDIS_QUEUE_PORT = settings.REDIS_QUEUE_PORT diff --git a/src/app/crud/crud_base.py b/src/app/crud/crud_base.py deleted file mode 100644 index dfbf627..0000000 --- a/src/app/crud/crud_base.py +++ /dev/null @@ -1,467 +0,0 @@ -from datetime import UTC, datetime -from typing import Any, Generic, TypeVar - -from pydantic import BaseModel -from sqlalchemy import and_, delete, func, inspect, select, update -from sqlalchemy.engine.row import Row -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.sql import Join - -from ..core.db.database import Base -from .helper import ( - _add_column_with_prefix, - _auto_detect_join_condition, - _extract_matching_columns_from_kwargs, - _extract_matching_columns_from_schema, -) - -ModelType = TypeVar("ModelType", bound=Base) -CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel) -UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) -UpdateSchemaInternalType = TypeVar("UpdateSchemaInternalType", bound=BaseModel) -DeleteSchemaType = TypeVar("DeleteSchemaType", bound=BaseModel) - - -class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType, UpdateSchemaInternalType, DeleteSchemaType]): - """Base class for CRUD operations on a model. - - Parameters - ---------- - model : Type[ModelType] - The SQLAlchemy model type. - """ - - def __init__(self, model: type[ModelType]) -> None: - self._model = model - - async def create(self, db: AsyncSession, object: CreateSchemaType) -> ModelType: - """Create a new record in the database. - - Parameters - ---------- - db : AsyncSession - The SQLAlchemy async session. - object : CreateSchemaType - The Pydantic schema containing the data to be saved. - - Returns - ------- - ModelType - The created database object. - """ - object_dict = object.model_dump() - db_object: ModelType = self._model(**object_dict) - db.add(db_object) - await db.commit() - return db_object - - async def get( - self, db: AsyncSession, schema_to_select: type[BaseModel] | list | None = None, **kwargs: Any - ) -> dict | None: - """Fetch a single record based on filters. - - Parameters - ---------- - db : AsyncSession - The SQLAlchemy async session. - schema_to_select : Union[Type[BaseModel], list, None], optional - Pydantic schema for selecting specific columns. Default is None to select all columns. - kwargs : dict - Filters to apply to the query. - - Returns - ------- - dict | None - The fetched database row or None if not found. - """ - to_select = _extract_matching_columns_from_schema(model=self._model, schema=schema_to_select) - stmt = select(*to_select).filter_by(**kwargs) - - db_row = await db.execute(stmt) - result: Row = db_row.first() - if result is not None: - out: dict = dict(result._mapping) - return out - - return None - - async def exists(self, db: AsyncSession, **kwargs: Any) -> bool: - """Check if a record exists based on filters. - - Parameters - ---------- - db : AsyncSession - The SQLAlchemy async session. - kwargs : dict - Filters to apply to the query. - - Returns - ------- - bool - True if a record exists, False otherwise. - """ - to_select = _extract_matching_columns_from_kwargs(model=self._model, kwargs=kwargs) - stmt = select(*to_select).filter_by(**kwargs).limit(1) - - result = await db.execute(stmt) - return result.first() is not None - - async def count(self, db: AsyncSession, **kwargs: Any) -> int: - """Count the records based on filters. - - Parameters - ---------- - db : AsyncSession - The SQLAlchemy async session. - kwargs : dict - Filters to apply to the query. - - Returns - ------- - int - Total count of records that match the applied filters. - - Note - ---- - This method provides a quick way to get the count of records without retrieving the actual data. - """ - if kwargs: - conditions = [getattr(self._model, key) == value for key, value in kwargs.items()] - combined_conditions = and_(*conditions) - count_query = select(func.count()).select_from(self._model).filter(combined_conditions) - else: - count_query = select(func.count()).select_from(self._model) - - total_count: int = await db.scalar(count_query) - - return total_count - - async def get_multi( - self, - db: AsyncSession, - offset: int = 0, - limit: int = 100, - schema_to_select: type[BaseModel] | list[type[BaseModel]] | None = None, - **kwargs: Any, - ) -> dict[str, Any]: - """Fetch multiple records based on filters. - - Parameters - ---------- - db : AsyncSession - The SQLAlchemy async session. - offset : int, optional - Number of rows to skip before fetching. Default is 0. - limit : int, optional - Maximum number of rows to fetch. Default is 100. - schema_to_select : Union[Type[BaseModel], list[Type[BaseModel]], None], optional - Pydantic schema for selecting specific columns. Default is None to select all columns. - kwargs : dict - Filters to apply to the query. - - Returns - ------- - dict[str, Any] - Dictionary containing the fetched rows under 'data' key and total count under 'total_count'. - """ - to_select = _extract_matching_columns_from_schema(model=self._model, schema=schema_to_select) - stmt = select(*to_select).filter_by(**kwargs).offset(offset).limit(limit) - - result = await db.execute(stmt) - data = [dict(row) for row in result.mappings()] - - total_count = await self.count(db=db, **kwargs) - - return {"data": data, "total_count": total_count} - - async def get_joined( - self, - db: AsyncSession, - join_model: type[ModelType], - join_prefix: str | None = None, - join_on: Join | None = None, - schema_to_select: type[BaseModel] | list | None = None, - join_schema_to_select: type[BaseModel] | list | None = None, - join_type: str = "left", - **kwargs: Any, - ) -> dict | None: - """Fetches a single record with a join on another model. If 'join_on' is not provided, the method attempts - to automatically detect the join condition using foreign key relationships. - - Parameters - ---------- - db : AsyncSession - The SQLAlchemy async session. - join_model : Type[ModelType] - The model to join with. - join_prefix : Optional[str] - Optional prefix to be added to all columns of the joined model. If None, no prefix is added. - join_on : Join, optional - SQLAlchemy Join object for specifying the ON clause of the join. If None, the join condition is - auto-detected based on foreign keys. - schema_to_select : Union[Type[BaseModel], list, None], optional - Pydantic schema for selecting specific columns from the primary model. - join_schema_to_select : Union[Type[BaseModel], list, None], optional - Pydantic schema for selecting specific columns from the joined model. - join_type : str, default "left" - Specifies the type of join operation to perform. Can be "left" for a left outer join - or "inner" for an inner join. - kwargs : dict - Filters to apply to the query. - - Returns - ------- - dict | None - The fetched database row or None if not found. - - Examples - -------- - Simple example: Joining User and Tier models without explicitly providing join_on - ```python - result = await crud_user.get_joined( - db=session, join_model=Tier, schema_to_select=UserSchema, join_schema_to_select=TierSchema - ) - ``` - - Complex example: Joining with a custom join condition, additional filter parameters, and a prefix - ```python - from sqlalchemy import and_ - - result = await crud_user.get_joined( - db=session, - join_model=Tier, - join_prefix="tier_", - join_on=and_(User.tier_id == Tier.id, User.is_superuser == True), - schema_to_select=UserSchema, - join_schema_to_select=TierSchema, - username="john_doe", - ) - ``` - - Return example: prefix added, no schema_to_select or join_schema_to_select - ```python - { - "id": 1, - "name": "John Doe", - "username": "john_doe", - "email": "johndoe@example.com", - "hashed_password": "hashed_password_example", - "profile_image_url": "https://profileimageurl.com/default.jpg", - "uuid": "123e4567-e89b-12d3-a456-426614174000", - "created_at": "2023-01-01T12:00:00", - "updated_at": "2023-01-02T12:00:00", - "deleted_at": null, - "is_deleted": false, - "is_superuser": false, - "tier_id": 2, - "tier_name": "Premium", - "tier_created_at": "2022-12-01T10:00:00", - "tier_updated_at": "2023-01-01T11:00:00", - } - ``` - """ - if join_on is None: - join_on = _auto_detect_join_condition(self._model, join_model) - - primary_select = _extract_matching_columns_from_schema(model=self._model, schema=schema_to_select) - join_select = [] - - if join_schema_to_select: - columns = _extract_matching_columns_from_schema(model=join_model, schema=join_schema_to_select) - else: - columns = inspect(join_model).c - - for column in columns: - labeled_column = _add_column_with_prefix(column, join_prefix) - if f"{join_prefix}{column.name}" not in [col.name for col in primary_select]: - join_select.append(labeled_column) - - if join_type == "left": - stmt = select(*primary_select, *join_select).outerjoin(join_model, join_on) - elif join_type == "inner": - stmt = select(*primary_select, *join_select).join(join_model, join_on) - else: - raise ValueError(f"Invalid join type: {join_type}. Only 'left' or 'inner' are valid.") - - for key, value in kwargs.items(): - if hasattr(self._model, key): - stmt = stmt.where(getattr(self._model, key) == value) - - db_row = await db.execute(stmt) - result: Row = db_row.first() - if result: - out: dict = dict(result._mapping) - return out - - return None - - async def get_multi_joined( - self, - db: AsyncSession, - join_model: type[ModelType], - join_prefix: str | None = None, - join_on: Join | None = None, - schema_to_select: type[BaseModel] | list[type[BaseModel]] | None = None, - join_schema_to_select: type[BaseModel] | list[type[BaseModel]] | None = None, - join_type: str = "left", - offset: int = 0, - limit: int = 100, - **kwargs: Any, - ) -> dict[str, Any]: - """Fetch multiple records with a join on another model, allowing for pagination. - - Parameters - ---------- - db : AsyncSession - The SQLAlchemy async session. - join_model : Type[ModelType] - The model to join with. - join_prefix : Optional[str] - Optional prefix to be added to all columns of the joined model. If None, no prefix is added. - join_on : Join, optional - SQLAlchemy Join object for specifying the ON clause of the join. If None, the join condition is - auto-detected based on foreign keys. - schema_to_select : Union[Type[BaseModel], list[Type[BaseModel]], None], optional - Pydantic schema for selecting specific columns from the primary model. - join_schema_to_select : Union[Type[BaseModel], list[Type[BaseModel]], None], optional - Pydantic schema for selecting specific columns from the joined model. - join_type : str, default "left" - Specifies the type of join operation to perform. Can be "left" for a left outer join - or "inner" for an inner join. - offset : int, default 0 - The offset (number of records to skip) for pagination. - limit : int, default 100 - The limit (maximum number of records to return) for pagination. - kwargs : dict - Filters to apply to the primary query. - - Returns - ------- - dict[str, Any] - A dictionary containing the fetched rows under 'data' key and total count under 'total_count'. - - Examples - -------- - # Fetching multiple User records joined with Tier records, using left join - users = await crud_user.get_multi_joined( - db=session, - join_model=Tier, - join_prefix="tier_", - schema_to_select=UserSchema, - join_schema_to_select=TierSchema, - offset=0, - limit=10 - ) - """ - if join_on is None: - join_on = _auto_detect_join_condition(self._model, join_model) - - primary_select = _extract_matching_columns_from_schema(model=self._model, schema=schema_to_select) - join_select = [] - - if join_schema_to_select: - columns = _extract_matching_columns_from_schema(model=join_model, schema=join_schema_to_select) - else: - columns = inspect(join_model).c - - for column in columns: - labeled_column = _add_column_with_prefix(column, join_prefix) - if f"{join_prefix}{column.name}" not in [col.name for col in primary_select]: - join_select.append(labeled_column) - - if join_type == "left": - stmt = select(*primary_select, *join_select).outerjoin(join_model, join_on) - elif join_type == "inner": - stmt = select(*primary_select, *join_select).join(join_model, join_on) - else: - raise ValueError(f"Invalid join type: {join_type}. Only 'left' or 'inner' are valid.") - - for key, value in kwargs.items(): - if hasattr(self._model, key): - stmt = stmt.where(getattr(self._model, key) == value) - - stmt = stmt.offset(offset).limit(limit) - - db_rows = await db.execute(stmt) - data = [dict(row._mapping) for row in db_rows] - - total_count = await self.count(db=db, **kwargs) - - return {"data": data, "total_count": total_count} - - async def update(self, db: AsyncSession, object: UpdateSchemaType | dict[str, Any], **kwargs: Any) -> None: - """Update an existing record in the database. - - Parameters - ---------- - db : AsyncSession - The SQLAlchemy async session. - object : Union[UpdateSchemaType, dict[str, Any]] - The Pydantic schema or dictionary containing the data to be updated. - kwargs : dict - Filters for the update. - - Returns - ------- - None - """ - if isinstance(object, dict): - update_data = object - else: - update_data = object.model_dump(exclude_unset=True) - - if "updated_at" in update_data.keys(): - update_data["updated_at"] = datetime.now(UTC) - - stmt = update(self._model).filter_by(**kwargs).values(update_data) - - await db.execute(stmt) - await db.commit() - - async def db_delete(self, db: AsyncSession, **kwargs: Any) -> None: - """Delete a record in the database. - - Parameters - ---------- - db : AsyncSession - The SQLAlchemy async session. - kwargs : dict - Filters for the delete. - - Returns - ------- - None - """ - stmt = delete(self._model).filter_by(**kwargs) - await db.execute(stmt) - await db.commit() - - async def delete(self, db: AsyncSession, db_row: Row | None = None, **kwargs: Any) -> None: - """Soft delete a record if it has "is_deleted" attribute, otherwise perform a hard delete. - - Parameters - ---------- - db : AsyncSession - The SQLAlchemy async session. - db_row : Row | None, optional - Existing database row to delete. If None, it will be fetched based on `kwargs`. Default is None. - kwargs : dict - Filters for fetching the database row if not provided. - - Returns - ------- - None - """ - db_row = db_row or await self.exists(db=db, **kwargs) - if db_row: - if "is_deleted" in self._model.__table__.columns: - object_dict = {"is_deleted": True, "deleted_at": datetime.now(UTC)} - stmt = update(self._model).filter_by(**kwargs).values(object_dict) - - await db.execute(stmt) - await db.commit() - - else: - stmt = delete(self._model).filter_by(**kwargs) - await db.execute(stmt) - await db.commit() diff --git a/src/app/crud/crud_posts.py b/src/app/crud/crud_posts.py index 5e7fe0c..c823625 100644 --- a/src/app/crud/crud_posts.py +++ b/src/app/crud/crud_posts.py @@ -1,6 +1,7 @@ +from fastcrud import FastCRUD + from ..models.post import Post from ..schemas.post import PostCreateInternal, PostDelete, PostUpdate, PostUpdateInternal -from .crud_base import CRUDBase -CRUDPost = CRUDBase[Post, PostCreateInternal, PostUpdate, PostUpdateInternal, PostDelete] +CRUDPost = FastCRUD[Post, PostCreateInternal, PostUpdate, PostUpdateInternal, PostDelete] crud_posts = CRUDPost(Post) diff --git a/src/app/crud/crud_rate_limit.py b/src/app/crud/crud_rate_limit.py index f018344..9824856 100644 --- a/src/app/crud/crud_rate_limit.py +++ b/src/app/crud/crud_rate_limit.py @@ -1,6 +1,7 @@ +from fastcrud import FastCRUD + from ..models.rate_limit import RateLimit from ..schemas.rate_limit import RateLimitCreateInternal, RateLimitDelete, RateLimitUpdate, RateLimitUpdateInternal -from .crud_base import CRUDBase -CRUDRateLimit = CRUDBase[RateLimit, RateLimitCreateInternal, RateLimitUpdate, RateLimitUpdateInternal, RateLimitDelete] +CRUDRateLimit = FastCRUD[RateLimit, RateLimitCreateInternal, RateLimitUpdate, RateLimitUpdateInternal, RateLimitDelete] crud_rate_limits = CRUDRateLimit(RateLimit) diff --git a/src/app/crud/crud_tier.py b/src/app/crud/crud_tier.py index efeeb7e..861eb36 100644 --- a/src/app/crud/crud_tier.py +++ b/src/app/crud/crud_tier.py @@ -1,6 +1,7 @@ +from fastcrud import FastCRUD + from ..models.tier import Tier from ..schemas.tier import TierCreateInternal, TierDelete, TierUpdate, TierUpdateInternal -from .crud_base import CRUDBase -CRUDTier = CRUDBase[Tier, TierCreateInternal, TierUpdate, TierUpdateInternal, TierDelete] +CRUDTier = FastCRUD[Tier, TierCreateInternal, TierUpdate, TierUpdateInternal, TierDelete] crud_tiers = CRUDTier(Tier) diff --git a/src/app/crud/crud_users.py b/src/app/crud/crud_users.py index f3b1419..3130bca 100644 --- a/src/app/crud/crud_users.py +++ b/src/app/crud/crud_users.py @@ -1,6 +1,7 @@ +from fastcrud import FastCRUD + from ..models.user import User from ..schemas.user import UserCreateInternal, UserDelete, UserUpdate, UserUpdateInternal -from .crud_base import CRUDBase -CRUDUser = CRUDBase[User, UserCreateInternal, UserUpdate, UserUpdateInternal, UserDelete] +CRUDUser = FastCRUD[User, UserCreateInternal, UserUpdate, UserUpdateInternal, UserDelete] crud_users = CRUDUser(User) diff --git a/src/app/crud/helper.py b/src/app/crud/helper.py deleted file mode 100644 index 258f97d..0000000 --- a/src/app/crud/helper.py +++ /dev/null @@ -1,124 +0,0 @@ -from typing import Any - -from pydantic import BaseModel -from sqlalchemy import inspect -from sqlalchemy.orm import DeclarativeMeta -from sqlalchemy.sql import ColumnElement -from sqlalchemy.sql.elements import Label -from sqlalchemy.sql.schema import Column - -from ..core.db.database import Base - - -def _extract_matching_columns_from_schema(model: type[Base], schema: type[BaseModel] | list | None) -> list[Any]: - """Retrieves a list of ORM column objects from a SQLAlchemy model that match the field names in a given - Pydantic schema. - - Parameters - ---------- - model: Type[Base] - The SQLAlchemy ORM model containing columns to be matched with the schema fields. - schema: Type[BaseModel] - The Pydantic schema containing field names to be matched with the model's columns. - - Returns - ------- - list[Any] - A list of ORM column objects from the model that correspond to the field names defined in the schema. - """ - column_list = list(model.__table__.columns) - if schema is not None: - if isinstance(schema, list): - schema_fields = schema - else: - schema_fields = schema.model_fields.keys() - - column_list = [] - for column_name in schema_fields: - if hasattr(model, column_name): - column_list.append(getattr(model, column_name)) - - return column_list - - -def _extract_matching_columns_from_kwargs(model: type[Base], kwargs: dict) -> list[Any]: - if kwargs is not None: - kwargs_fields = kwargs.keys() - column_list = [] - for column_name in kwargs_fields: - if hasattr(model, column_name): - column_list.append(getattr(model, column_name)) - - return column_list - - -def _extract_matching_columns_from_column_names(model: type[Base], column_names: list) -> list[Any]: - column_list = [] - for column_name in column_names: - if hasattr(model, column_name): - column_list.append(getattr(model, column_name)) - - return column_list - - -def _auto_detect_join_condition( - base_model: type[DeclarativeMeta], join_model: type[DeclarativeMeta] -) -> ColumnElement | None: - """Automatically detects the join condition for SQLAlchemy models based on foreign key relationships. This - function scans the foreign keys in the base model and tries to match them with columns in the join model. - - Parameters - ---------- - base_model : Type[DeclarativeMeta] - The base SQLAlchemy model from which to join. - join_model : Type[DeclarativeMeta] - The SQLAlchemy model to join with the base model. - - Returns - ------- - Optional[ColumnElement] - A SQLAlchemy ColumnElement representing the join condition, if successfully detected. - - Raises - ------ - ValueError - If the join condition cannot be automatically determined, a ValueError is raised. - - Example - ------- - # Assuming User has a foreign key reference to Tier: - join_condition = auto_detect_join_condition(User, Tier) - """ - fk_columns = [col for col in inspect(base_model).c if col.foreign_keys] - join_on = next( - ( - base_model.__table__.c[col.name] == join_model.__table__.c[list(col.foreign_keys)[0].column.name] - for col in fk_columns - if list(col.foreign_keys)[0].column.table == join_model.__table__ - ), - None, - ) - - if join_on is None: - raise ValueError("Could not automatically determine join condition. Please provide join_on.") - - return join_on - - -def _add_column_with_prefix(column: Column, prefix: str | None) -> Label: - """Creates a SQLAlchemy column label with an optional prefix. - - Parameters - ---------- - column : Column - The SQLAlchemy Column object to be labeled. - prefix : Optional[str] - An optional prefix to prepend to the column's name. - - Returns - ------- - Label - A labeled SQLAlchemy Column object. - """ - column_label = f"{prefix}{column.name}" if prefix else column.name - return column.label(column_label)