Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 20 additions & 47 deletions src/ai/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,26 +726,19 @@ def __aiter__(self) -> AsyncGenerator[events_.ToolCallResult]:

class ToolRunner:
def __init__(self) -> None:
# A future that gets signalled when we add a new tool, so that
# asyncio.wait gets woken up and cycles around in the loop to
# wait on the new thing as well.
# Also used when add_result is called, to signal that
self._sched_waiter: asyncio.Future[None] = (
asyncio.get_running_loop().create_future()
)
self._active: set[
asyncio.Future[events_.ToolCallResult] | asyncio.Future[None]
] = set()

self._new_results: list[events_.ToolCallResult] = []
self._tool_results: list[events_.ToolCallResult] = []
self._tg_base = asyncio.TaskGroup()
self._waiter: util.MultiWaiter[events_.ToolCallResult] = (
util.MultiWaiter()
)

async def __aenter__(self) -> Self:
self._tg = await self._tg_base.__aenter__()
return self

async def __aexit__(self, *args: Any) -> None:
self._waiter.clear()
return await self._tg_base.__aexit__(*args)

def events(self) -> _RestartableToolStream:
Expand All @@ -760,17 +753,13 @@ def schedule(self, tc: ToolCallCallable) -> None:
in custom logic (e.g. an approval hook await) and still ride the
runner's merge-and-iterate flow.
"""
self._active.add(self._tg.create_task(tc()))
if not self._sched_waiter.done():
self._sched_waiter.set_result(None)
self._waiter.add(self._tg.create_task(tc()))

def add_result(self, res: events_.ToolCallResult) -> None:
self._tool_results.append(res)
async def _feed() -> events_.ToolCallResult:
return res

# Also add to _new_results and signal sched_waiter to return them
self._new_results.append(res)
if not self._sched_waiter.done():
self._sched_waiter.set_result(None)
self._waiter.add(self._tg.create_task(_feed()))

def get_tool_message(self) -> types.messages.Message | None:
if self._tool_results:
Expand All @@ -780,34 +769,18 @@ def get_tool_message(self) -> types.messages.Message | None:
return None

async def _iterate(self) -> AsyncGenerator[events_.ToolCallResult]:
while self._active:
done, _ = await asyncio.wait(
[*self._active, self._sched_waiter],
return_when=asyncio.FIRST_COMPLETED,
)
for t in done:
self._active.discard(t)
if t is self._sched_waiter:
t.result()

new = self._new_results
self._new_results = []
for n in new:
yield n
self._sched_waiter = (
asyncio.get_running_loop().create_future()
)
else:
try:
res = t.result()
except asyncio.CancelledError:
# If a task got cancelled, that's fine.
# Need to catch it or the whole runner gets zapped.
continue

assert res is not None
self._tool_results.append(res)
yield res
while self._waiter.tasks():
t = await self._waiter
try:
res = t.result()
except asyncio.CancelledError:
# If a task got cancelled, that's fine.
# Need to catch it or the whole runner gets zapped.
continue

assert res is not None
self._tool_results.append(res)
yield res


class Context(pydantic.BaseModel):
Expand Down
135 changes: 106 additions & 29 deletions src/ai/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,20 @@
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from collections.abc import AsyncIterable, AsyncIterator
from collections.abc import (
AsyncIterable,
AsyncIterator,
Collection,
Generator,
)

_EMPTY: Any = object()

@dataclasses.dataclass
class _Empty:
pass


_EMPTY: Any = _Empty()


@dataclasses.dataclass
Expand Down Expand Up @@ -48,6 +59,63 @@ async def astop(self) -> None:
await self.put(_STOP)


class MultiWaiter[T]:
"""Waiter object for waiting on multiple futures.

The advantages over using asyncio.wait are:
* New futures may be added while the object is already being waited on
* Completion order of the tasks is preserved.

A *potential* downside is:
* Batching of future completion is lost

