From fa0ee03e82d2b73dc00bdeb75b266aba99be8d0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9r=C3=B4me=20Lafr=C3=A9choux?= Date: Sun, 4 Sep 2022 18:22:59 +0200 Subject: [PATCH] Merge EnumSymbol and EnumValue into Enum field --- src/marshmallow/fields.py | 91 +++++++++++++++-------------------- tests/base.py | 6 +-- tests/test_deserialization.py | 36 +++++++------- tests/test_serialization.py | 12 ++--- 4 files changed, 67 insertions(+), 78 deletions(-) diff --git a/src/marshmallow/fields.py b/src/marshmallow/fields.py index 1dc747143..8568aee5d 100644 --- a/src/marshmallow/fields.py +++ b/src/marshmallow/fields.py @@ -11,7 +11,7 @@ import math import typing import warnings -from enum import Enum +from enum import Enum as EnumType from collections.abc import Mapping as _Mapping from marshmallow import validate, utils, class_registry, types @@ -60,8 +60,7 @@ "IPInterface", "IPv4Interface", "IPv6Interface", - "EnumSymbol", - "EnumValue", + "Enum", "Method", "Function", "Str", @@ -1856,43 +1855,14 @@ class IPv6Interface(IPInterface): DESERIALIZATION_CLASS = ipaddress.IPv6Interface -class EnumSymbol(String): - """An Enum field (de)serializing enum members by symbol (name) as string. +class Enum(Field): + """An Enum field (de)serializing enum members by symbol (name) as string or by value. :param enum Enum: Enum class - - .. versionadded:: 3.18.0 - """ - - default_error_messages = { - "unknown": "Must be one of: {choices}.", - } - - def __init__(self, enum: type[Enum], **kwargs): - self.enum = enum - self.choices = ", ".join(enum.__members__) - super().__init__(**kwargs) - - def _serialize(self, value, attr, obj, **kwargs): - if value is None: - return None - return value.name - - def _deserialize(self, value, attr, data, **kwargs): - value = super()._deserialize(value, attr, data, **kwargs) - try: - return getattr(self.enum, value) - except AttributeError as exc: - raise self.make_error("unknown", choices=self.choices) from exc - - -class EnumValue(Field): - """An Enum field (de)serializing enum members by value. - - A Field must be provided to (de)serialize the value. - :param cls_or_instance: Field class or instance. - :param enum Enum: Enum class + + If a field is provided as ``cls_or_instance`` argument, the Enum is (de)serialized by + value using this field. Otherwise, it is (de)serialized by symbol (name) as string. .. versionadded:: 3.18.0 """ @@ -1901,30 +1871,49 @@ class EnumValue(Field): "unknown": "Must be one of: {choices}.", } - def __init__(self, cls_or_instance: Field | type, enum: type[Enum], **kwargs): + def __init__( + self, + enum: type[EnumType], + cls_or_instance: Field | type | None = None, + **kwargs, + ): super().__init__(**kwargs) - try: - self.field = resolve_field_instance(cls_or_instance) - except FieldInstanceResolutionError as error: - raise ValueError( - "The enum field must be a subclass or instance of " - "marshmallow.base.FieldABC." - ) from error self.enum = enum - self.choices = ", ".join( - [str(self.field._serialize(m.value, None, None)) for m in enum] - ) + if cls_or_instance is not None: + try: + self.field = resolve_field_instance(cls_or_instance) + except FieldInstanceResolutionError as error: + raise ValueError( + "The enum field must be a subclass or instance of " + "marshmallow.base.FieldABC." + ) from error + self.by_symbol_or_value = "value" + self.choices = ", ".join( + [str(self.field._serialize(m.value, None, None)) for m in enum] + ) + else: + self.field = String() + self.by_symbol_or_value = "symbol" + self.choices = ", ".join(enum.__members__) def _serialize(self, value, attr, obj, **kwargs): if value is None: return None - return self.field._serialize(value.value, attr, obj, **kwargs) + if self.by_symbol_or_value == "value": + return self.field._serialize(value.value, attr, obj, **kwargs) + return value.name def _deserialize(self, value, attr, data, **kwargs): + if self.by_symbol_or_value == "value": + value = self.field._deserialize(value, attr, data, **kwargs) + try: + return self.enum(value) + except ValueError as exc: + raise self.make_error("unknown", choices=self.choices) from exc value = self.field._deserialize(value, attr, data, **kwargs) try: - return self.enum(value) - except ValueError as exc: + return getattr(self.enum, value) + except AttributeError as exc: raise self.make_error("unknown", choices=self.choices) from exc diff --git a/tests/base.py b/tests/base.py index 4dbbcd5fc..36014c35e 100644 --- a/tests/base.py +++ b/tests/base.py @@ -54,9 +54,9 @@ class DateEnum(Enum): fields.IPInterface, fields.IPv4Interface, fields.IPv6Interface, - functools.partial(fields.EnumSymbol, GenderEnum), - functools.partial(fields.EnumValue, fields.String, HairColorEnum), - functools.partial(fields.EnumValue, fields.Integer, GenderEnum), + functools.partial(fields.Enum, GenderEnum), + functools.partial(fields.Enum, HairColorEnum, fields.String), + functools.partial(fields.Enum, GenderEnum, fields.Integer), ] diff --git a/tests/test_deserialization.py b/tests/test_deserialization.py index b922d6a56..48fb96260 100644 --- a/tests/test_deserialization.py +++ b/tests/test_deserialization.py @@ -1097,54 +1097,54 @@ def test_invalid_ipv6interface_deserialization(self, in_value): assert excinfo.value.args[0] == "Not a valid IPv6 interface." - def test_enumsymbol_field_deserialization(self): - field = fields.EnumSymbol(GenderEnum) + def test_enum_by_symbol_field_deserialization(self): + field = fields.Enum(GenderEnum) assert field.deserialize("male") == GenderEnum.male - def test_enumsymbol_field_invalid_value(self): - field = fields.EnumSymbol(GenderEnum) + def test_enum_by_symbol_field_invalid_value(self): + field = fields.Enum(GenderEnum) with pytest.raises( ValidationError, match="Must be one of: male, female, non_binary." ): field.deserialize("dummy") - def test_enumsymbol_field_not_string(self): - field = fields.EnumSymbol(GenderEnum) + def test_enum_by_symbol_field_not_string(self): + field = fields.Enum(GenderEnum) with pytest.raises(ValidationError, match="Not a valid string."): field.deserialize(12) - def test_enumvalue_field_deserialization(self): - field = fields.EnumValue(fields.String, HairColorEnum) + def test_enum_by_value_field_deserialization(self): + field = fields.Enum(HairColorEnum, fields.String) assert field.deserialize("black hair") == HairColorEnum.black - field = fields.EnumValue(fields.Integer, GenderEnum) + field = fields.Enum(GenderEnum, fields.Integer) assert field.deserialize(1) == GenderEnum.male - field = fields.EnumValue(fields.Date(format="%d/%m/%Y"), DateEnum) + field = fields.Enum(DateEnum, fields.Date(format="%d/%m/%Y")) assert field.deserialize("29/02/2004") == DateEnum.date_1 - def test_enumvalue_field_invalid_value(self): - field = fields.EnumValue(fields.String, HairColorEnum) + def test_enum_by_value_field_invalid_value(self): + field = fields.Enum(HairColorEnum, fields.String) with pytest.raises( ValidationError, match="Must be one of: black hair, brown hair, blond hair, red hair.", ): field.deserialize("dummy") - field = fields.EnumValue(fields.Integer, GenderEnum) + field = fields.Enum(GenderEnum, fields.Integer) with pytest.raises(ValidationError, match="Must be one of: 1, 2, 3."): field.deserialize(12) - field = fields.EnumValue(fields.Date(format="%d/%m/%Y"), DateEnum) + field = fields.Enum(DateEnum, fields.Date(format="%d/%m/%Y")) with pytest.raises( ValidationError, match="Must be one of: 29/02/2004, 29/02/2008, 29/02/2012." ): field.deserialize("28/02/2004") - def test_enumvalue_field_wrong_type(self): - field = fields.EnumValue(fields.String, HairColorEnum) + def test_enum_by_value_field_wrong_type(self): + field = fields.Enum(HairColorEnum, fields.String) with pytest.raises(ValidationError, match="Not a valid string."): field.deserialize(12) - field = fields.EnumValue(fields.Integer, GenderEnum) + field = fields.Enum(GenderEnum, fields.Integer) with pytest.raises(ValidationError, match="Not a valid integer."): field.deserialize("dummy") - field = fields.EnumValue(fields.Date(format="%d/%m/%Y"), DateEnum) + field = fields.Enum(DateEnum, fields.Date(format="%d/%m/%Y")) with pytest.raises(ValidationError, match="Not a valid date."): field.deserialize("30/02/2004") diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 8e751e5ce..216bddc2a 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -255,20 +255,20 @@ def test_ipv6_interface_field(self, user): == ipv6interface_exploded_string ) - def test_enumsymbol_field_serialization(self, user): + def test_enum_by_symbol_field_serialization(self, user): user.sex = GenderEnum.male - field = fields.EnumSymbol(GenderEnum) + field = fields.Enum(GenderEnum) assert field.serialize("sex", user) == "male" - def test_enumvalue_field_serialization(self, user): + def test_enum_by_value_field_serialization(self, user): user.hair_color = HairColorEnum.black - field = fields.EnumValue(fields.String, HairColorEnum) + field = fields.Enum(HairColorEnum, fields.String) assert field.serialize("hair_color", user) == "black hair" user.sex = GenderEnum.male - field = fields.EnumValue(fields.Integer, GenderEnum) + field = fields.Enum(GenderEnum, fields.Integer) assert field.serialize("sex", user) == 1 user.some_date = DateEnum.date_1 - field = fields.EnumValue(fields.Date(format="%d/%m/%Y"), DateEnum) + field = fields.Enum(DateEnum, fields.Date(format="%d/%m/%Y")) assert field.serialize("some_date", user) == "29/02/2004" def test_decimal_field(self, user):