From cff6daa67d5dfe40288f5408e37723dc622b4d45 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Tue, 16 Jun 2026 13:16:13 -0700 Subject: [PATCH 1/3] Add a util.MultiWaiter class for waiting on many events The big differences/advantages compared to asyncio.wait: * New futures to wait on may be added while the object is already being waited on * Completion order of the tasks is not lost. However, completion *batching* is, which is actually a plus for us. I think that MultiWaiter should be deterministic enough to use with workflows/temporal. --- src/ai/util.py | 72 ++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 70 insertions(+), 2 deletions(-) diff --git a/src/ai/util.py b/src/ai/util.py index ddaa3583..8e9f8e87 100644 --- a/src/ai/util.py +++ b/src/ai/util.py @@ -8,9 +8,20 @@ from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from collections.abc import AsyncIterable, AsyncIterator + from collections.abc import ( + AsyncIterable, + AsyncIterator, + Collection, + Generator, + ) -_EMPTY: Any = object() + +@dataclasses.dataclass +class _Empty: + pass + + +_EMPTY: Any = _Empty() @dataclasses.dataclass @@ -48,6 +59,63 @@ async def astop(self) -> None: await self.put(_STOP) +class MultiWaiter[T]: + """Waiter object for waiting on multiple futures. + + The advantages over using asyncio.wait are: + * New futures may be added while the object is already being waited on + * Completion order of the tasks is preserved. + + A *potential* downside is: + * Batching of future completion is lost + + But that is actually good for our use cases, since that introduces + a potential mismatch when using workflows/temporal. + """ + + def __init__(self, *tasks: asyncio.Future[T]) -> None: + self._queue: asyncio.Queue[asyncio.Future[T]] = asyncio.Queue(0) + self._tasks: dict[asyncio.Future[T], None] = {} + + # We bind this to an attribute so that the bound method is + # always the same and can be passed to remove_done_callback. + self._callback = self._queue.put_nowait + self.add(*tasks) + + def add(self, *tasks: asyncio.Future[T]) -> None: + for task in tasks: + self._tasks[task] = None + task.add_done_callback(self._callback) + + def clear(self) -> None: + for task in self._tasks: + task.remove_done_callback(self._callback) + self._tasks.clear() + + def tasks(self) -> Collection[asyncio.Future[T]]: + return self._tasks.keys() + + async def wait(self) -> asyncio.Future[T]: + t = await self._queue.get() + self._tasks.pop(t, None) + return t + + def __await__(self) -> Generator[Any, Any, asyncio.Future[T]]: + return self.wait().__await__() + + async def __aenter__(self) -> MultiWaiter[T]: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: Any | None, + ) -> bool: + self.clear() + return False + + @contextlib.asynccontextmanager async def unwrap_generator_exit() -> AsyncIterator[None]: """Unwrap ``BaseExceptionGroup`` containing only ``GeneratorExit``. From b26f087d46666d9e5b5c223fd8917574e611ca24 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Tue, 16 Jun 2026 13:16:33 -0700 Subject: [PATCH 2/3] Update ToolRunner to use the new MultiWaiter This is a big improvement! --- src/ai/agents/agent.py | 67 +++++++++++++----------------------------- 1 file changed, 20 insertions(+), 47 deletions(-) diff --git a/src/ai/agents/agent.py b/src/ai/agents/agent.py index 06bea9e7..3cb11164 100644 --- a/src/ai/agents/agent.py +++ b/src/ai/agents/agent.py @@ -726,26 +726,19 @@ def __aiter__(self) -> AsyncGenerator[events_.ToolCallResult]: class ToolRunner: def __init__(self) -> None: - # A future that gets signalled when we add a new tool, so that - # asyncio.wait gets woken up and cycles around in the loop to - # wait on the new thing as well. - # Also used when add_result is called, to signal that - self._sched_waiter: asyncio.Future[None] = ( - asyncio.get_running_loop().create_future() - ) - self._active: set[ - asyncio.Future[events_.ToolCallResult] | asyncio.Future[None] - ] = set() - self._new_results: list[events_.ToolCallResult] = [] self._tool_results: list[events_.ToolCallResult] = [] self._tg_base = asyncio.TaskGroup() + self._waiter: util.MultiWaiter[events_.ToolCallResult] = ( + util.MultiWaiter() + ) async def __aenter__(self) -> Self: self._tg = await self._tg_base.__aenter__() return self async def __aexit__(self, *args: Any) -> None: + self._waiter.clear() return await self._tg_base.__aexit__(*args) def events(self) -> _RestartableToolStream: @@ -760,17 +753,13 @@ def schedule(self, tc: ToolCallCallable) -> None: in custom logic (e.g. an approval hook await) and still ride the runner's merge-and-iterate flow. """ - self._active.add(self._tg.create_task(tc())) - if not self._sched_waiter.done(): - self._sched_waiter.set_result(None) + self._waiter.add(self._tg.create_task(tc())) def add_result(self, res: events_.ToolCallResult) -> None: - self._tool_results.append(res) + async def _feed() -> events_.ToolCallResult: + return res - # Also add to _new_results and signal sched_waiter to return them - self._new_results.append(res) - if not self._sched_waiter.done(): - self._sched_waiter.set_result(None) + self._waiter.add(self._tg.create_task(_feed())) def get_tool_message(self) -> types.messages.Message | None: if self._tool_results: @@ -780,34 +769,18 @@ def get_tool_message(self) -> types.messages.Message | None: return None async def _iterate(self) -> AsyncGenerator[events_.ToolCallResult]: - while self._active: - done, _ = await asyncio.wait( - [*self._active, self._sched_waiter], - return_when=asyncio.FIRST_COMPLETED, - ) - for t in done: - self._active.discard(t) - if t is self._sched_waiter: - t.result() - - new = self._new_results - self._new_results = [] - for n in new: - yield n - self._sched_waiter = ( - asyncio.get_running_loop().create_future() - ) - else: - try: - res = t.result() - except asyncio.CancelledError: - # If a task got cancelled, that's fine. - # Need to catch it or the whole runner gets zapped. - continue - - assert res is not None - self._tool_results.append(res) - yield res + while self._waiter.tasks(): + t = await self._waiter + try: + res = t.result() + except asyncio.CancelledError: + # If a task got cancelled, that's fine. + # Need to catch it or the whole runner gets zapped. + continue + + assert res is not None + self._tool_results.append(res) + yield res class Context(pydantic.BaseModel): From 5a2123e1923727c92b9c71ce1140e3ac9a2c09b3 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Tue, 16 Jun 2026 13:17:26 -0700 Subject: [PATCH 3/3] Update merge to use MultiWaiter I'm not sure if this one is an improvement though --- src/ai/util.py | 63 ++++++++++++++++++++++++++-------------------- tests/test_util.py | 16 ++++++------ 2 files changed, 44 insertions(+), 35 deletions(-) diff --git a/src/ai/util.py b/src/ai/util.py index 8e9f8e87..7e52bb4a 100644 --- a/src/ai/util.py +++ b/src/ai/util.py @@ -227,7 +227,11 @@ async def merge[T]( # We use unwrap_generator_exit() to keep a GeneratorExit that gets # packaged in an ExceptionGroup from causing grief. But maybe we # ought to not use a TaskGroup? - async with unwrap_generator_exit(), asyncio.TaskGroup() as tg: + async with ( + unwrap_generator_exit(), + asyncio.TaskGroup() as tg, + MultiWaiter[T]() as mw, + ): raw_aiters = [aiter(iter) for iter in aiterables] aiters = [decouple(iter, task_group=tg) for iter in raw_aiters] # We consider anything that doesn't __aiter__ to itself to be @@ -241,37 +245,42 @@ async def merge[T]( tasks: list[asyncio.Future[T] | None] = [ tg.create_task(anext(iter, _EMPTY)) for iter in aiters ] - - while any(tasks): - done, _ = await asyncio.wait( - [t for t in tasks if t], - return_when=asyncio.FIRST_COMPLETED, - ) - - fired = [] - for t in done: - idx = tasks.index(t) - val = t.result() - if val is _EMPTY: - tasks[idx] = None - else: - # Fire off a new task for the relevant iterator - fired.append(idx) - iter = aiters[idx] - tasks[idx] = tg.create_task(anext(iter, _EMPTY)) - yield val - - if restart and fired: + mw.add(*[t for t in tasks if t]) + + top_fired = False + while mw.tasks(): + t = await mw + + idx = tasks.index(t) + val = t.result() + if val is _EMPTY: + tasks[idx] = None + else: + # Fire off a new task for the relevant iterator + top_fired = True + iter = aiters[idx] + tasks[idx] = nt = tg.create_task(anext(iter, _EMPTY)) + mw.add(nt) + yield val + + if restart and ( + val is not _EMPTY or (not mw.tasks() and top_fired) + ): + if not mw.tasks(): + top_fired = False # Also, we try *restarting* other stopped streams # that may have more to do now. + # # N.B: We do this *after* the values are yielded, so - # they've had a chance to trigger things, and we do it - # after *all* tasks have been handled, so that if a - # task *just* finished, we still restart it. + # they've had a chance to trigger things, and we also + # do it if we would otherwise terminate and we have + # seen any elements since the start or the last time + # we may have been exhausted. for idx, (ok, otask) in enumerate( zip(restartable, tasks, strict=True) ): - if ok and otask is None and idx not in fired: + if ok and otask is None: niter = decouple(aiterables[idx], task_group=tg) aiters[idx] = niter - tasks[idx] = tg.create_task(anext(niter, _EMPTY)) + tasks[idx] = nt = tg.create_task(anext(niter, _EMPTY)) + mw.add(nt) diff --git a/tests/test_util.py b/tests/test_util.py index d8afe555..e42a0d94 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -465,8 +465,8 @@ async def driver() -> AsyncIterator[str]: result = await _collect(util.merge(driver(), src)) assert sorted(result) == ["d1", "d2", "d3", "r1", "r2", "r3", "r4"] - # __aiter__ called once initially + once after each driver yield. - assert src.iter_count == 4 + # __aiter__ called once initially, and once more at the end. + assert src.iter_count == 5 async def test_merge_does_not_restart_async_generator() -> None: @@ -520,8 +520,8 @@ async def driver() -> AsyncIterator[str]: result = await _collect(util.merge(driver(), src)) assert sorted(result) == ["d1", "d2", "only"] - # Still re-iterated once per driver yield, even though nothing new arrived. - assert src.iter_count == 3 + # Still re-iterated once per driver yield, and once more at the end. + assert src.iter_count == 4 async def test_merge_restart_with_multiple_restartables() -> None: @@ -539,8 +539,8 @@ async def driver() -> AsyncIterator[str]: result = await _collect(util.merge(driver(), a, b)) assert sorted(result) == ["a1", "a2", "b1", "b2", "d1"] - assert a.iter_count == 2 - assert b.iter_count == 2 + assert a.iter_count == 3 + assert b.iter_count == 3 async def test_merge_restart_only_after_other_iterable_yields() -> None: @@ -549,10 +549,10 @@ async def test_merge_restart_only_after_other_iterable_yields() -> None: src.push("r1") # Single-iterable merge: src exhausts itself and merge ends without - # __aiter__ being called again. + # __aiter__ being called again, and once more at the end. result = await _collect(util.merge(src)) assert result == ["r1"] - assert src.iter_count == 1 + assert src.iter_count == 2 async def test_merge_restart_when_yield_and_stop_collide() -> None: