1
1
import inspect
2
- import importlib
3
2
from types import GenericAlias
4
- from typing import cast , Dict , List , Tuple , Type , TypeVar , Union , Generic , Any , Callable , Mapping , overload , ForwardRef
3
+ from typing import cast , Dict , List , Tuple , Type , TypeVar , Union , Generic , Any , Callable , Mapping , overload
5
4
from typing_extensions import Optional , Self , TypeAlias , dataclass_transform
6
5
7
6
from .coder import Proto , proto_decode , proto_encode
8
7
9
- _ProtoTypes = Union [str , list , dict , bytes , int , float , bool , "ProtoStruct" ]
8
+ _ProtoBasicTypes = Union [str , list , dict , bytes , int , float , bool ]
9
+ _ProtoTypes = Union [_ProtoBasicTypes , "ProtoStruct" ]
10
10
11
11
T = TypeVar ("T" , str , list , dict , bytes , int , float , bool , "ProtoStruct" )
12
12
V = TypeVar ("V" )
13
13
NT : TypeAlias = Dict [int , Union [_ProtoTypes , "NT" ]]
14
- AMT : TypeAlias = Dict [str , Tuple [Type [_ProtoTypes ], "ProtoField" ]]
15
- DAMT : TypeAlias = Dict [str , "DelayAnnoType" ]
16
- DelayAnnoType = Union [str , type (List [str ])]
17
14
NoneType = type (None )
18
15
19
16
20
17
class ProtoField (Generic [T ]):
21
18
def __init__ (self , tag : int , default : T ):
22
19
if tag <= 0 :
23
20
raise ValueError ("Tag must be a positive integer" )
24
- self ._tag = tag
25
- self ._default = default
21
+ self ._tag : int = tag
22
+ self ._default : T = default
26
23
27
24
@property
28
25
def tag (self ) -> int :
@@ -86,16 +83,14 @@ def proto_field(
86
83
@dataclass_transform (kw_only_default = True , field_specifiers = (proto_field ,))
87
84
class ProtoStruct :
88
85
_anno_map : Dict [str , Tuple [Type [_ProtoTypes ], ProtoField [Any ]]]
89
- _delay_anno_map : Dict [str , DelayAnnoType ]
90
86
_proto_debug : bool
91
87
92
88
def __init__ (self , * args , ** kwargs ):
93
89
undefined_params : List [str ] = []
94
- args = list (args )
95
- self ._resolve_annotations (self )
90
+ arg_list = list (args )
96
91
for name , (typ , field ) in self ._anno_map .items ():
97
92
if args :
98
- self ._set_attr (name , typ , args .pop (0 ))
93
+ self ._set_attr (name , typ , arg_list .pop (0 ))
99
94
elif name in kwargs :
100
95
self ._set_attr (name , typ , kwargs .pop (name ))
101
96
else :
@@ -104,13 +99,11 @@ def __init__(self, *args, **kwargs):
104
99
else :
105
100
undefined_params .append (name )
106
101
if undefined_params :
107
- raise AttributeError (
108
- "Undefined parameters in '{}': {}" .format (self , undefined_params )
109
- )
102
+ raise AttributeError (f"Undefined parameters in '{ self } ': { undefined_params } " )
110
103
111
104
def __init_subclass__ (cls , ** kwargs ):
105
+ cls ._anno_map = cls ._get_annotations ()
112
106
cls ._proto_debug = kwargs .pop ("debug" ) if "debug" in kwargs else False
113
- cls ._anno_map , cls ._delay_anno_map = cls ._get_annotations ()
114
107
super ().__init_subclass__ (** kwargs )
115
108
116
109
def __repr__ (self ) -> str :
@@ -125,17 +118,14 @@ def _set_attr(self, name: str, data_typ: Type[V], value: V) -> None:
125
118
if isinstance (data_typ , GenericAlias ): # force ignore
126
119
pass
127
120
elif not isinstance (value , data_typ ) and value is not None :
128
- raise TypeError (
129
- "'{}' is not a instance of type '{}'" .format (value , data_typ )
130
- )
121
+ raise TypeError (f"{ value } is not a instance of type { data_typ } " )
131
122
setattr (self , name , value )
132
123
133
124
@classmethod
134
125
def _get_annotations (
135
126
cls ,
136
- ) -> Tuple [AMT , DAMT ]: # Name: (ReturnType, ProtoField)
137
- annotations : AMT = {}
138
- delay_annotations : DAMT = {}
127
+ ) -> Dict [str , Tuple [Type [_ProtoTypes ], "ProtoField" ]]: # Name: (ReturnType, ProtoField)
128
+ annotations : Dict [str , Tuple [Type [_ProtoTypes ], "ProtoField" ]] = {}
139
129
for obj in reversed (inspect .getmro (cls )):
140
130
if obj in (ProtoStruct , object ): # base object, ignore
141
131
continue
@@ -149,34 +139,15 @@ def _get_annotations(
149
139
if not isinstance (field , ProtoField ):
150
140
raise TypeError ("attribute '{name}' is not a ProtoField object" )
151
141
152
- _typ = typ
153
- annotations [name ] = (_typ , field )
154
- if isinstance (typ , str ):
155
- delay_annotations [name ] = typ
156
142
if hasattr (typ , "__origin__" ):
157
- typ = cast (GenericAlias , typ )
158
- _inner = typ .__args__ [0 ]
159
- _typ = typ .__origin__ [typ .__args__ [0 ]]
160
- annotations [name ] = (_typ , field )
161
-
162
- if isinstance (_inner , type ):
163
- continue
164
- if isinstance (_inner , GenericAlias ) and isinstance (_inner .__args__ [0 ], type ):
165
- continue
166
- if isinstance (_inner , str ):
167
- delay_annotations [name ] = _typ .__origin__ [_inner ]
168
- if isinstance (_inner , ForwardRef ):
169
- delay_annotations [name ] = _inner .__forward_arg__
170
- if isinstance (_inner , GenericAlias ):
171
- delay_annotations [name ] = _typ
172
-
173
- return annotations , delay_annotations
143
+ typ = typ .__origin__ [typ .__args__ [0 ]]
144
+ annotations [name ] = (typ , field )
145
+
146
+ return annotations
174
147
175
148
@classmethod
176
149
def _get_field_mapping (cls ) -> Dict [int , Tuple [str , Type [_ProtoTypes ]]]: # Tag, (Name, Type)
177
150
field_mapping : Dict [int , Tuple [str , Type [_ProtoTypes ]]] = {}
178
- if cls ._delay_anno_map :
179
- cls ._resolve_annotations (cls )
180
151
for name , (typ , field ) in cls ._anno_map .items ():
181
152
field_mapping [field .tag ] = (name , typ )
182
153
return field_mapping
@@ -187,17 +158,7 @@ def _get_stored_mapping(self) -> Dict[str, NT]:
187
158
stored_mapping [name ] = getattr (self , name )
188
159
return stored_mapping
189
160
190
- @staticmethod
191
- def _resolve_annotations (arg : Union [Type ["ProtoStruct" ], "ProtoStruct" ]) -> None :
192
- for k , v in arg ._delay_anno_map .copy ().items ():
193
- module = importlib .import_module (arg .__module__ )
194
- if hasattr (v , "__origin__" ): # resolve GenericAlias, such as list[str]
195
- arg ._anno_map [k ] = (v .__origin__ [module .__getattribute__ (v .__args__ [0 ])], arg ._anno_map [k ][1 ])
196
- else :
197
- arg ._anno_map [k ] = (module .__getattribute__ (v ), arg ._anno_map [k ][1 ])
198
- arg ._delay_anno_map .pop (k )
199
-
200
- def _encode (self , v : _ProtoTypes ) -> NT :
161
+ def _encode (self , v : _ProtoTypes ) -> _ProtoBasicTypes :
201
162
if isinstance (v , ProtoStruct ):
202
163
v = v .encode ()
203
164
return v
0 commit comments