Treat bool as equivalent to Literal[True, False] to allow better type inference #6113
Comments
I think this can be done using literal types: essentially we can treat `bool` as equivalent to `Literal[True, False]`. cc @Michael0x2a
This appeared again so raising priority to normal.
Would like to take a stab at this if no one's picked it up already!
Another use case:

```python
from typing import Literal, overload

@overload
def foo(x: Literal[False]) -> int: ...
@overload
def foo(x: Literal[True]) -> str: ...
def foo(x): pass

b: bool = True
foo(b)
```

Gives us:
It would be great if this could infer
I was puzzled by this for a while, with a case similar to srittau's example above. The third overload they mention is here in the docs:

```python
from typing import Literal, Union, overload

@overload
def fetch_data(raw: Literal[True]) -> bytes: ...
@overload
def fetch_data(raw: Literal[False]) -> str: ...
# The last overload is a fallback in case the caller
# provides a regular bool:
@overload
def fetch_data(raw: bool) -> Union[bytes, str]: ...
def fetch_data(raw: bool) -> Union[bytes, str]:
    # Implementation is omitted
    ...
```

Presumably when this issue is fixed, the third overload can be ditched.
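For anyone who wants to try the fallback pattern end to end, here is a minimal runnable sketch. The function body and the `flag` variable are my own placeholders (the docs omit the implementation), so treat them as assumptions rather than the documented behaviour.

```python
from typing import Literal, Union, overload


@overload
def fetch_data(raw: Literal[True]) -> bytes: ...
@overload
def fetch_data(raw: Literal[False]) -> str: ...
# Fallback for callers that only have a plain bool:
@overload
def fetch_data(raw: bool) -> Union[bytes, str]: ...
def fetch_data(raw: bool) -> Union[bytes, str]:
    # Hypothetical body, only here so the example runs.
    payload = "data"
    return payload.encode("utf-8") if raw else payload


# A bool whose value is only known at runtime, so only the
# fallback overload can apply to it.
flag: bool = len("x") > 0

result = fetch_data(flag)  # declared return type: Union[bytes, str]

# The caller has to narrow the union before using the result.
if isinstance(result, bytes):
    text = result.decode("utf-8")
else:
    text = result
```

Without the third overload, the `fetch_data(flag)` call is the one that gets rejected, which is why the fallback is still needed for now.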
This issue is marked as fixed but the last two comments have not been resolved - the fallback
Fixes #91654.

Currently, the `hook` parameters of `nn.Module.register_forward_pre_hook` and `nn.Module.register_forward_hook` are typed as `Callable[..., None]`, which 1) does not enable the validation of the signature of `hook` and 2) incorrectly restricts the return type of `hook`, which the docstrings of these methods themselves state can be non-`None`.

The typing of the first parameter of `hook` as `TypeVar("T", bound="Module")` allows the binding of `Callable` whose first parameter is a subclass of `Module`.

---

Here are some examples of:

1. forward hooks and forward pre-hooks being accepted by mypy according to the new type hints
2. mypy throwing errors due to incorrect `hook` signatures
3. false negatives of pre-hooks being accepted as forward hooks
4. false negatives of hooks with kwargs being accepted irrespective of the value provided for `with_kwargs`

```python
from typing import Any, Dict, Tuple

import torch
from torch import nn


def forward_pre_hook(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
) -> None:
    ...


def forward_pre_hook_return_input(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
) -> Tuple[torch.Tensor, ...]:
    ...


def forward_pre_hook_with_kwargs(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    kwargs: Dict[str, Any],
) -> None:
    ...


def forward_pre_hook_with_kwargs_return_input(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    kwargs: Dict[str, Any],
) -> Tuple[Tuple[torch.Tensor, ...], Dict[str, Any]]:
    ...


def forward_hook(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    output: torch.Tensor,
) -> None:
    ...


def forward_hook_return_output(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    output: torch.Tensor,
) -> torch.Tensor:
    ...


def forward_hook_with_kwargs(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    kwargs: Dict[str, Any],
    output: torch.Tensor,
) -> None:
    ...


def forward_hook_with_kwargs_return_output(
    module: nn.Linear,
    args: Tuple[torch.Tensor, ...],
    kwargs: Dict[str, Any],
    output: torch.Tensor,
) -> torch.Tensor:
    ...


model = nn.Module()

# OK
model.register_forward_pre_hook(forward_pre_hook)
model.register_forward_pre_hook(forward_pre_hook_return_input)
model.register_forward_pre_hook(forward_pre_hook_with_kwargs, with_kwargs=True)
model.register_forward_pre_hook(forward_pre_hook_with_kwargs_return_input, with_kwargs=True)

model.register_forward_hook(forward_hook)
model.register_forward_hook(forward_hook_return_output)
model.register_forward_hook(forward_hook_with_kwargs, with_kwargs=True)
model.register_forward_hook(forward_hook_with_kwargs_return_output, with_kwargs=True)

# mypy(error): [arg-type]
model.register_forward_pre_hook(forward_hook)
model.register_forward_pre_hook(forward_hook_return_output)
model.register_forward_pre_hook(forward_hook_with_kwargs)
model.register_forward_pre_hook(forward_hook_with_kwargs_return_output)

model.register_forward_hook(forward_pre_hook)
model.register_forward_hook(forward_pre_hook_return_input)

# false negatives
model.register_forward_hook(forward_pre_hook_with_kwargs)
model.register_forward_hook(forward_pre_hook_with_kwargs_return_input)

model.register_forward_pre_hook(forward_pre_hook_with_kwargs, with_kwargs=False)
model.register_forward_pre_hook(forward_pre_hook_with_kwargs_return_input, with_kwargs=False)
...
```

---

Though it is not functional as of mypy 0.991, the ideal typing of these methods would use [`typing.Literal`](https://mypy.readthedocs.io/en/stable/literal_types.html#literal-types):

```python
T = TypeVar("T", bound="Module")

class Module:

    @overload
    def register_forward_hook(
        self,
        hook: Callable[[T, Tuple[Any, ...], Any], Optional[Any]],
        *,
        prepend: bool = ...,
        with_kwargs: Literal[False] = ...,
    ) -> RemovableHandle:
        ...

    @overload
    def register_forward_hook(
        self,
        hook: Callable[[T, Tuple[Any, ...], Dict[str, Any], Any], Optional[Any]],
        *,
        prepend: bool = ...,
        with_kwargs: Literal[True] = ...,
    ) -> RemovableHandle:
        ...

    def register_forward_hook(...):
        ...
```

which would:

1. validate the signature of `hook` according to the corresponding literal value provided for `with_kwargs` (and fix the false negative examples above)
2. implicitly define the fallback `bool` signature (see python/mypy#6113 (comment)), e.g. to handle the case where a non-literal is provided for `with_kwargs`

Pull Request resolved: #92061
Approved by: https://github.com/albanD
Given that this issue is marked as closed, why does this example still not work:

```python
from typing import overload, Literal

@overload
def something(return_string: Literal[True]) -> str: ...
@overload
def something(return_string: Literal[False]) -> int: ...
def something(return_string: bool) -> int | str:
    if return_string:
        return "ten"
    else:
        return 10

something_type = True
value_int = something(something_type)
something_type = False
value_str = something(something_type)
print(value_str.upper())
```

This is the output I get from mypy 1.13.0:
When you assign If mypy would infer the type as The exception here is if you use
See #2008; we'll eventually change behaviour in that example.
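In the meantime, a possible workaround (a sketch, not an official recommendation) is to stop the variable from being inferred as a plain `bool` in the first place, either by declaring it `Final` or by annotating it with the literal type. The names `WANT_STRING` and `want_int` below are made up for illustration. As far as I understand mypy's documented handling of `Final` names and literal types, both calls should then resolve to the literal overloads.

```python
from typing import Final, Literal, Union, overload


@overload
def something(return_string: Literal[True]) -> str: ...
@overload
def something(return_string: Literal[False]) -> int: ...
def something(return_string: bool) -> Union[int, str]:
    if return_string:
        return "ten"
    return 10


# Option 1: a Final name keeps its literal value for the checker,
# at the cost of never being reassigned.
WANT_STRING: Final = True
value_str = something(WANT_STRING)

# Option 2: an explicit Literal annotation on an ordinary variable.
want_int: Literal[False] = False
value_int = something(want_int)

print(value_str.upper())  # str method on the Literal[True] result
print(value_int + 1)      # int arithmetic on the Literal[False] result
```

Neither option helps if the same variable really has to hold both `True` and `False` over its lifetime; in that case the fallback `bool` overload is still the way to go.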
Similar to #6027, but a more basic use case.
Mypy is giving the error
error: Item "bool" of "Union[bool, Collection[int]]" has no attribute "__iter__" (not iterable)
But it should be able to infer that it can't be a `bool`? Using if / elif / else also gives the error.
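The snippet from the report is not preserved in this copy of the thread, so here is a hypothetical function of the same shape (the name `double` and the return values are my own assumptions). It sidesteps the problem with an `isinstance` check, which type checkers use to narrow a union regardless of whether they treat `bool` as `Literal[True, False]`.

```python
from typing import Collection, List, Union


def double(bar: Union[bool, Collection[int]]) -> Union[int, List[int]]:
    # isinstance() removes bool from the union in the code that follows,
    # so no "Item bool ... has no attribute __iter__" error is reported.
    if isinstance(bar, bool):
        return 20 if bar else 10
    # Here bar is narrowed to Collection[int] and is safe to iterate.
    return [i * 2 for i in bar]
```

This is only a workaround for the narrowing described above; it does not address the request to narrow on `True` / `False` comparisons directly.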