Skip to content

Commit 2e38965

Browse files
committed
Fix union callees with functools.partial (#17903)
Fixes #17741.
1 parent c5d3673 commit 2e38965

File tree

2 files changed

+33
-3
lines changed

2 files changed

+33
-3
lines changed

mypy/plugins/functools.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
Type,
1919
TypeOfAny,
2020
UnboundType,
21+
UnionType,
2122
get_proper_type,
2223
)
2324

@@ -130,7 +131,19 @@ def partial_new_callback(ctx: mypy.plugin.FunctionContext) -> Type:
130131
if isinstance(get_proper_type(ctx.arg_types[0][0]), Overloaded):
131132
# TODO: handle overloads, just fall back to whatever the non-plugin code does
132133
return ctx.default_return_type
133-
fn_type = ctx.api.extract_callable_type(ctx.arg_types[0][0], ctx=ctx.default_return_type)
134+
return handle_partial_with_callee(ctx, callee=ctx.arg_types[0][0])
135+
136+
137+
def handle_partial_with_callee(ctx: mypy.plugin.FunctionContext, callee: Type) -> Type:
138+
if not isinstance(ctx.api, mypy.checker.TypeChecker): # use internals
139+
return ctx.default_return_type
140+
141+
if isinstance(callee_proper := get_proper_type(callee), UnionType):
142+
return UnionType.make_union(
143+
[handle_partial_with_callee(ctx, item) for item in callee_proper.items]
144+
)
145+
146+
fn_type = ctx.api.extract_callable_type(callee, ctx=ctx.default_return_type)
134147
if fn_type is None:
135148
return ctx.default_return_type
136149

test-data/unit/check-functools.test

+19-2
Original file line numberDiff line numberDiff line change
@@ -346,15 +346,32 @@ fn1: Union[Callable[[int], int], Callable[[int], int]]
346346
reveal_type(functools.partial(fn1, 2)()) # N: Revealed type is "builtins.int"
347347

348348
fn2: Union[Callable[[int], int], Callable[[int], str]]
349-
reveal_type(functools.partial(fn2, 2)()) # N: Revealed type is "builtins.object"
349+
reveal_type(functools.partial(fn2, 2)()) # N: Revealed type is "Union[builtins.int, builtins.str]"
350350

351351
fn3: Union[Callable[[int], int], str]
352352
reveal_type(functools.partial(fn3, 2)()) # E: "str" not callable \
353-
# E: "Union[Callable[[int], int], str]" not callable \
354353
# N: Revealed type is "builtins.int" \
355354
# E: Argument 1 to "partial" has incompatible type "Union[Callable[[int], int], str]"; expected "Callable[..., int]"
356355
[builtins fixtures/tuple.pyi]
357356

357+
[case testFunctoolsPartialUnionOfTypeAndCallable]
358+
import functools
359+
from typing import Callable, Union, Type
360+
from typing_extensions import TypeAlias
361+
362+
class FooBar:
363+
def __init__(self, arg1: str) -> None:
364+
pass
365+
366+
def f1(t: Union[Type[FooBar], Callable[..., 'FooBar']]) -> None:
367+
val = functools.partial(t)
368+
369+
FooBarFunc: TypeAlias = Callable[..., 'FooBar']
370+
371+
def f2(t: Union[Type[FooBar], FooBarFunc]) -> None:
372+
val = functools.partial(t)
373+
[builtins fixtures/tuple.pyi]
374+
358375
[case testFunctoolsPartialExplicitType]
359376
from functools import partial
360377
from typing import Type, TypeVar, Callable

0 commit comments

Comments
 (0)