diff --git a/AGENTS.md b/AGENTS.md index 6ad29a1..7b6073b 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -22,6 +22,7 @@ be brief, use simple language - UNLESS it's `typing` — then `from typing import Foo` (there are too many of them). - if the module name shadows a local variable in the same file, add a trailing underscore to the import: `from ..types import messages as messages_`. do not add trailing underscores preemptively, only when there is an actual collision. 3. minimize the number of helper functions, prioritize locality of behavior. +4. in any async generator, use `util.TaskGroup()` instead of `asyncio.TaskGroup()`. the stdlib version wraps a `GeneratorExit` into a `BaseExceptionGroup`, which breaks the `aclose()` path; `util.TaskGroup` unwraps a lone `GeneratorExit` back out. ## design principles diff --git a/src/ai/agents/agent.py b/src/ai/agents/agent.py index 3cb1116..e9ceb1a 100644 --- a/src/ai/agents/agent.py +++ b/src/ai/agents/agent.py @@ -728,7 +728,7 @@ class ToolRunner: def __init__(self) -> None: self._new_results: list[events_.ToolCallResult] = [] self._tool_results: list[events_.ToolCallResult] = [] - self._tg_base = asyncio.TaskGroup() + self._tg_base = util.TaskGroup() self._waiter: util.MultiWaiter[events_.ToolCallResult] = ( util.MultiWaiter() ) diff --git a/src/ai/agents/runtime.py b/src/ai/agents/runtime.py index 33e3e76..deda135 100644 --- a/src/ai/agents/runtime.py +++ b/src/ai/agents/runtime.py @@ -2,7 +2,6 @@ from __future__ import annotations -import asyncio import contextvars from typing import TYPE_CHECKING, Any @@ -92,7 +91,7 @@ async def _drain() -> None: _runtime.reset(token) - async with asyncio.TaskGroup() as tg: + async with util.TaskGroup() as tg: tg.create_task(_stop_when_done(rt, _drain())) async for item in rt._event_queue: diff --git a/src/ai/util.py b/src/ai/util.py index 7e52bb4..7d0b098 100644 --- a/src/ai/util.py +++ b/src/ai/util.py @@ -14,6 +14,7 @@ Collection, Generator, ) + from types import TracebackType @dataclasses.dataclass @@ -116,25 +117,36 @@ async def __aexit__( return False -@contextlib.asynccontextmanager -async def unwrap_generator_exit() -> AsyncIterator[None]: - """Unwrap ``BaseExceptionGroup`` containing only ``GeneratorExit``. - - ``asyncio.TaskGroup``'s ``__aexit__`` wraps any body exception (including - ``GeneratorExit``) into a ``BaseExceptionGroup``. Inside an async - generator that means ``aclose()`` propagates an ``ExceptionGroup`` instead - of the bare ``GeneratorExit`` the protocol expects, and the aclose-task - ends up with an unretrieved exception. Wrapping the ``async with - TaskGroup(...)`` block in this manager unwraps the group back to a plain - ``GeneratorExit`` so the close path stays clean. +class TaskGroup(asyncio.TaskGroup): + """asyncio.TaskGroup that directly propagates GeneratorExit. + + If the context body raises a GeneratorExit, we don't want to wrap + it in an ExceptionGroup because that will do the wrong thing when + it bubbles up. + + So if a GeneratorExit is raised inside the context and that is the + *only* exception reported, then unwrap it and raise it by itself. + + If there are multiple exceptions, keep them packaged so as to not + lose anything. """ - try: - yield - except BaseExceptionGroup as eg: - matched, rest = eg.split(GeneratorExit) - if matched is not None and rest is None: - raise GeneratorExit from None - raise + + async def __aexit__( + self, + et: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, + ) -> None: + try: + await super().__aexit__(et, exc, tb) + except BaseExceptionGroup as eg: + if ( + isinstance(exc, GeneratorExit) + and len(eg.exceptions) == 1 + and eg.exceptions[0] is exc + ): + raise exc from None + raise @contextlib.asynccontextmanager @@ -224,12 +236,8 @@ async def merge[T]( iterators (importantly, this means that async generators are not restarted). """ - # We use unwrap_generator_exit() to keep a GeneratorExit that gets - # packaged in an ExceptionGroup from causing grief. But maybe we - # ought to not use a TaskGroup? async with ( - unwrap_generator_exit(), - asyncio.TaskGroup() as tg, + TaskGroup() as tg, MultiWaiter[T]() as mw, ): raw_aiters = [aiter(iter) for iter in aiterables] diff --git a/tests/test_util.py b/tests/test_util.py index e42a0d9..f87665c 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -192,47 +192,51 @@ async def failing() -> AsyncIterable[int]: assert exc_info.group_contains(RuntimeError, match="boom") -# -- unwrap_generator_exit -------------------------------------------------- +# -- TaskGroup -------------------------------------------------------------- -async def test_unwrap_generator_exit_pure_generator_exit() -> None: - """A group containing only GeneratorExit unwraps to GeneratorExit.""" +async def test_taskgroup_unwraps_lone_generator_exit() -> None: + """A GeneratorExit in the body unwraps from the group it gets wrapped in.""" with pytest.raises(GeneratorExit): - async with util.unwrap_generator_exit(): - raise BaseExceptionGroup("group", [GeneratorExit()]) + async with util.TaskGroup(): + raise GeneratorExit -async def test_unwrap_generator_exit_nested_generator_exits() -> None: - """Nested groups containing only GeneratorExits also unwrap.""" - with pytest.raises(GeneratorExit): - async with util.unwrap_generator_exit(): - raise BaseExceptionGroup( - "outer", - [BaseExceptionGroup("inner", [GeneratorExit()])], - ) +async def test_taskgroup_generator_exit_with_task_error_propagates() -> None: + """A GeneratorExit alongside a task failure stays packaged in the group.""" + with pytest.raises(BaseExceptionGroup) as exc_info: + async with util.TaskGroup() as tg: + async def boom() -> None: + raise ValueError("x") -async def test_unwrap_generator_exit_mixed_propagates() -> None: - """A group with non-GeneratorExit exceptions propagates as-is.""" - with pytest.raises(BaseExceptionGroup) as exc_info: - async with util.unwrap_generator_exit(): - raise BaseExceptionGroup( - "group", [GeneratorExit(), ValueError("x")] - ) + tg.create_task(boom()) + await asyncio.sleep(0) + raise GeneratorExit assert exc_info.group_contains(ValueError, match="x") + assert exc_info.group_contains(GeneratorExit) -async def test_unwrap_generator_exit_non_group_passes_through() -> None: - """Non-group exceptions pass through unchanged.""" - with pytest.raises(ValueError, match="x"): - async with util.unwrap_generator_exit(): +async def test_taskgroup_non_generator_exit_propagates() -> None: + """A non-GeneratorExit body exception propagates as the usual group.""" + with pytest.raises(BaseExceptionGroup) as exc_info: + async with util.TaskGroup(): raise ValueError("x") + assert exc_info.group_contains(ValueError, match="x") + + +async def test_taskgroup_no_exception() -> None: + """No exception: behaves like a normal TaskGroup.""" + ran = False + + async with util.TaskGroup() as tg: + async def work() -> None: + nonlocal ran + ran = True -async def test_unwrap_generator_exit_no_exception() -> None: - """No exception → context manager returns normally.""" - async with util.unwrap_generator_exit(): - pass + tg.create_task(work()) + assert ran # -- maybe_aclosing --------------------------------------------------------