Skip to content

Commit

Permalink
feat: Enforce validation for updating nested relations (#405)
Browse files Browse the repository at this point in the history
* Perform full_clean for updating nested relations

* Add tests for validation on nested updates

* Perform full clean only if obj is not None

* Pass full_clean options to updatem2m as well
  • Loading branch information
tokr-bit authored Nov 6, 2023
1 parent b9f8704 commit 8cccadf
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 5 deletions.
13 changes: 12 additions & 1 deletion strawberry_django/mutations/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def update(
instance.save()

for field, value in m2m:
update_m2m(info, instance, field, value)
update_m2m(info, instance, field, value, full_clean)

retval.append(instance.__class__._default_manager.get(pk=instance.pk))

Expand Down Expand Up @@ -422,6 +422,7 @@ def update_m2m(
instance: Model,
field: ManyToManyField | ForeignObjectRel,
value: Any,
full_clean: bool | FullCleanOptions = True,
):
if value is UNSET:
return
Expand Down Expand Up @@ -457,6 +458,8 @@ def update_m2m(
to_delete = []
need_remove_cache = False

full_clean_options = full_clean if isinstance(full_clean, dict) else {}

values = value.set if isinstance(value, ParsedObjectList) else value
if isinstance(values, list):
if isinstance(value, ParsedObjectList) and getattr(value, "add", None):
Expand All @@ -474,6 +477,8 @@ def update_m2m(
if data:
for k, inner_value in data.items():
setattr(obj, k, inner_value)
if full_clean:
obj.full_clean(**full_clean_options)
obj.save()

if hasattr(manager, "through"):
Expand All @@ -496,13 +501,17 @@ def update_m2m(

for k, inner_value in through_defaults.items():
setattr(im, k, inner_value)
if full_clean:
im.full_clean(**full_clean_options)
im.save()
elif obj not in existing:
to_add.append(obj)

existing.discard(obj)
else:
obj, _ = manager.get_or_create(**data)
if full_clean:
obj.full_clean(**full_clean_options)
existing.discard(obj)

for remaining in existing:
Expand All @@ -516,6 +525,8 @@ def update_m2m(
for v in value.add or []:
obj, data = _parse_data(info, manager.model, v)
if obj and data:
if full_clean:
obj.full_clean(**full_clean_options)
manager.add(obj, **data)
elif obj:
# Do this later in a bulk
Expand Down
9 changes: 7 additions & 2 deletions tests/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, Optional

from django.core.exceptions import ImproperlyConfigured
from django.core.exceptions import ImproperlyConfigured, ValidationError
from django.db import models

from strawberry_django.descriptors import model_property
Expand All @@ -9,6 +9,11 @@
from django.db.models.manager import RelatedManager


def validate_fruit_type(value: str):
if "rotten" in value:
raise ValidationError("We do not allow rotten fruits.")


class Fruit(models.Model):
name = models.CharField(max_length=20)
color_id: Optional[int]
Expand Down Expand Up @@ -43,7 +48,7 @@ class Color(models.Model):


class FruitType(models.Model):
name = models.CharField(max_length=20)
name = models.CharField(max_length=20, validators=[validate_fruit_type])


class User(models.Model):
Expand Down
9 changes: 9 additions & 0 deletions tests/mutations/test_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,15 @@ def test_update(mutation, fruits):
]


def test_update_m2m_with_validation_error(mutation, fruit):
result = mutation(
'{ fruits: updateFruits(data: { types: [{ name: "rotten"} ] }) { id types {'
" name } }}",
)
assert result.errors
assert result.errors[0].message == "{'name': ['We do not allow rotten fruits.']}"


def test_update_lazy_object(mutation, fruit):
result = mutation(
'{ fruit: updateLazyFruit(data: { name: "orange" }) { id name } }',
Expand Down
4 changes: 2 additions & 2 deletions tests/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class GeoFieldPartialInput(GeoField):

@strawberry_django.input(models.Fruit)
class FruitInput(Fruit):
pass
types: List[FruitTypeInput] | None # noqa: UP006


@strawberry_django.input(models.Color)
Expand All @@ -70,7 +70,7 @@ class FruitTypeInput(FruitType):

@strawberry_django.input(models.Fruit, partial=True)
class FruitPartialInput(FruitInput):
pass
types: List[FruitTypePartialInput] | None # noqa: UP006


@strawberry_django.input(models.Color, partial=True)
Expand Down

0 comments on commit 8cccadf

Please # to comment.