diff --git a/strawberry_django/mutations/resolvers.py b/strawberry_django/mutations/resolvers.py index 50c76677..53a58cbe 100644 --- a/strawberry_django/mutations/resolvers.py +++ b/strawberry_django/mutations/resolvers.py @@ -15,6 +15,7 @@ import strawberry from django.db import models, transaction from django.db.models.base import Model +from django.db.models.fields import Field from django.db.models.fields.related import ManyToManyField from django.db.models.fields.reverse_related import ( ForeignObjectRel, @@ -44,7 +45,11 @@ ) if TYPE_CHECKING: - from django.db.models.manager import ManyToManyRelatedManager, RelatedManager + from django.db.models.manager import ( + BaseManager, + ManyToManyRelatedManager, + RelatedManager, + ) from strawberry.types.info import Info @@ -88,6 +93,7 @@ def _parse_data( value: Any, *, key_attr: str | None = None, + full_clean: bool | FullCleanOptions = True, ): obj, data = _parse_pk(value, model, key_attr=key_attr) parsed_data = {} @@ -97,10 +103,21 @@ def _parse_data( continue if isinstance(v, ParsedObject): - if v.pk is None: - v = create(info, model, v.data or {}) # noqa: PLW2901 + if v.pk in {None, UNSET}: + related_field = cast("Field", get_model_fields(model).get(k)) + related_model = related_field.related_model + v = create( # noqa: PLW2901 + info, + cast("type[Model]", related_model), + v.data or {}, + key_attr=key_attr, + full_clean=full_clean, + exclude_m2m=[related_field.name], + ) elif isinstance(v.pk, models.Model) and v.data: - v = update(info, v.pk, v.data, key_attr=key_attr) # noqa: PLW2901 + v = update( # noqa: PLW2901 + info, v.pk, v.data, key_attr=key_attr, full_clean=full_clean + ) else: v = v.pk # noqa: PLW2901 @@ -222,6 +239,7 @@ def prepare_create_update( data: dict[str, Any], key_attr: str | None = None, full_clean: bool | FullCleanOptions = True, + exclude_m2m: list[str] | None = None, ) -> tuple[ Model, dict[str, object], @@ -237,6 +255,7 @@ def prepare_create_update( fields = get_model_fields(model) m2m: list[tuple[ManyToManyField | ForeignObjectRel, Any]] = [] direct_field_values: dict[str, object] = {} + exclude_m2m = exclude_m2m or [] if dataclasses.is_dataclass(data): data = vars(data) @@ -256,6 +275,8 @@ def prepare_create_update( # (but only if the instance is already saved and we are updating it) value = False # noqa: PLW2901 elif isinstance(field, (ManyToManyField, ForeignObjectRel)): + if name in exclude_m2m: + continue # m2m will be processed later m2m.append((field, value)) direct_field_value = False @@ -269,14 +290,19 @@ def prepare_create_update( cast("type[Model]", field.related_model), value, key_attr=key_attr, + full_clean=full_clean, ) if value is None and not value_data: value = None # noqa: PLW2901 # If foreign object is not found, then create it - elif value is None: - value = field.related_model._default_manager.create( # noqa: PLW2901 - **value_data, + elif value in {None, UNSET}: + value = create( # noqa: PLW2901 + info, + field.related_model, + value_data, + key_attr=key_attr, + full_clean=full_clean, ) # If foreign object does not need updating, then skip it @@ -309,6 +335,7 @@ def create( key_attr: str | None = None, full_clean: bool | FullCleanOptions = True, pre_save_hook: Callable[[_M], None] | None = None, + exclude_m2m: list[str] | None = None, ) -> _M: ... @@ -321,10 +348,10 @@ def create( key_attr: str | None = None, full_clean: bool | FullCleanOptions = True, pre_save_hook: Callable[[_M], None] | None = None, + exclude_m2m: list[str] | None = None, ) -> list[_M]: ... -@transaction.atomic def create( info: Info, model: type[_M], @@ -333,12 +360,43 @@ def create( key_attr: str | None = None, full_clean: bool | FullCleanOptions = True, pre_save_hook: Callable[[_M], None] | None = None, + exclude_m2m: list[str] | None = None, +) -> list[_M] | _M: + return _create( + info, + model._default_manager, + data, + key_attr=key_attr, + full_clean=full_clean, + pre_save_hook=pre_save_hook, + exclude_m2m=exclude_m2m, + ) + + +@transaction.atomic +def _create( + info: Info, + manager: BaseManager, + data: dict[str, Any] | list[dict[str, Any]], + *, + key_attr: str | None = None, + full_clean: bool | FullCleanOptions = True, + pre_save_hook: Callable[[_M], None] | None = None, + exclude_m2m: list[str] | None = None, ) -> list[_M] | _M: + model = manager.model # Before creating your instance, verify this is not a bulk create # if so, add them one by one. Otherwise, get to work. if isinstance(data, list): return [ - create(info, model, d, key_attr=key_attr, full_clean=full_clean) + create( + info, + model, + d, + key_attr=key_attr, + full_clean=full_clean, + exclude_m2m=exclude_m2m, + ) for d in data ] @@ -365,6 +423,7 @@ def create( data=data, full_clean=full_clean, key_attr=key_attr, + exclude_m2m=exclude_m2m, ) # Creating the instance directly via create() without full-clean will @@ -376,7 +435,7 @@ def create( # Create the instance using the manager create method to respect # manager create overrides. This also ensures support for proxy-models. - instance = model._default_manager.create(**create_kwargs) + instance = manager.create(**create_kwargs) for field, value in m2m: update_m2m(info, instance, field, value, key_attr) @@ -393,6 +452,7 @@ def update( key_attr: str | None = None, full_clean: bool | FullCleanOptions = True, pre_save_hook: Callable[[_M], None] | None = None, + exclude_m2m: list[str] | None = None, ) -> _M: ... @@ -405,6 +465,7 @@ def update( key_attr: str | None = None, full_clean: bool | FullCleanOptions = True, pre_save_hook: Callable[[_M], None] | None = None, + exclude_m2m: list[str] | None = None, ) -> list[_M]: ... @@ -417,6 +478,7 @@ def update( key_attr: str | None = None, full_clean: bool | FullCleanOptions = True, pre_save_hook: Callable[[_M], None] | None = None, + exclude_m2m: list[str] | None = None, ) -> _M | list[_M]: # Unwrap lazy objects since they have a proxy __iter__ method that will make # them iterables even if the wrapped object isn't @@ -433,6 +495,7 @@ def update( key_attr=key_attr, full_clean=full_clean, pre_save_hook=pre_save_hook, + exclude_m2m=exclude_m2m, ) for instance in instances ] @@ -443,6 +506,7 @@ def update( data=data, key_attr=key_attr, full_clean=full_clean, + exclude_m2m=exclude_m2m, ) if pre_save_hook is not None: @@ -554,15 +618,22 @@ def update_m2m( use_remove = True if isinstance(field, ManyToManyField): manager = cast("RelatedManager", getattr(instance, field.attname)) + reverse_field_name = field.remote_field.related_name # type: ignore else: assert isinstance(field, (ManyToManyRel, ManyToOneRel)) accessor_name = field.get_accessor_name() + reverse_field_name = field.field.name assert accessor_name manager = cast("RelatedManager", getattr(instance, accessor_name)) if field.one_to_many: # remove if field is nullable, otherwise delete use_remove = field.remote_field.null is True + # Create a data dict containing the reference to the instance and exclude it from + # nested m2m creation (to break circular references) + ref_instance_data = {reverse_field_name: instance} + exclude_m2m = [reverse_field_name] + to_add = [] to_remove = [] to_delete = [] @@ -581,7 +652,11 @@ def update_m2m( need_remove_cache = need_remove_cache or bool(values) for v in values: obj, data = _parse_data( - info, cast("type[Model]", manager.model), v, key_attr=key_attr + info, + cast("type[Model]", manager.model), + v, + key_attr=key_attr, + full_clean=full_clean, ) if obj: data.pop(key_attr, None) @@ -621,14 +696,17 @@ def update_m2m( existing.discard(obj) else: - if key_attr not in data: # we have a Input Type - obj, _ = manager.get_or_create(**data) - else: - data.pop(key_attr) - obj = manager.create(**data) - - if full_clean: - obj.full_clean(**full_clean_options) + # If we've reached here, the key_attr should be UNSET or missing. So + # let's remove it if it is there. + data.pop(key_attr, None) + obj = _create( + info, + manager, + data | ref_instance_data, + key_attr=key_attr, + full_clean=full_clean, + exclude_m2m=exclude_m2m, + ) existing.discard(obj) for remaining in existing: @@ -645,6 +723,7 @@ def update_m2m( cast("type[Model]", manager.model), v, key_attr=key_attr, + full_clean=full_clean, ) if obj and data: data.pop(key_attr, None) @@ -656,18 +735,28 @@ def update_m2m( data.pop(key_attr, None) to_add.append(obj) elif data: - if key_attr not in data: - manager.get_or_create(**data) - else: - data.pop(key_attr) - manager.create(**data) + # If we've reached here, the key_attr should be UNSET or missing. So + # let's remove it if it is there. + data.pop(key_attr, None) + _create( + info, + manager, + data | ref_instance_data, + key_attr=key_attr, + full_clean=full_clean, + exclude_m2m=exclude_m2m, + ) else: raise AssertionError need_remove_cache = need_remove_cache or bool(value.remove) for v in value.remove or []: obj, data = _parse_data( - info, cast("type[Model]", manager.model), v, key_attr=key_attr + info, + cast("type[Model]", manager.model), + v, + key_attr=key_attr, + full_clean=full_clean, ) data.pop(key_attr, None) assert not data diff --git a/tests/projects/schema.py b/tests/projects/schema.py index 9a29c5bd..9c82e89a 100644 --- a/tests/projects/schema.py +++ b/tests/projects/schema.py @@ -337,6 +337,12 @@ class MilestoneIssueInput: name: strawberry.auto +@strawberry_django.partial(Issue) +class MilestoneIssueInputPartial: + name: strawberry.auto + tags: Optional[list[TagInputPartial]] + + @strawberry_django.partial(Project) class ProjectInputPartial(NodeInputPartial): name: strawberry.auto @@ -353,6 +359,8 @@ class MilestoneInput: @strawberry_django.partial(Milestone) class MilestoneInputPartial(NodeInputPartial): name: strawberry.auto + issues: Optional[list[MilestoneIssueInputPartial]] + project: Optional[ProjectInputPartial] @strawberry.type @@ -521,6 +529,11 @@ class Mutation: argument_name="input", key_attr="name", ) + create_project_with_milestones: ProjectType = mutations.create( + ProjectInputPartial, + handle_django_errors=True, + argument_name="input", + ) update_project: ProjectType = mutations.update( ProjectInputPartial, handle_django_errors=True, diff --git a/tests/projects/snapshots/schema.gql b/tests/projects/snapshots/schema.gql index eb6f1ad0..3b74f901 100644 --- a/tests/projects/snapshots/schema.gql +++ b/tests/projects/snapshots/schema.gql @@ -95,6 +95,8 @@ input CreateProjectInput { union CreateProjectPayload = ProjectType | OperationInfo +union CreateProjectWithMilestonesPayload = ProjectType | OperationInfo + input CreateQuizInput { title: String! fullCleanOptions: Boolean! = false @@ -365,12 +367,19 @@ input MilestoneInput { input MilestoneInputPartial { id: GlobalID name: String + issues: [MilestoneIssueInputPartial!] + project: ProjectInputPartial } input MilestoneIssueInput { name: String! } +input MilestoneIssueInputPartial { + name: String + tags: [TagInputPartial!] +} + input MilestoneOrder { name: Ordering project: ProjectOrder @@ -433,6 +442,7 @@ type Mutation { updateIssueWithKeyAttr(input: IssueInputPartialWithoutId!): UpdateIssueWithKeyAttrPayload! deleteIssue(input: NodeInput!): DeleteIssuePayload! deleteIssueWithKeyAttr(input: MilestoneIssueInput!): DeleteIssueWithKeyAttrPayload! + createProjectWithMilestones(input: ProjectInputPartial!): CreateProjectWithMilestonesPayload! updateProject(input: ProjectInputPartial!): UpdateProjectPayload! createMilestone(input: MilestoneInput!): CreateMilestonePayload! createProject( diff --git a/tests/projects/snapshots/schema_with_inheritance.gql b/tests/projects/snapshots/schema_with_inheritance.gql index b98209fb..ae16f683 100644 --- a/tests/projects/snapshots/schema_with_inheritance.gql +++ b/tests/projects/snapshots/schema_with_inheritance.gql @@ -159,6 +159,13 @@ input MilestoneFilter { input MilestoneInputPartial { id: GlobalID name: String + issues: [MilestoneIssueInputPartial!] + project: ProjectInputPartial +} + +input MilestoneIssueInputPartial { + name: String + tags: [TagInputPartial!] } input MilestoneOrder { @@ -308,6 +315,12 @@ type PageInfo { endCursor: String } +input ProjectInputPartial { + id: GlobalID + name: String + milestones: [MilestoneInputPartial!] +} + input ProjectOrder { id: Ordering name: Ordering @@ -405,6 +418,11 @@ input StrFilterLookup { iRegex: String } +input TagInputPartial { + id: GlobalID + name: String +} + type TagType implements Node & Named { """The Globally Unique ID of this object""" id: GlobalID! diff --git a/tests/test_input_mutations.py b/tests/test_input_mutations.py index 6f2a388b..4cdaaf38 100644 --- a/tests/test_input_mutations.py +++ b/tests/test_input_mutations.py @@ -1,4 +1,7 @@ +from unittest.mock import patch + import pytest +from django.core.exceptions import ValidationError from strawberry.relay import from_base64, to_base64 from tests.utils import GraphQLTestClient, assert_num_queries @@ -10,7 +13,7 @@ TagFactory, UserFactory, ) -from .projects.models import Issue, Milestone, Project +from .projects.models import Issue, Milestone, Project, Tag @pytest.mark.django_db(transaction=True) @@ -245,8 +248,8 @@ def test_input_create_mutation(db, gql_client: GraphQLTestClient): @pytest.mark.django_db(transaction=True) def test_input_create_mutation_nested_creation(db, gql_client: GraphQLTestClient): query = """ - mutation CreateMilestone ($input: MilestoneInput!) { - createMilestone (input: $input) { + mutation CreateIssue ($input: IssueInput!) { + createIssue (input: $input) { __typename ... on OperationInfo { messages { @@ -255,45 +258,70 @@ def test_input_create_mutation_nested_creation(db, gql_client: GraphQLTestClient message } } - ... on MilestoneType { + ... on IssueType { id name - project { + milestone { id name + project { + id + name + } } } } } """ assert not Project.objects.filter(name="New Project").exists() + assert not Milestone.objects.filter(name="New Milestone").exists() + assert not Issue.objects.filter(name="New Issue").exists() res = gql_client.query( query, { "input": { - "name": "Some Milestone", - "project": { - "name": "New Project", + "name": "New Issue", + "milestone": { + "name": "New Milestone", + "project": { + "name": "New Project", + }, }, }, }, ) + assert res.data - assert isinstance(res.data["createMilestone"], dict) + assert isinstance(res.data["createIssue"], dict) - typename, _pk = from_base64(res.data["createMilestone"].pop("id")) - assert typename == "MilestoneType" + typename, pk = from_base64(res.data["createIssue"].get("id")) + + assert typename == "IssueType" + issue = Issue.objects.get(pk=pk) + assert issue.name == "New Issue" + + milestone = Milestone.objects.get(name="New Milestone") + assert milestone.name == "New Milestone" project = Project.objects.get(name="New Project") + assert project.name == "New Project" + + assert milestone.project_id == project.pk + assert issue.milestone_id == milestone.pk assert res.data == { - "createMilestone": { - "__typename": "MilestoneType", - "name": "Some Milestone", - "project": { - "id": to_base64("ProjectType", project.pk), - "name": project.name, + "createIssue": { + "__typename": "IssueType", + "id": to_base64("IssueType", issue.pk), + "name": "New Issue", + "milestone": { + "id": to_base64("MilestoneType", milestone.pk), + "name": "New Milestone", + "project": { + "id": to_base64("ProjectType", project.pk), + "name": "New Project", + }, }, }, } @@ -377,6 +405,479 @@ def test_input_create_with_m2m_mutation(db, gql_client: GraphQLTestClient): } +@pytest.mark.django_db(transaction=True) +def test_input_create_mutation_with_multiple_level_nested_creation( + db, gql_client: GraphQLTestClient +): + query = """ + mutation createProjectWithMilestones ($input: ProjectInputPartial!) { + createProjectWithMilestones (input: $input) { + __typename + ... on OperationInfo { + messages { + kind + field + message + } + } + ... on ProjectType { + id + name + milestones { + id + name + issues { + id + name + tags { + name + } + } + } + } + } + } + """ + + shared_tag = TagFactory.create(name="Shared Tag") + shared_tag_id = to_base64("TagType", shared_tag.pk) + + res = gql_client.query( + query, + { + "input": { + "name": "Some Project", + "milestones": [ + { + "name": "Some Milestone", + "issues": [ + { + "name": "Some Issue", + "tags": [ + {"name": "Tag 1"}, + {"name": "Tag 2"}, + {"name": "Tag 3"}, + {"id": shared_tag_id}, + ], + } + ], + }, + { + "name": "Another Milestone", + "issues": [ + { + "name": "Some Issue", + "tags": [ + {"name": "Tag 4"}, + {"id": shared_tag_id}, + ], + }, + { + "name": "Another Issue", + "tags": [ + {"name": "Tag 5"}, + {"id": shared_tag_id}, + ], + }, + { + "name": "Third issue", + "tags": [ + {"name": "Tag 6"}, + {"id": shared_tag_id}, + ], + }, + ], + }, + ], + }, + }, + ) + + assert res.data + assert isinstance(res.data["createProjectWithMilestones"], dict) + + projects = Project.objects.all() + project_typename, project_pk = from_base64( + res.data["createProjectWithMilestones"].pop("id") + ) + assert project_typename == "ProjectType" + assert projects[0] == Project.objects.get(pk=project_pk) + + milestones = Milestone.objects.all() + assert len(milestones) == 2 + assert len(res.data["createProjectWithMilestones"]["milestones"]) == 2 + + some_milestone = res.data["createProjectWithMilestones"]["milestones"][0] + milestone_typename, milestone_pk = from_base64(some_milestone.pop("id")) + assert milestone_typename == "MilestoneType" + assert milestones[0] == Milestone.objects.get(pk=milestone_pk) + + another_milestone = res.data["createProjectWithMilestones"]["milestones"][1] + milestone_typename, milestone_pk = from_base64(another_milestone.pop("id")) + assert milestone_typename == "MilestoneType" + assert milestones[1] == Milestone.objects.get(pk=milestone_pk) + + issues = Issue.objects.all() + assert len(issues) == 4 + assert len(some_milestone["issues"]) == 1 + assert len(another_milestone["issues"]) == 3 + + # Issues for first milestone + fetched_issue = some_milestone["issues"][0] + issue_typename, issue_pk = from_base64(fetched_issue.pop("id")) + assert issue_typename == "IssueType" + assert issues[0] == Issue.objects.get(pk=issue_pk) + # Issues for second milestone + for i in range(3): + fetched_issue = another_milestone["issues"][i] + issue_typename, issue_pk = from_base64(fetched_issue.pop("id")) + assert issue_typename == "IssueType" + assert issues[i + 1] == Issue.objects.get(pk=issue_pk) + + tags = Tag.objects.all() + assert len(tags) == 7 + assert len(issues[0].tags.all()) == 4 # 3 new tags + shared tag + assert len(issues[1].tags.all()) == 2 # 1 new tag + shared tag + assert len(issues[2].tags.all()) == 2 # 1 new tag + shared tag + assert len(issues[3].tags.all()) == 2 # 1 new tag + shared tag + + assert res.data == { + "createProjectWithMilestones": { + "__typename": "ProjectType", + "name": "Some Project", + "milestones": [ + { + "name": "Some Milestone", + "issues": [ + { + "name": "Some Issue", + "tags": [ + {"name": "Shared Tag"}, + {"name": "Tag 1"}, + {"name": "Tag 2"}, + {"name": "Tag 3"}, + ], + } + ], + }, + { + "name": "Another Milestone", + "issues": [ + { + "name": "Some Issue", + "tags": [ + {"name": "Shared Tag"}, + {"name": "Tag 4"}, + ], + }, + { + "name": "Another Issue", + "tags": [ + {"name": "Shared Tag"}, + {"name": "Tag 5"}, + ], + }, + { + "name": "Third issue", + "tags": [ + {"name": "Shared Tag"}, + {"name": "Tag 6"}, + ], + }, + ], + }, + ], + }, + } + + +@pytest.mark.django_db(transaction=True) +def test_input_update_mutation_with_multiple_level_nested_creation( + db, gql_client: GraphQLTestClient +): + query = """ + mutation UpdateProject ($input: ProjectInputPartial!) { + updateProject (input: $input) { + __typename + ... on OperationInfo { + messages { + kind + field + message + } + } + ... on ProjectType { + id + name + milestones { + id + name + issues { + id + name + tags { + name + } + } + } + } + } + } + """ + + project = ProjectFactory.create(name="Some Project") + + shared_tag = TagFactory.create(name="Shared Tag") + shared_tag_id = to_base64("TagType", shared_tag.pk) + + res = gql_client.query( + query, + { + "input": { + "id": to_base64("ProjectType", project.pk), + "milestones": [ + { + "name": "Some Milestone", + "issues": [ + { + "name": "Some Issue", + "tags": [ + {"name": "Tag 1"}, + {"name": "Tag 2"}, + {"name": "Tag 3"}, + {"id": shared_tag_id}, + ], + } + ], + }, + { + "name": "Another Milestone", + "issues": [ + { + "name": "Some Issue", + "tags": [ + {"name": "Tag 4"}, + {"id": shared_tag_id}, + ], + }, + { + "name": "Another Issue", + "tags": [ + {"name": "Tag 5"}, + {"id": shared_tag_id}, + ], + }, + { + "name": "Third issue", + "tags": [ + {"name": "Tag 6"}, + {"id": shared_tag_id}, + ], + }, + ], + }, + ], + }, + }, + ) + + assert res.data + assert isinstance(res.data["updateProject"], dict) + + project_typename, project_pk = from_base64(res.data["updateProject"].pop("id")) + assert project_typename == "ProjectType" + assert project.pk == int(project_pk) + + milestones = Milestone.objects.all() + assert len(milestones) == 2 + assert len(res.data["updateProject"]["milestones"]) == 2 + + some_milestone = res.data["updateProject"]["milestones"][0] + milestone_typename, milestone_pk = from_base64(some_milestone.pop("id")) + assert milestone_typename == "MilestoneType" + assert milestones[0] == Milestone.objects.get(pk=milestone_pk) + + another_milestone = res.data["updateProject"]["milestones"][1] + milestone_typename, milestone_pk = from_base64(another_milestone.pop("id")) + assert milestone_typename == "MilestoneType" + assert milestones[1] == Milestone.objects.get(pk=milestone_pk) + + issues = Issue.objects.all() + assert len(issues) == 4 + assert len(some_milestone["issues"]) == 1 + assert len(another_milestone["issues"]) == 3 + + # Issues for first milestone + fetched_issue = some_milestone["issues"][0] + issue_typename, issue_pk = from_base64(fetched_issue.pop("id")) + assert issue_typename == "IssueType" + assert issues[0] == Issue.objects.get(pk=issue_pk) + # Issues for second milestone + for i in range(3): + fetched_issue = another_milestone["issues"][i] + issue_typename, issue_pk = from_base64(fetched_issue.pop("id")) + assert issue_typename == "IssueType" + assert issues[i + 1] == Issue.objects.get(pk=issue_pk) + + tags = Tag.objects.all() + assert len(tags) == 7 + assert len(issues[0].tags.all()) == 4 # 3 new tags + shared tag + assert len(issues[1].tags.all()) == 2 # 1 new tag + shared tag + assert len(issues[2].tags.all()) == 2 # 1 new tag + shared tag + assert len(issues[3].tags.all()) == 2 # 1 new tag + shared tag + + assert res.data == { + "updateProject": { + "__typename": "ProjectType", + "name": "Some Project", + "milestones": [ + { + "name": "Some Milestone", + "issues": [ + { + "name": "Some Issue", + "tags": [ + {"name": "Shared Tag"}, + {"name": "Tag 1"}, + {"name": "Tag 2"}, + {"name": "Tag 3"}, + ], + } + ], + }, + { + "name": "Another Milestone", + "issues": [ + { + "name": "Some Issue", + "tags": [ + {"name": "Shared Tag"}, + {"name": "Tag 4"}, + ], + }, + { + "name": "Another Issue", + "tags": [ + {"name": "Shared Tag"}, + {"name": "Tag 5"}, + ], + }, + { + "name": "Third issue", + "tags": [ + {"name": "Shared Tag"}, + {"name": "Tag 6"}, + ], + }, + ], + }, + ], + }, + } + + +@pytest.mark.parametrize("mock_model", ["Milestone", "Issue", "Tag"]) +@pytest.mark.django_db(transaction=True) +def test_input_create_mutation_with_nested_calls_nested_full_clean( + db, gql_client: GraphQLTestClient, mock_model: str +): + query = """ + mutation createProjectWithMilestones ($input: ProjectInputPartial!) { + createProjectWithMilestones (input: $input) { + __typename + ... on OperationInfo { + messages { + kind + field + message + } + } + ... on ProjectType { + id + name + milestones { + id + name + issues { + id + name + tags { + name + } + } + } + } + } + } + """ + + shared_tag = TagFactory.create(name="Shared Tag") + shared_tag_id = to_base64("TagType", shared_tag.pk) + + with patch( + f"tests.projects.models.{mock_model}.clean", + side_effect=ValidationError({"name": ValidationError("Invalid name")}), + ) as mocked_full_clean: + res = gql_client.query( + query, + { + "input": { + "name": "Some Project", + "milestones": [ + { + "name": "Some Milestone", + "issues": [ + { + "name": "Some Issue", + "tags": [ + {"name": "Tag 1"}, + {"name": "Tag 2"}, + {"name": "Tag 3"}, + {"id": shared_tag_id}, + ], + } + ], + }, + { + "name": "Another Milestone", + "issues": [ + { + "name": "Some Issue", + "tags": [ + {"name": "Tag 4"}, + {"id": shared_tag_id}, + ], + }, + { + "name": "Another Issue", + "tags": [ + {"name": "Tag 5"}, + {"id": shared_tag_id}, + ], + }, + { + "name": "Third issue", + "tags": [ + {"name": "Tag 6"}, + {"id": shared_tag_id}, + ], + }, + ], + }, + ], + }, + }, + ) + + assert res.data + assert isinstance(res.data["createProjectWithMilestones"], dict) + assert res.data["createProjectWithMilestones"]["__typename"] == "OperationInfo" + assert mocked_full_clean.call_count == 1 + assert res.data["createProjectWithMilestones"]["messages"] == [ + {"field": "name", "kind": "VALIDATION", "message": "Invalid name"} + ] + + @pytest.mark.django_db(transaction=True) def test_input_update_mutation(db, gql_client: GraphQLTestClient): query = """