Skip to content

Commit 03901ef

Browse files
authored
Running dataclass transform in a later pass to fix crashes (#12762)
The dataclass plugin could crash if it encountered a placeholder. Fix the issue by running the plugin after the main semantic analysis pass, when all placeholders have been resolved. Also add a new hook called get_class_decorator_hook_2 that is used by the dataclass plugin. We may want to do a similar change to the attrs plugin, but let's change one thing at a time. Fix #12685.
1 parent e1c03ab commit 03901ef

File tree

6 files changed

+217
-37
lines changed

6 files changed

+217
-37
lines changed

mypy/plugin.py

+30-2
Original file line numberDiff line numberDiff line change
@@ -692,9 +692,33 @@ def get_class_decorator_hook(self, fullname: str
692692
693693
The plugin can modify a TypeInfo _in place_ (for example add some generated
694694
methods to the symbol table). This hook is called after the class body was
695-
semantically analyzed.
695+
semantically analyzed, but *there may still be placeholders* (typically
696+
caused by forward references).
696697
697-
The hook is called with full names of all class decorators, for example
698+
NOTE: Usually get_class_decorator_hook_2 is the better option, since it
699+
guarantees that there are no placeholders.
700+
701+
The hook is called with full names of all class decorators.
702+
703+
The hook can be called multiple times per class, so it must be
704+
idempotent.
705+
"""
706+
return None
707+
708+
def get_class_decorator_hook_2(self, fullname: str
709+
) -> Optional[Callable[[ClassDefContext], bool]]:
710+
"""Update class definition for given class decorators.
711+
712+
Similar to get_class_decorator_hook, but this runs in a later pass when
713+
placeholders have been resolved.
714+
715+
The hook can return False if some base class hasn't been
716+
processed yet using class hooks. It causes all class hooks
717+
(that are run in this same pass) to be invoked another time for
718+
the file(s) currently being processed.
719+
720+
The hook can be called multiple times per class, so it must be
721+
idempotent.
698722
"""
699723
return None
700724

@@ -815,6 +839,10 @@ def get_class_decorator_hook(self, fullname: str
815839
) -> Optional[Callable[[ClassDefContext], None]]:
816840
return self._find_hook(lambda plugin: plugin.get_class_decorator_hook(fullname))
817841

842+
def get_class_decorator_hook_2(self, fullname: str
843+
) -> Optional[Callable[[ClassDefContext], bool]]:
844+
return self._find_hook(lambda plugin: plugin.get_class_decorator_hook_2(fullname))
845+
818846
def get_metaclass_hook(self, fullname: str
819847
) -> Optional[Callable[[ClassDefContext], None]]:
820848
return self._find_hook(lambda plugin: plugin.get_metaclass_hook(fullname))

mypy/plugins/dataclasses.py

+35-12
Original file line numberDiff line numberDiff line change
@@ -107,10 +107,19 @@ def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
107107

108108

109109
class DataclassTransformer:
110+
"""Implement the behavior of @dataclass.
111+
112+
Note that this may be executed multiple times on the same class, so
113+
everything here must be idempotent.
114+
115+
This runs after the main semantic analysis pass, so you can assume that
116+
there are no placeholders.
117+
"""
118+
110119
def __init__(self, ctx: ClassDefContext) -> None:
111120
self._ctx = ctx
112121

113-
def transform(self) -> None:
122+
def transform(self) -> bool:
114123
"""Apply all the necessary transformations to the underlying
115124
dataclass so as to ensure it is fully type checked according
116125
to the rules in PEP 557.
@@ -119,12 +128,11 @@ def transform(self) -> None:
119128
info = self._ctx.cls.info
120129
attributes = self.collect_attributes()
121130
if attributes is None:
122-
# Some definitions are not ready, defer() should be already called.
123-
return
131+
# Some definitions are not ready. We need another pass.
132+
return False
124133
for attr in attributes:
125134
if attr.type is None:
126-
ctx.api.defer()
127-
return
135+
return False
128136
decorator_arguments = {
129137
'init': _get_decorator_bool_argument(self._ctx, 'init', True),
130138
'eq': _get_decorator_bool_argument(self._ctx, 'eq', True),
@@ -236,6 +244,8 @@ def transform(self) -> None:
236244
'frozen': decorator_arguments['frozen'],
237245
}
238246

247+
return True
248+
239249
def add_slots(self,
240250
info: TypeInfo,
241251
attributes: List[DataclassAttribute],
@@ -294,6 +304,9 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]:
294304
b: SomeOtherType = ...
295305
296306
are collected.
307+
308+
Return None if some dataclass base class hasn't been processed
309+
yet and thus we'll need to ask for another pass.
297310
"""
298311
# First, collect attributes belonging to the current class.
299312
ctx = self._ctx
@@ -315,14 +328,11 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]:
315328

316329
sym = cls.info.names.get(lhs.name)
317330
if sym is None:
318-
# This name is likely blocked by a star import. We don't need to defer because
319-
# defer() is already called by mark_incomplete().
331+
# There was probably a semantic analysis error.
320332
continue
321333

322334
node = sym.node
323-
if isinstance(node, PlaceholderNode):
324-
# This node is not ready yet.
325-
return None
335+
assert not isinstance(node, PlaceholderNode)
326336
assert isinstance(node, Var)
327337

328338
# x: ClassVar[int] is ignored by dataclasses.
@@ -390,6 +400,9 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]:
390400
# we'll have unmodified attrs laying around.
391401
all_attrs = attrs.copy()
392402
for info in cls.info.mro[1:-1]:
403+
if 'dataclass_tag' in info.metadata and 'dataclass' not in info.metadata:
404+
# We haven't processed the base class yet. Need another pass.
405+
return None
393406
if 'dataclass' not in info.metadata:
394407
continue
395408

@@ -517,11 +530,21 @@ def _add_dataclass_fields_magic_attribute(self) -> None:
517530
)
518531

519532

520-
def dataclass_class_maker_callback(ctx: ClassDefContext) -> None:
533+
def dataclass_tag_callback(ctx: ClassDefContext) -> None:
534+
"""Record that we have a dataclass in the main semantic analysis pass.
535+
536+
The later pass implemented by DataclassTransformer will use this
537+
to detect dataclasses in base classes.
538+
"""
539+
# The value is ignored, only the existence matters.
540+
ctx.cls.info.metadata['dataclass_tag'] = {}
541+
542+
543+
def dataclass_class_maker_callback(ctx: ClassDefContext) -> bool:
521544
"""Hooks into the class typechecking process to add support for dataclasses.
522545
"""
523546
transformer = DataclassTransformer(ctx)
524-
transformer.transform()
547+
return transformer.transform()
525548

526549

527550
def _collect_field_args(expr: Expression,

mypy/plugins/default.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -117,12 +117,21 @@ def get_class_decorator_hook(self, fullname: str
117117
auto_attribs_default=None,
118118
)
119119
elif fullname in dataclasses.dataclass_makers:
120-
return dataclasses.dataclass_class_maker_callback
120+
return dataclasses.dataclass_tag_callback
121121
elif fullname in functools.functools_total_ordering_makers:
122122
return functools.functools_total_ordering_maker_callback
123123

124124
return None
125125

126+
def get_class_decorator_hook_2(self, fullname: str
127+
) -> Optional[Callable[[ClassDefContext], bool]]:
128+
from mypy.plugins import dataclasses
129+
130+
if fullname in dataclasses.dataclass_makers:
131+
return dataclasses.dataclass_class_maker_callback
132+
133+
return None
134+
126135

127136
def contextmanager_callback(ctx: FunctionContext) -> Type:
128137
"""Infer a better return type for 'contextlib.contextmanager'."""

mypy/semanal.py

+19-18
Original file line numberDiff line numberDiff line change
@@ -1234,43 +1234,44 @@ def analyze_namedtuple_classdef(self, defn: ClassDef) -> bool:
12341234

12351235
def apply_class_plugin_hooks(self, defn: ClassDef) -> None:
12361236
"""Apply a plugin hook that may infer a more precise definition for a class."""
1237-
def get_fullname(expr: Expression) -> Optional[str]:
1238-
if isinstance(expr, CallExpr):
1239-
return get_fullname(expr.callee)
1240-
elif isinstance(expr, IndexExpr):
1241-
return get_fullname(expr.base)
1242-
elif isinstance(expr, RefExpr):
1243-
if expr.fullname:
1244-
return expr.fullname
1245-
# If we don't have a fullname look it up. This happens because base classes are
1246-
# analyzed in a different manner (see exprtotype.py) and therefore those AST
1247-
# nodes will not have full names.
1248-
sym = self.lookup_type_node(expr)
1249-
if sym:
1250-
return sym.fullname
1251-
return None
12521237

12531238
for decorator in defn.decorators:
1254-
decorator_name = get_fullname(decorator)
1239+
decorator_name = self.get_fullname_for_hook(decorator)
12551240
if decorator_name:
12561241
hook = self.plugin.get_class_decorator_hook(decorator_name)
12571242
if hook:
12581243
hook(ClassDefContext(defn, decorator, self))
12591244

12601245
if defn.metaclass:
1261-
metaclass_name = get_fullname(defn.metaclass)
1246+
metaclass_name = self.get_fullname_for_hook(defn.metaclass)
12621247
if metaclass_name:
12631248
hook = self.plugin.get_metaclass_hook(metaclass_name)
12641249
if hook:
12651250
hook(ClassDefContext(defn, defn.metaclass, self))
12661251

12671252
for base_expr in defn.base_type_exprs:
1268-
base_name = get_fullname(base_expr)
1253+
base_name = self.get_fullname_for_hook(base_expr)
12691254
if base_name:
12701255
hook = self.plugin.get_base_class_hook(base_name)
12711256
if hook:
12721257
hook(ClassDefContext(defn, base_expr, self))
12731258

1259+
def get_fullname_for_hook(self, expr: Expression) -> Optional[str]:
1260+
if isinstance(expr, CallExpr):
1261+
return self.get_fullname_for_hook(expr.callee)
1262+
elif isinstance(expr, IndexExpr):
1263+
return self.get_fullname_for_hook(expr.base)
1264+
elif isinstance(expr, RefExpr):
1265+
if expr.fullname:
1266+
return expr.fullname
1267+
# If we don't have a fullname look it up. This happens because base classes are
1268+
# analyzed in a different manner (see exprtotype.py) and therefore those AST
1269+
# nodes will not have full names.
1270+
sym = self.lookup_type_node(expr)
1271+
if sym:
1272+
return sym.fullname
1273+
return None
1274+
12741275
def analyze_class_keywords(self, defn: ClassDef) -> None:
12751276
for value in defn.keywords.values():
12761277
value.accept(self)

mypy/semanal_main.py

+56-3
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@
4545
from mypy.checker import FineGrainedDeferredNode
4646
from mypy.server.aststrip import SavedAttributes
4747
from mypy.util import is_typeshed_file
48+
from mypy.options import Options
49+
from mypy.plugin import ClassDefContext
4850
import mypy.build
4951

5052
if TYPE_CHECKING:
@@ -82,6 +84,8 @@ def semantic_analysis_for_scc(graph: 'Graph', scc: List[str], errors: Errors) ->
8284
apply_semantic_analyzer_patches(patches)
8385
# This pass might need fallbacks calculated above.
8486
check_type_arguments(graph, scc, errors)
87+
# Run class decorator hooks (they require complete MROs and no placeholders).
88+
apply_class_plugin_hooks(graph, scc, errors)
8589
calculate_class_properties(graph, scc, errors)
8690
check_blockers(graph, scc)
8791
# Clean-up builtins, so that TypeVar etc. are not accessible without importing.
@@ -132,6 +136,7 @@ def semantic_analysis_for_targets(
132136

133137
check_type_arguments_in_targets(nodes, state, state.manager.errors)
134138
calculate_class_properties(graph, [state.id], state.manager.errors)
139+
apply_class_plugin_hooks(graph, [state.id], state.manager.errors)
135140

136141

137142
def restore_saved_attrs(saved_attrs: SavedAttributes) -> None:
@@ -382,14 +387,62 @@ def check_type_arguments_in_targets(targets: List[FineGrainedDeferredNode], stat
382387
target.node.accept(analyzer)
383388

384389

390+
def apply_class_plugin_hooks(graph: 'Graph', scc: List[str], errors: Errors) -> None:
391+
"""Apply class plugin hooks within a SCC.
392+
393+
We run these after the main semantic analysis so that the hooks
394+
don't need to deal with incomplete definitions such as placeholder
395+
types.
396+
397+
Note that some hooks incorrectly run during the main semantic
398+
analysis pass, for historical reasons.
399+
"""
400+
num_passes = 0
401+
incomplete = True
402+
# If we encounter a base class that has not been processed, we'll run another
403+
# pass. This should eventually reach a fixed point.
404+
while incomplete:
405+
assert num_passes < 10, "Internal error: too many class plugin hook passes"
406+
num_passes += 1
407+
incomplete = False
408+
for module in scc:
409+
state = graph[module]
410+
tree = state.tree
411+
assert tree
412+
for _, node, _ in tree.local_definitions():
413+
if isinstance(node.node, TypeInfo):
414+
if not apply_hooks_to_class(state.manager.semantic_analyzer,
415+
module, node.node, state.options, tree, errors):
416+
incomplete = True
417+
418+
419+
def apply_hooks_to_class(self: SemanticAnalyzer,
420+
module: str,
421+
info: TypeInfo,
422+
options: Options,
423+
file_node: MypyFile,
424+
errors: Errors) -> bool:
425+
# TODO: Move more class-related hooks here?
426+
defn = info.defn
427+
ok = True
428+
for decorator in defn.decorators:
429+
with self.file_context(file_node, options, info):
430+
decorator_name = self.get_fullname_for_hook(decorator)
431+
if decorator_name:
432+
hook = self.plugin.get_class_decorator_hook_2(decorator_name)
433+
if hook:
434+
ok = ok and hook(ClassDefContext(defn, decorator, self))
435+
return ok
436+
437+
385438
def calculate_class_properties(graph: 'Graph', scc: List[str], errors: Errors) -> None:
386439
for module in scc:
387-
tree = graph[module].tree
440+
state = graph[module]
441+
tree = state.tree
388442
assert tree
389443
for _, node, _ in tree.local_definitions():
390444
if isinstance(node.node, TypeInfo):
391-
saved = (module, node.node, None) # module, class, function
392-
with errors.scope.saved_scope(saved) if errors.scope else nullcontext():
445+
with state.manager.semantic_analyzer.file_context(tree, state.options, node.node):
393446
calculate_class_abstract_status(node.node, tree.is_stub, errors)
394447
check_protocol_status(node.node, errors)
395448
calculate_class_vars(node.node)

0 commit comments

Comments
 (0)