Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

fix-schema-generation-for-list-dict-fields #79

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,5 @@ ENV/
# IDE settings
.vscode/
.idea/

tests/.hypothesis
5 changes: 5 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

## [UNRELEASED]
### Fixed
- use a pattern to correctly describe the `attr` value in validation errors in all cases. Specifically, this fixes
the incorrect description of the `attr` value for list serializers and list/dict fields. Previously, the `attr`
value was described with an enum having a single value like `INDEX.field`. Now, it shows up as a string with
the pattern `\d+\.field`.

## [0.14.0] - 2024-06-19
### Added
Expand Down
12 changes: 7 additions & 5 deletions docs/settings.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,16 +43,18 @@ DRF_STANDARDIZED_ERRORS = {
"ERROR_SCHEMAS": None,

# When there is a validation error in list serializers, the "attr" returned
# will be sth like "0.email", "1.email", "2.email", ... So, to describe
# the error codes linked to the same field in a list serializer, the field
# will appear in the schema with the name "INDEX.email"
# will be sth like "0.email", "1.email", "2.email", ... So, this setting is
# used to represent the error codes linked to the same field during API
# schema generation and its value will be part of the name of the
# corresponding error component.
"LIST_INDEX_IN_API_SCHEMA": "INDEX",

# When there is a validation error in a DictField with the name "extra_data",
# the "attr" returned will be sth like "extra_data.<key1>", "extra_data.<key2>",
# "extra_data.<key3>", ... Since the keys of a DictField are not predetermined,
# this setting is used as a common name to be used in the API schema. So, the
# corresponding "attr" value for the previous example will be "extra_data.KEY"
# this setting is used to represent the error codes linked to the same field
# during API schema generation and its value will be part of the name of the
# corresponding error component.
"DICT_KEY_IN_API_SCHEMA": "KEY",

# should be unique to error components since it is used to identify error
Expand Down
66 changes: 53 additions & 13 deletions drf_standardized_errors/openapi_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import re
from dataclasses import dataclass, field as dataclass_field
from typing import Any, Dict, List, Optional, Set, Type, Union
from typing import Dict, List, Optional, Set, Type, Union

from django import forms
from django.core.validators import (
DecimalValidator,
RegexValidator,
validate_image_file_extension,
validate_integer,
validate_ipv4_address,
Expand Down Expand Up @@ -421,21 +423,16 @@ class Meta:
def get_error_serializer(
operation_id: str, attr: Optional[str], error_codes: Set[str]
) -> Type[serializers.Serializer]:
attr_kwargs: Dict[str, Any] = {"choices": [(attr, attr)]}
if not attr:
attr_kwargs["allow_null"] = True
if attr is not None:
attr_regex = _get_attr_regex(attr)
attr_field = serializers.CharField(validators=[RegexValidator(attr_regex)])
else:
attr_field = serializers.CharField(allow_null=True)
error_code_choices = sorted(zip(error_codes, error_codes))

camelcase_operation_id = camelize(operation_id)
attr_with_underscores = (attr or "").replace(
package_settings.NESTED_FIELD_SEPARATOR, "_"
)
camelcase_attr = camelize(attr_with_underscores)
suffix = package_settings.ERROR_COMPONENT_NAME_SUFFIX
component_name = f"{camelcase_operation_id}{camelcase_attr}{suffix}"
component_name = _get_error_component_name(operation_id, attr)

class ErrorSerializer(serializers.Serializer):
attr = serializers.ChoiceField(**attr_kwargs)
attr = attr_field
code = serializers.ChoiceField(choices=error_code_choices)
detail = serializers.CharField()

Expand All @@ -445,6 +442,49 @@ class Meta:
return ErrorSerializer


def _get_attr_regex(attr: str) -> str:
r"""
- For ListSerializers:
- input attr: "INDEX.field1", "INDEX.field2", ...
- regex generated: "\d+\.field1", "\d+\.field2", ...
- actual field name: "0.field1", "1.field2", ...
- For ListFields:
- input attr: "field.INDEX"
- regex generated: "field\.\d+"
- actual field name: "0.field1", "1.field2", ...
- For DictFields:
- input attr: "field.KEY"
- regex generated: "field\..+"
- actual field name: "field.key1", "field.key2", ...
- For other cases
- input attr: "field.nested_field"
- regex generated: "field\.nested_field"
- actual field name: "field.nested_field"
"""
parts = attr.split(package_settings.NESTED_FIELD_SEPARATOR)
regex_parts = []
for part in parts:
if part == package_settings.LIST_INDEX_IN_API_SCHEMA:
regex_parts.append(r"\d+")
elif part == package_settings.DICT_KEY_IN_API_SCHEMA:
regex_parts.append(".+")
else:
regex_parts.append(re.escape(part))

escaped_separator = re.escape(package_settings.NESTED_FIELD_SEPARATOR)
return escaped_separator.join(regex_parts)


def _get_error_component_name(operation_id: str, attr: Optional[str]) -> str:
camelcase_operation_id = camelize(operation_id)
attr_with_underscores = (attr or "").replace(
package_settings.NESTED_FIELD_SEPARATOR, "_"
)
camelcase_attr = camelize(attr_with_underscores)
suffix = package_settings.ERROR_COMPONENT_NAME_SUFFIX
return f"{camelcase_operation_id}{camelcase_attr}{suffix}"


@dataclass
class InputDataField:
name: str
Expand Down
25 changes: 25 additions & 0 deletions tests/fuzzing_urls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from django.urls import path
from drf_spectacular.views import SpectacularAPIView

from .views import (
DictFieldFuzzingSerializer,
FuzzingView,
ListFieldFuzzingSerializer,
ListSerializerFuzzingSerializer,
)

urlpatterns = [
path(
"fuzzing/list_field/",
FuzzingView.as_view(serializer_class=ListFieldFuzzingSerializer),
),
path(
"fuzzing/list_serializer/",
FuzzingView.as_view(serializer_class=ListSerializerFuzzingSerializer),
),
path(
"fuzzing/dict_field/",
FuzzingView.as_view(serializer_class=DictFieldFuzzingSerializer),
),
path("schema/", SpectacularAPIView.as_view(), name="api-schema"),
]
3 changes: 3 additions & 0 deletions tests/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,6 @@
"drf_standardized_errors.openapi_hooks.postprocess_schema_enums"
],
}

