Skip to content

Commit 228f864

Browse files
committed
feat(protobuf): enhanced type annotations (mypy)
1 parent 74d3ccf commit 228f864

File tree

1 file changed

+17
-56
lines changed

1 file changed

+17
-56
lines changed

lagrange/utils/binary/protobuf/models.py

+17-56
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,25 @@
11
import inspect
2-
import importlib
32
from types import GenericAlias
4-
from typing import cast, Dict, List, Tuple, Type, TypeVar, Union, Generic, Any, Callable, Mapping, overload, ForwardRef
3+
from typing import cast, Dict, List, Tuple, Type, TypeVar, Union, Generic, Any, Callable, Mapping, overload
54
from typing_extensions import Optional, Self, TypeAlias, dataclass_transform
65

76
from .coder import Proto, proto_decode, proto_encode
87

9-
_ProtoTypes = Union[str, list, dict, bytes, int, float, bool, "ProtoStruct"]
8+
_ProtoBasicTypes = Union[str, list, dict, bytes, int, float, bool]
9+
_ProtoTypes = Union[_ProtoBasicTypes, "ProtoStruct"]
1010

1111
T = TypeVar("T", str, list, dict, bytes, int, float, bool, "ProtoStruct")
1212
V = TypeVar("V")
1313
NT: TypeAlias = Dict[int, Union[_ProtoTypes, "NT"]]
14-
AMT: TypeAlias = Dict[str, Tuple[Type[_ProtoTypes], "ProtoField"]]
15-
DAMT: TypeAlias = Dict[str, "DelayAnnoType"]
16-
DelayAnnoType = Union[str, type(List[str])]
1714
NoneType = type(None)
1815

1916

2017
class ProtoField(Generic[T]):
2118
def __init__(self, tag: int, default: T):
2219
if tag <= 0:
2320
raise ValueError("Tag must be a positive integer")
24-
self._tag = tag
25-
self._default = default
21+
self._tag: int = tag
22+
self._default: T = default
2623

