Skip to content

Commit

Permalink
provide global enum naming. #70
Browse files Browse the repository at this point in the history
  • Loading branch information
tfranzel committed May 29, 2020
1 parent 7ef494e commit 2f47eba
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 6 deletions.
19 changes: 14 additions & 5 deletions drf_spectacular/plumbing.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,9 @@ def postprocess_schema_enums(result, generator, **kwargs):
the same choices. Aids client generation to not generate a separate enum for
every occurrence. only takes effect when replacement is guaranteed to be correct.
"""
def choice_hash(choices):
return hash(json.dumps(choices, sort_keys=True))

def iter_prop_containers(schema):
if isinstance(schema, list):
for item in schema:
Expand All @@ -481,15 +484,18 @@ def iter_prop_containers(schema):

schemas = result.get('components', {}).get('schemas', {})

overrides = {
choice_hash(list(dict(choices).keys())): name
for name, choices in spectacular_settings.ENUM_NAME_OVERRIDES.items()
}

hash_mapping = defaultdict(set)
# collect all enums, their names and contents
for props in iter_prop_containers(list(schemas.values())):
for prop_name, prop_schema in props.items():
if 'enum' not in prop_schema:
continue
hash_mapping[prop_name].add(
hash(json.dumps(prop_schema['enum'], sort_keys=True))
)
hash_mapping[prop_name].add(choice_hash(prop_schema['enum']))
# safe replacement requires name to have only one set of enum values
candidate_enums = {
prop_name for prop_name, prop_hash_set in hash_mapping.items()
Expand All @@ -500,10 +506,13 @@ def iter_prop_containers(schema):
for prop_name, prop_schema in props.items():
if 'enum' not in prop_schema:
continue
if prop_name not in candidate_enums:
elif choice_hash(prop_schema['enum']) in overrides:
enum_name = overrides[choice_hash(prop_schema['enum'])]
elif prop_name in candidate_enums:
enum_name = f'{inflection.camelize(prop_name)}Enum'
else:
continue

enum_name = f'{inflection.camelize(prop_name)}Enum'
enum_schema = {k: v for k, v in prop_schema.items() if k in ['type', 'enum']}
prop_schema = {k: v for k, v in prop_schema.items() if k not in ['type', 'enum']}

Expand Down
3 changes: 3 additions & 0 deletions drf_spectacular/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@
'drf_spectacular.plumbing.postprocess_schema_enums'
],

# enum name overrides. dict with keys "YourEnum" and their choice values "field.choices"
'ENUM_NAME_OVERRIDES': {},

# General schema metadata. Refer to spec for valid inputs
# https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.3.md#openapi-object
'TITLE': '',
Expand Down
22 changes: 21 additions & 1 deletion tests/test_postprocessing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from rest_framework import serializers, viewsets, mixins
from unittest import mock

from rest_framework import serializers, viewsets, mixins, generics
from rest_framework.decorators import action

from drf_spectacular.utils import extend_schema
Expand Down Expand Up @@ -32,3 +34,21 @@ def selection(self, request):
def test_postprocessing(no_warnings):
schema = generate_schema('a', AViewset)
assert_schema(schema, 'tests/test_postprocessing.yml')


@mock.patch('drf_spectacular.settings.spectacular_settings.ENUM_NAME_OVERRIDES', {
'LanguageEnum': language_choices
})
def test_global_enum_naming_override(no_warnings):

class XSerializer(serializers.Serializer):
foo = serializers.ChoiceField(choices=language_choices)
bar = serializers.ChoiceField(choices=language_choices)

class XView(generics.RetrieveAPIView):
serializer_class = XSerializer

schema = generate_schema('/x', view=XView)
assert 'LanguageEnum' in schema['components']['schemas']['X']['properties']['foo']['$ref']
assert 'LanguageEnum' in schema['components']['schemas']['X']['properties']['bar']['$ref']
assert len(schema['components']['schemas']) == 2

0 comments on commit 2f47eba

Please # to comment.