From 9cdfb193831dade20c56e29220d78f7dc1d68718 Mon Sep 17 00:00:00 2001 From: A5rocks Date: Fri, 31 Jan 2025 08:54:04 +0900 Subject: [PATCH] Narrow literals in the negative case even with custom equality --- mypy/checker.py | 43 +++++++++++++++++------------ test-data/unit/check-narrowing.test | 26 ++++++++++++++++- 2 files changed, 51 insertions(+), 18 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index c69b80a55fd9c..de0e028b32754 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -6413,33 +6413,42 @@ def equality_type_narrowing_helper( should_narrow_by_identity = True else: - def is_exactly_literal_type(t: Type) -> bool: - return isinstance(get_proper_type(t), LiteralType) - def has_no_custom_eq_checks(t: Type) -> bool: return not custom_special_method( t, "__eq__", check_all=False ) and not custom_special_method(t, "__ne__", check_all=False) - is_valid_target = is_exactly_literal_type - coerce_only_in_literal_context = True - expr_types = [operand_types[i] for i in expr_indices] should_narrow_by_identity = all( map(has_no_custom_eq_checks, expr_types) ) and not is_ambiguous_mix_of_enums(expr_types) - if_map: TypeMap = {} - else_map: TypeMap = {} - if should_narrow_by_identity: - if_map, else_map = self.refine_identity_comparison_expression( - operands, - operand_types, - expr_indices, - narrowable_operand_index_to_hash.keys(), - is_valid_target, - coerce_only_in_literal_context, - ) + def is_exactly_literal_type_possibly_except_enum(t: Type) -> bool: + p_t = get_proper_type(t) + if isinstance(p_t, LiteralType): + if should_narrow_by_identity: + return True + else: + return not p_t.fallback.type.is_enum + else: + return False + + is_valid_target = is_exactly_literal_type_possibly_except_enum + coerce_only_in_literal_context = True + + if_map, else_map = self.refine_identity_comparison_expression( + operands, + operand_types, + expr_indices, + narrowable_operand_index_to_hash.keys(), + is_valid_target, + coerce_only_in_literal_context, + ) + if not should_narrow_by_identity: + # refine_identity_comparison_expression narrows against a single literal + # -- and we know that literal will only go to the positive branch. + # This means that the negative branch narrowing is actually correct. + if_map = {} if if_map == {} and else_map == {}: if_map, else_map = self.refine_away_none_in_comparison( diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index ec647366e7437..835b416270fc0 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -845,7 +845,7 @@ x1: Union[Custom, Literal[1], Literal[2]] if x1 == 1: reveal_type(x1) # N: Revealed type is "Union[__main__.Custom, Literal[1], Literal[2]]" else: - reveal_type(x1) # N: Revealed type is "Union[__main__.Custom, Literal[1], Literal[2]]" + reveal_type(x1) # N: Revealed type is "Union[__main__.Custom, Literal[2]]" x2: Union[Default, Literal[1], Literal[2]] if x2 == 1: @@ -2416,3 +2416,27 @@ while x is not None and b(): x = f() [builtins fixtures/primitives.pyi] + +[case testNegativeNarrowingWithCustomEq] +from typing import Union +from typing_extensions import Literal + +class A: + def __eq__(self, other: object) -> bool: ... # necessary + +def f(v: Union[A, Literal["text"]]) -> Union[A, None]: + if v == "text": + reveal_type(v) # N: Revealed type is "Union[__main__.A, Literal['text']]" + return None + else: + reveal_type(v) # N: Revealed type is "__main__.A" + return v # no error + +def g(v: Union[A, Literal["text"]]) -> Union[A, None]: + if v != "text": + reveal_type(v) # N: Revealed type is "__main__.A" + return None + else: + reveal_type(v) # N: Revealed type is "Union[__main__.A, Literal['text']]" + return v # E: Incompatible return value type (got "Union[A, Literal['text']]", expected "Optional[A]") +[builtins fixtures/primitives.pyi]