From 94190bb07f3dc70eeee9d47379f684d5817f35c5 Mon Sep 17 00:00:00 2001 From: "T. Franzel" Date: Thu, 28 Jul 2022 21:46:54 +0200 Subject: [PATCH] Extend OpenApiSerializerExtension interface. #392 #705 1. Allows get_name with full parameters without breaking legacy code. 2. enable calling ``auto_schema.resolve_serializer`` with ``bypass_extensions=True`` to allow offloading component creation --- drf_spectacular/extensions.py | 2 +- drf_spectacular/openapi.py | 28 +++++++++++------- drf_spectacular/plumbing.py | 7 +++++ tests/test_extensions.py | 55 ++++++++++++++++++++++++++++++++++- 4 files changed, 79 insertions(+), 13 deletions(-) diff --git a/drf_spectacular/extensions.py b/drf_spectacular/extensions.py index 8e4ed150..d3dc656a 100644 --- a/drf_spectacular/extensions.py +++ b/drf_spectacular/extensions.py @@ -55,7 +55,7 @@ class OpenApiSerializerExtension(OpenApiGeneratorExtension['OpenApiSerializerExt """ _registry: List['OpenApiSerializerExtension'] = [] - def get_name(self) -> Optional[str]: + def get_name(self, auto_schema: 'AutoSchema', direction: Direction) -> Optional[str]: """ return str for overriding default name extraction """ return None diff --git a/drf_spectacular/openapi.py b/drf_spectacular/openapi.py index d397fa76..f0183161 100644 --- a/drf_spectacular/openapi.py +++ b/drf_spectacular/openapi.py @@ -28,10 +28,10 @@ ComponentRegistry, ResolvedComponent, UnableToProceedError, append_meta, assert_basic_serializer, build_array_type, build_basic_type, build_choice_field, build_examples_list, build_generic_type, build_listed_example_value, build_media_type_object, - build_mocked_view, build_object_type, build_parameter_type, error, follow_field_source, - follow_model_field_lookup, force_instance, get_doc, get_list_serializer, get_type_hints, - get_view_model, is_basic_serializer, is_basic_type, is_field, is_list_serializer, - is_list_serializer_customized, is_patched_serializer, is_serializer, + build_mocked_view, build_object_type, build_parameter_type, error, filter_supported_arguments, + follow_field_source, follow_model_field_lookup, force_instance, get_doc, get_list_serializer, + get_type_hints, get_view_model, is_basic_serializer, is_basic_type, is_field, + is_list_serializer, is_list_serializer_customized, is_patched_serializer, is_serializer, is_trivial_string_variation, modify_media_types_for_versioning, resolve_django_path_parameter, resolve_regex_path_parameter, resolve_type_hint, safe_ref, sanitize_specification_extensions, warn, whitelisted, @@ -1422,11 +1422,17 @@ def _get_response_headers_for_code(self, status_code, direction='response') -> d return result - def _get_serializer_name(self, serializer, direction): + def _get_serializer_name(self, serializer, direction, bypass_extensions=False): serializer_extension = OpenApiSerializerExtension.get_match(serializer) - if serializer_extension and serializer_extension.get_name(): - # library override mechanisms - name = serializer_extension.get_name() + if serializer_extension and not bypass_extensions: + custom_name = serializer_extension.get_name(**filter_supported_arguments( + serializer_extension.get_name, auto_schema=self, direction=direction + )) + else: + custom_name = None + + if custom_name: + name = custom_name elif has_override(serializer, 'component_name'): name = get_override(serializer, 'component_name') elif getattr(getattr(serializer, 'Meta', None), 'ref_name', None) is not None: @@ -1457,13 +1463,13 @@ def _get_serializer_name(self, serializer, direction): return name - def resolve_serializer(self, serializer, direction) -> ResolvedComponent: + def resolve_serializer(self, serializer, direction, bypass_extensions=False) -> ResolvedComponent: assert_basic_serializer(serializer) serializer = force_instance(serializer) with add_trace_message(serializer.__class__.__name__): component = ResolvedComponent( - name=self._get_serializer_name(serializer, direction), + name=self._get_serializer_name(serializer, direction, bypass_extensions), type=ResolvedComponent.SCHEMA, object=serializer, ) @@ -1471,7 +1477,7 @@ def resolve_serializer(self, serializer, direction) -> ResolvedComponent: return self.registry[component] # return component with schema self.registry.register(component) - component.schema = self._map_serializer(serializer, direction) + component.schema = self._map_serializer(serializer, direction, bypass_extensions) discard_component = ( # components with empty schemas serve no purpose diff --git a/drf_spectacular/plumbing.py b/drf_spectacular/plumbing.py index b320f9a7..1f20c13a 100644 --- a/drf_spectacular/plumbing.py +++ b/drf_spectacular/plumbing.py @@ -1275,3 +1275,10 @@ def build_listed_example_value(value: Any, paginator, direction): f"provide example values themselves. Using the plain example value as fallback." ) return value + + +def filter_supported_arguments(func, **kwargs): + sig = inspect.signature(func) + return { + arg: val for arg, val in kwargs.items() if arg in sig.parameters + } diff --git a/tests/test_extensions.py b/tests/test_extensions.py index a673e1f2..7ef2a93d 100644 --- a/tests/test_extensions.py +++ b/tests/test_extensions.py @@ -1,3 +1,6 @@ +from typing import TYPE_CHECKING +from unittest import mock + from rest_framework import fields, mixins, permissions, serializers, viewsets from rest_framework.authentication import BaseAuthentication from rest_framework.decorators import api_view @@ -12,10 +15,13 @@ ResolvedComponent, build_array_type, build_basic_type, build_object_type, ) from drf_spectacular.types import OpenApiTypes -from drf_spectacular.utils import extend_schema +from drf_spectacular.utils import Direction, extend_schema from tests import generate_schema, get_response_schema from tests.models import SimpleModel, SimpleSerializer +if TYPE_CHECKING: + from drf_spectacular.openapi import AutoSchema + class Base64Field(fields.Field): pass # pragma: no cover @@ -193,3 +199,50 @@ class XViewset(mixins.ListModelMixin, viewsets.GenericViewSet): 'data': {'type': 'array', 'items': {'$ref': '#/components/schemas/X'}} } } + + +@mock.patch('drf_spectacular.settings.spectacular_settings.COMPONENT_SPLIT_REQUEST', True) +def test_serializer_envelope_through_extension(no_warnings): + class EnvelopeMixin: + pass + + # actual enveloping not implemented. This could be done internally with + # to_representation or externally with a custom Renderer + class XSerializer(EnvelopeMixin, serializers.ModelSerializer): + name = serializers.CharField() + + class Meta: + model = SimpleModel + fields = '__all__' + envelope = 'foo' # some arbitrary addition to Meta for example + + class EnvelopeFix(OpenApiSerializerExtension): + target_class = EnvelopeMixin + match_subclasses = True + + def get_name(self, auto_schema: 'AutoSchema', direction: Direction): + if direction == 'request': + return None + else: + return f"Enveloped{self.target.__class__.__name__}" + + def map_serializer(self, auto_schema: 'AutoSchema', direction: Direction): + if direction == 'request': + return auto_schema._map_serializer(self.target, direction, bypass_extensions=True) + else: + component = auto_schema.resolve_serializer(self.target, direction, bypass_extensions=True) + if not component: + return {} + return build_object_type( + properties={self.target.Meta.envelope: component.ref} + ) + + class XViewset(viewsets.ModelViewSet): + serializer_class = XSerializer + queryset = SimpleModel.objects.none() + + schema = generate_schema('/x', XViewset) + assert 'X' in schema['components']['schemas'] + assert 'EnvelopedX' in schema['components']['schemas'] + assert 'XRequest' in schema['components']['schemas'] + assert 'PatchedXRequest' in schema['components']['schemas']