Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Narrow literals in the negative case even with custom equality #18574

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 26 additions & 17 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6412,33 +6412,42 @@ def equality_type_narrowing_helper(
should_narrow_by_identity = True
else:

def is_exactly_literal_type(t: Type) -> bool:
return isinstance(get_proper_type(t), LiteralType)

def has_no_custom_eq_checks(t: Type) -> bool:
return not custom_special_method(
t, "__eq__", check_all=False
) and not custom_special_method(t, "__ne__", check_all=False)

is_valid_target = is_exactly_literal_type
coerce_only_in_literal_context = True

expr_types = [operand_types[i] for i in expr_indices]
should_narrow_by_identity = all(
map(has_no_custom_eq_checks, expr_types)
) and not is_ambiguous_mix_of_enums(expr_types)

if_map: TypeMap = {}
else_map: TypeMap = {}
if should_narrow_by_identity:
if_map, else_map = self.refine_identity_comparison_expression(
operands,
operand_types,
expr_indices,
narrowable_operand_index_to_hash.keys(),
is_valid_target,
coerce_only_in_literal_context,
)
def is_exactly_literal_type_possibly_except_enum(t: Type) -> bool:
p_t = get_proper_type(t)
if isinstance(p_t, LiteralType):
if should_narrow_by_identity:
return True
else:
return not p_t.fallback.type.is_enum
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternative: I could try has_no_custom_eq_checks(p_t.fallback) maybe?

else:
return False

is_valid_target = is_exactly_literal_type_possibly_except_enum
coerce_only_in_literal_context = True

if_map, else_map = self.refine_identity_comparison_expression(
operands,
operand_types,
expr_indices,
narrowable_operand_index_to_hash.keys(),
is_valid_target,
coerce_only_in_literal_context,
)
if not should_narrow_by_identity:
# refine_identity_comparison_expression narrows against a single literal
# -- and we know that literal will only go to the positive branch.
# This means that the negative branch narrowing is actually correct.
if_map = {}

if if_map == {} and else_map == {}:
if_map, else_map = self.refine_away_none_in_comparison(
Expand Down
27 changes: 26 additions & 1 deletion test-data/unit/check-narrowing.test
Original file line number Diff line number Diff line change
Expand Up @@ -845,7 +845,7 @@ x1: Union[Custom, Literal[1], Literal[2]]
if x1 == 1:
reveal_type(x1) # N: Revealed type is "Union[__main__.Custom, Literal[1], Literal[2]]"
else:
reveal_type(x1) # N: Revealed type is "Union[__main__.Custom, Literal[1], Literal[2]]"
reveal_type(x1) # N: Revealed type is "Union[__main__.Custom, Literal[2]]"

x2: Union[Default, Literal[1], Literal[2]]
if x2 == 1:
Expand Down Expand Up @@ -2417,6 +2417,31 @@ while x is not None and b():

[builtins fixtures/primitives.pyi]


[case testNegativeNarrowingWithCustomEq]
from typing import Union
from typing_extensions import Literal

class A:
def __eq__(self, other: object) -> bool: ... # necessary

def f(v: Union[A, Literal["text"]]) -> Union[A, None]:
if v == "text":
reveal_type(v) # N: Revealed type is "Union[__main__.A, Literal['text']]"
return None
else:
reveal_type(v) # N: Revealed type is "__main__.A"
return v # no error

def g(v: Union[A, Literal["text"]]) -> Union[A, None]:
if v != "text":
reveal_type(v) # N: Revealed type is "__main__.A"
return None
else:
reveal_type(v) # N: Revealed type is "Union[__main__.A, Literal['text']]"
return v # E: Incompatible return value type (got "Union[A, Literal['text']]", expected "Optional[A]")
[builtins fixtures/primitives.pyi]

[case testNarrowingTypeVarMultiple]
from typing import TypeVar

Expand Down