Skip to content

Commit

Permalink
Merge pull request #779 from tfranzel/extend_extension_interface
Browse files Browse the repository at this point in the history
Extend OpenApiSerializerExtension interface. #392 #705
  • Loading branch information
tfranzel authored Aug 5, 2022
2 parents 5408d37 + 94190bb commit 8ec2a8d
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 13 deletions.
2 changes: 1 addition & 1 deletion drf_spectacular/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class OpenApiSerializerExtension(OpenApiGeneratorExtension['OpenApiSerializerExt
"""
_registry: List['OpenApiSerializerExtension'] = []

def get_name(self) -> Optional[str]:
def get_name(self, auto_schema: 'AutoSchema', direction: Direction) -> Optional[str]:
""" return str for overriding default name extraction """
return None

Expand Down
28 changes: 17 additions & 11 deletions drf_spectacular/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@
ComponentRegistry, ResolvedComponent, UnableToProceedError, append_meta,
assert_basic_serializer, build_array_type, build_basic_type, build_choice_field,
build_examples_list, build_generic_type, build_listed_example_value, build_media_type_object,
build_mocked_view, build_object_type, build_parameter_type, error, follow_field_source,
follow_model_field_lookup, force_instance, get_doc, get_list_serializer, get_type_hints,
get_view_model, is_basic_serializer, is_basic_type, is_field, is_list_serializer,
is_list_serializer_customized, is_patched_serializer, is_serializer,
build_mocked_view, build_object_type, build_parameter_type, error, filter_supported_arguments,
follow_field_source, follow_model_field_lookup, force_instance, get_doc, get_list_serializer,
get_type_hints, get_view_model, is_basic_serializer, is_basic_type, is_field,
is_list_serializer, is_list_serializer_customized, is_patched_serializer, is_serializer,
is_trivial_string_variation, modify_media_types_for_versioning, resolve_django_path_parameter,
resolve_regex_path_parameter, resolve_type_hint, safe_ref, sanitize_specification_extensions,
warn, whitelisted,
Expand Down Expand Up @@ -1422,11 +1422,17 @@ def _get_response_headers_for_code(self, status_code, direction='response') -> d

return result

def _get_serializer_name(self, serializer, direction):
def _get_serializer_name(self, serializer, direction, bypass_extensions=False):
serializer_extension = OpenApiSerializerExtension.get_match(serializer)
if serializer_extension and serializer_extension.get_name():
# library override mechanisms
name = serializer_extension.get_name()
if serializer_extension and not bypass_extensions:
custom_name = serializer_extension.get_name(**filter_supported_arguments(
serializer_extension.get_name, auto_schema=self, direction=direction
))
else:
custom_name = None

if custom_name:
name = custom_name
elif has_override(serializer, 'component_name'):
name = get_override(serializer, 'component_name')
elif getattr(getattr(serializer, 'Meta', None), 'ref_name', None) is not None:
Expand Down Expand Up @@ -1457,21 +1463,21 @@ def _get_serializer_name(self, serializer, direction):

return name

def resolve_serializer(self, serializer, direction) -> ResolvedComponent:
def resolve_serializer(self, serializer, direction, bypass_extensions=False) -> ResolvedComponent:
assert_basic_serializer(serializer)
serializer = force_instance(serializer)

with add_trace_message(serializer.__class__.__name__):
component = ResolvedComponent(
name=self._get_serializer_name(serializer, direction),
name=self._get_serializer_name(serializer, direction, bypass_extensions),
type=ResolvedComponent.SCHEMA,
object=serializer,
)
if component in self.registry:
return self.registry[component] # return component with schema

self.registry.register(component)
component.schema = self._map_serializer(serializer, direction)
component.schema = self._map_serializer(serializer, direction, bypass_extensions)

discard_component = (
# components with empty schemas serve no purpose
Expand Down
7 changes: 7 additions & 0 deletions drf_spectacular/plumbing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1275,3 +1275,10 @@ def build_listed_example_value(value: Any, paginator, direction):
f"provide example values themselves. Using the plain example value as fallback."
)
return value


def filter_supported_arguments(func, **kwargs):
sig = inspect.signature(func)
return {
arg: val for arg, val in kwargs.items() if arg in sig.parameters
}
55 changes: 54 additions & 1 deletion tests/test_extensions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from typing import TYPE_CHECKING
from unittest import mock

from rest_framework import fields, mixins, permissions, serializers, viewsets
from rest_framework.authentication import BaseAuthentication
from rest_framework.decorators import api_view
Expand All @@ -12,10 +15,13 @@
ResolvedComponent, build_array_type, build_basic_type, build_object_type,
)
from drf_spectacular.types import OpenApiTypes
from drf_spectacular.utils import extend_schema
from drf_spectacular.utils import Direction, extend_schema
from tests import generate_schema, get_response_schema
from tests.models import SimpleModel, SimpleSerializer

if TYPE_CHECKING:
from drf_spectacular.openapi import AutoSchema


class Base64Field(fields.Field):
pass # pragma: no cover
Expand Down Expand Up @@ -193,3 +199,50 @@ class XViewset(mixins.ListModelMixin, viewsets.GenericViewSet):
'data': {'type': 'array', 'items': {'$ref': '#/components/schemas/X'}}
}
}


@mock.patch('drf_spectacular.settings.spectacular_settings.COMPONENT_SPLIT_REQUEST', True)
def test_serializer_envelope_through_extension(no_warnings):
class EnvelopeMixin:
pass

# actual enveloping not implemented. This could be done internally with
# to_representation or externally with a custom Renderer
class XSerializer(EnvelopeMixin, serializers.ModelSerializer):
name = serializers.CharField()

class Meta:
model = SimpleModel
fields = '__all__'
envelope = 'foo' # some arbitrary addition to Meta for example

class EnvelopeFix(OpenApiSerializerExtension):
target_class = EnvelopeMixin
match_subclasses = True

def get_name(self, auto_schema: 'AutoSchema', direction: Direction):
if direction == 'request':
return None
else:
return f"Enveloped{self.target.__class__.__name__}"

def map_serializer(self, auto_schema: 'AutoSchema', direction: Direction):
if direction == 'request':
return auto_schema._map_serializer(self.target, direction, bypass_extensions=True)
else:
component = auto_schema.resolve_serializer(self.target, direction, bypass_extensions=True)
if not component:
return {}
return build_object_type(
properties={self.target.Meta.envelope: component.ref}
)

class XViewset(viewsets.ModelViewSet):
serializer_class = XSerializer
queryset = SimpleModel.objects.none()

schema = generate_schema('/x', XViewset)
assert 'X' in schema['components']['schemas']
assert 'EnvelopedX' in schema['components']['schemas']
assert 'XRequest' in schema['components']['schemas']
assert 'PatchedXRequest' in schema['components']['schemas']

0 comments on commit 8ec2a8d

Please # to comment.