From c94e36f44f442f20ee1fd57a0b53d9595ef6c75b Mon Sep 17 00:00:00 2001 From: Alex Streed Date: Wed, 29 Jan 2025 10:50:38 -0600 Subject: [PATCH] Add missing overload for `Task.__call__` --- src/prefect/tasks.py | 14 ++++++++------ tests/typesafety/test_flows.yml | 14 +++++++++++++- tests/typesafety/test_tasks.yml | 14 +++++++++++++- 3 files changed, 34 insertions(+), 8 deletions(-) diff --git a/src/prefect/tasks.py b/src/prefect/tasks.py index 77637225730e..9b291f7a6686 100644 --- a/src/prefect/tasks.py +++ b/src/prefect/tasks.py @@ -965,7 +965,7 @@ async def create_local_run( def __call__( self: "Task[P, NoReturn]", *args: P.args, - return_state: Literal[False], + return_state: Literal[False] = False, wait_for: Optional[OneOrManyFutureOrResult[Any]] = None, **kwargs: P.kwargs, ) -> None: @@ -977,20 +977,22 @@ def __call__( def __call__( self: "Task[P, R]", *args: P.args, - return_state: Literal[True], - wait_for: Optional[OneOrManyFutureOrResult[Any]] = None, **kwargs: P.kwargs, - ) -> State[R]: + ) -> R: ... + # Keyword parameters `return_state` and `wait_for` aren't allowed after the + # ParamSpec `*args` parameter, so we lose return type typing when either of + # those are provided. + # TODO: Find a way to expose this functionality without losing type information @overload def __call__( self: "Task[P, R]", *args: P.args, - return_state: Literal[False], + return_state: Literal[True] = True, wait_for: Optional[OneOrManyFutureOrResult[Any]] = None, **kwargs: P.kwargs, - ) -> R: + ) -> State[R]: ... @overload diff --git a/tests/typesafety/test_flows.yml b/tests/typesafety/test_flows.yml index 6c4f42794093..f3537f16df96 100644 --- a/tests/typesafety/test_flows.yml +++ b/tests/typesafety/test_flows.yml @@ -34,4 +34,16 @@ reveal_type(foo) out: "main:5: note: Revealed type is \"\ prefect.flows.Flow[[bar: builtins.str], builtins.int]\ - \"" \ No newline at end of file + \"" + +- case: prefect_flow_call + main: | + from prefect import flow + @flow + def foo(bar: str) -> int: + return 42 + ret = foo(bar="baz") + reveal_type(ret) + out: "main:6: note: Revealed type is \"\ + builtins.int\ + \"" diff --git a/tests/typesafety/test_tasks.yml b/tests/typesafety/test_tasks.yml index 9dec994907cd..b2a5934d23e1 100644 --- a/tests/typesafety/test_tasks.yml +++ b/tests/typesafety/test_tasks.yml @@ -56,4 +56,16 @@ reveal_type(foo) out: "main:9: note: Revealed type is \"\ prefect.tasks.Task[[bar: builtins.str], builtins.int]\ - \"" \ No newline at end of file + \"" + +- case: prefect_task_call + main: | + from prefect import task + @task + def foo(bar: str) -> int: + return 42 + ret = foo(bar="baz") + reveal_type(ret) + out: "main:6: note: Revealed type is \"\ + builtins.int\ + \""