But that is actually good for our use cases, since that introduces
a potential mismatch when using workflows/temporal.
"""

def __init__(self, *tasks: asyncio.Future[T]) -> None:
self._queue: asyncio.Queue[asyncio.Future[T]] = asyncio.Queue(0)
self._tasks: dict[asyncio.Future[T], None] = {}

# We bind this to an attribute so that the bound method is
# always the same and can be passed to remove_done_callback.
self._callback = self._queue.put_nowait
self.add(*tasks)

def add(self, *tasks: asyncio.Future[T]) -> None:
for task in tasks:
self._tasks[task] = None
task.add_done_callback(self._callback)

def clear(self) -> None:
for task in self._tasks:
task.remove_done_callback(self._callback)
self._tasks.clear()

def tasks(self) -> Collection[asyncio.Future[T]]:
return self._tasks.keys()

async def wait(self) -> asyncio.Future[T]:
t = await self._queue.get()
self._tasks.pop(t, None)
return t

def __await__(self) -> Generator[Any, Any, asyncio.Future[T]]:
return self.wait().__await__()

async def __aenter__(self) -> MultiWaiter[T]:
return self

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc: BaseException | None,
tb: Any | None,
) -> bool:
self.clear()
return False


@contextlib.asynccontextmanager
async def unwrap_generator_exit() -> AsyncIterator[None]:
"""Unwrap ``BaseExceptionGroup`` containing only ``GeneratorExit``.
Expand Down Expand Up @@ -159,7 +227,11 @@ async def merge[T](
# We use unwrap_generator_exit() to keep a GeneratorExit that gets
# packaged in an ExceptionGroup from causing grief. But maybe we
# ought to not use a TaskGroup?
async with unwrap_generator_exit(), asyncio.TaskGroup() as tg:
async with (
unwrap_generator_exit(),
asyncio.TaskGroup() as tg,
MultiWaiter[T]() as mw,
):
raw_aiters = [aiter(iter) for iter in aiterables]
aiters = [decouple(iter, task_group=tg) for iter in raw_aiters]
# We consider anything that doesn't __aiter__ to itself to be
Expand All @@ -173,37 +245,42 @@ async def merge[T](
tasks: list[asyncio.Future[T] | None] = [
tg.create_task(anext(iter, _EMPTY)) for iter in aiters
]

while any(tasks):
done, _ = await asyncio.wait(
[t for t in tasks if t],
return_when=asyncio.FIRST_COMPLETED,
)

fired = []
for t in done:
idx = tasks.index(t)
val = t.result()
if val is _EMPTY:
tasks[idx] = None
else:
# Fire off a new task for the relevant iterator
fired.append(idx)
iter = aiters[idx]
tasks[idx] = tg.create_task(anext(iter, _EMPTY))
yield val

if restart and fired:
mw.add(*[t for t in tasks if t])

top_fired = False
while mw.tasks():
t = await mw

idx = tasks.index(t)
val = t.result()
if val is _EMPTY:
tasks[idx] = None
else:
# Fire off a new task for the relevant iterator
top_fired = True
iter = aiters[idx]
tasks[idx] = nt = tg.create_task(anext(iter, _EMPTY))
mw.add(nt)
yield val

if restart and (
val is not _EMPTY or (not mw.tasks() and top_fired)
):
if not mw.tasks():
top_fired = False
# Also, we try *restarting* other stopped streams
# that may have more to do now.
#
# N.B: We do this *after* the values are yielded, so
# they've had a chance to trigger things, and we do it
# after *all* tasks have been handled, so that if a
# task *just* finished, we still restart it.
# they've had a chance to trigger things, and we also
# do it if we would otherwise terminate and we have
# seen any elements since the start or the last time
# we may have been exhausted.
for idx, (ok, otask) in enumerate(
zip(restartable, tasks, strict=True)
):
if ok and otask is None and idx not in fired:
if ok and otask is None:
niter = decouple(aiterables[idx], task_group=tg)
aiters[idx] = niter
tasks[idx] = tg.create_task(anext(niter, _EMPTY))
tasks[idx] = nt = tg.create_task(anext(niter, _EMPTY))
mw.add(nt)
16 changes: 8 additions & 8 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,8 +465,8 @@ async def driver() -> AsyncIterator[str]:

result = await _collect(util.merge(driver(), src))
assert sorted(result) == ["d1", "d2", "d3", "r1", "r2", "r3", "r4"]
# __aiter__ called once initially + once after each driver yield.
assert src.iter_count == 4
# __aiter__ called once initially, once after each driver yield,
# and once more at the end.
assert src.iter_count == 5


async def test_merge_does_not_restart_async_generator() -> None:
Expand Down Expand Up @@ -520,8 +520,8 @@ async def driver() -> AsyncIterator[str]:

result = await _collect(util.merge(driver(), src))
assert sorted(result) == ["d1", "d2", "only"]
# Still re-iterated once per driver yield, even though nothing new arrived.
assert src.iter_count == 3
# Still re-iterated once per driver yield, and once more at the end.
assert src.iter_count == 4


async def test_merge_restart_with_multiple_restartables() -> None:
Expand All @@ -539,8 +539,8 @@ async def driver() -> AsyncIterator[str]:

result = await _collect(util.merge(driver(), a, b))
assert sorted(result) == ["a1", "a2", "b1", "b2", "d1"]
assert a.iter_count == 2
assert b.iter_count == 2
assert a.iter_count == 3
assert b.iter_count == 3


async def test_merge_restart_only_after_other_iterable_yields() -> None:
Expand All @@ -549,10 +549,10 @@ async def test_merge_restart_only_after_other_iterable_yields() -> None:
src.push("r1")

# Single-iterable merge: src exhausts itself and merge ends without
# __aiter__ being called again.
# __aiter__ being called mid-stream; it is called once more at the end.
result = await _collect(util.merge(src))
assert result == ["r1"]
assert src.iter_count == 1
assert src.iter_count == 2


async def test_merge_restart_when_yield_and_stop_collide() -> None:
Expand Down
Loading