2724
@property
2825
def tag(self) -> int:
@@ -86,16 +83,14 @@ def proto_field(
8683
@dataclass_transform(kw_only_default=True, field_specifiers=(proto_field,))
8784
class ProtoStruct:
8885
_anno_map: Dict[str, Tuple[Type[_ProtoTypes], ProtoField[Any]]]
89-
_delay_anno_map: Dict[str, DelayAnnoType]
9086
_proto_debug: bool
9187

9288
def __init__(self, *args, **kwargs):
9389
undefined_params: List[str] = []
94-
args = list(args)
95-
self._resolve_annotations(self)
90+
arg_list = list(args)
9691
for name, (typ, field) in self._anno_map.items():
9792
if args:
98-
self._set_attr(name, typ, args.pop(0))
93+
self._set_attr(name, typ, arg_list.pop(0))
9994
elif name in kwargs:
10095
self._set_attr(name, typ, kwargs.pop(name))
10196
else:
@@ -104,13 +99,11 @@ def __init__(self, *args, **kwargs):
10499
else:
105100
undefined_params.append(name)
106101
if undefined_params:
107-
raise AttributeError(
108-
"Undefined parameters in '{}': {}".format(self, undefined_params)
109-
)
102+
raise AttributeError(f"Undefined parameters in '{self}': {undefined_params}")
110103

111104
def __init_subclass__(cls, **kwargs):
105+
cls._anno_map = cls._get_annotations()
112106
cls._proto_debug = kwargs.pop("debug") if "debug" in kwargs else False
113-
cls._anno_map, cls._delay_anno_map = cls._get_annotations()
114107
super().__init_subclass__(**kwargs)
115108

116109
def __repr__(self) -> str:
@@ -125,17 +118,14 @@ def _set_attr(self, name: str, data_typ: Type[V], value: V) -> None:
125118
if isinstance(data_typ, GenericAlias): # force ignore
126119
pass
127120
elif not isinstance(value, data_typ) and value is not None:
128-
raise TypeError(
129-
"'{}' is not a instance of type '{}'".format(value, data_typ)
130-
)
121+
raise TypeError(f"{value} is not a instance of type {data_typ}")
131122
setattr(self, name, value)
132123

133124
@classmethod
134125
def _get_annotations(
135126
cls,
136-
) -> Tuple[AMT, DAMT]: # Name: (ReturnType, ProtoField)
137-
annotations: AMT = {}
138-
delay_annotations: DAMT = {}
127+
) -> Dict[str, Tuple[Type[_ProtoTypes], "ProtoField"]]: # Name: (ReturnType, ProtoField)
128+
annotations: Dict[str, Tuple[Type[_ProtoTypes], "ProtoField"]] = {}
139129
for obj in reversed(inspect.getmro(cls)):
140130
if obj in (ProtoStruct, object): # base object, ignore
141131
continue
@@ -149,34 +139,15 @@ def _get_annotations(
149139
if not isinstance(field, ProtoField):
150140
raise TypeError("attribute '{name}' is not a ProtoField object")
151141

152-
_typ = typ
153-
annotations[name] = (_typ, field)
154-
if isinstance(typ, str):
155-
delay_annotations[name] = typ
156142
if hasattr(typ, "__origin__"):
157-
typ = cast(GenericAlias, typ)
158-
_inner = typ.__args__[0]
159-
_typ = typ.__origin__[typ.__args__[0]]
160-
annotations[name] = (_typ, field)
161-
162-
if isinstance(_inner, type):
163-
continue
164-
if isinstance(_inner, GenericAlias) and isinstance(_inner.__args__[0], type):
165-
continue
166-
if isinstance(_inner, str):
167-
delay_annotations[name] = _typ.__origin__[_inner]
168-
if isinstance(_inner, ForwardRef):
169-
delay_annotations[name] = _inner.__forward_arg__
170-
if isinstance(_inner, GenericAlias):
171-
delay_annotations[name] = _typ
172-
173-
return annotations, delay_annotations
143+
typ = typ.__origin__[typ.__args__[0]]
144+
annotations[name] = (typ, field)
145+
146+
return annotations
174147

175148
@classmethod
176149
def _get_field_mapping(cls) -> Dict[int, Tuple[str, Type[_ProtoTypes]]]: # Tag, (Name, Type)
177150
field_mapping: Dict[int, Tuple[str, Type[_ProtoTypes]]] = {}
178-
if cls._delay_anno_map:
179-
cls._resolve_annotations(cls)
180151
for name, (typ, field) in cls._anno_map.items():
181152
field_mapping[field.tag] = (name, typ)
182153
return field_mapping
@@ -187,17 +158,7 @@ def _get_stored_mapping(self) -> Dict[str, NT]:
187158
stored_mapping[name] = getattr(self, name)
188159
return stored_mapping
189160

190-
@staticmethod
191-
def _resolve_annotations(arg: Union[Type["ProtoStruct"], "ProtoStruct"]) -> None:
192-
for k, v in arg._delay_anno_map.copy().items():
193-
module = importlib.import_module(arg.__module__)
194-
if hasattr(v, "__origin__"): # resolve GenericAlias, such as list[str]
195-
arg._anno_map[k] = (v.__origin__[module.__getattribute__(v.__args__[0])], arg._anno_map[k][1])
196-
else:
197-
arg._anno_map[k] = (module.__getattribute__(v), arg._anno_map[k][1])
198-
arg._delay_anno_map.pop(k)
199-
200-
def _encode(self, v: _ProtoTypes) -> NT:
161+
def _encode(self, v: _ProtoTypes) -> _ProtoBasicTypes:
201162
if isinstance(v, ProtoStruct):
202163
v = v.encode()
203164
return v

0 commit comments

Comments
 (0)