Skip to content

Commit 9ad47be

Browse files
committed
Merge branch 'kwargs-overload'
Fixes #1150.
2 parents 9befcc7 + aafceae commit 9ad47be

File tree

5 files changed

+153
-130
lines changed

5 files changed

+153
-130
lines changed

mypy/checkexpr.py

+111-94
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Expression type checker. This file is conceptually part of TypeChecker."""
22

3-
from typing import cast, List, Tuple, Dict, Callable, Union
3+
from typing import cast, List, Tuple, Dict, Callable, Union, Optional
44

55
from mypy.types import (
66
Type, AnyType, CallableType, Overloaded, NoneTyp, Void, TypeVarDef,
@@ -28,7 +28,7 @@
2828
from mypy import messages
2929
from mypy.infer import infer_type_arguments, infer_function_type_arguments
3030
from mypy import join
31-
from mypy.expandtype import expand_type, expand_caller_var_args
31+
from mypy.expandtype import expand_type
3232
from mypy.subtypes import is_subtype, is_more_precise
3333
from mypy import applytype
3434
from mypy import erasetype
@@ -38,6 +38,16 @@
3838
from mypy.checkstrformat import StringFormatterChecker
3939

4040

41+
# Type of callback user for checking individual function arguments. See
42+
# check_args() below for details.
43+
ArgChecker = Callable[[Type, Type, Type, int, int, CallableType, Context, MessageBuilder],
44+
None]
45+
46+
47+
class Finished(Exception):
48+
"""Raised if we can terminate overload argument check early (no match)."""
49+
50+
4151
class ExpressionChecker:
4252
"""Expression type checker.
4353
@@ -204,7 +214,6 @@ def check_call(self, callee: Type, args: List[Node],
204214
arg_messages: TODO
205215
"""
206216
arg_messages = arg_messages or self.msg
207-
is_var_arg = nodes.ARG_STAR in arg_kinds
208217
if isinstance(callee, CallableType):
209218
if callee.is_type_obj() and callee.type_object().is_abstract:
210219
type = callee.type_object()
@@ -227,7 +236,7 @@ def check_call(self, callee: Type, args: List[Node],
227236
callee, args, arg_kinds, formal_to_actual)
228237

229238
self.check_argument_count(callee, arg_types, arg_kinds,
230-
arg_names, formal_to_actual, context)
239+
arg_names, formal_to_actual, context, self.msg)
231240

232241
self.check_argument_types(arg_types, arg_kinds, callee,
233242
formal_to_actual, context,
@@ -244,7 +253,7 @@ def check_call(self, callee: Type, args: List[Node],
244253
arg_types = self.infer_arg_types_in_context(None, args)
245254
self.msg.enable_errors()
246255

247-
target = self.overload_call_target(arg_types, is_var_arg,
256+
target = self.overload_call_target(arg_types, arg_kinds, arg_names,
248257
callee, context,
249258
messages=arg_messages)
250259
return self.check_call(target, args, arg_kinds, context, arg_names,
@@ -495,66 +504,86 @@ def apply_inferred_arguments(self, callee_type: CallableType,
495504
def check_argument_count(self, callee: CallableType, actual_types: List[Type],
496505
actual_kinds: List[int], actual_names: List[str],
497506
formal_to_actual: List[List[int]],
498-
context: Context) -> None:
499-
"""Check that the number of arguments to a function are valid.
507+
context: Context,
508+
messages: Optional[MessageBuilder]) -> bool:
509+
"""Check that there is a value for all required arguments to a function.
500510
501-
Also check that there are no duplicate values for arguments.
511+
Also check that there are no duplicate values for arguments. Report found errors
512+
using 'messages' if it's not None.
513+
514+
Return False if there were any errors. Otherwise return True
502515
"""
516+
# TODO(jukka): We could return as soon as we find an error if messages is None.
503517
formal_kinds = callee.arg_kinds
504518

505519
# Collect list of all actual arguments matched to formal arguments.
506520
all_actuals = [] # type: List[int]
507521
for actuals in formal_to_actual:
508522
all_actuals.extend(actuals)
509523

