Skip to content

Commit

Permalink
feat: converting sqla models to patch-style
Browse files Browse the repository at this point in the history
  • Loading branch information
niqzart committed Oct 7, 2023
1 parent 3dad230 commit 32766b0
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 9 deletions.
12 changes: 10 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class User(Base):
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String(100))
description: Mapped[str | None] = mapped_column(Text())
admin: Mapped[bool] = mapped_column() # empty `mapped_column()` is required for models

avatar_id: Mapped[int] = mapped_column(ForeignKey("avatars.id"))
avatar: Mapped[Avatar] = relationship()
Expand All @@ -35,20 +36,26 @@ class User(Base):

BaseModel = MappedModel.create(columns=[id])
CreateModel = MappedModel.create(columns=[name, description])
PatchModel = CreateModel.as_patch()
IndexModel = MappedModel.create(properties=[representation])
FullModel = BaseModel.extend(
columns=[admin],
relationships=[(avatar, Avatar.IdModel)],
includes=[CreateModel, IndexModel],
)


with sessionmaker.begin() as session:
user = User(name="alex", description="cool person", avatar=Avatar())
user = User(name="alex", description="cool person", avatar=Avatar(), admin=False)
session.add(user)
session.flush()

print(User.BaseModel.model_validate(user).model_dump())
# {"id": 0}
print(User.PatchModel.model_validate({}).model_dump(exclude_defaults=True))
# {}
print(User.PatchModel.model_validate({"description": None}).model_dump(exclude_defaults=True))
# {"description": None}
print(User.CreateModel.model_validate(user).model_dump())
# {"name": "alex", "description": "cool person"}
print(User.IndexModel.model_validate(user).model_dump())
Expand All @@ -59,7 +66,8 @@ with sessionmaker.begin() as session:
# "name": "alex",
# "description": "cool person",
# "representation": "User #0: alex",
# "avatar": {"id": 0}
# "avatar": {"id": 0},
# "admin": False
# }
```

Expand Down
2 changes: 2 additions & 0 deletions pydantic_marshals/mypy/magic.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,5 @@ class MappedModelStub:
bases: Sequence[type[BaseModel]] = (),
includes: Sequence[MarshalModel | MappedModelStub] = (),
) -> Self: ...
@classmethod
def as_patch(cls) -> Self: ...
2 changes: 1 addition & 1 deletion pydantic_marshals/mypy/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# TODO redo and generalize
base_model_qualname: Final = "pydantic_marshals.sqlalchemy.models.MappedModel"
type_matrix = {base_model_qualname}
methods = {"extend", "create"}
methods = {"create", "extend", "as_patch"}

pydantic_base_model_qualname: Final = "pydantic_marshals.base.models.MarshalBaseModel"
stub_module_name: Final = "pydantic_marshals.mypy.magic"
Expand Down
13 changes: 12 additions & 1 deletion pydantic_marshals/sqlalchemy/models.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

from collections.abc import Sequence
from collections.abc import Iterator, Sequence

from pydantic import BaseModel
from typing_extensions import Self

from pydantic_marshals.base.fields.base import MarshalField, PatchMarshalField
from pydantic_marshals.base.fields.properties import PropertyField, PropertyType
from pydantic_marshals.base.models import MarshalModel
from pydantic_marshals.sqlalchemy.fields.columns import ColumnField, ColumnType
Expand Down Expand Up @@ -66,3 +67,13 @@ def extend(
bases=bases,
includes=(self, *includes),
)

def patch_fields(self) -> Iterator[MarshalField]:
for field in self.fields:
if isinstance(field, PatchMarshalField):
yield field.as_patch()
else:
yield field

def as_patch(self) -> Self:
return type(self)(*self.patch_fields(), bases=self.bases)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pydantic-marshals"
version = "0.3.3"
version = "0.3.4"
description = "Library for creating partial pydantic models (automatic converters) from different mappings"
authors = ["niqzart <niqzart@gmail.com>"]
readme = "README.md"
Expand Down
32 changes: 28 additions & 4 deletions tests/functional/test_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from sqlalchemy import ForeignKey, MetaData
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship

from pydantic_marshals.base.fields.base import PatchDefault
from pydantic_marshals.sqlalchemy import MappedModel
from pydantic_marshals.utils import is_subtype
from tests.unit.conftest import SampleEnum
Expand Down Expand Up @@ -102,24 +103,37 @@ class T(declarative_base): # type: ignore[valid-type, misc]

Model = MappedModel.create(columns=[field])
ExtendedModel = Model.extend()
PatchModel = Model.as_patch()
ExtendedPatchModel = PatchModel.extend()

return T


@pytest.fixture(params=[False, True], ids=["full", "patch"])
def column_patch_mode(request: PytestRequest[bool]) -> bool:
return request.param


@pytest.fixture(params=[False, True], ids=["normal", "extended"])
def column_marshal_model(
column_model: Any,
column_patch_mode: bool,
request: PytestRequest[bool],
) -> type[BaseModel]:
if request.param:
if column_patch_mode:
return column_model.ExtendedPatchModel # type: ignore[no-any-return]
return column_model.ExtendedModel # type: ignore[no-any-return]
if column_patch_mode:
return column_model.PatchModel # type: ignore[no-any-return]
return column_model.Model # type: ignore[no-any-return]


def test_column_model_inspection(
column_type: Any,
column_default: Any,
column_nullable: bool,
column_patch_mode: bool,
column_use_default: bool,
column_marshal_model: type[BaseModel],
) -> None:
Expand All @@ -129,10 +143,20 @@ def test_column_model_inspection(
field = column_marshal_model.model_fields.get("field")
assert isinstance(field, FieldInfo)
assert field.annotation == column_type
assert field.is_required() != (column_nullable or column_use_default)
if column_nullable:
expected_default = column_default if column_use_default else None
assert field.default is expected_default

expected_required = not (column_nullable or column_use_default or column_patch_mode)
assert field.is_required() == expected_required

if column_patch_mode:
expected_default = PatchDefault
elif column_use_default:
expected_default = column_default
elif column_nullable:
expected_default = None
else:
expected_default = PydanticUndefined

assert field.default is expected_default


def test_column_model_usage(
Expand Down

0 comments on commit 32766b0

Please # to comment.