From a8b5e28d9806c3ac15c53ddffc20baab0d04b548 Mon Sep 17 00:00:00 2001 From: Jacob Tomlinson Date: Tue, 21 May 2024 17:51:16 +0100 Subject: [PATCH] Ensure type is passed through _io.sync (#384) --- kr8s/_io.py | 31 +++++++++++++++++++++++-------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/kr8s/_io.py b/kr8s/_io.py index adaaca12..f3f50f96 100644 --- a/kr8s/_io.py +++ b/kr8s/_io.py @@ -14,15 +14,31 @@ import inspect import subprocess +import sys import tempfile from contextlib import asynccontextmanager from functools import partial, wraps from threading import Thread -from typing import Any, AsyncGenerator, Awaitable, Callable, Generator, Tuple, TypeVar +from typing import ( + Any, + AsyncGenerator, + Awaitable, + Callable, + Generator, + Tuple, + TypeVar, +) + +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec import anyio T = TypeVar("T") +C = TypeVar("C") +P = ParamSpec("P") class Portal: @@ -44,13 +60,13 @@ async def _run(self): self._portal = portal await portal.sleep_until_stopped() - def call(self, func: Callable[..., T], *args, **kwargs) -> T: + def call(self, func: Callable[P, T], *args, **kwargs) -> T: while not self._portal: pass return self._portal.call(func, *args, **kwargs) -def run_sync(coro: Callable[..., Awaitable[T]]) -> Callable[..., T]: +def run_sync(coro: Callable[P, Awaitable[T]]) -> Callable[P, T]: """Wraps coroutine in a function that blocks until it has executed. Parameters @@ -65,18 +81,17 @@ def run_sync(coro: Callable[..., Awaitable[T]]) -> Callable[..., T]: """ @wraps(coro) - def wrapped(*args, **kwargs): + def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: wrapped = partial(coro, *args, **kwargs) wrapped.__doc__ = coro.__doc__ if inspect.isasyncgenfunction(coro): return iter_over_async(wrapped) - portal = Portal() if inspect.iscoroutinefunction(coro): + portal = Portal() return portal.call(wrapped) raise TypeError(f"Expected coroutine function, got {coro.__class__.__name__}") - wrapped.__doc__ = coro.__doc__ - return wrapped + return wrapper def iter_over_async(agen: AsyncGenerator) -> Generator: @@ -97,7 +112,7 @@ async def get_next() -> Tuple[bool, Any]: yield obj -def sync(source: object) -> object: +def sync(source: C) -> C: """Convert all public async methods/properties of an object to universal methods. See :func:`run_sync` for more info