510-
is_error = False # Keep track of errors to avoid duplicate errors.
524+
is_unexpected_arg_error = False # Keep track of errors to avoid duplicate errors.
525+
ok = True # False if we've found any error.
511526
for i, kind in enumerate(actual_kinds):
512527
if i not in all_actuals and (
513528
kind != nodes.ARG_STAR or
514529
not is_empty_tuple(actual_types[i])):
515530
# Extra actual: not matched by a formal argument.
531+
ok = False
516532
if kind != nodes.ARG_NAMED:
517-
self.msg.too_many_arguments(callee, context)
533+
if messages:
534+
messages.too_many_arguments(callee, context)
518535
else:
519-
self.msg.unexpected_keyword_argument(
520-
callee, actual_names[i], context)
521-
is_error = True
536+
if messages:
537+
messages.unexpected_keyword_argument(
538+
callee, actual_names[i], context)
539+
is_unexpected_arg_error = True
522540
elif kind == nodes.ARG_STAR and (
523541
nodes.ARG_STAR not in formal_kinds):
524542
actual_type = actual_types[i]
525543
if isinstance(actual_type, TupleType):
526544
if all_actuals.count(i) < len(actual_type.items):
527545
# Too many tuple items as some did not match.
528-
self.msg.too_many_arguments(callee, context)
546+
if messages:
547+
messages.too_many_arguments(callee, context)
548+
ok = False
529549
# *args can be applied even if the function takes a fixed
530550
# number of positional arguments. This may succeed at runtime.
531551

532552
for i, kind in enumerate(formal_kinds):
533553
if kind == nodes.ARG_POS and (not formal_to_actual[i] and
534-
not is_error):
554+
not is_unexpected_arg_error):
535555
# No actual for a mandatory positional formal.
536-
self.msg.too_few_arguments(callee, context, actual_names)
556+
if messages:
557+
messages.too_few_arguments(callee, context, actual_names)
558+
ok = False
537559
elif kind in [nodes.ARG_POS, nodes.ARG_OPT,
538560
nodes.ARG_NAMED] and is_duplicate_mapping(
539561
formal_to_actual[i], actual_kinds):
540562
if (self.chk.typing_mode_full() or
541563
isinstance(actual_types[formal_to_actual[i][0]], TupleType)):
542-
self.msg.duplicate_argument_value(callee, i, context)
564+
if messages:
565+
messages.duplicate_argument_value(callee, i, context)
566+
ok = False
543567
elif (kind == nodes.ARG_NAMED and formal_to_actual[i] and
544568
actual_kinds[formal_to_actual[i][0]] != nodes.ARG_NAMED):
545569
# Positional argument when expecting a keyword argument.
546-
self.msg.too_many_positional_arguments(callee, context)
570+
if messages:
571+
messages.too_many_positional_arguments(callee, context)
572+
ok = False
573+
return ok
547574

