From 36b2d06616fa2a217dd408ddc4152239119d0670 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Thu, 18 Jun 2026 11:28:02 -0700 Subject: [PATCH] Define our own TaskGroup subtype that doesn't munge up GeneratorExit TaskGroup will wrap a GeneratorExit in the with body into an ExceptionGroup, which then won't be recognized right by the enclosing generator's machinery. Fix this by unwrapping an ExceptionGroup if it contains *only* a GeneratorExit that was raised by the context body. This means that the straightforward case still works cleanly but we don't ever throw away exceptions from tasks. Previously we were handling this case with a second decorator, just in `merge`, but I've been seeing this error in other places hit logs while working on seal. This is a known issue in Python but there doesn't seem to quite have been consensus on what the fix is? This not quite what Yury proposed in https://github.com/python/cpython/issues/135736#issuecomment-3348203075, which is to *always* propagate GeneratorExit directly, and to call `loop.call_exception_handler()` if there are any exceptions in tasks. I prefer the version in this PR because it keeps the task errors reported while still covering the common case, but I could be missing something. @1st1? If we like this approach then maybe we should suggest it as the fix for the cpython issue? (There might also be some reason why it is suitable for our use but not as the default?) --- AGENTS.md | 1 + src/ai/agents/agent.py | 2 +- src/ai/agents/runtime.py | 3 +- src/ai/util.py | 54 +++++++++++++++++++++--------------- tests/test_util.py | 60 +++++++++++++++++++++------------------- 5 files changed, 66 insertions(+), 54 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index 6ad29a1c..7b6073bd 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -22,6 +22,7 @@ be brief, use simple language - UNLESS it's `typing` — then `from typing import Foo` (there are too many of them). - if the module name shadows a local variable in the same file, add a trailing underscore to the import: `from ..types import messages as messages_`. do not add trailing underscores preemptively, only when there is an actual collision. 3. minimize the number of helper functions, prioritize locality of behavior. +4. in any async generator, use `util.TaskGroup()` instead of `asyncio.TaskGroup()`. the stdlib version wraps a `GeneratorExit` into a `BaseExceptionGroup`, which breaks the `aclose()` path; `util.TaskGroup` unwraps a lone `GeneratorExit` back out. ## design principles diff --git a/src/ai/agents/agent.py b/src/ai/agents/agent.py index 3cb11164..e9ceb1a7 100644 --- a/src/ai/agents/agent.py +++ b/src/ai/agents/agent.py @@ -728,7 +728,7 @@ class ToolRunner: def __init__(self) -> None: self._new_results: list[events_.ToolCallResult] = [] self._tool_results: list[events_.ToolCallResult] = [] - self._tg_base = asyncio.TaskGroup() + self._tg_base = util.TaskGroup() self._waiter: util.MultiWaiter[events_.ToolCallResult] = ( util.MultiWaiter() ) diff --git a/src/ai/agents/runtime.py b/src/ai/agents/runtime.py index 33e3e760..deda1356 100644 --- a/src/ai/agents/runtime.py +++ b/src/ai/agents/runtime.py @@ -2,7 +2,6 @@ from __future__ import annotations -import asyncio import contextvars from typing import TYPE_CHECKING, Any @@ -92,7 +91,7 @@ async def _drain() -> None: _runtime.reset(token) - async with asyncio.TaskGroup() as tg: + async with util.TaskGroup() as tg: tg.create_task(_stop_when_done(rt, _drain())) async for item in rt._event_queue: diff --git a/src/ai/util.py b/src/ai/util.py index 7e52bb4a..7d0b098e 100644 --- a/src/ai/util.py +++ b/src/ai/util.py @@ -14,6 +14,7 @@ Collection, Generator, ) + from types import TracebackType @dataclasses.dataclass @@ -116,25 +117,36 @@ async def __aexit__( return False -@contextlib.asynccontextmanager -async def unwrap_generator_exit() -> AsyncIterator[None]: - """Unwrap ``BaseExceptionGroup`` containing only ``GeneratorExit``. - - ``asyncio.TaskGroup``'s ``__aexit__`` wraps any body exception (including - ``GeneratorExit``) into a ``BaseExceptionGroup``. Inside an async - generator that means ``aclose()`` propagates an ``ExceptionGroup`` instead - of the bare ``GeneratorExit`` the protocol expects, and the aclose-task - ends up with an unretrieved exception. Wrapping the ``async with - TaskGroup(...)`` block in this manager unwraps the group back to a plain - ``GeneratorExit`` so the close path stays clean. +class TaskGroup(asyncio.TaskGroup): + """asyncio.TaskGroup that directly propagates GeneratorExit. + + If the context body raises a GeneratorExit, we don't want to wrap + it in an ExceptionGroup because that will do the wrong thing when + it bubbles up. + + So if a GeneratorExit is raised inside the context and that is the + *only* exception reported, then unwrap it and raise it by itself. + + If there are multiple exceptions, keep them packaged so as to not + lose anything. """ - try: - yield - except BaseExceptionGroup as eg: - matched, rest = eg.split(GeneratorExit) - if matched is not None and rest is None: - raise GeneratorExit from None - raise + + async def __aexit__( + self, + et: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + try: + await super().__aexit__(et, exc, tb) + except BaseExceptionGroup as eg: + if ( + isinstance(exc, GeneratorExit) + and len(eg.exceptions) == 1 + and eg.exceptions[0] is exc + ): + raise exc from None + raise @contextlib.asynccontextmanager @@ -224,12 +236,8 @@ async def merge[T]( iterators (importantly, this means that async generators are not restarted). """ - # We use unwrap_generator_exit() to keep a GeneratorExit that gets - # packaged in an ExceptionGroup from causing grief. But maybe we - # ought to not use a TaskGroup? async with ( - unwrap_generator_exit(), - asyncio.TaskGroup() as tg, + TaskGroup() as tg, MultiWaiter[T]() as mw, ): raw_aiters = [aiter(iter) for iter in aiterables] diff --git a/tests/test_util.py b/tests/test_util.py index e42a0d94..f87665c3 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -192,47 +192,51 @@ async def failing() -> AsyncIterable[int]: assert exc_info.group_contains(RuntimeError, match="boom") -# -- unwrap_generator_exit -------------------------------------------------- +# -- TaskGroup -------------------------------------------------------------- -async def test_unwrap_generator_exit_pure_generator_exit() -> None: - """A group containing only GeneratorExit unwraps to GeneratorExit.""" +async def test_taskgroup_unwraps_lone_generator_exit() -> None: + """A GeneratorExit in the body unwraps from the group it gets wrapped in.""" with pytest.raises(GeneratorExit): - async with util.unwrap_generator_exit(): - raise BaseExceptionGroup("group", [GeneratorExit()]) + async with util.TaskGroup(): + raise GeneratorExit -async def test_unwrap_generator_exit_nested_generator_exits() -> None: - """Nested groups containing only GeneratorExits also unwrap.""" - with pytest.raises(GeneratorExit): - async with util.unwrap_generator_exit(): - raise BaseExceptionGroup( - "outer", - [BaseExceptionGroup("inner", [GeneratorExit()])], - ) +async def test_taskgroup_generator_exit_with_task_error_propagates() -> None: + """A GeneratorExit alongside a task failure stays packaged in the group.""" + with pytest.raises(BaseExceptionGroup) as exc_info: + async with util.TaskGroup() as tg: + async def boom() -> None: + raise ValueError("x") -async def test_unwrap_generator_exit_mixed_propagates() -> None: - """A group with non-GeneratorExit exceptions propagates as-is.""" - with pytest.raises(BaseExceptionGroup) as exc_info: - async with util.unwrap_generator_exit(): - raise BaseExceptionGroup( - "group", [GeneratorExit(), ValueError("x")] - ) + tg.create_task(boom()) + await asyncio.sleep(0) + raise GeneratorExit assert exc_info.group_contains(ValueError, match="x") + assert exc_info.group_contains(GeneratorExit) -async def test_unwrap_generator_exit_non_group_passes_through() -> None: - """Non-group exceptions pass through unchanged.""" - with pytest.raises(ValueError, match="x"): - async with util.unwrap_generator_exit(): +async def test_taskgroup_non_generator_exit_propagates() -> None: + """A non-GeneratorExit body exception propagates as the usual group.""" + with pytest.raises(BaseExceptionGroup) as exc_info: + async with util.TaskGroup(): raise ValueError("x") + assert exc_info.group_contains(ValueError, match="x") + + +async def test_taskgroup_no_exception() -> None: + """No exception: behaves like a normal TaskGroup.""" + ran = False + + async with util.TaskGroup() as tg: + async def work() -> None: + nonlocal ran + ran = True -async def test_unwrap_generator_exit_no_exception() -> None: - """No exception → context manager returns normally.""" - async with util.unwrap_generator_exit(): - pass + tg.create_task(work()) + assert ran # -- maybe_aclosing --------------------------------------------------------