STATIC_URL = "/static/"
MEDIA_URL = "/media/"
85 changes: 85 additions & 0 deletions tests/test_api_fuzzing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from typing import List
from urllib.parse import urljoin

import pytest
import schemathesis
from hypothesis import settings
from schemathesis import Case, DataGenerationMethod


@pytest.fixture
def run_test_server(live_server, settings):
settings.ROOT_URLCONF = "tests.fuzzing_urls"

schema_url = urljoin(live_server.url, "/schema/")
return schemathesis.from_uri(schema_url)


schema = schemathesis.from_pytest_fixture("run_test_server")


@schemathesis.hook
def before_add_examples(
context: schemathesis.hooks.HookContext,
examples: List[Case],
) -> None:
if context.operation.path == "/fuzzing/list_field/":
case = Case(
context.operation,
0.01,
body={"field1": [None]},
media_type="application/json",
)
examples.append(case)
if context.operation.path == "/fuzzing/dict_field/":
case = Case(
context.operation,
0.01,
body={"field1": {"my_int": "non_integer_value"}},
media_type="application/json",
)
examples.append(case)
if context.operation.path == "/fuzzing/list_serializer/":
case = Case(
context.operation,
0.01,
body={"field1": [{"field2": None}]},
media_type="application/json",
)
examples.append(case)


@schema.parametrize(
endpoint="fuzzing/list_field/",
data_generation_methods=[
DataGenerationMethod.negative,
DataGenerationMethod.positive,
],
)
@settings(max_examples=100)
def test_compliance_to_api_schema_for_list_field(case):
case.call_and_validate()


@schema.parametrize(
endpoint="fuzzing/list_serializer/",
data_generation_methods=[
DataGenerationMethod.negative,
DataGenerationMethod.positive,
],
)
@settings(max_examples=100)
def test_compliance_to_api_schema_for_list_serializer(case):
case.call_and_validate()


