@@ -958,38 +958,44 @@ def get_dict_base_type(self, expr: Expression) -> list[Instance]:
958
958
959
959
This is useful for dict subclasses like SymbolTable.
960
960
"""
961
- target_type = get_proper_type (self .types [expr ])
961
+ return self .get_dict_base_type_from_type (self .types [expr ])
962
+
963
+ def get_dict_base_type_from_type (self , target_type : Type ) -> list [Instance ]:
964
+ target_type = get_proper_type (target_type )
962
965
if isinstance (target_type , UnionType ):
963
- types = [get_proper_type (item ) for item in target_type .items ]
966
+ return [
967
+ inner
968
+ for item in target_type .items
969
+ for inner in self .get_dict_base_type_from_type (item )
970
+ ]
971
+ if isinstance (target_type , TypeVarLikeType ):
972
+ # Match behaviour of self.node_type
973
+ # We can only reach this point if `target_type` was a TypeVar(bound=dict[...])
974
+ # or a ParamSpec.
975
+ return self .get_dict_base_type_from_type (target_type .upper_bound )
976
+
977
+ if isinstance (target_type , TypedDictType ):
978
+ target_type = target_type .fallback
979
+ dict_base = next (
980
+ base for base in target_type .type .mro if base .fullname == "typing.Mapping"
981
+ )
982
+ elif isinstance (target_type , Instance ):
983
+ dict_base = next (
984
+ base for base in target_type .type .mro if base .fullname == "builtins.dict"
985
+ )
964
986
else :
965
- types = [target_type ]
966
-
967
- dict_types = []
968
- for t in types :
969
- if isinstance (t , TypedDictType ):
970
- t = t .fallback
971
- dict_base = next (base for base in t .type .mro if base .fullname == "typing.Mapping" )
972
- else :
973
- assert isinstance (t , Instance ), t
974
- dict_base = next (base for base in t .type .mro if base .fullname == "builtins.dict" )
975
- dict_types .append (map_instance_to_supertype (t , dict_base ))
976
- return dict_types
987
+ assert False , f"Failed to extract dict base from { target_type } "
988
+ return [map_instance_to_supertype (target_type , dict_base )]
977
989
978
990
def get_dict_key_type (self , expr : Expression ) -> RType :
979
991
dict_base_types = self .get_dict_base_type (expr )
980
- if len (dict_base_types ) == 1 :
981
- return self .type_to_rtype (dict_base_types [0 ].args [0 ])
982
- else :
983
- rtypes = [self .type_to_rtype (t .args [0 ]) for t in dict_base_types ]
984
- return RUnion .make_simplified_union (rtypes )
992
+ rtypes = [self .type_to_rtype (t .args [0 ]) for t in dict_base_types ]
993
+ return RUnion .make_simplified_union (rtypes )
985
994
986
995
def get_dict_value_type (self , expr : Expression ) -> RType :
987
996
dict_base_types = self .get_dict_base_type (expr )
988
- if len (dict_base_types ) == 1 :
989
- return self .type_to_rtype (dict_base_types [0 ].args [1 ])
990
- else :
991
- rtypes = [self .type_to_rtype (t .args [1 ]) for t in dict_base_types ]
992
- return RUnion .make_simplified_union (rtypes )
997
+ rtypes = [self .type_to_rtype (t .args [1 ]) for t in dict_base_types ]
998
+ return RUnion .make_simplified_union (rtypes )
993
999
994
1000
def get_dict_item_type (self , expr : Expression ) -> RType :
995
1001
key_type = self .get_dict_key_type (expr )
0 commit comments