Skip to content

Commit b7523f4

Browse files
authored
Support UniqueConstraint (#7438)
1 parent 9882207 commit b7523f4

File tree

3 files changed

+170
-41
lines changed

3 files changed

+170
-41
lines changed

rest_framework/serializers.py

+46-36
Original file line numberDiff line numberDiff line change
@@ -1398,6 +1398,23 @@ def get_extra_kwargs(self):
13981398

13991399
return extra_kwargs
14001400

1401+
def get_unique_together_constraints(self, model):
1402+
"""
1403+
Returns iterator of (fields, queryset), each entry describes an unique together
1404+
constraint on `fields` in `queryset`.
1405+
"""
1406+
for parent_class in [model] + list(model._meta.parents):
1407+
for unique_together in parent_class._meta.unique_together:
1408+
yield unique_together, model._default_manager
1409+
for constraint in parent_class._meta.constraints:
1410+
if isinstance(constraint, models.UniqueConstraint) and len(constraint.fields) > 1:
1411+
yield (
1412+
constraint.fields,
1413+
model._default_manager
1414+
if constraint.condition is None
1415+
else model._default_manager.filter(constraint.condition)
1416+
)
1417+
14011418
def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs):
14021419
"""
14031420
Return any additional field options that need to be included as a
@@ -1426,12 +1443,11 @@ def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs
14261443

14271444
unique_constraint_names -= {None}
14281445

1429-
# Include each of the `unique_together` field names,
1446+
# Include each of the `unique_together` and `UniqueConstraint` field names,
14301447
# so long as all the field names are included on the serializer.
1431-
for parent_class in [model] + list(model._meta.parents):
1432-
for unique_together_list in parent_class._meta.unique_together:
1433-
if set(field_names).issuperset(unique_together_list):
1434-
unique_constraint_names |= set(unique_together_list)
1448+
for unique_together_list, queryset in self.get_unique_together_constraints(model):
1449+
if set(field_names).issuperset(unique_together_list):
1450+
unique_constraint_names |= set(unique_together_list)
14351451

14361452
# Now we have all the field names that have uniqueness constraints
14371453
# applied, we can add the extra 'required=...' or 'default=...'
@@ -1526,11 +1542,6 @@ def get_unique_together_validators(self):
15261542
"""
15271543
Determine a default set of validators for any unique_together constraints.
15281544
"""
1529-
model_class_inheritance_tree = (
1530-
[self.Meta.model] +
1531-
list(self.Meta.model._meta.parents)
1532-
)
1533-
15341545
# The field names we're passing though here only include fields
15351546
# which may map onto a model field. Any dotted field name lookups
15361547
# cannot map to a field, and must be a traversal, so we're not
@@ -1556,34 +1567,33 @@ def get_unique_together_validators(self):
15561567
# Note that we make sure to check `unique_together` both on the
15571568
# base model class, but also on any parent classes.
15581569
validators = []
1559-
for parent_class in model_class_inheritance_tree:
1560-
for unique_together in parent_class._meta.unique_together:
1561-
# Skip if serializer does not map to all unique together sources
1562-
if not set(source_map).issuperset(unique_together):
1563-
continue
1564-
1565-
for source in unique_together:
1566-
assert len(source_map[source]) == 1, (
1567-
"Unable to create `UniqueTogetherValidator` for "
1568-
"`{model}.{field}` as `{serializer}` has multiple "
1569-
"fields ({fields}) that map to this model field. "
1570-
"Either remove the extra fields, or override "
1571-
"`Meta.validators` with a `UniqueTogetherValidator` "
1572-
"using the desired field names."
1573-
.format(
1574-
model=self.Meta.model.__name__,
1575-
serializer=self.__class__.__name__,
1576-
field=source,
1577-
fields=', '.join(source_map[source]),
1578-
)
1579-
)
1570+
for unique_together, queryset in self.get_unique_together_constraints(self.Meta.model):
1571+
# Skip if serializer does not map to all unique together sources
1572+
if not set(source_map).issuperset(unique_together):
1573+
continue
15801574

1581-
field_names = tuple(source_map[f][0] for f in unique_together)
1582-
validator = UniqueTogetherValidator(
1583-
queryset=parent_class._default_manager,
1584-
fields=field_names
1575+
for source in unique_together:
1576+
assert len(source_map[source]) == 1, (
1577+
"Unable to create `UniqueTogetherValidator` for "
1578+
"`{model}.{field}` as `{serializer}` has multiple "
1579+
"fields ({fields}) that map to this model field. "
1580+
"Either remove the extra fields, or override "
1581+
"`Meta.validators` with a `UniqueTogetherValidator` "
1582+
"using the desired field names."
1583+
.format(
1584+
model=self.Meta.model.__name__,
1585+
serializer=self.__class__.__name__,
1586+
field=source,
1587+
fields=', '.join(source_map[source]),
1588+
)
15851589
)
1586-
validators.append(validator)
1590+
1591+
field_names = tuple(source_map[f][0] for f in unique_together)
1592+
validator = UniqueTogetherValidator(
1593+
queryset=queryset,
1594+
fields=field_names
1595+
)
1596+
validators.append(validator)
15871597
return validators
15881598

15891599
def get_unique_for_date_validators(self):

rest_framework/utils/field_mapping.py

+24-5
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,29 @@ def get_detail_view_name(model):
6262
}
6363

6464

65+
def get_unique_validators(field_name, model_field):
66+
"""
67+
Returns a list of UniqueValidators that should be applied to the field.
68+
"""
69+
field_set = set([field_name])
70+
conditions = {
71+
c.condition
72+
for c in model_field.model._meta.constraints
73+
if isinstance(c, models.UniqueConstraint) and set(c.fields) == field_set
74+
}
75+
if getattr(model_field, 'unique', False):
76+
conditions.add(None)
77+
if not conditions:
78+
return
79+
unique_error_message = get_unique_error_message(model_field)
80+
queryset = model_field.model._default_manager
81+
for condition in conditions:
82+
yield UniqueValidator(
83+
queryset=queryset if condition is None else queryset.filter(condition),
84+
message=unique_error_message
85+
)
86+
87+
6588
def get_field_kwargs(field_name, model_field):
6689
"""
6790
Creates a default instance of a basic non-relational field.
@@ -216,11 +239,7 @@ def get_field_kwargs(field_name, model_field):
216239
if not isinstance(validator, validators.MinLengthValidator)
217240
]
218241