548575
def check_argument_types(self, arg_types: List[Type], arg_kinds: List[int],
549576
callee: CallableType,
550577
formal_to_actual: List[List[int]],
551578
context: Context,
552-
messages: MessageBuilder = None) -> None:
579+
messages: MessageBuilder = None,
580+
check_arg: ArgChecker = None) -> None:
553581
"""Check argument types against a callable type.
554582
555583
Report errors if the argument types are not compatible.
556584
"""
557585
messages = messages or self.msg
586+
check_arg = check_arg or self.check_arg
558587
# Keep track of consumed tuple *arg items.
559588
tuple_counter = [0]
560589
for i, actuals in enumerate(formal_to_actual):
@@ -567,13 +596,13 @@ def check_argument_types(self, arg_types: List[Type], arg_kinds: List[int],
567596
if (arg_kinds[actual] == nodes.ARG_STAR2 and
568597
not self.is_valid_keyword_var_arg(arg_type)):
569598
messages.invalid_keyword_var_arg(arg_type, context)
570-
# Get the type of an inidividual actual argument (for *args
599+
# Get the type of an individual actual argument (for *args
571600
# and **args this is the item type, not the collection type).
572601
actual_type = get_actual_type(arg_type, arg_kinds[actual],
573602
tuple_counter)
574-
self.check_arg(actual_type, arg_type,
575-
callee.arg_types[i],
576-
actual + 1, i + 1, callee, context, messages)
603+
check_arg(actual_type, arg_type,
604+
callee.arg_types[i],
605+
actual + 1, i + 1, callee, context, messages)
577606

578607
# There may be some remaining tuple varargs items that haven't
579608
# been checked yet. Handle them.
@@ -585,9 +614,9 @@ def check_argument_types(self, arg_types: List[Type], arg_kinds: List[int],
585614
actual_type = get_actual_type(arg_type,
586615
arg_kinds[actual],
587616
tuple_counter)
588-
self.check_arg(actual_type, arg_type,
589-
callee.arg_types[i],
590-
actual + 1, i + 1, callee, context, messages)
617+
check_arg(actual_type, arg_type,
618+
callee.arg_types[i],
619+
actual + 1, i + 1, callee, context, messages)
591620

592621
def check_arg(self, caller_type: Type, original_caller_type: Type,
593622
callee_type: Type, n: int, m: int, callee: CallableType,
@@ -601,14 +630,14 @@ def check_arg(self, caller_type: Type, original_caller_type: Type,
601630
messages.incompatible_argument(n, m, callee, original_caller_type,
602631
context)
603632

604-
def overload_call_target(self, arg_types: List[Type], is_var_arg: bool,
633+
def overload_call_target(self, arg_types: List[Type], arg_kinds: List[int],
634+
arg_names: List[str],
605635
overload: Overloaded, context: Context,
606636
messages: MessageBuilder = None) -> Type:
607637
"""Infer the correct overload item to call with given argument types.
608638
609639
The return value may be CallableType or AnyType (if an unique item
610-
could not be determined). If is_var_arg is True, the caller
611-
uses varargs.
640+
could not be determined).
612641
"""
613642
messages = messages or self.msg
614643
# TODO also consider argument names and kinds
@@ -617,7 +646,8 @@ def overload_call_target(self, arg_types: List[Type], is_var_arg: bool,
617646
match = [] # type: List[CallableType]
618647
best_match = 0
619648
for typ in overload.items():
620-
similarity = self.erased_signature_similarity(arg_types, is_var_arg, typ)
649+
similarity = self.erased_signature_similarity(arg_types, arg_kinds, arg_names,
650+
typ)
621651
if similarity > 0 and similarity >= best_match:
622652
if (match and not is_same_type(match[-1].ret_type,
623653
typ.ret_type) and
@@ -647,12 +677,12 @@ def overload_call_target(self, arg_types: List[Type], is_var_arg: bool,
647677
# matching signature, or default to the first one if none
648678
# match.
649679
for m in match:
650-
if self.match_signature_types(arg_types, is_var_arg, m):
680+
if self.match_signature_types(arg_types, arg_kinds, arg_names, m):
651681
return m
652682
return match[0]
653683

654-
def erased_signature_similarity(self, arg_types: List[Type], is_var_arg: bool,
655-
callee: CallableType) -> int:
684+
def erased_signature_similarity(self, arg_types: List[Type], arg_kinds: List[int],
685+
arg_names: List[str], callee: CallableType) -> int:
656686
"""Determine whether arguments could match the signature at runtime.
657687
658688
If is_var_arg is True, the caller uses varargs. This is used for
@@ -661,59 +691,63 @@ def erased_signature_similarity(self, arg_types: List[Type], is_var_arg: bool,
661691
Return similarity level (0 = no match, 1 = can match, 2 = non-promotion match). See
662692
overload_arg_similarity for a discussion of similarity levels.
663693
"""
664-
if not is_valid_argc(len(arg_types), is_var_arg, callee):
665-
return False
666-
667-
if is_var_arg:
668-
if not self.is_valid_var_arg(arg_types[-1]):
669-
return False
670-
arg_types, rest = expand_caller_var_args(arg_types,
671-
callee.max_fixed_args())
694+
formal_to_actual = map_actuals_to_formals(arg_kinds,
695+
arg_names,
696+
callee.arg_kinds,
697+
callee.arg_names,
698+
lambda i: arg_types[i])
699+
700+
if not self.check_argument_count(callee, arg_types, arg_kinds, arg_names,
701+
formal_to_actual, None, None):
702+
# Too few or many arguments -> no match.
703+
return 0
672704

673-
# Fixed function arguments.
674-
func_fixed = callee.max_fixed_args()
675705
similarity = 2
676-
for i in range(min(len(arg_types), func_fixed)):
677-
# Instead of just is_subtype, we use a relaxed overlapping check to determine
678-
# which overload variant could apply.
706+
707+
def check_arg(caller_type: Type, original_caller_type: Type,
708+
callee_type: Type, n: int, m: int, callee: CallableType,
709+
context: Context, messages: MessageBuilder) -> None:
710+
nonlocal similarity
679711
similarity = min(similarity,
680-
overload_arg_similarity(arg_types[i], callee.arg_types[i]))
712+
overload_arg_similarity(caller_type, callee_type))
681713
if similarity == 0:
682-
return 0
683-
# Function varargs.
684-
if callee.is_var_arg:
685-
for i in range(func_fixed, len(arg_types)):
686-
# See above for why we use is_compatible_overload_arg.
687-
similarity = min(similarity,
688-
overload_arg_similarity(arg_types[i],
689-
callee.arg_types[func_fixed]))
690-
if similarity == 0:
691-
return 0
714+
# No match -- exit early since none of the remaining work can change
715+
# the result.
716+
raise Finished
717+
718+
try:
719+
self.check_argument_types(arg_types, arg_kinds, callee, formal_to_actual,
720+
None, check_arg=check_arg)
721+
except Finished:
722+
pass
723+
692724
return similarity
693725

694-
def match_signature_types(self, arg_types: List[Type], is_var_arg: bool,
695-
callee: CallableType) -> bool:
726+
def match_signature_types(self, arg_types: List[Type], arg_kinds: List[int],
727+
arg_names: List[str], callee: CallableType) -> bool:
696728
"""Determine whether arguments types match the signature.
697729
698-
If is_var_arg is True, the caller uses varargs. Assume that argument
699-
counts are compatible.
730+
Assume that argument counts are compatible.
731+
732+
Return True if arguments match.
700733
"""
701-
if is_var_arg:
702-
arg_types, rest = expand_caller_var_args(arg_types,
703-
callee.max_fixed_args())
704-
705-
# Fixed function arguments.
706-
func_fixed = callee.max_fixed_args()
707-
for i in range(min(len(arg_types), func_fixed)):
708-
if not is_subtype(arg_types[i], callee.arg_types[i]):
709-
return False
710-
# Function varargs.
711-
if callee.is_var_arg:
712-
for i in range(func_fixed, len(arg_types)):
713-
if not is_subtype(arg_types[i],
714-
callee.arg_types[func_fixed]):
715-
return False
716-
return True
734+
formal_to_actual = map_actuals_to_formals(arg_kinds,
735+
arg_names,
736+
callee.arg_kinds,
737+
callee.arg_names,
738+
lambda i: arg_types[i])
739+
ok = True
740+
741+
def check_arg(caller_type: Type, original_caller_type: Type,
742+
callee_type: Type, n: int, m: int, callee: CallableType,
743+
context: Context, messages: MessageBuilder) -> None:
744+
nonlocal ok
745+
if not is_subtype(caller_type, callee_type):
746+
ok = False
747+
748+
self.check_argument_types(arg_types, arg_kinds, callee, formal_to_actual,
749+
None, check_arg=check_arg)
750+
return ok
717751

718752
def apply_generic_arguments(self, callable: CallableType, types: List[Type],
719753
context: Context) -> Type:
@@ -1437,23 +1471,6 @@ def has_member(self, typ: Type, member: str) -> bool:
14371471
return False
14381472

14391473

1440-
def is_valid_argc(nargs: int, is_var_arg: bool, callable: CallableType) -> bool:
1441-
"""Return a boolean indicating whether a call expression has a
1442-
(potentially) compatible number of arguments for calling a function.
1443-
Varargs at caller are not checked.
1444-
"""
1445-
if is_var_arg:
1446-
if callable.is_var_arg:
1447-
return True
1448-
else:
1449-
return nargs - 1 <= callable.max_fixed_args()
1450-
elif callable.is_var_arg:
1451-
return nargs >= callable.min_args
1452-
else:
1453-
# Neither has varargs.
1454-
return nargs <= len(callable.arg_types) and nargs >= callable.min_args
1455-
1456-
14571474
def map_actuals_to_formals(caller_kinds: List[int],
14581475
caller_names: List[str],
14591476
callee_kinds: List[int],

mypy/constraints.py

-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
Instance, TupleType, UnionType, Overloaded, ErasedType, PartialType, DeletedType,
88
is_named_instance
99
)
10-
from mypy.expandtype import expand_caller_var_args
1110
from mypy.maptype import map_instance_to_supertype
1211
from mypy import nodes
1312
import mypy.subtypes

0 commit comments

Comments
 (0)