Skip to content

Commit

Permalink
Support multi-level nested create/update with model full_clean() (#659
Browse files Browse the repository at this point in the history
)

* Add support for nested creation/update in mutations. This also has the benefit of consistently calling `full_clean()` before creating related instances.

This does remove the `get_or_create()` calls and instead uses `create` only. The expectation here is that `key_attr` could and should be used to indicate what field should be used as the unique identifier, and not something hard coded that could have unintended side effects when creating related instances that don't have unique constraints and expect new instances to always be created.

* Formatting

* First test (heavily based on one from an existing PR)

* Update new test with m2m creation/use

* Add test for nested creation when creating a new resource

* Add test for full_clean being called when performing nested creation or resources

* Remove unecessary `@transaction.atomic()` call

* Add support for nested creation of ForeignKeys
  • Loading branch information
philipstarkey authored Dec 21, 2024
1 parent 9e6d2bb commit 0967d04
Show file tree
Hide file tree
Showing 5 changed files with 673 additions and 42 deletions.
139 changes: 114 additions & 25 deletions strawberry_django/mutations/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import strawberry
from django.db import models, transaction
from django.db.models.base import Model
from django.db.models.fields import Field
from django.db.models.fields.related import ManyToManyField
from django.db.models.fields.reverse_related import (
ForeignObjectRel,
Expand Down Expand Up @@ -44,7 +45,11 @@
)

if TYPE_CHECKING:
from django.db.models.manager import ManyToManyRelatedManager, RelatedManager
from django.db.models.manager import (
BaseManager,
ManyToManyRelatedManager,
RelatedManager,
)
from strawberry.types.info import Info


