From fec120a32b7fd945ff8a27f9fb753da880239de8 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Mon, 27 Apr 2026 17:53:51 -0700 Subject: [PATCH 1/3] Implement merge for async iterables (Was called interleave in discussions earlier, but @elprans thought interleave sounded too deterministic.) --- .gitignore | 5 +- src/ai/util.py | 46 +++++++++++ tests/test_util.py | 186 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 235 insertions(+), 2 deletions(-) create mode 100644 src/ai/util.py create mode 100644 tests/test_util.py diff --git a/.gitignore b/.gitignore index 57de6beb..8d829c19 100644 --- a/.gitignore +++ b/.gitignore @@ -182,9 +182,9 @@ cython_debug/ .abstra/ # Visual Studio Code -# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore +# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore -# and can be added to the global gitignore or merged into this file. However, if you prefer, +# and can be added to the global gitignore or merged into this file. 
However, if you prefer, # you could uncomment the following to ignore the entire vscode folder # .vscode/ @@ -212,3 +212,4 @@ __marimo__/ .env*.local .DS_Store +*~ diff --git a/src/ai/util.py b/src/ai/util.py new file mode 100644 index 00000000..15739e40 --- /dev/null +++ b/src/ai/util.py @@ -0,0 +1,46 @@ +"""Utility functions""" + +from __future__ import annotations + +import asyncio +from collections.abc import AsyncIterable, AsyncIterator + + +async def merge[T](*aiterables: AsyncIterable[T]) -> AsyncIterator[T]: + aiters = [iter.__aiter__() for iter in aiterables] + + # Launch a task doing anext on every iterator + tasks: list[asyncio.Future[T] | None] = [ + asyncio.ensure_future(iter.__anext__()) for iter in aiters + ] + + try: + while any(tasks): + done, _ = await asyncio.wait( + [t for t in tasks if t], + return_when=asyncio.FIRST_COMPLETED, + ) + + for t in done: + idx = tasks.index(t) + # Note: .exception() could raise CancelledError + if exc := t.exception(): + # Happy case for exception is StopAsyncIteration + # For other exceptions, raise + tasks[idx] = None + if not isinstance(exc, StopAsyncIteration): + raise exc + else: + # Fire off a new task for the relevant iterator + iter = aiters[idx] + tasks[idx] = asyncio.ensure_future(iter.__anext__()) + yield t.result() + except Exception: + for task in tasks: + if task: + task.cancel() + + live = [t for t in tasks if t] + await asyncio.gather(*live, return_exceptions=True) + + raise diff --git a/tests/test_util.py b/tests/test_util.py new file mode 100644 index 00000000..5b7b1f18 --- /dev/null +++ b/tests/test_util.py @@ -0,0 +1,186 @@ +"""Tests for ai.util.merge.""" + +from __future__ import annotations + +import asyncio +from collections.abc import AsyncIterable +from typing import Any + +import async_solipsism # type: ignore[import-untyped] +import pytest + +from ai import util + + +@pytest.fixture +def event_loop_policy() -> async_solipsism.EventLoopPolicy: + return async_solipsism.EventLoopPolicy() + + 
+async def _from_list(items: list[Any], delay: float = 0) -> AsyncIterable[Any]: + for item in items: + if delay: + await asyncio.sleep(delay) + print(asyncio.get_event_loop().time(), item) + yield item + + +async def _collect(aiter: AsyncIterable[Any]) -> list[Any]: + return [item async for item in aiter] + + +# -- basic behavior -------------------------------------------------------- + + +async def test_single_iterable() -> None: + result = await _collect(util.merge(_from_list([1, 2, 3]))) + assert result == [1, 2, 3] + + +async def test_empty_iterable() -> None: + result = await _collect(util.merge(_from_list([]))) + assert result == [] + + +async def test_no_iterables() -> None: + result = await _collect(util.merge()) + assert result == [] + + +async def test_multiple_iterables_all_items_yielded() -> None: + result = await _collect( + util.merge( + _from_list(["a", "b"]), + _from_list(["x", "y"]), + ) + ) + assert sorted(result) == ["a", "b", "x", "y"] + + +async def test_different_lengths() -> None: + result = await _collect( + util.merge( + _from_list([1, 2, 3]), + _from_list([10]), + ) + ) + assert sorted(result) == [1, 2, 3, 10] + + +async def test_many_iterables() -> None: + result = await _collect( + util.merge( + _from_list([1]), + _from_list([2]), + _from_list([3]), + _from_list([4]), + ) + ) + assert sorted(result) == [1, 2, 3, 4] + + +# -- timing (async-solipsism) --------------------------------------------- + + +async def test_simulated_clock_advances() -> None: + loop = asyncio.get_event_loop() + t0 = loop.time() + await _collect( + util.merge( + _from_list([1, 2], delay=10), + _from_list([3, 4], delay=5), + ) + ) + elapsed = loop.time() - t0 + assert elapsed == 20.0 + + +async def test_ordering_shorter_delay_first() -> None: + result = await _collect( + util.merge( + _from_list(["slow"], delay=100), + _from_list(["fast"], delay=1), + ) + ) + assert result == ["fast", "slow"] + + +# -- error handling 
-------------------------------------------------------- + + +async def test_error_cancels_other_iterables() -> None: + """When one iterable raises, the others are closed.""" + closed: list[str] = [] + + async def good() -> AsyncIterable[str]: + try: + await asyncio.sleep(10) + yield "never" + finally: + closed.append("good") + + async def bad() -> AsyncIterable[str]: + await asyncio.sleep(1) + raise RuntimeError("boom") + yield "unreachable" # noqa: B027 + + with pytest.raises(RuntimeError, match="boom"): + await _collect(util.merge(good(), bad())) + + assert "good" in closed + + +async def test_error_propagates() -> None: + """The original exception is re-raised after cleanup.""" + + async def failing() -> AsyncIterable[int]: + yield 1 + raise ValueError("oops") + + with pytest.raises(ValueError, match="oops"): + await _collect(util.merge(failing())) + + +async def test_items_before_error_are_yielded() -> None: + """Items yielded before the error are still collected.""" + + async def ok() -> AsyncIterable[str]: + yield "a" + await asyncio.sleep(100) + yield "b" + + async def fails_later() -> AsyncIterable[str]: + yield "x" + await asyncio.sleep(1) + raise RuntimeError("fail") + + results: list[str] = [] + with pytest.raises(RuntimeError, match="fail"): + async for item in util.merge(ok(), fails_later()): + results.append(item) + + assert "x" in results + + +async def test_cleanup_with_non_generator_iterable() -> None: + """Iterables without aclose are handled gracefully.""" + + class SimpleIter: + def __init__(self) -> None: + self.values = iter([1, 2]) + + def __aiter__(self) -> SimpleIter: + return self + + async def __anext__(self) -> int: + try: + return next(self.values) + except StopIteration: + raise StopAsyncIteration from None + + async def failing() -> AsyncIterable[int]: + raise RuntimeError("boom") + yield 0 # noqa: B027 + + with pytest.raises(RuntimeError, match="boom"): + await _collect(util.merge(SimpleIter(), failing())) From 
25dc522457521e19a1e8c652ea5917bb58bac081 Mon Sep 17 00:00:00 2001 From: "Michael J. Sullivan" Date: Tue, 28 Apr 2026 10:36:15 -0700 Subject: [PATCH 2/3] use aiter/anext --- src/ai/util.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/ai/util.py b/src/ai/util.py index 15739e40..27833e02 100644 --- a/src/ai/util.py +++ b/src/ai/util.py @@ -7,11 +7,11 @@ async def merge[T](*aiterables: AsyncIterable[T]) -> AsyncIterator[T]: - aiters = [iter.__aiter__() for iter in aiterables] + aiters = [aiter(iter) for iter in aiterables] # Launch a task doing anext on every iterator tasks: list[asyncio.Future[T] | None] = [ - asyncio.ensure_future(iter.__anext__()) for iter in aiters + asyncio.ensure_future(anext(iter)) for iter in aiters ] try: @@ -33,7 +33,7 @@ async def merge[T](*aiterables: AsyncIterable[T]) -> AsyncIterator[T]: else: # Fire off a new task for the relevant iterator iter = aiters[idx] - tasks[idx] = asyncio.ensure_future(iter.__anext__()) + tasks[idx] = asyncio.ensure_future(anext(iter)) yield t.result() except Exception: for task in tasks: From efa5422f2d27a9c3d832df80e81684ad9ecf7ecb Mon Sep 17 00:00:00 2001 From: "Michael J. 
Sullivan" Date: Tue, 28 Apr 2026 11:01:01 -0700 Subject: [PATCH 3/3] Switch to using TaskGroup --- pyproject.toml | 1 + src/ai/util.py | 34 ++++++++++++---------------------- tests/test_util.py | 14 ++++++++++---- uv.lock | 11 +++++++++++ 4 files changed, 34 insertions(+), 26 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1ae6b44f..ea4f2fae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dev = [ "mypy>=1.11", "ruff>=0.8", "pyright>=1.1.408", + "async-solipsism>=0.9", ] [tool.mypy] diff --git a/src/ai/util.py b/src/ai/util.py index 27833e02..955302fd 100644 --- a/src/ai/util.py +++ b/src/ai/util.py @@ -4,17 +4,20 @@ import asyncio from collections.abc import AsyncIterable, AsyncIterator +from typing import Any + +_EMPTY: Any = object() async def merge[T](*aiterables: AsyncIterable[T]) -> AsyncIterator[T]: aiters = [aiter(iter) for iter in aiterables] - # Launch a task doing anext on every iterator - tasks: list[asyncio.Future[T] | None] = [ - asyncio.ensure_future(anext(iter)) for iter in aiters - ] + async with asyncio.TaskGroup() as tg: + # Launch a task doing anext on every iterator + tasks: list[asyncio.Future[T] | None] = [ + tg.create_task(anext(iter, _EMPTY)) for iter in aiters + ] - try: while any(tasks): done, _ = await asyncio.wait( [t for t in tasks if t], @@ -23,24 +26,11 @@ async def merge[T](*aiterables: AsyncIterable[T]) -> AsyncIterator[T]: for t in done: idx = tasks.index(t) - # Note: .exception() could raise CancelledError - if exc := t.exception(): - # Happy case for exception is StopAsyncIteration - # For other exceptions, raise + val = t.result() + if val is _EMPTY: tasks[idx] = None - if not isinstance(exc, StopAsyncIteration): - raise exc else: # Fire off a new task for the relevant iterator iter = aiters[idx] - tasks[idx] = asyncio.ensure_future(anext(iter)) - yield t.result() - except Exception: - for task in tasks: - if task: - task.cancel() - - live = [t for t in tasks if t] - await 
asyncio.gather(*live, return_exceptions=True) - - raise + tasks[idx] = tg.create_task(anext(iter, _EMPTY)) + yield val diff --git a/tests/test_util.py b/tests/test_util.py index 5b7b1f18..c30af951 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -124,9 +124,10 @@ async def bad() -> AsyncIterable[str]: raise RuntimeError("boom") yield "unreachable" # noqa: B027 - with pytest.raises(RuntimeError, match="boom"): + with pytest.raises(ExceptionGroup) as exc_info: await _collect(util.merge(good(), bad())) + assert exc_info.group_contains(RuntimeError, match="boom") assert "good" in closed @@ -137,9 +138,11 @@ async def failing() -> AsyncIterable[int]: yield 1 raise ValueError("oops") - with pytest.raises(ValueError, match="oops"): + with pytest.raises(ExceptionGroup) as exc_info: await _collect(util.merge(failing())) + assert exc_info.group_contains(ValueError, match="oops") + async def test_items_before_error_are_yielded() -> None: """Items yielded before the error are still collected.""" @@ -155,10 +158,11 @@ async def fails_later() -> AsyncIterable[str]: raise RuntimeError("fail") results: list[str] = [] - with pytest.raises(RuntimeError, match="fail"): + with pytest.raises(ExceptionGroup) as exc_info: async for item in util.merge(ok(), fails_later()): results.append(item) + assert exc_info.group_contains(RuntimeError, match="fail") assert "x" in results @@ -182,5 +186,7 @@ async def failing() -> AsyncIterable[int]: raise RuntimeError("boom") yield 0 # noqa: B027 - with pytest.raises(RuntimeError, match="boom"): + with pytest.raises(ExceptionGroup) as exc_info: await _collect(util.merge(SimpleIter(), failing())) + + assert exc_info.group_contains(RuntimeError, match="boom") diff --git a/uv.lock b/uv.lock index 7d20e44c..994b964b 100644 --- a/uv.lock +++ b/uv.lock @@ -47,6 +47,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/38/0e/27be9fdef66e72d64c0cdc3cc2823101b80585f8119b5c112c2e8f5f7dab/anyio-4.12.1-py3-none-any.whl", hash = 
"sha256:d405828884fc140aa80a3c667b8beed277f1dfedec42ba031bd6ac3db606ab6c", size = 113592, upload-time = "2026-01-06T11:45:19.497Z" }, ] +[[package]] +name = "async-solipsism" +version = "0.9" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/03/c6/136e3d282bbd292fd350f854e5e1690de78a9eb21591792c759ba840550c/async_solipsism-0.9.tar.gz", hash = "sha256:552325d3b6e4f1415fbcc9aa7dc2ba8fb3f30c39b1864ad9d104fdb666b3612d", size = 30760, upload-time = "2025-12-28T12:33:16.377Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b6/be/b11466a6fa2b9685d1fa10ab79e1234d4af065603b3e1a001dd36d6f8898/async_solipsism-0.9-py3-none-any.whl", hash = "sha256:3ef10bd7dc4ee8d18564d480d88f196e44fa0bc1e8ebe9357dd9932b2154e6dc", size = 26390, upload-time = "2025-12-28T12:33:17.921Z" }, +] + [[package]] name = "attrs" version = "25.4.0" @@ -1032,6 +1041,7 @@ dependencies = [ [package.dev-dependencies] dev = [ + { name = "async-solipsism" }, { name = "mypy" }, { name = "pyright" }, { name = "pytest" }, @@ -1053,6 +1063,7 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ + { name = "async-solipsism", specifier = ">=0.9" }, { name = "mypy", specifier = ">=1.11" }, { name = "pyright", specifier = ">=1.1.408" }, { name = "pytest", specifier = ">=8.0" },