@schema.parametrize(
endpoint="fuzzing/dict_field/",
data_generation_methods=[
DataGenerationMethod.negative,
DataGenerationMethod.positive,
],
)
@settings(max_examples=100)
def test_compliance_to_api_schema_for_dict_field(case):
case.call_and_validate()
44 changes: 43 additions & 1 deletion tests/test_openapi_validation_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from drf_spectacular.utils import extend_schema
from rest_framework import serializers
from rest_framework.decorators import action, api_view
from rest_framework.generics import DestroyAPIView, UpdateAPIView
from rest_framework.generics import DestroyAPIView, GenericAPIView, UpdateAPIView
from rest_framework.response import Response
from rest_framework.versioning import URLPathVersioning
from rest_framework.viewsets import ModelViewSet
Expand Down Expand Up @@ -512,3 +512,45 @@ def test_extra_validation_errors_for_nested_list_serializer_field(
schema = generate_view_schema(route, view)
error_codes = get_error_codes(schema, "ValidateCreateGroupsINDEXNameErrorComponent")
assert "some_error" in error_codes


def test_pattern_for_list_serializer_field(viewset_with_nested_serializer):
route = "validate/"
view = viewset_with_nested_serializer.as_view({"post": "create"})
schema = generate_view_schema(route, view)
attr = schema["components"]["schemas"][
"ValidateCreateGroupsINDEXNameErrorComponent"
]["properties"]["attr"]
assert attr["pattern"] == r"groups\.\d+\.name"


@pytest.fixture
def list_dict_fields_view():
class SomeSerializer(serializers.Serializer):
field1 = serializers.DictField(child=serializers.IntegerField())
field2 = serializers.ListField(child=serializers.IntegerField())

class SomeView(GenericAPIView):
serializer_class = SomeSerializer

def post(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
return Response(serializer.data)

return SomeView


def test_pattern_for_list_dict_fields(list_dict_fields_view):
route = "validate/"
view = list_dict_fields_view.as_view()
schema = generate_view_schema(route, view)
dict_attr = schema["components"]["schemas"][
"ValidateCreateField1KEYErrorComponent"
]["properties"]["attr"]
assert dict_attr["pattern"] == r"field1\..+"

list_attr = schema["components"]["schemas"][
"ValidateCreateField2INDEXErrorComponent"
]["properties"]["attr"]
assert list_attr["pattern"] == r"field2\.\d+"
26 changes: 26 additions & 0 deletions tests/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from rest_framework import serializers
from rest_framework.authentication import BasicAuthentication
from rest_framework.generics import GenericAPIView
from rest_framework.parsers import JSONParser
from rest_framework.permissions import IsAuthenticated
from rest_framework.response import Response
from rest_framework.throttling import BaseThrottle
Expand Down Expand Up @@ -73,3 +74,28 @@ class RecursionView(APIView):
def get(self, request, *args, **kwargs):
errors = [{"field": ["Some Error"]} for _ in range(1, 1000)]
raise serializers.ValidationError(errors)


class ListFieldFuzzingSerializer(serializers.Serializer):
field1 = serializers.ListField(child=serializers.IntegerField())


class SomeSerializer(serializers.Serializer):
field2 = serializers.IntegerField()


class ListSerializerFuzzingSerializer(serializers.Serializer):
field1 = SomeSerializer(many=True)


class DictFieldFuzzingSerializer(serializers.Serializer):
field1 = serializers.DictField(child=serializers.IntegerField())


class FuzzingView(GenericAPIView):
parser_classes = [JSONParser]

def post(self, request):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
return Response(data=serializer.data)
4 changes: 4 additions & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ python =
[testenv]
deps =
pytest
pytest-env
pytest-django
schemathesis
drf-spectacular>=0.27.0
django-filter
dj32: Django>=3.2,<4.0
Expand Down Expand Up @@ -54,6 +56,8 @@ commands = sphinx-build -d "{toxworkdir}/docs_doctree" docs "{toxworkdir}/docs_o
DJANGO_SETTINGS_MODULE = tests.settings
testpaths = tests
pythonpath = . drf_standardized_errors
env =
DJANGO_LIVE_TEST_SERVER_ADDRESS=127.0.0.1:8000

[coverage:run]
branch = True
Expand Down
Loading