Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Simplify cattrs._compat.is_typeddict #384

Merged
merged 9 commits into from
Jun 14, 2023
1 change: 1 addition & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
- Optimize and improve unstructuring of `Optional` (unions of one type and `None`).
([#380](https://github.com/python-attrs/cattrs/issues/380) [#381](https://github.com/python-attrs/cattrs/pull/381))
- Fix `format_exception` and `transform_error` type annotations.
- Improve the implementation of `cattrs._compat.is_typeddict`. The implementation is now simpler, and relies on fewer private implementation details from `typing` and typing_extensions. ([#384](https://github.com/python-attrs/cattrs/pull/384))

## 23.1.2 (2023-06-02)

Expand Down
54 changes: 16 additions & 38 deletions src/cattrs/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,13 @@
from attr import fields as attrs_fields
from attr import resolve_types

__all__ = ["ExceptionGroup", "ExtensionsTypedDict", "TypedDict", "is_typeddict"]

try:
from typing_extensions import TypedDict as ExtensionsTypedDict
except ImportError:
ExtensionsTypedDict = None

try:
from typing_extensions import _TypedDictMeta as ExtensionsTypedDictMeta
except ImportError:
ExtensionsTypedDictMeta = None

if sys.version_info >= (3, 8):
from typing import Final, Protocol, get_args, get_origin

Expand All @@ -44,9 +41,20 @@ def get_origin(cl):
from typing_extensions import Final, Protocol

if sys.version_info >= (3, 11):
ExceptionGroup = ExceptionGroup
from builtins import ExceptionGroup
else:
from exceptiongroup import ExceptionGroup as ExceptionGroup # noqa: PLC0414
from exceptiongroup import ExceptionGroup

try:
from typing_extensions import is_typeddict as _is_typeddict
except ImportError:
assert sys.version_info >= (3, 10)
from typing import is_typeddict as _is_typeddict


def is_typeddict(cls):
"""Thin wrapper around typing(_extensions).is_typeddict"""
return _is_typeddict(getattr(cls, "__origin__", cls))


def has(cls):
Expand Down Expand Up @@ -157,7 +165,6 @@ def get_final_base(type) -> Optional[type]:
_AnnotatedAlias,
_GenericAlias,
_SpecialGenericAlias,
_TypedDictMeta,
_UnionGenericAlias,
)

Expand Down Expand Up @@ -234,20 +241,6 @@ def get_newtype_base(typ: Any) -> Optional[type]:
return supertype
return None

def is_typeddict(cls) -> bool:
return (
cls.__class__ is _TypedDictMeta
or (is_generic(cls) and (cls.__origin__.__class__ is _TypedDictMeta))
or (
ExtensionsTypedDictMeta is not None
and cls.__class__ is ExtensionsTypedDictMeta
or (
is_generic(cls)
and (cls.__origin__.__class__ is ExtensionsTypedDictMeta)
)
)
)

def get_notrequired_base(type) -> "Union[Any, Literal[NOTHING]]":
if get_origin(type) in (NotRequired, Required):
return get_args(type)[0]
Expand Down Expand Up @@ -364,9 +357,8 @@ def copy_with(type, args):
from typing_extensions import get_origin as te_get_origin

if sys.version_info >= (3, 8):
from typing import TypedDict, _TypedDictMeta
from typing import TypedDict
else:
_TypedDictMeta = None
TypedDict = ExtensionsTypedDict

def is_annotated(type) -> bool:
Expand Down Expand Up @@ -462,20 +454,6 @@ def copy_with(type, args):
"""Replace a generic type's arguments."""
return type.copy_with(args)

def is_typeddict(cls) -> bool:
return (
cls.__class__ is _TypedDictMeta
or (is_generic(cls) and (cls.__origin__.__class__ is _TypedDictMeta))
or (
ExtensionsTypedDictMeta is not None
and cls.__class__ is ExtensionsTypedDictMeta
or (
is_generic(cls)
and (cls.__origin__.__class__ is ExtensionsTypedDictMeta)
)
)
)

def get_notrequired_base(type) -> "Union[Any, Literal[NOTHING]]":
if get_origin(type) in (NotRequired, Required):
return get_args(type)[0]
Expand Down