219-
if getattr(model_field, 'unique', False):
220-
validator = UniqueValidator(
221-
queryset=model_field.model._default_manager,
222-
message=get_unique_error_message(model_field))
223-
validator_kwarg.append(validator)
242+
validator_kwarg += get_unique_validators(field_name, model_field)
224243

225244
if validator_kwarg:
226245
kwargs['validators'] = validator_kwarg

tests/test_validators.py

+100
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,106 @@ def filter(self, **kwargs):
464464
assert queryset.called_with == {'race_name': 'bar', 'position': 1}
465465

466466

467+
class UniqueConstraintModel(models.Model):
468+
race_name = models.CharField(max_length=100)
469+
position = models.IntegerField()
470+
global_id = models.IntegerField()
471+
fancy_conditions = models.IntegerField(null=True)
472+
473+
class Meta:
474+
constraints = [
475+
models.UniqueConstraint(
476+
name="unique_constraint_model_global_id_uniq",
477+
fields=('global_id',),
478+
),
479+
models.UniqueConstraint(
480+
name="unique_constraint_model_fancy_1_uniq",
481+
fields=('fancy_conditions',),
482+
condition=models.Q(global_id__lte=1)
483+
),
484+
models.UniqueConstraint(
485+
name="unique_constraint_model_fancy_3_uniq",
486+
fields=('fancy_conditions',),
487+
condition=models.Q(global_id__gte=3)
488+
),
489+
models.UniqueConstraint(
490+
name="unique_constraint_model_together_uniq",
491+
fields=('race_name', 'position'),
492+
condition=models.Q(race_name='example'),
493+
)
494+
]
495+
496+
497+
class UniqueConstraintSerializer(serializers.ModelSerializer):
498+
class Meta:
499+
model = UniqueConstraintModel
500+
fields = '__all__'
501+
502+
503+
class TestUniqueConstraintValidation(TestCase):
504+
def setUp(self):
505+
self.instance = UniqueConstraintModel.objects.create(
506+
race_name='example',
507+
position=1,
508+
global_id=1
509+
)
510+
UniqueConstraintModel.objects.create(
511+
race_name='example',
512+
position=2,
513+
global_id=2
514+
)
515+
UniqueConstraintModel.objects.create(
516+
race_name='other',
517+
position=1,
518+
global_id=3
519+
)
520+
521+
def test_repr(self):
522+
serializer = UniqueConstraintSerializer()
523+
# the order of validators isn't deterministic so delete
524+
# fancy_conditions field that has two of them
525+
del serializer.fields['fancy_conditions']
526+
expected = dedent("""
527+
UniqueConstraintSerializer():
528+
id = IntegerField(label='ID', read_only=True)
529+
race_name = CharField(max_length=100, required=True)
530+
position = IntegerField(required=True)
531+
global_id = IntegerField(validators=[<UniqueValidator(queryset=UniqueConstraintModel.objects.all())>])
532+
class Meta:
533+
validators = [<UniqueTogetherValidator(queryset=<QuerySet [<UniqueConstraintModel: UniqueConstraintModel object (1)>, <UniqueConstraintModel: UniqueConstraintModel object (2)>]>, fields=('race_name', 'position'))>]
534+
""")
535+
assert repr(serializer) == expected
536+
537+
def test_unique_together_field(self):
538+
"""
539+
UniqueConstraint fields and condition attributes must be passed
540+
to UniqueTogetherValidator as fields and queryset
541+
"""
542+
serializer = UniqueConstraintSerializer()
543+
assert len(serializer.validators) == 1
544+
validator = serializer.validators[0]
545+
assert validator.fields == ('race_name', 'position')
546+
assert set(validator.queryset.values_list(flat=True)) == set(
547+
UniqueConstraintModel.objects.filter(race_name='example').values_list(flat=True)
548+
)
549+
550+
def test_single_field_uniq_validators(self):
551+
"""
552+
UniqueConstraint with single field must be transformed into
553+
field's UniqueValidator
554+
"""
555+
serializer = UniqueConstraintSerializer()
556+
assert len(serializer.validators) == 1
557+
validators = serializer.fields['global_id'].validators
558+
assert len(validators) == 1
559+
assert validators[0].queryset == UniqueConstraintModel.objects
560+
561+
validators = serializer.fields['fancy_conditions'].validators
562+
assert len(validators) == 2
563+
ids_in_qs = {frozenset(v.queryset.values_list(flat=True)) for v in validators}
564+
assert ids_in_qs == {frozenset([1]), frozenset([3])}
565+
566+
467567
# Tests for `UniqueForDateValidator`
468568
# ----------------------------------
469569

0 commit comments

Comments
 (0)