Skip to content

Commit a7d6e68

Browse files
authored
Fix mypy crash on dataclasses.field(**unpack) (#11137)
1 parent fab534b commit a7d6e68

File tree

3 files changed

+71
-5
lines changed

3 files changed

+71
-5
lines changed

mypy/plugins/dataclasses.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ def collect_attributes(self) -> Optional[List[DataclassAttribute]]:
269269
if self._is_kw_only_type(node_type):
270270
kw_only = True
271271

272-
has_field_call, field_args = _collect_field_args(stmt.rvalue)
272+
has_field_call, field_args = _collect_field_args(stmt.rvalue, ctx)
273273

274274
is_in_init_param = field_args.get('init')
275275
if is_in_init_param is None:
@@ -447,7 +447,8 @@ def dataclass_class_maker_callback(ctx: ClassDefContext) -> None:
447447
transformer.transform()
448448

449449

450-
def _collect_field_args(expr: Expression) -> Tuple[bool, Dict[str, Expression]]:
450+
def _collect_field_args(expr: Expression,
451+
ctx: ClassDefContext) -> Tuple[bool, Dict[str, Expression]]:
451452
"""Returns a tuple where the first value represents whether or not
452453
the expression is a call to dataclass.field and the second is a
453454
dictionary of the keyword arguments that field() was called with.
@@ -460,7 +461,15 @@ def _collect_field_args(expr: Expression) -> Tuple[bool, Dict[str, Expression]]:
460461
# field() only takes keyword arguments.
461462
args = {}
462463
for name, arg in zip(expr.arg_names, expr.args):
463-
assert name is not None
464+
if name is None:
465+
# This means that `field` is used with `**` unpacking,
466+
# the best we can do for now is not to fail.
467+
# TODO: we can infer what's inside `**` and try to collect it.
468+
ctx.api.fail(
469+
'Unpacking **kwargs in "field()" is not supported',
470+
expr,
471+
)
472+
return True, {}
464473
args[name] = arg
465474
return True, args
466475
return False, {}

test-data/unit/check-dataclasses.test

+36
Original file line numberDiff line numberDiff line change
@@ -1300,3 +1300,39 @@ a.x = x
13001300
a.x = x2 # E: Incompatible types in assignment (expression has type "Callable[[str], str]", variable has type "Callable[[int], int]")
13011301

13021302
[builtins fixtures/dataclasses.pyi]
1303+
1304+
1305+
[case testDataclassFieldDoesNotFailOnKwargsUnpacking]
1306+
# flags: --python-version 3.7
1307+
# https://github.com/python/mypy/issues/10879
1308+
from dataclasses import dataclass, field
1309+
1310+
@dataclass
1311+
class Foo:
1312+
bar: float = field(**{"repr": False})
1313+
[out]
1314+
main:7: error: Unpacking **kwargs in "field()" is not supported
1315+
main:7: error: No overload variant of "field" matches argument type "Dict[str, bool]"
1316+
main:7: note: Possible overload variants:
1317+
main:7: note: def [_T] field(*, default: _T, init: bool = ..., repr: bool = ..., hash: Optional[bool] = ..., compare: bool = ..., metadata: Optional[Mapping[str, Any]] = ..., kw_only: bool = ...) -> _T
1318+
main:7: note: def [_T] field(*, default_factory: Callable[[], _T], init: bool = ..., repr: bool = ..., hash: Optional[bool] = ..., compare: bool = ..., metadata: Optional[Mapping[str, Any]] = ..., kw_only: bool = ...) -> _T
1319+
main:7: note: def field(*, init: bool = ..., repr: bool = ..., hash: Optional[bool] = ..., compare: bool = ..., metadata: Optional[Mapping[str, Any]] = ..., kw_only: bool = ...) -> Any
1320+
[builtins fixtures/dataclasses.pyi]
1321+
1322+
1323+
[case testDataclassFieldWithTypedDictUnpacking]
1324+
# flags: --python-version 3.7
1325+
from dataclasses import dataclass, field
1326+
from typing_extensions import TypedDict
1327+
1328+
class FieldKwargs(TypedDict):
1329+
repr: bool
1330+
1331+
field_kwargs: FieldKwargs = {"repr": False}
1332+
1333+
@dataclass
1334+
class Foo:
1335+
bar: float = field(**field_kwargs) # E: Unpacking **kwargs in "field()" is not supported
1336+
1337+
reveal_type(Foo(bar=1.5)) # N: Revealed type is "__main__.Foo"
1338+
[builtins fixtures/dataclasses.pyi]

test-data/unit/fixtures/dataclasses.pyi

+23-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1-
from typing import Generic, Sequence, TypeVar
1+
from typing import (
2+
Generic, Iterator, Iterable, Mapping, Optional, Sequence, Tuple,
3+
TypeVar, Union, overload,
4+
)
25

36
_T = TypeVar('_T')
47
_U = TypeVar('_U')
8+
KT = TypeVar('KT')
9+
VT = TypeVar('VT')
510

611
class object:
712
def __init__(self) -> None: pass
@@ -15,7 +20,23 @@ class int: pass
1520
class float: pass
1621
class str: pass
1722
class bool(int): pass
18-
class dict(Generic[_T, _U]): pass
23+
24+
class dict(Mapping[KT, VT]):
25+
@overload
26+
def __init__(self, **kwargs: VT) -> None: pass
27+
@overload
28+
def __init__(self, arg: Iterable[Tuple[KT, VT]], **kwargs: VT) -> None: pass
29+
def __getitem__(self, key: KT) -> VT: pass
30+
def __setitem__(self, k: KT, v: VT) -> None: pass
31+
def __iter__(self) -> Iterator[KT]: pass
32+
def __contains__(self, item: object) -> int: pass
33+
def update(self, a: Mapping[KT, VT]) -> None: pass
34+
@overload
35+
def get(self, k: KT) -> Optional[VT]: pass
36+
@overload
37+
def get(self, k: KT, default: Union[KT, _T]) -> Union[VT, _T]: pass
38+
def __len__(self) -> int: ...
39+
1940
class list(Generic[_T], Sequence[_T]): pass
2041
class function: pass
2142
class classmethod: pass

0 commit comments

Comments
 (0)