diff --git a/drf_spectacular/openapi.py b/drf_spectacular/openapi.py index 84645516..e20449fb 100644 --- a/drf_spectacular/openapi.py +++ b/drf_spectacular/openapi.py @@ -808,12 +808,18 @@ def _get_response_for_code(self, serializer): def _get_serializer_name(self, serializer, direction): serializer_extension = OpenApiSerializerExtension.get_match(serializer) if serializer_extension and serializer_extension.get_name(): - return serializer_extension.get_name() - - name = serializer.__class__.__name__ + # library override mechanisms + name = serializer_extension.get_name() + elif getattr(getattr(serializer, 'Meta', None), 'ref_name', None) is not None: + # local override mechanisms. for compatibility with drf-yasg we support meta ref_name, + # though we do not support the serializer inlining feature. + # https://drf-yasg.readthedocs.io/en/stable/custom_spec.html#serializer-meta-nested-class + name = serializer.Meta.ref_name + else: + name = serializer.__class__.__name__ + if name.endswith('Serializer'): + name = name[:-10] - if name.endswith('Serializer'): - name = name[:-10] if self.method == 'PATCH' and not serializer.read_only and direction == 'request': name = 'Patched' + name diff --git a/tests/test_regressions.py b/tests/test_regressions.py index 5b001659..eb21d1c5 100644 --- a/tests/test_regressions.py +++ b/tests/test_regressions.py @@ -356,7 +356,7 @@ def pi(request, foo): assert parameter['schema']['type'] == 'integer' -def test_serializer_naming_collision_resolution(no_warnings): +def test_lib_serializer_naming_collision_resolution(no_warnings): """ parity test in tests.test_warnings.test_serializer_name_reuse """ def x_lib1(): class XSerializer(serializers.Serializer): @@ -390,3 +390,35 @@ def get_name(self): assert request_component == '#/components/schemas/X' response_component = operation['responses']['200']['content']['application/json']['schema']['$ref'] assert response_component == '#/components/schemas/RenamedLib2X' + + +def test_owned_serializer_naming_override_with_ref_name(no_warnings): + def x_owned1(): + class XSerializer(serializers.Serializer): + x = serializers.UUIDField() + + return XSerializer + + def x_owned2(): + class XSerializer(serializers.Serializer): + x = serializers.IntegerField() + + class Meta: + ref_name = 'Y' + + return XSerializer + + x_owned1, x_owned2 = x_owned1(), x_owned2() + + class XAPIView(APIView): + @extend_schema(request=x_owned1, responses=x_owned2) + def post(self, request): + pass # pragma: no cover + + schema = generate_schema('/x', view=XAPIView) + + operation = schema['paths']['/x']['post'] + request_component = operation['requestBody']['content']['application/json']['schema']['$ref'] + assert request_component == '#/components/schemas/X' + response_component = operation['responses']['200']['content']['application/json']['schema']['$ref'] + assert response_component == '#/components/schemas/Y' diff --git a/tests/test_warnings.py b/tests/test_warnings.py index 5b6a86f2..8e21c7e1 100644 --- a/tests/test_warnings.py +++ b/tests/test_warnings.py @@ -2,6 +2,7 @@ from rest_framework import serializers, mixins, viewsets, views from rest_framework.authentication import BaseAuthentication from rest_framework.decorators import action +from rest_framework.views import APIView from drf_spectacular.utils import extend_schema from drf_spectacular.validation import validate_schema @@ -39,6 +40,24 @@ class X2Viewset(mixins.ListModelMixin, viewsets.GenericViewSet): generator.get_schema(request=None, public=True) +def test_owned_serializer_naming_override_with_ref_name_collision(warnings): + class XSerializer(serializers.Serializer): + x = serializers.UUIDField() + + class YSerializer(serializers.Serializer): + x = serializers.IntegerField() + + class Meta: + ref_name = 'X' # already used above + + class XAPIView(APIView): + @extend_schema(request=XSerializer, responses=YSerializer) + def post(self, request): + pass # pragma: no cover + + generate_schema('/x', view=XAPIView) + + def test_no_queryset_warn(capsys): class X1Serializer(serializers.Serializer): uuid = serializers.UUIDField()