Expand Down Expand Up @@ -88,6 +93,7 @@ def _parse_data(
value: Any,
*,
key_attr: str | None = None,
full_clean: bool | FullCleanOptions = True,
):
obj, data = _parse_pk(value, model, key_attr=key_attr)
parsed_data = {}
Expand All @@ -97,10 +103,21 @@ def _parse_data(
continue

if isinstance(v, ParsedObject):
if v.pk is None:
v = create(info, model, v.data or {}) # noqa: PLW2901
if v.pk in {None, UNSET}:
related_field = cast("Field", get_model_fields(model).get(k))
related_model = related_field.related_model
v = create( # noqa: PLW2901
info,
cast("type[Model]", related_model),
v.data or {},
key_attr=key_attr,
full_clean=full_clean,
exclude_m2m=[related_field.name],
)
elif isinstance(v.pk, models.Model) and v.data:
v = update(info, v.pk, v.data, key_attr=key_attr) # noqa: PLW2901
v = update( # noqa: PLW2901
info, v.pk, v.data, key_attr=key_attr, full_clean=full_clean
)
else:
v = v.pk # noqa: PLW2901

Expand Down Expand Up @@ -222,6 +239,7 @@ def prepare_create_update(
data: dict[str, Any],
key_attr: str | None = None,
full_clean: bool | FullCleanOptions = True,
exclude_m2m: list[str] | None = None,
) -> tuple[
Model,
dict[str, object],
Expand All @@ -237,6 +255,7 @@ def prepare_create_update(
fields = get_model_fields(model)
m2m: list[tuple[ManyToManyField | ForeignObjectRel, Any]] = []
direct_field_values: dict[str, object] = {}
exclude_m2m = exclude_m2m or []

if dataclasses.is_dataclass(data):
data = vars(data)
Expand All @@ -256,6 +275,8 @@ def prepare_create_update(
# (but only if the instance is already saved and we are updating it)
value = False # noqa: PLW2901
elif isinstance(field, (ManyToManyField, ForeignObjectRel)):
if name in exclude_m2m:
continue
# m2m will be processed later
m2m.append((field, value))
direct_field_value = False
Expand All @@ -269,14 +290,19 @@ def prepare_create_update(
cast("type[Model]", field.related_model),
value,
key_attr=key_attr,
full_clean=full_clean,
)
if value is None and not value_data:
value = None # noqa: PLW2901

# If foreign object is not found, then create it
elif value is None:
value = field.related_model._default_manager.create( # noqa: PLW2901
**value_data,
elif value in {None, UNSET}:
value = create( # noqa: PLW2901
info,
field.related_model,
value_data,
key_attr=key_attr,
full_clean=full_clean,
)

# If foreign object does not need updating, then skip it
Expand Down Expand Up @@ -309,6 +335,7 @@ def create(
key_attr: str | None = None,
full_clean: bool | FullCleanOptions = True,
pre_save_hook: Callable[[_M], None] | None = None,
exclude_m2m: list[str] | None = None,
) -> _M: ...


Expand All @@ -321,10 +348,10 @@ def create(
key_attr: str | None = None,
full_clean: bool | FullCleanOptions = True,
pre_save_hook: Callable[[_M], None] | None = None,
exclude_m2m: list[str] | None = None,
) -> list[_M]: ...


@transaction.atomic
def create(
info: Info,
model: type[_M],
Expand All @@ -333,12 +360,43 @@ def create(
key_attr: str | None = None,
full_clean: bool | FullCleanOptions = True,
pre_save_hook: Callable[[_M], None] | None = None,
exclude_m2m: list[str] | None = None,
) -> list[_M] | _M:
return _create(
info,
model._default_manager,
data,
key_attr=key_attr,
full_clean=full_clean,
pre_save_hook=pre_save_hook,
exclude_m2m=exclude_m2m,
)


@transaction.atomic
def _create(
info: Info,
manager: BaseManager,
data: dict[str, Any] | list[dict[str, Any]],
*,
key_attr: str | None = None,
full_clean: bool | FullCleanOptions = True,
pre_save_hook: Callable[[_M], None] | None = None,
exclude_m2m: list[str] | None = None,
) -> list[_M] | _M:
model = manager.model
# Before creating your instance, verify this is not a bulk create
# if so, add them one by one. Otherwise, get to work.
if isinstance(data, list):
return [
create(info, model, d, key_attr=key_attr, full_clean=full_clean)
create(
info,
model,
d,
key_attr=key_attr,
full_clean=full_clean,
exclude_m2m=exclude_m2m,
)
for d in data
]

Expand All @@ -365,6 +423,7 @@ def create(
data=data,
full_clean=full_clean,
key_attr=key_attr,
exclude_m2m=exclude_m2m,
)

# Creating the instance directly via create() without full-clean will
Expand All @@ -376,7 +435,7 @@ def create(

# Create the instance using the manager create method to respect
# manager create overrides. This also ensures support for proxy-models.
instance = model._default_manager.create(**create_kwargs)
instance = manager.create(**create_kwargs)

for field, value in m2m:
update_m2m(info, instance, field, value, key_attr)
Expand All @@ -393,6 +452,7 @@ def update(
key_attr: str | None = None,
full_clean: bool | FullCleanOptions = True,
pre_save_hook: Callable[[_M], None] | None = None,
exclude_m2m: list[str] | None = None,
) -> _M: ...


Expand All @@ -405,6 +465,7 @@ def update(
key_attr: str | None = None,
full_clean: bool | FullCleanOptions = True,
pre_save_hook: Callable[[_M], None] | None = None,
exclude_m2m: list[str] | None = None,
) -> list[_M]: ...


Expand All @@ -417,6 +478,7 @@ def update(
key_attr: str | None = None,
full_clean: bool | FullCleanOptions = True,
pre_save_hook: Callable[[_M], None] | None = None,
exclude_m2m: list[str] | None = None,
) -> _M | list[_M]:
# Unwrap lazy objects since they have a proxy __iter__ method that will make
# them iterables even if the wrapped object isn't
Expand All @@ -433,6 +495,7 @@ def update(
key_attr=key_attr,
full_clean=full_clean,
pre_save_hook=pre_save_hook,
exclude_m2m=exclude_m2m,
)
for instance in instances
]
Expand All @@ -443,6 +506,7 @@ def update(
data=data,
key_attr=key_attr,
full_clean=full_clean,
exclude_m2m=exclude_m2m,
)

if pre_save_hook is not None:
Expand Down Expand Up @@ -554,15 +618,22 @@ def update_m2m(
use_remove = True
if isinstance(field, ManyToManyField):
manager = cast("RelatedManager", getattr(instance, field.attname))
reverse_field_name = field.remote_field.related_name # type: ignore
else:
assert isinstance(field, (ManyToManyRel, ManyToOneRel))
accessor_name = field.get_accessor_name()
reverse_field_name = field.field.name
assert accessor_name
manager = cast("RelatedManager", getattr(instance, accessor_name))
if field.one_to_many:
# remove if field is nullable, otherwise delete
use_remove = field.remote_field.null is True

# Create a data dict containing the reference to the instance and exclude it from
# nested m2m creation (to break circular references)
ref_instance_data = {reverse_field_name: instance}
exclude_m2m = [reverse_field_name]

to_add = []
to_remove = []
to_delete = []
Expand All @@ -581,7 +652,11 @@ def update_m2m(
need_remove_cache = need_remove_cache or bool(values)
for v in values:
obj, data = _parse_data(
info, cast("type[Model]", manager.model), v, key_attr=key_attr
info,
cast("type[Model]", manager.model),
v,
key_attr=key_attr,
full_clean=full_clean,
)
if obj:
data.pop(key_attr, None)
Expand Down Expand Up @@ -621,14 +696,17 @@ def update_m2m(

existing.discard(obj)
else:
if key_attr not in data: # we have a Input Type
obj, _ = manager.get_or_create(**data)
else:
data.pop(key_attr)
obj = manager.create(**data)

if full_clean:
obj.full_clean(**full_clean_options)
# If we've reached here, the key_attr should be UNSET or missing. So
# let's remove it if it is there.
data.pop(key_attr, None)
obj = _create(
info,
manager,
data | ref_instance_data,
key_attr=key_attr,
full_clean=full_clean,
exclude_m2m=exclude_m2m,
)
existing.discard(obj)

for remaining in existing:
Expand All @@ -645,6 +723,7 @@ def update_m2m(
cast("type[Model]", manager.model),
v,
key_attr=key_attr,
full_clean=full_clean,
)
if obj and data:
data.pop(key_attr, None)
Expand All @@ -656,18 +735,28 @@ def update_m2m(
data.pop(key_attr, None)
to_add.append(obj)
elif data:
if key_attr not in data:
manager.get_or_create(**data)
else:
data.pop(key_attr)
manager.create(**data)
# If we've reached here, the key_attr should be UNSET or missing. So
# let's remove it if it is there.
data.pop(key_attr, None)
_create(
info,
manager,
data | ref_instance_data,
key_attr=key_attr,
full_clean=full_clean,
exclude_m2m=exclude_m2m,
)
else:
raise AssertionError

need_remove_cache = need_remove_cache or bool(value.remove)
for v in value.remove or []:
obj, data = _parse_data(
info, cast("type[Model]", manager.model), v, key_attr=key_attr
info,
cast("type[Model]", manager.model),
v,
key_attr=key_attr,
full_clean=full_clean,
)
data.pop(key_attr, None)
assert not data
Expand Down
13 changes: 13 additions & 0 deletions tests/projects/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,12 @@ class MilestoneIssueInput:
name: strawberry.auto


@strawberry_django.partial(Issue)
class MilestoneIssueInputPartial:
name: strawberry.auto
tags: Optional[list[TagInputPartial]]


@strawberry_django.partial(Project)
class ProjectInputPartial(NodeInputPartial):
name: strawberry.auto
Expand All @@ -353,6 +359,8 @@ class MilestoneInput:
@strawberry_django.partial(Milestone)
class MilestoneInputPartial(NodeInputPartial):
name: strawberry.auto
issues: Optional[list[MilestoneIssueInputPartial]]
project: Optional[ProjectInputPartial]


@strawberry.type
Expand Down Expand Up @@ -521,6 +529,11 @@ class Mutation:
argument_name="input",
key_attr="name",
)
create_project_with_milestones: ProjectType = mutations.create(
ProjectInputPartial,
handle_django_errors=True,
argument_name="input",
)
update_project: ProjectType = mutations.update(
ProjectInputPartial,
handle_django_errors=True,
Expand Down
10 changes: 10 additions & 0 deletions tests/projects/snapshots/schema.gql
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ input CreateProjectInput {

union CreateProjectPayload = ProjectType | OperationInfo

union CreateProjectWithMilestonesPayload = ProjectType | OperationInfo

input CreateQuizInput {
title: String!
fullCleanOptions: Boolean! = false
Expand Down Expand Up @@ -365,12 +367,19 @@ input MilestoneInput {
input MilestoneInputPartial {
id: GlobalID
name: String
issues: [MilestoneIssueInputPartial!]
project: ProjectInputPartial
}

input MilestoneIssueInput {
name: String!
}

input MilestoneIssueInputPartial {
name: String
tags: [TagInputPartial!]
}

input MilestoneOrder {
name: Ordering
project: ProjectOrder
Expand Down Expand Up @@ -433,6 +442,7 @@ type Mutation {
updateIssueWithKeyAttr(input: IssueInputPartialWithoutId!): UpdateIssueWithKeyAttrPayload!
deleteIssue(input: NodeInput!): DeleteIssuePayload!
deleteIssueWithKeyAttr(input: MilestoneIssueInput!): DeleteIssueWithKeyAttrPayload!
createProjectWithMilestones(input: ProjectInputPartial!): CreateProjectWithMilestonesPayload!
updateProject(input: ProjectInputPartial!): UpdateProjectPayload!
createMilestone(input: MilestoneInput!): CreateMilestonePayload!
createProject(
Expand Down
Loading

0 comments on commit 0967d04

Please # to comment.