Skip to content

Commit

Permalink
Fix polymorphic application for callback protocols (#16514)
Browse files Browse the repository at this point in the history
Fixes #16512

The problems were caused if same callback protocol appeared multiple
times in a signature. Previous logic confused this with a recursive
callback protocol.
  • Loading branch information
ilevkivskyi authored and JukkaL committed Nov 22, 2023
1 parent 661adb7 commit e6399d1
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 5 deletions.
16 changes: 11 additions & 5 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -6209,11 +6209,16 @@ class PolyTranslator(TypeTranslator):
See docstring for apply_poly() for details.
"""

def __init__(self, poly_tvars: Sequence[TypeVarLikeType]) -> None:
def __init__(
self,
poly_tvars: Iterable[TypeVarLikeType],
bound_tvars: frozenset[TypeVarLikeType] = frozenset(),
seen_aliases: frozenset[TypeInfo] = frozenset(),
) -> None:
self.poly_tvars = set(poly_tvars)
# This is a simplified version of TypeVarScope used during semantic analysis.
self.bound_tvars: set[TypeVarLikeType] = set()
self.seen_aliases: set[TypeInfo] = set()
self.bound_tvars = bound_tvars
self.seen_aliases = seen_aliases

def collect_vars(self, t: CallableType | Parameters) -> list[TypeVarLikeType]:
found_vars = []
Expand Down Expand Up @@ -6289,10 +6294,11 @@ def visit_instance(self, t: Instance) -> Type:
if t.args and t.type.is_protocol and t.type.protocol_members == ["__call__"]:
if t.type in self.seen_aliases:
raise PolyTranslationError()
self.seen_aliases.add(t.type)
call = find_member("__call__", t, t, is_operator=True)
assert call is not None
return call.accept(self)
return call.accept(
PolyTranslator(self.poly_tvars, self.bound_tvars, self.seen_aliases | {t.type})
)
return super().visit_instance(t)


Expand Down
25 changes: 25 additions & 0 deletions test-data/unit/check-inference.test
Original file line number Diff line number Diff line change
Expand Up @@ -3788,3 +3788,28 @@ def func2(arg: T) -> List[Union[T, str]]:
reveal_type(func2) # N: Revealed type is "def [S] (S`4) -> Union[S`4, builtins.str]"
reveal_type(func2(42)) # N: Revealed type is "Union[builtins.int, builtins.str]"
[builtins fixtures/list.pyi]

[case testInferenceAgainstGenericCallbackProtoMultiple]
from typing import Callable, Protocol, TypeVar
from typing_extensions import Concatenate, ParamSpec

V_co = TypeVar("V_co", covariant=True)
class Metric(Protocol[V_co]):
def __call__(self) -> V_co: ...

T = TypeVar("T")
P = ParamSpec("P")
def simple_metric(func: Callable[Concatenate[int, P], T]) -> Callable[P, T]: ...

@simple_metric
def Negate(count: int, /, metric: Metric[float]) -> float: ...
@simple_metric
def Combine(count: int, m1: Metric[T], m2: Metric[T], /, *more: Metric[T]) -> T: ...

reveal_type(Negate) # N: Revealed type is "def (metric: __main__.Metric[builtins.float]) -> builtins.float"
reveal_type(Combine) # N: Revealed type is "def [T] (def () -> T`4, def () -> T`4, *more: def () -> T`4) -> T`4"

def m1() -> float: ...
def m2() -> float: ...
reveal_type(Combine(m1, m2)) # N: Revealed type is "builtins.float"
[builtins fixtures/list.pyi]

0 comments on commit e6399d1

Please # to comment.