diff --git a/.gitignore b/.gitignore index 57de6beb..8d829c19 100644 --- a/.gitignore +++ b/.gitignore @@ -182,9 +182,9 @@ cython_debug/ .abstra/ # Visual Studio Code -# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore +# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore -# and can be added to the global gitignore or merged into this file. However, if you prefer, +# and can be added to the global gitignore or merged into this file. However, if you prefer, # you could uncomment the following to ignore the entire vscode folder # .vscode/ @@ -212,3 +212,4 @@ __marimo__/ .env*.local .DS_Store +*~ diff --git a/pyproject.toml b/pyproject.toml index 1ae6b44f..ea4f2fae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dev = [ "mypy>=1.11", "ruff>=0.8", "pyright>=1.1.408", + "async-solipsism>=0.9", ] [tool.mypy] diff --git a/src/ai/util.py b/src/ai/util.py new file mode 100644 index 00000000..955302fd --- /dev/null +++ b/src/ai/util.py @@ -0,0 +1,36 @@ +"""Utility functions""" + +from __future__ import annotations + +import asyncio +from collections.abc import AsyncIterable, AsyncIterator +from typing import Any + +_EMPTY: Any = object() + + +async def merge[T](*aiterables: AsyncIterable[T]) -> AsyncIterator[T]: + aiters = [aiter(iter) for iter in aiterables] + + async with asyncio.TaskGroup() as tg: + # Launch a task doing anext on every iterator + tasks: list[asyncio.Future[T] | None] = [ + tg.create_task(anext(iter, _EMPTY)) for iter in aiters + ] + + while any(tasks): + done, _ = await asyncio.wait( + [t for t in tasks if t], + return_when=asyncio.FIRST_COMPLETED, + ) + + for t in done: + idx = tasks.index(t) + val = t.result() + if val is _EMPTY: + tasks[idx] = None + else: + # Fire off a new task for the relevant iterator + iter = aiters[idx] + tasks[idx] = tg.create_task(anext(iter, _EMPTY)) + yield val diff --git a/tests/test_util.py b/tests/test_util.py new file mode 100644 index 00000000..c30af951 --- /dev/null +++ b/tests/test_util.py @@ -0,0 +1,192 @@ +"""Tests for ai.util.merge.""" + +from __future__ import annotations + +import asyncio +from collections.abc import AsyncIterable +from typing import Any + +import async_solipsism # type: ignore[import-untyped] +import pytest + +from ai import util + + +@pytest.fixture +def event_loop_policy() -> async_solipsism.EventLoopPolicy: + return async_solipsism.EventLoopPolicy() + + +async def _from_list(items: list[Any], delay: float = 0) -> AsyncIterable[Any]: + for item in items: + if delay: + await asyncio.sleep(delay) + print(asyncio.get_event_loop().time(), item) + yield item + + +async def _collect(aiter: AsyncIterable[Any]) -> list[Any]: + return [item async for item in aiter] + + +# -- basic behavior -------------------------------------------------------- + + +async def test_single_iterable() -> None: + result = await _collect(util.merge(_from_list([1, 2, 3]))) + assert result == [1, 2, 3] + + +async def test_empty_iterable() -> None: + result = await _collect(util.merge(_from_list([]))) + assert result == [] + + +async def test_no_iterables() -> None: + result = await _collect(util.merge()) + assert result == [] + + +async def test_multiple_iterables_all_items_yielded() -> None: + result = await _collect( + util.merge( + _from_list(["a", "b"]), + _from_list(["x", "y"]), + ) + ) + assert sorted(result) == ["a", "b", "x", "y"] + + +async def test_different_lengths() -> None: + result = await _collect( + util.merge( + _from_list([1, 2, 3]), + _from_list([10]), + ) + ) + assert sorted(result) == [1, 2, 3, 10] + + +async def test_many_iterables() -> None: + result = await _collect( + util.merge( + _from_list([1]), + _from_list([2]), + _from_list([3]), + _from_list([4]), + ) + ) + assert sorted(result) == [1, 2, 3, 4] + + +# -- timing (async-solipsism) --------------------------------------------- + + +async def test_simulated_clock_advances() -> None: + loop = asyncio.get_event_loop() + t0 = loop.time() + await _collect( + util.merge( + _from_list([1, 2], delay=10), + _from_list([3, 4], delay=5), + ) + ) + elapsed = loop.time() - t0 + assert elapsed == 20.0 + + +async def test_ordering_shorter_delay_first() -> None: + result = await _collect( + util.merge( + _from_list(["slow"], delay=100), + _from_list(["fast"], delay=1), + ) + ) + assert result == ["fast", "slow"] + + +# -- error handling -------------------------------------------------------- + + +async def test_error_cancels_other_iterables() -> None: + """When one iterable raises, the others are closed.""" + closed: list[str] = [] + + async def good() -> AsyncIterable[str]: + try: + await asyncio.sleep(10) + yield "never" + finally: + closed.append("good") + + async def bad() -> AsyncIterable[str]: + await asyncio.sleep(1) + raise RuntimeError("boom") + yield "unreachable" # noqa: B027 + + with pytest.raises(ExceptionGroup) as exc_info: + await _collect(util.merge(good(), bad())) + + assert exc_info.group_contains(RuntimeError, match="boom") + assert "good" in closed + + +async def test_error_propagates() -> None: + """The original exception is re-raised after cleanup.""" + + async def failing() -> AsyncIterable[int]: + yield 1 + raise ValueError("oops") + + with pytest.raises(ExceptionGroup) as exc_info: + await _collect(util.merge(failing())) + + assert exc_info.group_contains(ValueError, match="oops") + + +async def test_items_before_error_are_yielded() -> None: + """Items yielded before the error are still collected.""" + + async def ok() -> AsyncIterable[str]: + yield "a" + await asyncio.sleep(100) + yield "b" + + async def fails_later() -> AsyncIterable[str]: + yield "x" + await asyncio.sleep(1) + raise RuntimeError("fail") + + results: list[str] = [] + with pytest.raises(ExceptionGroup) as exc_info: + async for item in util.merge(ok(), fails_later()): + results.append(item) + + assert exc_info.group_contains(RuntimeError, match="fail") + assert "x" in results + + +async def test_cleanup_with_non_generator_iterable() -> None: + """Iterables without aclose are handled gracefully.""" + + class SimpleIter: + def __init__(self) -> None: + self.values = iter([1, 2]) + + def __aiter__(self) -> SimpleIter: + return self + + async def __anext__(self) -> int: + try: + return next(self.values) + except StopIteration: + raise StopAsyncIteration from None + + async def failing() -> AsyncIterable[int]: + raise RuntimeError("boom") + yield 0 # noqa: B027 + + with pytest.raises(ExceptionGroup) as exc_info: + await _collect(util.merge(SimpleIter(), failing())) + + assert exc_info.group_contains(RuntimeError, match="boom") diff --git a/uv.lock b/uv.lock index 7d20e44c..994b964b 100644 --- a/uv.lock +++ b/uv.lock @@ -47,6 +47,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/38/0e/27be9fdef66e72d64c0cdc3cc2823101b80585f8119b5c112c2e8f5f7dab/anyio-4.12.1-py3-none-any.whl", hash = "sha256:d405828884fc140aa80a3c667b8beed277f1dfedec42ba031bd6ac3db606ab6c", size = 113592, upload-time = "2026-01-06T11:45:19.497Z" }, ] +[[package]] +name = "async-solipsism" +version = "0.9" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/03/c6/136e3d282bbd292fd350f854e5e1690de78a9eb21591792c759ba840550c/async_solipsism-0.9.tar.gz", hash = "sha256:552325d3b6e4f1415fbcc9aa7dc2ba8fb3f30c39b1864ad9d104fdb666b3612d", size = 30760, upload-time = "2025-12-28T12:33:16.377Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b6/be/b11466a6fa2b9685d1fa10ab79e1234d4af065603b3e1a001dd36d6f8898/async_solipsism-0.9-py3-none-any.whl", hash = "sha256:3ef10bd7dc4ee8d18564d480d88f196e44fa0bc1e8ebe9357dd9932b2154e6dc", size = 26390, upload-time = "2025-12-28T12:33:17.921Z" }, +] + [[package]] name = "attrs" version = "25.4.0" @@ -1032,6 +1041,7 @@ dependencies = [ [package.dev-dependencies] dev = [ + { name = "async-solipsism" }, { name = "mypy" }, { name = "pyright" }, { name = "pytest" }, @@ -1053,6 +1063,7 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ + { name = "async-solipsism", specifier = ">=0.9" }, { name = "mypy", specifier = ">=1.11" }, { name = "pyright", specifier = ">=1.1.408" }, { name = "pytest", specifier = ">=8.0" },