diff --git a/docs/source/reference-core.rst b/docs/source/reference-core.rst index efab743dc..ee4468ce6 100644 --- a/docs/source/reference-core.rst +++ b/docs/source/reference-core.rst @@ -1607,7 +1607,14 @@ the numbers 0 through 9 with a 1-second delay before each one: trio.run(use_it) -Trio supports async generators, with some caveats described in this section. +Trio supports async generators, but there are several caveats and it's very +hard to handle them properly. Therefore Trio bundles a helper, +`trio.as_safe_channel`, that does it for you. + + +.. autofunction:: trio.as_safe_channel + +The details behind the problems are described in the following sections. Finalization ~~~~~~~~~~~~ @@ -1737,7 +1744,8 @@ so sometimes you'll get an unhelpful `TrioInternalError`. (And sometimes it will seem to work, which is probably the worst outcome of all, since then you might not notice the issue until you perform some minor refactoring of the generator or the code that's iterating it, or -just get unlucky. There is a `proposed Python enhancement +just get unlucky. There is a draft :pep:`789` with accompanying +`discussion thread `__ that would at least make it fail consistently.) @@ -1753,12 +1761,6 @@ the generator is suspended, what should the background tasks do? There's no good way to suspend them, but if they keep running and throw an exception, where can that exception be reraised? -If you have an async generator that wants to ``yield`` from within a nursery -or cancel scope, your best bet is to refactor it to be a separate task -that communicates over memory channels. The ``trio_util`` package offers a -`decorator that does this for you transparently -`__. 
- For more discussion, see Trio issues `264 `__ (especially `this comment diff --git a/newsfragments/3197.feature.rst b/newsfragments/3197.feature.rst new file mode 100644 index 000000000..2a04c5f52 --- /dev/null +++ b/newsfragments/3197.feature.rst @@ -0,0 +1,2 @@ +Add :func:`@trio.as_safe_channel `, a wrapper that can be used to make async generators safe. +This will be the suggested fix for the flake8-async lint rule `ASYNC900 `_. diff --git a/pyproject.toml b/pyproject.toml index 463f68908..ba2c4387d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -331,6 +331,7 @@ exclude_also = [ "@overload", 'class .*\bProtocol\b.*\):', "raise NotImplementedError", + '.*if "sphinx" in sys.modules:', 'TODO: test this line' ] partial_branches = [ diff --git a/src/trio/__init__.py b/src/trio/__init__.py index 339ef7b0c..b937ac5b9 100644 --- a/src/trio/__init__.py +++ b/src/trio/__init__.py @@ -27,6 +27,7 @@ MemoryChannelStatistics as MemoryChannelStatistics, MemoryReceiveChannel as MemoryReceiveChannel, MemorySendChannel as MemorySendChannel, + as_safe_channel as as_safe_channel, open_memory_channel as open_memory_channel, ) from ._core import ( diff --git a/src/trio/_channel.py b/src/trio/_channel.py index 6410d9120..1ed594579 100644 --- a/src/trio/_channel.py +++ b/src/trio/_channel.py @@ -1,6 +1,10 @@ from __future__ import annotations +import sys from collections import OrderedDict, deque +from collections.abc import AsyncGenerator, Callable # noqa: TC003 # Needed for Sphinx +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from functools import wraps from math import inf from typing import ( TYPE_CHECKING, @@ -14,12 +18,31 @@ from ._abc import ReceiveChannel, ReceiveType, SendChannel, SendType, T from ._core import Abort, RaiseCancelT, Task, enable_ki_protection -from ._util import NoPublicConstructor, final, generic_function +from ._util import ( + MultipleExceptionError, + NoPublicConstructor, + final, + generic_function, + 
class RecvChanWrapper(ReceiveChannel[T]):
    """Receive channel handed to the user by `as_safe_channel`.

    Delegates to a `MemoryReceiveChannel`, but releases ``send_semaphore``
    before every receive. The background task started by `as_safe_channel`
    acquires that semaphore before advancing the generator, so the generator
    only runs when a value has actually been requested.
    """

    def __init__(
        self, recv_chan: MemoryReceiveChannel[T], send_semaphore: trio.Semaphore
    ) -> None:
        self._recv_chan = recv_chan
        self._send_semaphore = send_semaphore

    async def receive(self) -> T:
        """Request one more element from the generator and wait for it."""
        # Allow the producer task to compute exactly one more element.
        self._send_semaphore.release()
        return await self._recv_chan.receive()

    async def aclose(self) -> None:
        """Close the underlying memory receive channel."""
        await self._recv_chan.aclose()

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        # Synchronous close of the wrapped channel; returning None means
        # exceptions from the ``with`` body are never suppressed here.
        self._recv_chan.close()


def as_safe_channel(
    fn: Callable[P, AsyncGenerator[T, None]],
) -> Callable[P, AbstractAsyncContextManager[ReceiveChannel[T]]]:
    """Decorate an async generator function to make it cancellation-safe.

    The ``yield`` keyword offers a very convenient way to write iterators...
    which makes it really unfortunate that async generators are so difficult
    to call correctly.  Yielding from the inside of a cancel scope or a nursery
    to the outside `violates structured concurrency <https://xkcd.com/292/>`_
    with consequences explained in :pep:`789`.  Even then, resource cleanup
    errors remain common (:pep:`533`) unless you wrap every call in
    :func:`~contextlib.aclosing`.

    This decorator gives you the best of both worlds: with careful exception
    handling and a background task we preserve structured concurrency by
    offering only the safe interface, and you can still write your iterables
    with the convenience of ``yield``.  For example::

        @as_safe_channel
        async def my_async_iterable(arg, *, kwarg=True):
            while ...:
                item = await ...
                yield item

        async with my_async_iterable(...) as recv_chan:
            async for item in recv_chan:
                ...

    While the combined async-with-async-for can be inconvenient at first,
    the context manager is indispensable for both correctness and for prompt
    cleanup of resources.
    """
    # Perhaps a future PEP will adopt `async with for` syntax, like
    # https://coconut.readthedocs.io/en/master/DOCS.html#async-with-for

    @asynccontextmanager
    @wraps(fn)
    async def context_manager(
        *args: P.args, **kwargs: P.kwargs
    ) -> AsyncGenerator[trio._channel.RecvChanWrapper[T], None]:
        send_chan, recv_chan = trio.open_memory_channel[T](0)
        try:
            async with trio.open_nursery(strict_exception_groups=True) as nursery:
                agen = fn(*args, **kwargs)
                send_semaphore = trio.Semaphore(0)
                # `nursery.start` to make sure that we will clean up send_chan & agen
                # If this errors we don't close `recv_chan`, but the caller
                # never gets access to it, so that's not a problem.
                await nursery.start(
                    _move_elems_to_channel, agen, send_chan, send_semaphore
                )
                # `async with recv_chan` could eat exceptions, so use sync cm
                with RecvChanWrapper(recv_chan, send_semaphore) as wrapped_recv_chan:
                    yield wrapped_recv_chan
                # User has exited context manager, cancel to immediately close the
                # abandoned generator if it's still alive.
                nursery.cancel_scope.cancel()
        except BaseExceptionGroup as eg:
            try:
                # Unwrap the nursery's group when it holds a single real error.
                raise_single_exception_from_group(eg)
            except MultipleExceptionError:
                # In case user has except* we make it possible for them to handle the
                # exceptions.
                raise BaseExceptionGroup(
                    "Encountered exception during cleanup of generator object, as well as exception in the contextmanager body - unable to unwrap.",
                    [eg],
                ) from None

    async def _move_elems_to_channel(
        agen: AsyncGenerator[T, None],
        send_chan: trio.MemorySendChannel[T],
        send_semaphore: trio.Semaphore,
        task_status: trio.TaskStatus,
    ) -> None:
        """Background task: pull items from ``agen`` on demand and send them."""
        # `async with send_chan` will eat exceptions,
        # see https://github.com/python-trio/trio/issues/1559
        with send_chan:
            try:
                task_status.started()
                while True:
                    # wait for receiver to call next on the aiter
                    await send_semaphore.acquire()
                    try:
                        value = await agen.__anext__()
                    except StopAsyncIteration:
                        return
                    # Send the value to the channel
                    try:
                        await send_chan.send(value)
                    except trio.BrokenResourceError:
                        # The receive side was closed while we were producing
                        # this value.  Nobody can consume further items, so
                        # stop quietly instead of failing the context manager;
                        # the `finally` below still closes the generator.
                        return
            finally:
                # replace try-finally with contextlib.aclosing once python39 is dropped
                await agen.aclose()

    return context_manager
+ async with agen() as recv_chan: + async for x in recv_chan: + assert x == 1 + + +async def test_as_safe_channel_broken_resource() -> None: + @as_safe_channel + async def agen() -> AsyncGenerator[int]: + yield 1 + yield 2 + + async with agen() as recv_chan: + assert await recv_chan.__anext__() == 1 + + # close the receiving channel + await recv_chan.aclose() + + # trying to get the next element errors + with pytest.raises(trio.ClosedResourceError): + await recv_chan.__anext__() + + # but we don't get an error on exit of the cm + + +async def test_as_safe_channel_cancelled() -> None: + with trio.CancelScope() as cs: + + @as_safe_channel + async def agen() -> AsyncGenerator[None]: # pragma: no cover + raise AssertionError( + "cancel before consumption means generator should not be iterated" + ) + yield # indicate that we're an iterator + + async with agen(): + cs.cancel() + + +async def test_as_safe_channel_no_race() -> None: + # this previously led to a race condition due to + # https://github.com/python-trio/trio/issues/1559 + @as_safe_channel + async def agen() -> AsyncGenerator[int]: + yield 1 + raise ValueError("oae") + + with pytest.raises(ValueError, match=r"^oae$"): + async with agen() as recv_chan: + async for x in recv_chan: + assert x == 1 + + +async def test_as_safe_channel_buffer_size_too_small( + autojump_clock: trio.testing.MockClock, +) -> None: + @as_safe_channel + async def agen() -> AsyncGenerator[int]: + yield 1 + raise AssertionError( + "buffer size 0 means we shouldn't be asked for another value" + ) # pragma: no cover + + with trio.move_on_after(5): + async with agen() as recv_chan: + async for x in recv_chan: # pragma: no branch + assert x == 1 + await trio.sleep_forever() + + +async def test_as_safe_channel_no_interleave() -> None: + @as_safe_channel + async def agen() -> AsyncGenerator[int]: + yield 1 + raise AssertionError # pragma: no cover + + async with agen() as recv_chan: + assert await recv_chan.__anext__() == 1 + await 
trio.lowlevel.checkpoint() + + +async def test_as_safe_channel_genexit_finally() -> None: + @as_safe_channel + async def agen(events: list[str]) -> AsyncGenerator[int]: + try: + yield 1 + except BaseException as e: + events.append(repr(e)) + raise + finally: + events.append("finally") + raise ValueError("agen") + + events: list[str] = [] + with RaisesGroup( + RaisesGroup( + Matcher(ValueError, match="^agen$"), + Matcher(TypeError, match="^iterator$"), + ), + match=r"^Encountered exception during cleanup of generator object, as well as exception in the contextmanager body - unable to unwrap.$", + ): + async with agen(events) as recv_chan: + async for i in recv_chan: # pragma: no branch + assert i == 1 + raise TypeError("iterator") + + assert events == ["GeneratorExit()", "finally"] + + +async def test_as_safe_channel_nested_loop() -> None: + @as_safe_channel + async def agen() -> AsyncGenerator[int]: + for i in range(2): + yield i + + ii = 0 + async with agen() as recv_chan1: + async for i in recv_chan1: + async with agen() as recv_chan: + jj = 0 + async for j in recv_chan: + assert (i, j) == (ii, jj) + jj += 1 + ii += 1 + + +async def test_as_safe_channel_doesnt_leak_cancellation() -> None: + @as_safe_channel + async def agen() -> AsyncGenerator[None]: + with trio.CancelScope() as cscope: + cscope.cancel() + yield + + with pytest.raises(AssertionError): + async with agen() as recv_chan: + async for _ in recv_chan: + pass + raise AssertionError("should be reachable") + + +async def test_as_safe_channel_dont_unwrap_user_exceptiongroup() -> None: + @as_safe_channel + async def agen() -> AsyncGenerator[None]: + raise NotImplementedError("not entered") + yield # pragma: no cover + + with RaisesGroup(Matcher(ValueError, match="bar"), match="foo"): + async with agen() as _: + raise ExceptionGroup("foo", [ValueError("bar")]) + + +async def test_as_safe_channel_multiple_receiver() -> None: + event = trio.Event() + + @as_safe_channel + async def agen() -> 
AsyncGenerator[int]: + await event.wait() + yield 0 + yield 1 + + async def handle_value( + recv_chan: trio.abc.ReceiveChannel[int], + value: int, + task_status: trio.TaskStatus, + ) -> None: + task_status.started() + assert await recv_chan.receive() == value + + async with agen() as recv_chan: + async with trio.open_nursery() as nursery: + await nursery.start(handle_value, recv_chan, 0) + await nursery.start(handle_value, recv_chan, 1) + event.set() + + +async def test_as_safe_channel_multi_cancel() -> None: + @as_safe_channel + async def agen(events: list[str]) -> AsyncGenerator[None]: + try: + yield + finally: + # this will give a warning of ASYNC120, although it's not technically a + # problem of swallowing existing exceptions + try: + await trio.lowlevel.checkpoint() + except trio.Cancelled: + events.append("agen cancel") + raise + + events: list[str] = [] + with trio.CancelScope() as cs: + with pytest.raises(trio.Cancelled): + async with agen(events) as recv_chan: + async for _ in recv_chan: # pragma: no branch + cs.cancel() + try: + await trio.lowlevel.checkpoint() + except trio.Cancelled: + events.append("body cancel") + raise + assert events == ["body cancel", "agen cancel"] diff --git a/src/trio/_tests/test_util.py b/src/trio/_tests/test_util.py index ba11f5a31..c0b0a3108 100644 --- a/src/trio/_tests/test_util.py +++ b/src/trio/_tests/test_util.py @@ -18,6 +18,7 @@ ) from .._util import ( ConflictDetector, + MultipleExceptionError, NoPublicConstructor, coroutine_or_error, final, @@ -288,18 +289,20 @@ async def test_raise_single_exception_from_group() -> None: assert excinfo.value.__cause__ == cause assert excinfo.value.__context__ == context - with pytest.raises(ValueError, match="foo") as excinfo: - raise_single_exception_from_group( - ExceptionGroup("", [ExceptionGroup("", [exc])]) - ) - assert excinfo.value.__cause__ == cause - assert excinfo.value.__context__ == context + # only unwraps one layer of exceptiongroup + inner_eg = ExceptionGroup("inner 
eg", [exc]) + inner_cause = SyntaxError("inner eg cause") + inner_context = TypeError("inner eg context") + inner_eg.__cause__ = inner_cause + inner_eg.__context__ = inner_context + with RaisesGroup(Matcher(ValueError, match="^foo$"), match="^inner eg$") as eginfo: + raise_single_exception_from_group(ExceptionGroup("", [inner_eg])) + assert eginfo.value.__cause__ == inner_cause + assert eginfo.value.__context__ == inner_context with pytest.raises(ValueError, match="foo") as excinfo: raise_single_exception_from_group( - BaseExceptionGroup( - "", [cancelled, BaseExceptionGroup("", [cancelled, exc])] - ) + BaseExceptionGroup("", [cancelled, cancelled, exc]) ) assert excinfo.value.__cause__ == cause assert excinfo.value.__context__ == context @@ -307,7 +310,7 @@ async def test_raise_single_exception_from_group() -> None: # multiple non-cancelled eg = ExceptionGroup("", [ValueError("foo"), ValueError("bar")]) with pytest.raises( - AssertionError, + MultipleExceptionError, match=r"^Attempted to unwrap exceptiongroup with multiple non-cancelled exceptions. 
This is often caused by a bug in the caller.$", ) as excinfo: raise_single_exception_from_group(eg) @@ -328,6 +331,20 @@ async def test_raise_single_exception_from_group() -> None: assert excinfo.value.__cause__ is eg_ki assert excinfo.value.__context__ is None + # and same for SystemExit + systemexit_ki = BaseExceptionGroup( + "", + [ + ValueError("foo"), + ValueError("bar"), + SystemExit("this exc doesn't get reraised"), + ], + ) + with pytest.raises(SystemExit, match=r"^$") as excinfo: + raise_single_exception_from_group(systemexit_ki) + assert excinfo.value.__cause__ is systemexit_ki + assert excinfo.value.__context__ is None + # if we only got cancelled, first one is reraised with pytest.raises(trio.Cancelled, match=r"^Cancelled$") as excinfo: raise_single_exception_from_group( diff --git a/src/trio/_util.py b/src/trio/_util.py index 54fc5ff73..665674911 100644 --- a/src/trio/_util.py +++ b/src/trio/_util.py @@ -4,7 +4,6 @@ import collections.abc import inspect import signal -import sys from abc import ABCMeta from collections.abc import Awaitable, Callable, Sequence from functools import update_wrapper @@ -21,19 +20,20 @@ import trio -if sys.version_info < (3, 11): - from exceptiongroup import BaseExceptionGroup - # Explicit "Any" is not allowed CallT = TypeVar("CallT", bound=Callable[..., Any]) # type: ignore[explicit-any] T = TypeVar("T") RetT = TypeVar("RetT") if TYPE_CHECKING: + import sys from types import AsyncGeneratorType, TracebackType from typing_extensions import ParamSpec, Self, TypeVarTuple, Unpack + if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup + ArgsT = ParamSpec("ArgsT") PosArgsT = TypeVarTuple("PosArgsT") @@ -359,9 +359,10 @@ def wraps( # type: ignore[explicit-any] from functools import wraps # noqa: F401 # this is re-exported -def _raise(exc: BaseException) -> NoReturn: +def raise_saving_context(exc: BaseException) -> NoReturn: """This helper allows re-raising an exception without __context__ being set.""" # 
cause does not need special handling, we simply avoid using `raise .. from ..` + # __suppress_context__ also does not need handling, it's only set if modifying cause __tracebackhide__ = True context = exc.__context__ try: @@ -371,12 +372,19 @@ def _raise(exc: BaseException) -> NoReturn: del exc, context +class MultipleExceptionError(Exception): + """Raised by raise_single_exception_from_group if encountering multiple + non-cancelled exceptions.""" + + def raise_single_exception_from_group( eg: BaseExceptionGroup[BaseException], ) -> NoReturn: """This function takes an exception group that is assumed to have at most one non-cancelled exception, which it reraises as a standalone exception. + This exception may be an exceptiongroup itself, in which case it will not be unwrapped. + If a :exc:`KeyboardInterrupt` is encountered, a new KeyboardInterrupt is immediately raised with the entire group as cause. @@ -388,30 +396,27 @@ def raise_single_exception_from_group( If multiple non-cancelled exceptions are encountered, it raises :exc:`AssertionError`. 
""" - cancelled_exceptions = [] - noncancelled_exceptions = [] - - # subgroup/split retains excgroup structure, so we need to manually traverse - def _parse_excg(e: BaseException) -> None: + # immediately bail out if there's any KI or SystemExit + for e in eg.exceptions: if isinstance(e, (KeyboardInterrupt, SystemExit)): - # immediately bail out - raise KeyboardInterrupt from eg + raise type(e) from eg + + cancelled_exception: trio.Cancelled | None = None + noncancelled_exception: BaseException | None = None + for e in eg.exceptions: if isinstance(e, trio.Cancelled): - cancelled_exceptions.append(e) - elif isinstance(e, BaseExceptionGroup): - for sub_e in e.exceptions: - _parse_excg(sub_e) + if cancelled_exception is None: + cancelled_exception = e + elif noncancelled_exception is None: + noncancelled_exception = e else: - noncancelled_exceptions.append(e) - - _parse_excg(eg) - - if len(noncancelled_exceptions) > 1: - raise AssertionError( - "Attempted to unwrap exceptiongroup with multiple non-cancelled exceptions. This is often caused by a bug in the caller." - ) from eg - if len(noncancelled_exceptions) == 1: - _raise(noncancelled_exceptions[0]) - assert cancelled_exceptions, "internal error" - _raise(cancelled_exceptions[0]) + raise MultipleExceptionError( + "Attempted to unwrap exceptiongroup with multiple non-cancelled exceptions. This is often caused by a bug in the caller." + ) from eg + + if noncancelled_exception is not None: + raise_saving_context(noncancelled_exception) + + assert cancelled_exception is not None, "group can't be empty" + raise_saving_context(cancelled_exception)