diff --git a/README.md b/README.md
index edfeb42..9117e71 100644
--- a/README.md
+++ b/README.md
@@ -55,8 +55,8 @@
- 🏬 Easy redis caching
- 👜 Easy client-side caching
- 🚦 ARQ integration for task queue
-- ⚙️ Efficient querying (only queries what's needed) with support for joins
-- ⎘ Out of the box pagination support
+- ⚙️ Efficient and robust queries with fastcrud
+- ⎘ Out of the box offset and cursor pagination support with fastcrud
- 🛑 Rate Limiter dependency
- 👮 FastAPI docs behind authentication and hidden based on the environment
- 🦾 Easily extendable
@@ -749,14 +749,15 @@ poetry run alembic upgrade head
### 5.6 CRUD
-Inside `app/crud`, create a new `crud_entities.py` inheriting from `CRUDBase` for each new entity:
+Inside `app/crud`, create a new `crud_entities.py` inheriting from `FastCRUD` for each new entity:
```python
-from app.crud.crud_base import CRUDBase
+from fastcrud import FastCRUD
+
from app.models.entity import Entity
from app.schemas.entity import EntityCreateInternal, EntityUpdate, EntityUpdateInternal, EntityDelete
-CRUDEntity = CRUDBase[Entity, EntityCreateInternal, EntityUpdate, EntityUpdateInternal, EntityDelete]
+CRUDEntity = FastCRUD[Entity, EntityCreateInternal, EntityUpdate, EntityUpdateInternal, EntityDelete]
crud_entity = CRUDEntity(Entity)
```
@@ -767,7 +768,7 @@ So, for users:
from app.model.user import User
from app.schemas.user import UserCreateInternal, UserUpdate, UserUpdateInternal, UserDelete
-CRUDUser = CRUDBase[User, UserCreateInternal, UserUpdate, UserUpdateInternal, UserDelete]
+CRUDUser = FastCRUD[User, UserCreateInternal, UserUpdate, UserUpdateInternal, UserDelete]
crud_users = CRUDUser(User)
```
diff --git a/pyproject.toml b/pyproject.toml
index 1463706..35dcad8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,6 +30,7 @@ redis = "^5.0.1"
arq = "^0.25.0"
gunicorn = "^21.2.0"
bcrypt = "^4.1.1"
+fastcrud = "^0.1.5"
[build-system]
diff --git a/src/app/core/db/crud_token_blacklist.py b/src/app/core/db/crud_token_blacklist.py
index bd15058..a67aec6 100644
--- a/src/app/core/db/crud_token_blacklist.py
+++ b/src/app/core/db/crud_token_blacklist.py
@@ -1,6 +1,7 @@
-from ...crud.crud_base import CRUDBase
+from fastcrud import FastCRUD
+
from ..db.token_blacklist import TokenBlacklist
from ..schemas import TokenBlacklistCreate, TokenBlacklistUpdate
-CRUDTokenBlacklist = CRUDBase[TokenBlacklist, TokenBlacklistCreate, TokenBlacklistUpdate, TokenBlacklistUpdate, None]
+CRUDTokenBlacklist = FastCRUD[TokenBlacklist, TokenBlacklistCreate, TokenBlacklistUpdate, TokenBlacklistUpdate, None]
crud_token_blacklist = CRUDTokenBlacklist(TokenBlacklist)
diff --git a/src/app/core/exceptions/http_exceptions.py b/src/app/core/exceptions/http_exceptions.py
index 167934a..e69de29 100644
--- a/src/app/core/exceptions/http_exceptions.py
+++ b/src/app/core/exceptions/http_exceptions.py
@@ -1,45 +0,0 @@
-from http import HTTPStatus
-
-from fastapi import HTTPException, status
-
-
-class CustomException(HTTPException):
- def __init__(self, status_code: int = status.HTTP_500_INTERNAL_SERVER_ERROR, detail: str | None = None):
- if not detail:
- detail = HTTPStatus(status_code).description
- super().__init__(status_code=status_code, detail=detail)
-
-
-class BadRequestException(CustomException):
- def __init__(self, detail: str | None = None):
- super().__init__(status_code=status.HTTP_400_BAD_REQUEST, detail=detail)
-
-
-class NotFoundException(CustomException):
- def __init__(self, detail: str | None = None):
- super().__init__(status_code=status.HTTP_404_NOT_FOUND, detail=detail)
-
-
-class ForbiddenException(CustomException):
- def __init__(self, detail: str | None = None):
- super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
-
-
-class UnauthorizedException(CustomException):
- def __init__(self, detail: str | None = None):
- super().__init__(status_code=status.HTTP_401_UNAUTHORIZED, detail=detail)
-
-
-class UnprocessableEntityException(CustomException):
- def __init__(self, detail: str | None = None):
- super().__init__(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=detail)
-
-
-class DuplicateValueException(CustomException):
- def __init__(self, detail: str | None = None):
- super().__init__(status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=detail)
-
-
-class RateLimitException(CustomException):
- def __init__(self, detail: str | None = None):
- super().__init__(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail=detail)
diff --git a/src/app/core/worker/functions.py b/src/app/core/worker/functions.py
index 88dcf3b..296d894 100644
--- a/src/app/core/worker/functions.py
+++ b/src/app/core/worker/functions.py
@@ -1,5 +1,6 @@
import asyncio
import logging
+
import uvloop
from arq.worker import Worker
diff --git a/src/app/core/worker/settings.py b/src/app/core/worker/settings.py
index 89556c4..bb21894 100644
--- a/src/app/core/worker/settings.py
+++ b/src/app/core/worker/settings.py
@@ -1,6 +1,7 @@
from arq.connections import RedisSettings
+
from ...core.config import settings
-from .functions import sample_background_task, startup, shutdown
+from .functions import sample_background_task, shutdown, startup
REDIS_QUEUE_HOST = settings.REDIS_QUEUE_HOST
REDIS_QUEUE_PORT = settings.REDIS_QUEUE_PORT
diff --git a/src/app/crud/crud_base.py b/src/app/crud/crud_base.py
deleted file mode 100644
index dfbf627..0000000
--- a/src/app/crud/crud_base.py
+++ /dev/null
@@ -1,467 +0,0 @@
-from datetime import UTC, datetime
-from typing import Any, Generic, TypeVar
-
-from pydantic import BaseModel
-from sqlalchemy import and_, delete, func, inspect, select, update
-from sqlalchemy.engine.row import Row
-from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy.sql import Join
-
-from ..core.db.database import Base
-from .helper import (
- _add_column_with_prefix,
- _auto_detect_join_condition,
- _extract_matching_columns_from_kwargs,
- _extract_matching_columns_from_schema,
-)
-
-ModelType = TypeVar("ModelType", bound=Base)
-CreateSchemaType = TypeVar("CreateSchemaType", bound=BaseModel)
-UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel)
-UpdateSchemaInternalType = TypeVar("UpdateSchemaInternalType", bound=BaseModel)
-DeleteSchemaType = TypeVar("DeleteSchemaType", bound=BaseModel)
-
-
-class CRUDBase(Generic[ModelType, CreateSchemaType, UpdateSchemaType, UpdateSchemaInternalType, DeleteSchemaType]):
- """Base class for CRUD operations on a model.
-
- Parameters
- ----------
- model : Type[ModelType]
- The SQLAlchemy model type.
- """
-
- def __init__(self, model: type[ModelType]) -> None:
- self._model = model
-
- async def create(self, db: AsyncSession, object: CreateSchemaType) -> ModelType:
- """Create a new record in the database.
-
- Parameters
- ----------
- db : AsyncSession
- The SQLAlchemy async session.
- object : CreateSchemaType
- The Pydantic schema containing the data to be saved.
-
- Returns
- -------
- ModelType
- The created database object.
- """
- object_dict = object.model_dump()
- db_object: ModelType = self._model(**object_dict)
- db.add(db_object)
- await db.commit()
- return db_object
-
- async def get(
- self, db: AsyncSession, schema_to_select: type[BaseModel] | list | None = None, **kwargs: Any
- ) -> dict | None:
- """Fetch a single record based on filters.
-
- Parameters
- ----------
- db : AsyncSession
- The SQLAlchemy async session.
- schema_to_select : Union[Type[BaseModel], list, None], optional
- Pydantic schema for selecting specific columns. Default is None to select all columns.
- kwargs : dict
- Filters to apply to the query.
-
- Returns
- -------
- dict | None
- The fetched database row or None if not found.
- """
- to_select = _extract_matching_columns_from_schema(model=self._model, schema=schema_to_select)
- stmt = select(*to_select).filter_by(**kwargs)
-
- db_row = await db.execute(stmt)
- result: Row = db_row.first()
- if result is not None:
- out: dict = dict(result._mapping)
- return out
-
- return None
-
- async def exists(self, db: AsyncSession, **kwargs: Any) -> bool:
- """Check if a record exists based on filters.
-
- Parameters
- ----------
- db : AsyncSession
- The SQLAlchemy async session.
- kwargs : dict
- Filters to apply to the query.
-
- Returns
- -------
- bool
- True if a record exists, False otherwise.
- """
- to_select = _extract_matching_columns_from_kwargs(model=self._model, kwargs=kwargs)
- stmt = select(*to_select).filter_by(**kwargs).limit(1)
-
- result = await db.execute(stmt)
- return result.first() is not None
-
- async def count(self, db: AsyncSession, **kwargs: Any) -> int:
- """Count the records based on filters.
-
- Parameters
- ----------
- db : AsyncSession
- The SQLAlchemy async session.
- kwargs : dict
- Filters to apply to the query.
-
- Returns
- -------
- int
- Total count of records that match the applied filters.
-
- Note
- ----
- This method provides a quick way to get the count of records without retrieving the actual data.
- """
- if kwargs:
- conditions = [getattr(self._model, key) == value for key, value in kwargs.items()]
- combined_conditions = and_(*conditions)
- count_query = select(func.count()).select_from(self._model).filter(combined_conditions)
- else:
- count_query = select(func.count()).select_from(self._model)
-
- total_count: int = await db.scalar(count_query)
-
- return total_count
-
- async def get_multi(
- self,
- db: AsyncSession,
- offset: int = 0,
- limit: int = 100,
- schema_to_select: type[BaseModel] | list[type[BaseModel]] | None = None,
- **kwargs: Any,
- ) -> dict[str, Any]:
- """Fetch multiple records based on filters.
-
- Parameters
- ----------
- db : AsyncSession
- The SQLAlchemy async session.
- offset : int, optional
- Number of rows to skip before fetching. Default is 0.
- limit : int, optional
- Maximum number of rows to fetch. Default is 100.
- schema_to_select : Union[Type[BaseModel], list[Type[BaseModel]], None], optional
- Pydantic schema for selecting specific columns. Default is None to select all columns.
- kwargs : dict
- Filters to apply to the query.
-
- Returns
- -------
- dict[str, Any]
- Dictionary containing the fetched rows under 'data' key and total count under 'total_count'.
- """
- to_select = _extract_matching_columns_from_schema(model=self._model, schema=schema_to_select)
- stmt = select(*to_select).filter_by(**kwargs).offset(offset).limit(limit)
-
- result = await db.execute(stmt)
- data = [dict(row) for row in result.mappings()]
-
- total_count = await self.count(db=db, **kwargs)
-
- return {"data": data, "total_count": total_count}
-
- async def get_joined(
- self,
- db: AsyncSession,
- join_model: type[ModelType],
- join_prefix: str | None = None,
- join_on: Join | None = None,
- schema_to_select: type[BaseModel] | list | None = None,
- join_schema_to_select: type[BaseModel] | list | None = None,
- join_type: str = "left",
- **kwargs: Any,
- ) -> dict | None:
- """Fetches a single record with a join on another model. If 'join_on' is not provided, the method attempts
- to automatically detect the join condition using foreign key relationships.
-
- Parameters
- ----------
- db : AsyncSession
- The SQLAlchemy async session.
- join_model : Type[ModelType]
- The model to join with.
- join_prefix : Optional[str]
- Optional prefix to be added to all columns of the joined model. If None, no prefix is added.
- join_on : Join, optional
- SQLAlchemy Join object for specifying the ON clause of the join. If None, the join condition is
- auto-detected based on foreign keys.
- schema_to_select : Union[Type[BaseModel], list, None], optional
- Pydantic schema for selecting specific columns from the primary model.
- join_schema_to_select : Union[Type[BaseModel], list, None], optional
- Pydantic schema for selecting specific columns from the joined model.
- join_type : str, default "left"
- Specifies the type of join operation to perform. Can be "left" for a left outer join
- or "inner" for an inner join.
- kwargs : dict
- Filters to apply to the query.
-
- Returns
- -------
- dict | None
- The fetched database row or None if not found.
-
- Examples
- --------
- Simple example: Joining User and Tier models without explicitly providing join_on
- ```python
- result = await crud_user.get_joined(
- db=session, join_model=Tier, schema_to_select=UserSchema, join_schema_to_select=TierSchema
- )
- ```
-
- Complex example: Joining with a custom join condition, additional filter parameters, and a prefix
- ```python
- from sqlalchemy import and_
-
- result = await crud_user.get_joined(
- db=session,
- join_model=Tier,
- join_prefix="tier_",
- join_on=and_(User.tier_id == Tier.id, User.is_superuser == True),
- schema_to_select=UserSchema,
- join_schema_to_select=TierSchema,
- username="john_doe",
- )
- ```
-
- Return example: prefix added, no schema_to_select or join_schema_to_select
- ```python
- {
- "id": 1,
- "name": "John Doe",
- "username": "john_doe",
- "email": "johndoe@example.com",
- "hashed_password": "hashed_password_example",
- "profile_image_url": "https://profileimageurl.com/default.jpg",
- "uuid": "123e4567-e89b-12d3-a456-426614174000",
- "created_at": "2023-01-01T12:00:00",
- "updated_at": "2023-01-02T12:00:00",
- "deleted_at": null,
- "is_deleted": false,
- "is_superuser": false,
- "tier_id": 2,
- "tier_name": "Premium",
- "tier_created_at": "2022-12-01T10:00:00",
- "tier_updated_at": "2023-01-01T11:00:00",
- }
- ```
- """
- if join_on is None:
- join_on = _auto_detect_join_condition(self._model, join_model)
-
- primary_select = _extract_matching_columns_from_schema(model=self._model, schema=schema_to_select)
- join_select = []
-
- if join_schema_to_select:
- columns = _extract_matching_columns_from_schema(model=join_model, schema=join_schema_to_select)
- else:
- columns = inspect(join_model).c
-
- for column in columns:
- labeled_column = _add_column_with_prefix(column, join_prefix)
- if f"{join_prefix}{column.name}" not in [col.name for col in primary_select]:
- join_select.append(labeled_column)
-
- if join_type == "left":
- stmt = select(*primary_select, *join_select).outerjoin(join_model, join_on)
- elif join_type == "inner":
- stmt = select(*primary_select, *join_select).join(join_model, join_on)
- else:
- raise ValueError(f"Invalid join type: {join_type}. Only 'left' or 'inner' are valid.")
-
- for key, value in kwargs.items():
- if hasattr(self._model, key):
- stmt = stmt.where(getattr(self._model, key) == value)
-
- db_row = await db.execute(stmt)
- result: Row = db_row.first()
- if result:
- out: dict = dict(result._mapping)
- return out
-
- return None
-
- async def get_multi_joined(
- self,
- db: AsyncSession,
- join_model: type[ModelType],
- join_prefix: str | None = None,
- join_on: Join | None = None,
- schema_to_select: type[BaseModel] | list[type[BaseModel]] | None = None,
- join_schema_to_select: type[BaseModel] | list[type[BaseModel]] | None = None,
- join_type: str = "left",
- offset: int = 0,
- limit: int = 100,
- **kwargs: Any,
- ) -> dict[str, Any]:
- """Fetch multiple records with a join on another model, allowing for pagination.
-
- Parameters
- ----------
- db : AsyncSession
- The SQLAlchemy async session.
- join_model : Type[ModelType]
- The model to join with.
- join_prefix : Optional[str]
- Optional prefix to be added to all columns of the joined model. If None, no prefix is added.
- join_on : Join, optional
- SQLAlchemy Join object for specifying the ON clause of the join. If None, the join condition is
- auto-detected based on foreign keys.
- schema_to_select : Union[Type[BaseModel], list[Type[BaseModel]], None], optional
- Pydantic schema for selecting specific columns from the primary model.
- join_schema_to_select : Union[Type[BaseModel], list[Type[BaseModel]], None], optional
- Pydantic schema for selecting specific columns from the joined model.
- join_type : str, default "left"
- Specifies the type of join operation to perform. Can be "left" for a left outer join
- or "inner" for an inner join.
- offset : int, default 0
- The offset (number of records to skip) for pagination.
- limit : int, default 100
- The limit (maximum number of records to return) for pagination.
- kwargs : dict
- Filters to apply to the primary query.
-
- Returns
- -------
- dict[str, Any]
- A dictionary containing the fetched rows under 'data' key and total count under 'total_count'.
-
- Examples
- --------
- # Fetching multiple User records joined with Tier records, using left join
- users = await crud_user.get_multi_joined(
- db=session,
- join_model=Tier,
- join_prefix="tier_",
- schema_to_select=UserSchema,
- join_schema_to_select=TierSchema,
- offset=0,
- limit=10
- )
- """
- if join_on is None:
- join_on = _auto_detect_join_condition(self._model, join_model)
-
- primary_select = _extract_matching_columns_from_schema(model=self._model, schema=schema_to_select)
- join_select = []
-
- if join_schema_to_select:
- columns = _extract_matching_columns_from_schema(model=join_model, schema=join_schema_to_select)
- else:
- columns = inspect(join_model).c
-
- for column in columns:
- labeled_column = _add_column_with_prefix(column, join_prefix)
- if f"{join_prefix}{column.name}" not in [col.name for col in primary_select]:
- join_select.append(labeled_column)
-
- if join_type == "left":
- stmt = select(*primary_select, *join_select).outerjoin(join_model, join_on)
- elif join_type == "inner":
- stmt = select(*primary_select, *join_select).join(join_model, join_on)
- else:
- raise ValueError(f"Invalid join type: {join_type}. Only 'left' or 'inner' are valid.")
-
- for key, value in kwargs.items():
- if hasattr(self._model, key):
- stmt = stmt.where(getattr(self._model, key) == value)
-
- stmt = stmt.offset(offset).limit(limit)
-
- db_rows = await db.execute(stmt)
- data = [dict(row._mapping) for row in db_rows]
-
- total_count = await self.count(db=db, **kwargs)
-
- return {"data": data, "total_count": total_count}
-
- async def update(self, db: AsyncSession, object: UpdateSchemaType | dict[str, Any], **kwargs: Any) -> None:
- """Update an existing record in the database.
-
- Parameters
- ----------
- db : AsyncSession
- The SQLAlchemy async session.
- object : Union[UpdateSchemaType, dict[str, Any]]
- The Pydantic schema or dictionary containing the data to be updated.
- kwargs : dict
- Filters for the update.
-
- Returns
- -------
- None
- """
- if isinstance(object, dict):
- update_data = object
- else:
- update_data = object.model_dump(exclude_unset=True)
-
- if "updated_at" in update_data.keys():
- update_data["updated_at"] = datetime.now(UTC)
-
- stmt = update(self._model).filter_by(**kwargs).values(update_data)
-
- await db.execute(stmt)
- await db.commit()
-
- async def db_delete(self, db: AsyncSession, **kwargs: Any) -> None:
- """Delete a record in the database.
-
- Parameters
- ----------
- db : AsyncSession
- The SQLAlchemy async session.
- kwargs : dict
- Filters for the delete.
-
- Returns
- -------
- None
- """
- stmt = delete(self._model).filter_by(**kwargs)
- await db.execute(stmt)
- await db.commit()
-
- async def delete(self, db: AsyncSession, db_row: Row | None = None, **kwargs: Any) -> None:
- """Soft delete a record if it has "is_deleted" attribute, otherwise perform a hard delete.
-
- Parameters
- ----------
- db : AsyncSession
- The SQLAlchemy async session.
- db_row : Row | None, optional
- Existing database row to delete. If None, it will be fetched based on `kwargs`. Default is None.
- kwargs : dict
- Filters for fetching the database row if not provided.
-
- Returns
- -------
- None
- """
- db_row = db_row or await self.exists(db=db, **kwargs)
- if db_row:
- if "is_deleted" in self._model.__table__.columns:
- object_dict = {"is_deleted": True, "deleted_at": datetime.now(UTC)}
- stmt = update(self._model).filter_by(**kwargs).values(object_dict)
-
- await db.execute(stmt)
- await db.commit()
-
- else:
- stmt = delete(self._model).filter_by(**kwargs)
- await db.execute(stmt)
- await db.commit()
diff --git a/src/app/crud/crud_posts.py b/src/app/crud/crud_posts.py
index 5e7fe0c..c823625 100644
--- a/src/app/crud/crud_posts.py
+++ b/src/app/crud/crud_posts.py
@@ -1,6 +1,7 @@
+from fastcrud import FastCRUD
+
from ..models.post import Post
from ..schemas.post import PostCreateInternal, PostDelete, PostUpdate, PostUpdateInternal
-from .crud_base import CRUDBase
-CRUDPost = CRUDBase[Post, PostCreateInternal, PostUpdate, PostUpdateInternal, PostDelete]
+CRUDPost = FastCRUD[Post, PostCreateInternal, PostUpdate, PostUpdateInternal, PostDelete]
crud_posts = CRUDPost(Post)
diff --git a/src/app/crud/crud_rate_limit.py b/src/app/crud/crud_rate_limit.py
index f018344..9824856 100644
--- a/src/app/crud/crud_rate_limit.py
+++ b/src/app/crud/crud_rate_limit.py
@@ -1,6 +1,7 @@
+from fastcrud import FastCRUD
+
from ..models.rate_limit import RateLimit
from ..schemas.rate_limit import RateLimitCreateInternal, RateLimitDelete, RateLimitUpdate, RateLimitUpdateInternal
-from .crud_base import CRUDBase
-CRUDRateLimit = CRUDBase[RateLimit, RateLimitCreateInternal, RateLimitUpdate, RateLimitUpdateInternal, RateLimitDelete]
+CRUDRateLimit = FastCRUD[RateLimit, RateLimitCreateInternal, RateLimitUpdate, RateLimitUpdateInternal, RateLimitDelete]
crud_rate_limits = CRUDRateLimit(RateLimit)
diff --git a/src/app/crud/crud_tier.py b/src/app/crud/crud_tier.py
index efeeb7e..861eb36 100644
--- a/src/app/crud/crud_tier.py
+++ b/src/app/crud/crud_tier.py
@@ -1,6 +1,7 @@
+from fastcrud import FastCRUD
+
from ..models.tier import Tier
from ..schemas.tier import TierCreateInternal, TierDelete, TierUpdate, TierUpdateInternal
-from .crud_base import CRUDBase
-CRUDTier = CRUDBase[Tier, TierCreateInternal, TierUpdate, TierUpdateInternal, TierDelete]
+CRUDTier = FastCRUD[Tier, TierCreateInternal, TierUpdate, TierUpdateInternal, TierDelete]
crud_tiers = CRUDTier(Tier)
diff --git a/src/app/crud/crud_users.py b/src/app/crud/crud_users.py
index f3b1419..3130bca 100644
--- a/src/app/crud/crud_users.py
+++ b/src/app/crud/crud_users.py
@@ -1,6 +1,7 @@
+from fastcrud import FastCRUD
+
from ..models.user import User
from ..schemas.user import UserCreateInternal, UserDelete, UserUpdate, UserUpdateInternal
-from .crud_base import CRUDBase
-CRUDUser = CRUDBase[User, UserCreateInternal, UserUpdate, UserUpdateInternal, UserDelete]
+CRUDUser = FastCRUD[User, UserCreateInternal, UserUpdate, UserUpdateInternal, UserDelete]
crud_users = CRUDUser(User)
diff --git a/src/app/crud/helper.py b/src/app/crud/helper.py
deleted file mode 100644
index 258f97d..0000000
--- a/src/app/crud/helper.py
+++ /dev/null
@@ -1,124 +0,0 @@
-from typing import Any
-
-from pydantic import BaseModel
-from sqlalchemy import inspect
-from sqlalchemy.orm import DeclarativeMeta
-from sqlalchemy.sql import ColumnElement
-from sqlalchemy.sql.elements import Label
-from sqlalchemy.sql.schema import Column
-
-from ..core.db.database import Base
-
-
-def _extract_matching_columns_from_schema(model: type[Base], schema: type[BaseModel] | list | None) -> list[Any]:
- """Retrieves a list of ORM column objects from a SQLAlchemy model that match the field names in a given
- Pydantic schema.
-
- Parameters
- ----------
- model: Type[Base]
- The SQLAlchemy ORM model containing columns to be matched with the schema fields.
- schema: Type[BaseModel]
- The Pydantic schema containing field names to be matched with the model's columns.
-
- Returns
- -------
- list[Any]
- A list of ORM column objects from the model that correspond to the field names defined in the schema.
- """
- column_list = list(model.__table__.columns)
- if schema is not None:
- if isinstance(schema, list):
- schema_fields = schema
- else:
- schema_fields = schema.model_fields.keys()
-
- column_list = []
- for column_name in schema_fields:
- if hasattr(model, column_name):
- column_list.append(getattr(model, column_name))
-
- return column_list
-
-
-def _extract_matching_columns_from_kwargs(model: type[Base], kwargs: dict) -> list[Any]:
- if kwargs is not None:
- kwargs_fields = kwargs.keys()
- column_list = []
- for column_name in kwargs_fields:
- if hasattr(model, column_name):
- column_list.append(getattr(model, column_name))
-
- return column_list
-
-
-def _extract_matching_columns_from_column_names(model: type[Base], column_names: list) -> list[Any]:
- column_list = []
- for column_name in column_names:
- if hasattr(model, column_name):
- column_list.append(getattr(model, column_name))
-
- return column_list
-
-
-def _auto_detect_join_condition(
- base_model: type[DeclarativeMeta], join_model: type[DeclarativeMeta]
-) -> ColumnElement | None:
- """Automatically detects the join condition for SQLAlchemy models based on foreign key relationships. This
- function scans the foreign keys in the base model and tries to match them with columns in the join model.
-
- Parameters
- ----------
- base_model : Type[DeclarativeMeta]
- The base SQLAlchemy model from which to join.
- join_model : Type[DeclarativeMeta]
- The SQLAlchemy model to join with the base model.
-
- Returns
- -------
- Optional[ColumnElement]
- A SQLAlchemy ColumnElement representing the join condition, if successfully detected.
-
- Raises
- ------
- ValueError
- If the join condition cannot be automatically determined, a ValueError is raised.
-
- Example
- -------
- # Assuming User has a foreign key reference to Tier:
- join_condition = auto_detect_join_condition(User, Tier)
- """
- fk_columns = [col for col in inspect(base_model).c if col.foreign_keys]
- join_on = next(
- (
- base_model.__table__.c[col.name] == join_model.__table__.c[list(col.foreign_keys)[0].column.name]
- for col in fk_columns
- if list(col.foreign_keys)[0].column.table == join_model.__table__
- ),
- None,
- )
-
- if join_on is None:
- raise ValueError("Could not automatically determine join condition. Please provide join_on.")
-
- return join_on
-
-
-def _add_column_with_prefix(column: Column, prefix: str | None) -> Label:
- """Creates a SQLAlchemy column label with an optional prefix.
-
- Parameters
- ----------
- column : Column
- The SQLAlchemy Column object to be labeled.
- prefix : Optional[str]
- An optional prefix to prepend to the column's name.
-
- Returns
- -------
- Label
- A labeled SQLAlchemy Column object.
- """
- column_label = f"{prefix}{column.name}" if prefix else column.name
- return column.label(column_label)