Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ be brief, use simple language
- UNLESS it's `typing` — then `from typing import Foo` (there are too many of them).
- if the module name shadows a local variable in the same file, add a trailing underscore to the import: `from ..types import messages as messages_`. do not add trailing underscores preemptively, only when there is an actual collision.
3. minimize the number of helper functions, prioritize locality of behavior.
4. in any async generator, use `util.TaskGroup()` instead of `asyncio.TaskGroup()`. the stdlib version wraps a `GeneratorExit` into a `BaseExceptionGroup`, which breaks the `aclose()` path; `util.TaskGroup` unwraps a lone `GeneratorExit` back out.

## design principles

Expand Down
2 changes: 1 addition & 1 deletion src/ai/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,7 @@ class ToolRunner:
def __init__(self) -> None:
self._new_results: list[events_.ToolCallResult] = []
self._tool_results: list[events_.ToolCallResult] = []
self._tg_base = asyncio.TaskGroup()
self._tg_base = util.TaskGroup()
self._waiter: util.MultiWaiter[events_.ToolCallResult] = (
util.MultiWaiter()
)
Expand Down
3 changes: 1 addition & 2 deletions src/ai/agents/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from __future__ import annotations

import asyncio
import contextvars
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -92,7 +91,7 @@ async def _drain() -> None:

_runtime.reset(token)

async with asyncio.TaskGroup() as tg:
async with util.TaskGroup() as tg:
tg.create_task(_stop_when_done(rt, _drain()))

async for item in rt._event_queue:
Expand Down
54 changes: 31 additions & 23 deletions src/ai/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Collection,
Generator,
)
from types import TracebackType


@dataclasses.dataclass
Expand Down Expand Up @@ -116,25 +117,36 @@ async def __aexit__(
return False


@contextlib.asynccontextmanager
async def unwrap_generator_exit() -> AsyncIterator[None]:
"""Unwrap ``BaseExceptionGroup`` containing only ``GeneratorExit``.

``asyncio.TaskGroup``'s ``__aexit__`` wraps any body exception (including
``GeneratorExit``) into a ``BaseExceptionGroup``. Inside an async
generator that means ``aclose()`` propagates an ``ExceptionGroup`` instead
of the bare ``GeneratorExit`` the protocol expects, and the aclose-task
ends up with an unretrieved exception. Wrapping the ``async with
TaskGroup(...)`` block in this manager unwraps the group back to a plain
``GeneratorExit`` so the close path stays clean.
class TaskGroup(asyncio.TaskGroup):
"""asyncio.TaskGroup that directly propagates GeneratorExit.

If the context body raises a GeneratorExit, we don't want to wrap
it in an ExceptionGroup because that will do the wrong thing when
it bubbles up.

So if a GeneratorExit is raised inside the context and that is the
*only* exception reported, then unwrap it and raise it by itself.

If there are multiple exceptions, keep them packaged so as to not
lose anything.
"""
try:
yield
except BaseExceptionGroup as eg:
matched, rest = eg.split(GeneratorExit)
if matched is not None and rest is None:
raise GeneratorExit from None
raise

async def __aexit__(
self,
et: type[BaseException] | None,
exc: BaseException | None,
tb: TracebackType | None,
) -> None:
try:
await super().__aexit__(et, exc, tb)
except BaseExceptionGroup as eg:
if (
isinstance(exc, GeneratorExit)
and len(eg.exceptions) == 1
and eg.exceptions[0] is exc
):
raise exc from None
raise


@contextlib.asynccontextmanager
Expand Down Expand Up @@ -224,12 +236,8 @@ async def merge[T](
iterators (importantly, this means that async generators are not
restarted).
"""
# We use unwrap_generator_exit() to keep a GeneratorExit that gets
# packaged in an ExceptionGroup from causing grief. But maybe we
# ought to not use a TaskGroup?
async with (
unwrap_generator_exit(),
asyncio.TaskGroup() as tg,
TaskGroup() as tg,
MultiWaiter[T]() as mw,
):
raw_aiters = [aiter(iter) for iter in aiterables]
Expand Down
60 changes: 32 additions & 28 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,47 +192,51 @@ async def failing() -> AsyncIterable[int]:
assert exc_info.group_contains(RuntimeError, match="boom")


# -- unwrap_generator_exit --------------------------------------------------
# -- TaskGroup --------------------------------------------------------------


async def test_unwrap_generator_exit_pure_generator_exit() -> None:
"""A group containing only GeneratorExit unwraps to GeneratorExit."""
async def test_taskgroup_unwraps_lone_generator_exit() -> None:
"""A GeneratorExit in the body unwraps from the group it gets wrapped in."""
with pytest.raises(GeneratorExit):
async with util.unwrap_generator_exit():
raise BaseExceptionGroup("group", [GeneratorExit()])
async with util.TaskGroup():
raise GeneratorExit


async def test_unwrap_generator_exit_nested_generator_exits() -> None:
"""Nested groups containing only GeneratorExits also unwrap."""
with pytest.raises(GeneratorExit):
async with util.unwrap_generator_exit():
raise BaseExceptionGroup(
"outer",
[BaseExceptionGroup("inner", [GeneratorExit()])],
)
async def test_taskgroup_generator_exit_with_task_error_propagates() -> None:
"""A GeneratorExit alongside a task failure stays packaged in the group."""
with pytest.raises(BaseExceptionGroup) as exc_info:
async with util.TaskGroup() as tg:

async def boom() -> None:
raise ValueError("x")

async def test_unwrap_generator_exit_mixed_propagates() -> None:
"""A group with non-GeneratorExit exceptions propagates as-is."""
with pytest.raises(BaseExceptionGroup) as exc_info:
async with util.unwrap_generator_exit():
raise BaseExceptionGroup(
"group", [GeneratorExit(), ValueError("x")]
)
tg.create_task(boom())
await asyncio.sleep(0)
raise GeneratorExit
assert exc_info.group_contains(ValueError, match="x")
assert exc_info.group_contains(GeneratorExit)


async def test_unwrap_generator_exit_non_group_passes_through() -> None:
"""Non-group exceptions pass through unchanged."""
with pytest.raises(ValueError, match="x"):
async with util.unwrap_generator_exit():
async def test_taskgroup_non_generator_exit_propagates() -> None:
"""A non-GeneratorExit body exception propagates as the usual group."""
with pytest.raises(BaseExceptionGroup) as exc_info:
async with util.TaskGroup():
raise ValueError("x")
assert exc_info.group_contains(ValueError, match="x")


async def test_taskgroup_no_exception() -> None:
"""No exception: behaves like a normal TaskGroup."""
ran = False

async with util.TaskGroup() as tg:

async def work() -> None:
nonlocal ran
ran = True

async def test_unwrap_generator_exit_no_exception() -> None:
"""No exception → context manager returns normally."""
async with util.unwrap_generator_exit():
pass
tg.create_task(work())
assert ran


# -- maybe_aclosing --------------------------------------------------------
Expand Down
Loading