diff --git a/drf_spectacular/plumbing.py b/drf_spectacular/plumbing.py index 6b5df736..39d439f9 100644 --- a/drf_spectacular/plumbing.py +++ b/drf_spectacular/plumbing.py @@ -969,7 +969,7 @@ def resolve_type_hint(hint): if all(type(args[0]) is type(choice) for choice in args): schema.update(build_basic_type(type(args[0]))) return schema - elif inspect.isclass(hint) and issubclass(hint, Choices): + elif inspect.isclass(hint) and issubclass(hint, Enum): return { 'enum': [item.value for item in hint], **build_basic_type([t for t in hint.__mro__ if is_basic_type(t)][0]) diff --git a/tests/test_plumbing.py b/tests/test_plumbing.py index 69272a4e..d70e9ba7 100644 --- a/tests/test_plumbing.py +++ b/tests/test_plumbing.py @@ -4,6 +4,7 @@ import sys import typing from datetime import datetime +from enum import Enum import pytest from django import __version__ as DJANGO_VERSION @@ -90,6 +91,11 @@ class NamedTupleB(typing.NamedTuple): b: str +class LanguageEnum(str, Enum): + EN = 'en' + DE = 'de' + + TYPE_HINT_TEST_PARAMS = [ ( typing.Optional[int], @@ -136,6 +142,11 @@ class NamedTupleB(typing.NamedTuple): ) ] +TYPE_HINT_TEST_PARAMS.append(( + LanguageEnum, + {'enum': ['en', 'de'], 'type': 'string'} +)) + if DJANGO_VERSION > '3': from django.db.models.enums import TextChoices # only available in Django>3 diff --git a/tests/test_postprocessing.py b/tests/test_postprocessing.py index 5a7429d9..4dfcbcf3 100644 --- a/tests/test_postprocessing.py +++ b/tests/test_postprocessing.py @@ -37,6 +37,10 @@ class LanguageEnum(Enum): EN = 'en' +class LanguageStrEnum(str, Enum): + EN = 'en' + + class LanguageChoices(TextChoices): EN = 'en' @@ -174,7 +178,7 @@ def partial_update(self, request): def test_enum_override_variations(no_warnings): - enum_override_variations = ['language_list', 'LanguageEnum'] + enum_override_variations = ['language_list', 'LanguageEnum', 'LanguageStrEnum'] if DJANGO_VERSION > '3': enum_override_variations += ['LanguageChoices', 'LanguageChoices.choices']