Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Simplify cattrs._compat.is_typeddict #384

Merged
merged 9 commits into from
Jun 14, 2023
36 changes: 5 additions & 31 deletions src/cattrs/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,18 @@
from attr import fields as attrs_fields
from attr import resolve_types

__all__ = ["ExtensionsTypedDict", "is_typeddict", "TypedDict"]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like __all__ was deleted from this file in #382, but I'm not entirely sure why. I've added it back here, as otherwise ruff complains (reasonably) that is_typeddict is imported but unused. The alternative would be to do from typing_extensions import is_typeddict as is_typeddict, but I find that syntax really ugly personally :p

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

whoops, my bad. You might wanna add ExceptionGroup too which uses the x as x syntax currently.


try:
from typing_extensions import TypedDict as ExtensionsTypedDict
except ImportError:
ExtensionsTypedDict = None

try:
from typing_extensions import _TypedDictMeta as ExtensionsTypedDictMeta
from typing_extensions import is_typeddict
except ImportError:
ExtensionsTypedDictMeta = None
assert sys.version_info >= (3, 10)
from typing import is_typeddict

if sys.version_info >= (3, 8):
from typing import Final, Protocol, get_args, get_origin
Expand Down Expand Up @@ -157,7 +160,6 @@ def get_final_base(type) -> Optional[type]:
_AnnotatedAlias,
_GenericAlias,
_SpecialGenericAlias,
_TypedDictMeta,
_UnionGenericAlias,
)

Expand Down Expand Up @@ -234,20 +236,6 @@ def get_newtype_base(typ: Any) -> Optional[type]:
return supertype
return None

def is_typeddict(cls) -> bool:
return (
cls.__class__ is _TypedDictMeta
or (is_generic(cls) and (cls.__origin__.__class__ is _TypedDictMeta))
or (
ExtensionsTypedDictMeta is not None
and cls.__class__ is ExtensionsTypedDictMeta
or (
is_generic(cls)
and (cls.__origin__.__class__ is ExtensionsTypedDictMeta)
)
)
)

def get_notrequired_base(type) -> "Union[Any, Literal[NOTHING]]":
if get_origin(type) in (NotRequired, Required):
return get_args(type)[0]
Expand Down Expand Up @@ -462,20 +450,6 @@ def copy_with(type, args):
"""Replace a generic type's arguments."""
return type.copy_with(args)

def is_typeddict(cls) -> bool:
return (
cls.__class__ is _TypedDictMeta
or (is_generic(cls) and (cls.__origin__.__class__ is _TypedDictMeta))
or (
ExtensionsTypedDictMeta is not None
and cls.__class__ is ExtensionsTypedDictMeta
or (
is_generic(cls)
and (cls.__origin__.__class__ is ExtensionsTypedDictMeta)
)
)
)

def get_notrequired_base(type) -> "Union[Any, Literal[NOTHING]]":
if get_origin(type) in (NotRequired, Required):
return get_args(type)[0]
Expand Down