From 0967d04b6708e9f060e2a67235e4710f2f96a96a Mon Sep 17 00:00:00 2001 From: Phil Starkey Date: Sat, 21 Dec 2024 21:01:46 +1100 Subject: [PATCH] Support multi-level nested create/update with model `full_clean()` (#659) * Add support for nested creation/update in mutations. This also has the benefit of consistently calling `full_clean()` before creating related instances. This does remove the `get_or_create()` calls and instead uses `create` only. The expectation here is that `key_attr` could and should be used to indicate what field should be used as the unique identifier, and not something hard coded that could have unintended side effects when creating related instances that don't have unique constraints and expect new instances to always be created. * Formatting * First test (heavily based on one from an existing PR) * Update new test with m2m creation/use * Add test for nested creation when creating a new resource * Add test for full_clean being called when performing nested creation or resources * Remove unecessary `@transaction.atomic()` call * Add support for nested creation of ForeignKeys --- strawberry_django/mutations/resolvers.py | 139 ++++- tests/projects/schema.py | 13 + tests/projects/snapshots/schema.gql | 10 + .../snapshots/schema_with_inheritance.gql | 18 + tests/test_input_mutations.py | 535 +++++++++++++++++- 5 files changed, 673 insertions(+), 42 deletions(-) diff --git a/strawberry_django/mutations/resolvers.py b/strawberry_django/mutations/resolvers.py index 50c76677..53a58cbe 100644 --- a/strawberry_django/mutations/resolvers.py +++ b/strawberry_django/mutations/resolvers.py @@ -15,6 +15,7 @@ import strawberry from django.db import models, transaction from django.db.models.base import Model +from django.db.models.fields import Field from django.db.models.fields.related import ManyToManyField from django.db.models.fields.reverse_related import ( ForeignObjectRel, @@ -44,7 +45,11 @@ ) if TYPE_CHECKING: - from django.db.models.manager import ManyToManyRelatedManager, RelatedManager + from django.db.models.manager import ( + BaseManager, + ManyToManyRelatedManager, + RelatedManager, + ) from strawberry.types.info import Info @@ -88,6 +93,7 @@ def _parse_data( value: Any, *, key_attr: str | None = None, + full_clean: bool | FullCleanOptions = True, ): obj, data = _parse_pk(value, model, key_attr=key_attr) parsed_data = {} @@ -97,10 +103,21 @@ def _parse_data( continue if isinstance(v, ParsedObject): - if v.pk is None: - v = create(info, model, v.data or {}) # noqa: PLW2901 + if v.pk in {None, UNSET}: + related_field = cast("Field", get_model_fields(model).get(k)) + related_model = related_field.related_model + v = create( # noqa: PLW2901 + info, + cast("type[Model]", related_model), + v.data or {}, + key_attr=key_attr, + full_clean=full_clean, + exclude_m2m=[related_field.name], + ) elif isinstance(v.pk, models.Model) and v.data: - v = update(info, v.pk, v.data, key_attr=key_attr) # noqa: PLW2901 + v = update( # noqa: PLW2901 + info, v.pk, v.data, key_attr=key_attr, full_clean=full_clean + ) else: v = v.pk # noqa: PLW2901 @@ -222,6 +239,7 @@ def prepare_create_update( data: dict[str, Any], key_attr: str | None = None, full_clean: bool | FullCleanOptions = True, + exclude_m2m: list[str] | None = None, ) -> tuple[ Model, dict[str, object], @@ -237,6 +255,7 @@ def prepare_create_update( fields = get_model_fields(model) m2m: list[tuple[ManyToManyField | ForeignObjectRel, Any]] = [] direct_field_values: dict[str, object] = {} + exclude_m2m = exclude_m2m or [] if dataclasses.is_dataclass(data): data = vars(data) @@ -256,6 +275,8 @@ def prepare_create_update( # (but only if the instance is already saved and we are updating it) value = False # noqa: PLW2901 elif isinstance(field, (ManyToManyField, ForeignObjectRel)): + if name in exclude_m2m: + continue # m2m will be processed later m2m.append((field, value)) direct_field_value = False @@ -269,14 +290,19 @@ def prepare_create_update( cast("type[Model]", field.related_model), value, key_attr=key_attr, + full_clean=full_clean, ) if value is None and not value_data: value = None # noqa: PLW2901 # If foreign object is not found, then create it - elif value is None: - value = field.related_model._default_manager.create( # noqa: PLW2901 - **value_data, + elif value in {None, UNSET}: + value = create( # noqa: PLW2901 + info, + field.related_model, + value_data, + key_attr=key_attr, + full_clean=full_clean, ) # If foreign object does not need updating, then skip it @@ -309,6 +335,7 @@ def create( key_attr: str | None = None, full_clean: bool | FullCleanOptions = True, pre_save_hook: Callable[[_M], None] | None = None, + exclude_m2m: list[str] | None = None, ) -> _M: ... @@ -321,10 +348,10 @@ def create( key_attr: str | None = None, full_clean: bool | FullCleanOptions = True, pre_save_hook: Callable[[_M], None] | None = None, + exclude_m2m: list[str] | None = None, ) -> list[_M]: ... -@transaction.atomic def create( info: Info, model: type[_M], @@ -333,12 +360,43 @@ def create( key_attr: str | None = None, full_clean: bool | FullCleanOptions = True, pre_save_hook: Callable[[_M], None] | None = None, + exclude_m2m: list[str] | None = None, +) -> list[_M] | _M: + return _create( + info, + model._default_manager, + data, + key_attr=key_attr, + full_clean=full_clean, + pre_save_hook=pre_save_hook, + exclude_m2m=exclude_m2m, + ) + + +@transaction.atomic +def _create( + info: Info, + manager: BaseManager, + data: dict[str, Any] | list[dict[str, Any]], + *, + key_attr: str | None = None, + full_clean: bool | FullCleanOptions = True, + pre_save_hook: Callable[[_M], None] | None = None, + exclude_m2m: list[str] | None = None, ) -> list[_M] | _M: + model = manager.model # Before creating your instance, verify this is not a bulk create # if so, add them one by one. Otherwise, get to work. if isinstance(data, list): return [ - create(info, model, d, key_attr=key_attr, full_clean=full_clean) + create( + info, + model, + d, + key_attr=key_attr, + full_clean=full_clean, + exclude_m2m=exclude_m2m, + ) for d in data ] @@ -365,6 +423,7 @@ def create( data=data, full_clean=full_clean, key_attr=key_attr, + exclude_m2m=exclude_m2m, ) # Creating the instance directly via create() without full-clean will @@ -376,7 +435,7 @@ def create( # Create the instance using the manager create method to respect # manager create overrides. This also ensures support for proxy-models. - instance = model._default_manager.create(**create_kwargs) + instance = manager.create(**create_kwargs) for field, value in m2m: update_m2m(info, instance, field, value, key_attr) @@ -393,6 +452,7 @@ def update( key_attr: str | None = None, full_clean: bool | FullCleanOptions = True, pre_save_hook: Callable[[_M], None] | None = None, + exclude_m2m: list[str] | None = None, ) -> _M: ... @@ -405,6 +465,7 @@ def update( key_attr: str | None = None, full_clean: bool | FullCleanOptions = True, pre_save_hook: Callable[[_M], None] | None = None, + exclude_m2m: list[str] | None = None, ) -> list[_M]: ... @@ -417,6 +478,7 @@ def update( key_attr: str | None = None, full_clean: bool | FullCleanOptions = True, pre_save_hook: Callable[[_M], None] | None = None, + exclude_m2m: list[str] | None = None, ) -> _M | list[_M]: # Unwrap lazy objects since they have a proxy __iter__ method that will make # them iterables even if the wrapped object isn't @@ -433,6 +495,7 @@ def update( key_attr=key_attr, full_clean=full_clean, pre_save_hook=pre_save_hook, + exclude_m2m=exclude_m2m, ) for instance in instances ] @@ -443,6 +506,7 @@ def update( data=data, key_attr=key_attr, full_clean=full_clean, + exclude_m2m=exclude_m2m, ) if pre_save_hook is not None: @@ -554,15 +618,22 @@ def update_m2m( use_remove = True if isinstance(field, ManyToManyField): manager = cast("RelatedManager", getattr(instance, field.attname)) + reverse_field_name = field.remote_field.related_name # type: ignore else: assert isinstance(field, (ManyToManyRel, ManyToOneRel)) accessor_name = field.get_accessor_name() + reverse_field_name = field.field.name assert accessor_name manager = cast("RelatedManager", getattr(instance, accessor_name)) if field.one_to_many: # remove if field is nullable, otherwise delete use_remove = field.remote_field.null is True + # Create a data dict containing the reference to the instance and exclude it from + # nested m2m creation (to break circular references) + ref_instance_data = {reverse_field_name: instance} + exclude_m2m = [reverse_field_name] + to_add = [] to_remove = [] to_delete = [] @@ -581,7 +652,11 @@ def update_m2m( need_remove_cache = need_remove_cache or bool(values) for v in values: obj, data = _parse_data( - info, cast("type[Model]", manager.model), v, key_attr=key_attr + info, + cast("type[Model]", manager.model), + v, + key_attr=key_attr, + full_clean=full_clean, ) if obj: data.pop(key_attr, None) @@ -621,14 +696,17 @@ def update_m2m( existing.discard(obj) else: - if key_attr not in data: # we have a Input Type - obj, _ = manager.get_or_create(**data) - else: - data.pop(key_attr) - obj = manager.create(**data) - - if full_clean: - obj.full_clean(**full_clean_options) + # If we've reached here, the key_attr should be UNSET or missing. So + # let's remove it if it is there. + data.pop(key_attr, None) + obj = _create( + info, + manager, + data | ref_instance_data, + key_attr=key_attr, + full_clean=full_clean, + exclude_m2m=exclude_m2m, + ) existing.discard(obj) for remaining in existing: @@ -645,6 +723,7 @@ def update_m2m( cast("type[Model]", manager.model), v, key_attr=key_attr, + full_clean=full_clean, ) if obj and data: data.pop(key_attr, None) @@ -656,18 +735,28 @@ def update_m2m( data.pop(key_attr, None) to_add.append(obj) elif data: - if key_attr not in data: - manager.get_or_create(**data) - else: - data.pop(key_attr) - manager.create(**data) + # If we've reached here, the key_attr should be UNSET or missing. So + # let's remove it if it is there. + data.pop(key_attr, None) + _create( + info, + manager, + data | ref_instance_data, + key_attr=key_attr, + full_clean=full_clean, + exclude_m2m=exclude_m2m, + ) else: raise AssertionError need_remove_cache = need_remove_cache or bool(value.remove) for v in value.remove or []: obj, data = _parse_data( - info, cast("type[Model]", manager.model), v, key_attr=key_attr + info, + cast("type[Model]", manager.model), + v, + key_attr=key_attr, + full_clean=full_clean, ) data.pop(key_attr, None) assert not data diff --git a/tests/projects/schema.py b/tests/projects/schema.py index 9a29c5bd..9c82e89a 100644 --- a/tests/projects/schema.py +++ b/tests/projects/schema.py @@ -337,6 +337,12 @@ class MilestoneIssueInput: name: strawberry.auto +@strawberry_django.partial(Issue) +class MilestoneIssueInputPartial: + name: strawberry.auto + tags: Optional[list[TagInputPartial]] + + @strawberry_django.partial(Project) class ProjectInputPartial(NodeInputPartial): name: strawberry.auto @@ -353,6 +359,8 @@ class MilestoneInput: @strawberry_django.partial(Milestone) class MilestoneInputPartial(NodeInputPartial): name: strawberry.auto + issues: Optional[list[MilestoneIssueInputPartial]] + project: Optional[ProjectInputPartial] @strawberry.type @@ -521,6 +529,11 @@ class Mutation: argument_name="input", key_attr="name", ) + create_project_with_milestones: ProjectType = mutations.create( + ProjectInputPartial, + handle_django_errors=True, + argument_name="input", + ) update_project: ProjectType = mutations.update( ProjectInputPartial, handle_django_errors=True, diff --git a/tests/projects/snapshots/schema.gql b/tests/projects/snapshots/schema.gql index eb6f1ad0..3b74f901 100644 --- a/tests/projects/snapshots/schema.gql +++ b/tests/projects/snapshots/schema.gql @@ -95,6 +95,8 @@ input CreateProjectInput { union CreateProjectPayload = ProjectType | OperationInfo +union CreateProjectWithMilestonesPayload = ProjectType | OperationInfo + input CreateQuizInput { title: String! fullCleanOptions: Boolean! = false @@ -365,12 +367,19 @@ input MilestoneInput { input MilestoneInputPartial { id: GlobalID name: String + issues: [MilestoneIssueInputPartial!] + project: ProjectInputPartial } input MilestoneIssueInput { name: String! } +input MilestoneIssueInputPartial { + name: String + tags: [TagInputPartial!] +} + input MilestoneOrder { name: Ordering project: ProjectOrder @@ -433,6 +442,7 @@ type Mutation { updateIssueWithKeyAttr(input: IssueInputPartialWithoutId!): UpdateIssueWithKeyAttrPayload! deleteIssue(input: NodeInput!): DeleteIssuePayload! deleteIssueWithKeyAttr(input: MilestoneIssueInput!): DeleteIssueWithKeyAttrPayload! + createProjectWithMilestones(input: ProjectInputPartial!): CreateProjectWithMilestonesPayload! updateProject(input: ProjectInputPartial!): UpdateProjectPayload! createMilestone(input: MilestoneInput!): CreateMilestonePayload! createProject( diff --git a/tests/projects/snapshots/schema_with_inheritance.gql b/tests/projects/snapshots/schema_with_inheritance.gql index b98209fb..ae16f683 100644 --- a/tests/projects/snapshots/schema_with_inheritance.gql +++ b/tests/projects/snapshots/schema_with_inheritance.gql @@ -159,6 +159,13 @@ input MilestoneFilter { input MilestoneInputPartial { id: GlobalID name: String + issues: [MilestoneIssueInputPartial!] + project: ProjectInputPartial +} + +input MilestoneIssueInputPartial { + name: String + tags: [TagInputPartial!] } input MilestoneOrder { @@ -308,6 +315,12 @@ type PageInfo { endCursor: String } +input ProjectInputPartial { + id: GlobalID + name: String + milestones: [MilestoneInputPartial!] +} + input ProjectOrder { id: Ordering name: Ordering @@ -405,6 +418,11 @@ input StrFilterLookup { iRegex: String } +input TagInputPartial { + id: GlobalID + name: String +} + type TagType implements Node & Named { """The Globally Unique ID of this object""" id: GlobalID! diff --git a/tests/test_input_mutations.py b/tests/test_input_mutations.py index 6f2a388b..4cdaaf38 100644 --- a/tests/test_input_mutations.py +++ b/tests/test_input_mutations.py @@ -1,4 +1,7 @@ +from unittest.mock import patch + import pytest +from django.core.exceptions import ValidationError from strawberry.relay import from_base64, to_base64 from tests.utils import GraphQLTestClient, assert_num_queries @@ -10,7 +13,7 @@ TagFactory, UserFactory, ) -from .projects.models import Issue, Milestone, Project +from .projects.models import Issue, Milestone, Project, Tag @pytest.mark.django_db(transaction=True) @@ -245,8 +248,8 @@ def test_input_create_mutation(db, gql_client: GraphQLTestClient): @pytest.mark.django_db(transaction=True) def test_input_create_mutation_nested_creation(db, gql_client: GraphQLTestClient): query = """ - mutation CreateMilestone ($input: MilestoneInput!) { - createMilestone (input: $input) { + mutation CreateIssue ($input: IssueInput!) { + createIssue (input: $input) { __typename ... on OperationInfo { messages { @@ -255,45 +258,70 @@ def test_input_create_mutation_nested_creation(db, gql_client: GraphQLTestClient message } } - ... on MilestoneType { + ... on IssueType { id name - project { + milestone { id name + project { + id + name + } } } } } """ assert not Project.objects.filter(name="New Project").exists() + assert not Milestone.objects.filter(name="New Milestone").exists() + assert not Issue.objects.filter(name="New Issue").exists() res = gql_client.query( query, { "input": { - "name": "Some Milestone", - "project": { - "name": "New Project", + "name": "New Issue", + "milestone": { + "name": "New Milestone", + "project": { + "name": "New Project", + }, }, }, }, ) + assert res.data - assert isinstance(res.data["createMilestone"], dict) + assert isinstance(res.data["createIssue"], dict) - typename, _pk = from_base64(res.data["createMilestone"].pop("id")) - assert typename == "MilestoneType" + typename, pk = from_base64(res.data["createIssue"].get("id")) + + assert typename == "IssueType" + issue = Issue.objects.get(pk=pk) + assert issue.name == "New Issue" + + milestone = Milestone.objects.get(name="New Milestone") + assert milestone.name == "New Milestone" project = Project.objects.get(name="New Project") + assert project.name == "New Project" + + assert milestone.project_id == project.pk + assert issue.milestone_id == milestone.pk assert res.data == { - "createMilestone": { - "__typename": "MilestoneType", - "name": "Some Milestone", - "project": { - "id": to_base64("ProjectType", project.pk), - "name": project.name, + "createIssue": { + "__typename": "IssueType", + "id": to_base64("IssueType", issue.pk), + "name": "New Issue", + "milestone": { + "id": to_base64("MilestoneType", milestone.pk), + "name": "New Milestone", + "project": { + "id": to_base64("ProjectType", project.pk), + "name": "New Project", + }, }, }, } @@ -377,6 +405,479 @@ def test_input_create_with_m2m_mutation(db, gql_client: GraphQLTestClient): } +@pytest.mark.django_db(transaction=True) +def test_input_create_mutation_with_multiple_level_nested_creation( + db, gql_client: GraphQLTestClient +): + query = """ + mutation createProjectWithMilestones ($input: ProjectInputPartial!) { + createProjectWithMilestones (input: $input) { + __typename + ... on OperationInfo { + messages { + kind + field + message + } + } + ... on ProjectType { + id + name + milestones { + id + name + issues { + id + name + tags { + name + } + } + } + } + } + } + """ + + shared_tag = TagFactory.create(name="Shared Tag") + shared_tag_id = to_base64("TagType", shared_tag.pk) + + res = gql_client.query( + query, + { + "input": { + "name": "Some Project", + "milestones": [ + { + "name": "Some Milestone", + "issues": [ + { + "name": "Some Issue", + "tags": [ + {"name": "Tag 1"}, + {"name": "Tag 2"}, + {"name": "Tag 3"}, + {"id": shared_tag_id}, + ], + } + ], + }, + { + "name": "Another Milestone", + "issues": [ + { + "name": "Some Issue", + "tags": [ + {"name": "Tag 4"}, + {"id": shared_tag_id}, + ], + }, + { + "name": "Another Issue", + "tags": [ + {"name": "Tag 5"}, + {"id": shared_tag_id}, + ], + }, + { + "name": "Third issue", + "tags": [ + {"name": "Tag 6"}, + {"id": shared_tag_id}, + ], + }, + ], + }, + ], + }, + }, + ) + + assert res.data + assert isinstance(res.data["createProjectWithMilestones"], dict) + + projects = Project.objects.all() + project_typename, project_pk = from_base64( + res.data["createProjectWithMilestones"].pop("id") + ) + assert project_typename == "ProjectType" + assert projects[0] == Project.objects.get(pk=project_pk) + + milestones = Milestone.objects.all() + assert len(milestones) == 2 + assert len(res.data["createProjectWithMilestones"]["milestones"]) == 2 + + some_milestone = res.data["createProjectWithMilestones"]["milestones"][0] + milestone_typename, milestone_pk = from_base64(some_milestone.pop("id")) + assert milestone_typename == "MilestoneType" + assert milestones[0] == Milestone.objects.get(pk=milestone_pk) + + another_milestone = res.data["createProjectWithMilestones"]["milestones"][1] + milestone_typename, milestone_pk = from_base64(another_milestone.pop("id")) + assert milestone_typename == "MilestoneType" + assert milestones[1] == Milestone.objects.get(pk=milestone_pk) + + issues = Issue.objects.all() + assert len(issues) == 4 + assert len(some_milestone["issues"]) == 1 + assert len(another_milestone["issues"]) == 3 + + # Issues for first milestone + fetched_issue = some_milestone["issues"][0] + issue_typename, issue_pk = from_base64(fetched_issue.pop("id")) + assert issue_typename == "IssueType" + assert issues[0] == Issue.objects.get(pk=issue_pk) + # Issues for second milestone + for i in range(3): + fetched_issue = another_milestone["issues"][i] + issue_typename, issue_pk = from_base64(fetched_issue.pop("id")) + assert issue_typename == "IssueType" + assert issues[i + 1] == Issue.objects.get(pk=issue_pk) + + tags = Tag.objects.all() + assert len(tags) == 7 + assert len(issues[0].tags.all()) == 4 # 3 new tags + shared tag + assert len(issues[1].tags.all()) == 2 # 1 new tag + shared tag + assert len(issues[2].tags.all()) == 2 # 1 new tag + shared tag + assert len(issues[3].tags.all()) == 2 # 1 new tag + shared tag + + assert res.data == { + "createProjectWithMilestones": { + "__typename": "ProjectType", + "name": "Some Project", + "milestones": [ + { + "name": "Some Milestone", + "issues": [ + { + "name": "Some Issue", + "tags": [ + {"name": "Shared Tag"}, + {"name": "Tag 1"}, + {"name": "Tag 2"}, + {"name": "Tag 3"}, + ], + } + ], + }, + { + "name": "Another Milestone", + "issues": [ + { + "name": "Some Issue", + "tags": [ + {"name": "Shared Tag"}, + {"name": "Tag 4"}, + ], + }, + { + "name": "Another Issue", + "tags": [ + {"name": "Shared Tag"}, + {"name": "Tag 5"}, + ], + }, + { + "name": "Third issue", + "tags": [ + {"name": "Shared Tag"}, + {"name": "Tag 6"}, + ], + }, + ], + }, + ], + }, + } + + +@pytest.mark.django_db(transaction=True) +def test_input_update_mutation_with_multiple_level_nested_creation( + db, gql_client: GraphQLTestClient +): + query = """ + mutation UpdateProject ($input: ProjectInputPartial!) { + updateProject (input: $input) { + __typename + ... on OperationInfo { + messages { + kind + field + message + } + } + ... on ProjectType { + id + name + milestones { + id + name + issues { + id + name + tags { + name + } + } + } + } + } + } + """ + + project = ProjectFactory.create(name="Some Project") + + shared_tag = TagFactory.create(name="Shared Tag") + shared_tag_id = to_base64("TagType", shared_tag.pk) + + res = gql_client.query( + query, + { + "input": { + "id": to_base64("ProjectType", project.pk), + "milestones": [ + { + "name": "Some Milestone", + "issues": [ + { + "name": "Some Issue", + "tags": [ + {"name": "Tag 1"}, + {"name": "Tag 2"}, + {"name": "Tag 3"}, + {"id": shared_tag_id}, + ], + } + ], + }, + { + "name": "Another Milestone", + "issues": [ + { + "name": "Some Issue", + "tags": [ + {"name": "Tag 4"}, + {"id": shared_tag_id}, + ], + }, + { + "name": "Another Issue", + "tags": [ + {"name": "Tag 5"}, + {"id": shared_tag_id}, + ], + }, + { + "name": "Third issue", + "tags": [ + {"name": "Tag 6"}, + {"id": shared_tag_id}, + ], + }, + ], + }, + ], + }, + }, + ) + + assert res.data + assert isinstance(res.data["updateProject"], dict) + + project_typename, project_pk = from_base64(res.data["updateProject"].pop("id")) + assert project_typename == "ProjectType" + assert project.pk == int(project_pk) + + milestones = Milestone.objects.all() + assert len(milestones) == 2 + assert len(res.data["updateProject"]["milestones"]) == 2 + + some_milestone = res.data["updateProject"]["milestones"][0] + milestone_typename, milestone_pk = from_base64(some_milestone.pop("id")) + assert milestone_typename == "MilestoneType" + assert milestones[0] == Milestone.objects.get(pk=milestone_pk) + + another_milestone = res.data["updateProject"]["milestones"][1] + milestone_typename, milestone_pk = from_base64(another_milestone.pop("id")) + assert milestone_typename == "MilestoneType" + assert milestones[1] == Milestone.objects.get(pk=milestone_pk) + + issues = Issue.objects.all() + assert len(issues) == 4 + assert len(some_milestone["issues"]) == 1 + assert len(another_milestone["issues"]) == 3 + + # Issues for first milestone + fetched_issue = some_milestone["issues"][0] + issue_typename, issue_pk = from_base64(fetched_issue.pop("id")) + assert issue_typename == "IssueType" + assert issues[0] == Issue.objects.get(pk=issue_pk) + # Issues for second milestone + for i in range(3): + fetched_issue = another_milestone["issues"][i] + issue_typename, issue_pk = from_base64(fetched_issue.pop("id")) + assert issue_typename == "IssueType" + assert issues[i + 1] == Issue.objects.get(pk=issue_pk) + + tags = Tag.objects.all() + assert len(tags) == 7 + assert len(issues[0].tags.all()) == 4 # 3 new tags + shared tag + assert len(issues[1].tags.all()) == 2 # 1 new tag + shared tag + assert len(issues[2].tags.all()) == 2 # 1 new tag + shared tag + assert len(issues[3].tags.all()) == 2 # 1 new tag + shared tag + + assert res.data == { + "updateProject": { + "__typename": "ProjectType", + "name": "Some Project", + "milestones": [ + { + "name": "Some Milestone", + "issues": [ + { + "name": "Some Issue", + "tags": [ + {"name": "Shared Tag"}, + {"name": "Tag 1"}, + {"name": "Tag 2"}, + {"name": "Tag 3"}, + ], + } + ], + }, + { + "name": "Another Milestone", + "issues": [ + { + "name": "Some Issue", + "tags": [ + {"name": "Shared Tag"}, + {"name": "Tag 4"}, + ], + }, + { + "name": "Another Issue", + "tags": [ + {"name": "Shared Tag"}, + {"name": "Tag 5"}, + ], + }, + { + "name": "Third issue", + "tags": [ + {"name": "Shared Tag"}, + {"name": "Tag 6"}, + ], + }, + ], + }, + ], + }, + } + + +@pytest.mark.parametrize("mock_model", ["Milestone", "Issue", "Tag"]) +@pytest.mark.django_db(transaction=True) +def test_input_create_mutation_with_nested_calls_nested_full_clean( + db, gql_client: GraphQLTestClient, mock_model: str +): + query = """ + mutation createProjectWithMilestones ($input: ProjectInputPartial!) { + createProjectWithMilestones (input: $input) { + __typename + ... on OperationInfo { + messages { + kind + field + message + } + } + ... on ProjectType { + id + name + milestones { + id + name + issues { + id + name + tags { + name + } + } + } + } + } + } + """ + + shared_tag = TagFactory.create(name="Shared Tag") + shared_tag_id = to_base64("TagType", shared_tag.pk) + + with patch( + f"tests.projects.models.{mock_model}.clean", + side_effect=ValidationError({"name": ValidationError("Invalid name")}), + ) as mocked_full_clean: + res = gql_client.query( + query, + { + "input": { + "name": "Some Project", + "milestones": [ + { + "name": "Some Milestone", + "issues": [ + { + "name": "Some Issue", + "tags": [ + {"name": "Tag 1"}, + {"name": "Tag 2"}, + {"name": "Tag 3"}, + {"id": shared_tag_id}, + ], + } + ], + }, + { + "name": "Another Milestone", + "issues": [ + { + "name": "Some Issue", + "tags": [ + {"name": "Tag 4"}, + {"id": shared_tag_id}, + ], + }, + { + "name": "Another Issue", + "tags": [ + {"name": "Tag 5"}, + {"id": shared_tag_id}, + ], + }, + { + "name": "Third issue", + "tags": [ + {"name": "Tag 6"}, + {"id": shared_tag_id}, + ], + }, + ], + }, + ], + }, + }, + ) + + assert res.data + assert isinstance(res.data["createProjectWithMilestones"], dict) + assert res.data["createProjectWithMilestones"]["__typename"] == "OperationInfo" + assert mocked_full_clean.call_count == 1 + assert res.data["createProjectWithMilestones"]["messages"] == [ + {"field": "name", "kind": "VALIDATION", "message": "Invalid name"} + ] + + @pytest.mark.django_db(transaction=True) def test_input_update_mutation(db, gql_client: GraphQLTestClient): query = """