Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,9 @@ cython_debug/
.abstra/

# Visual Studio Code
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
# and can be added to the global gitignore or merged into this file. However, if you prefer,
# and can be added to the global gitignore or merged into this file. However, if you prefer,
# you could uncomment the following to ignore the entire vscode folder
# .vscode/

Expand Down Expand Up @@ -212,3 +212,4 @@ __marimo__/
.env*.local

.DS_Store
*~
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ dev = [
"mypy>=1.11",
"ruff>=0.8",
"pyright>=1.1.408",
"async-solipsism>=0.9",
]

[tool.mypy]
Expand Down
36 changes: 36 additions & 0 deletions src/ai/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""Utility functions"""

from __future__ import annotations

import asyncio
from collections.abc import AsyncIterable, AsyncIterator
from typing import Any

_EMPTY: Any = object()


async def merge[T](*aiterables: AsyncIterable[T]) -> AsyncIterator[T]:
aiters = [aiter(iter) for iter in aiterables]

async with asyncio.TaskGroup() as tg:
# Launch a task doing anext on every iterator
tasks: list[asyncio.Future[T] | None] = [
tg.create_task(anext(iter, _EMPTY)) for iter in aiters
]

while any(tasks):
done, _ = await asyncio.wait(
[t for t in tasks if t],
return_when=asyncio.FIRST_COMPLETED,
)

for t in done:
idx = tasks.index(t)
val = t.result()
if val is _EMPTY:
tasks[idx] = None
else:
# Fire off a new task for the relevant iterator
iter = aiters[idx]
tasks[idx] = tg.create_task(anext(iter, _EMPTY))
yield val
192 changes: 192 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
"""Tests for ai.util.merge."""

from __future__ import annotations

import asyncio
from collections.abc import AsyncIterable
from typing import Any

import async_solipsism # type: ignore[import-untyped]
import pytest

from ai import util


@pytest.fixture
def event_loop_policy() -> async_solipsism.EventLoopPolicy:
return async_solipsism.EventLoopPolicy()


async def _from_list(items: list[Any], delay: float = 0) -> AsyncIterable[Any]:
for item in items:
if delay:
await asyncio.sleep(delay)
print(asyncio.get_event_loop().time(), item)
yield item


async def _collect(aiter: AsyncIterable[Any]) -> list[Any]:
return [item async for item in aiter]


# -- basic behavior --------------------------------------------------------


async def test_single_iterable() -> None:
result = await _collect(util.merge(_from_list([1, 2, 3])))
assert result == [1, 2, 3]


async def test_empty_iterable() -> None:
result = await _collect(util.merge(_from_list([])))
assert result == []


async def test_no_iterables() -> None:
result = await _collect(util.merge())
assert result == []


async def test_multiple_iterables_all_items_yielded() -> None:
result = await _collect(
util.merge(
_from_list(["a", "b"]),
_from_list(["x", "y"]),
)
)
assert sorted(result) == ["a", "b", "x", "y"]


async def test_different_lengths() -> None:
result = await _collect(
util.merge(
_from_list([1, 2, 3]),
_from_list([10]),
)
)
assert sorted(result) == [1, 2, 3, 10]


async def test_many_iterables() -> None:
result = await _collect(
util.merge(
_from_list([1]),
_from_list([2]),
_from_list([3]),
_from_list([4]),
)
)
assert sorted(result) == [1, 2, 3, 4]


# -- timing (async-solipsism) ---------------------------------------------


async def test_simulated_clock_advances() -> None:
loop = asyncio.get_event_loop()
t0 = loop.time()
await _collect(
util.merge(
_from_list([1, 2], delay=10),
_from_list([3, 4], delay=5),
)
)
elapsed = loop.time() - t0
assert elapsed == 20.0


async def test_ordering_shorter_delay_first() -> None:
result = await _collect(
util.merge(
_from_list(["slow"], delay=100),
_from_list(["fast"], delay=1),
)
)
assert result == ["fast", "slow"]


# -- error handling --------------------------------------------------------


async def test_error_cancels_other_iterables() -> None:
"""When one iterable raises, the others are closed."""
closed: list[str] = []

async def good() -> AsyncIterable[str]:
try:
await asyncio.sleep(10)
yield "never"
finally:
closed.append("good")

async def bad() -> AsyncIterable[str]:
await asyncio.sleep(1)
raise RuntimeError("boom")
yield "unreachable" # noqa: B027

with pytest.raises(ExceptionGroup) as exc_info:
await _collect(util.merge(good(), bad()))

assert exc_info.group_contains(RuntimeError, match="boom")
assert "good" in closed


async def test_error_propagates() -> None:
"""The original exception is re-raised after cleanup."""

async def failing() -> AsyncIterable[int]:
yield 1
raise ValueError("oops")

with pytest.raises(ExceptionGroup) as exc_info:
await _collect(util.merge(failing()))

assert exc_info.group_contains(ValueError, match="oops")


async def test_items_before_error_are_yielded() -> None:
"""Items yielded before the error are still collected."""

async def ok() -> AsyncIterable[str]:
yield "a"
await asyncio.sleep(100)
yield "b"

async def fails_later() -> AsyncIterable[str]:
yield "x"
await asyncio.sleep(1)
raise RuntimeError("fail")

results: list[str] = []
with pytest.raises(ExceptionGroup) as exc_info:
async for item in util.merge(ok(), fails_later()):
results.append(item)

assert exc_info.group_contains(RuntimeError, match="fail")
assert "x" in results


async def test_cleanup_with_non_generator_iterable() -> None:
"""Iterables without aclose are handled gracefully."""

class SimpleIter:
def __init__(self) -> None:
self.values = iter([1, 2])

def __aiter__(self) -> SimpleIter:
return self

async def __anext__(self) -> int:
try:
return next(self.values)
except StopIteration:
raise StopAsyncIteration from None

async def failing() -> AsyncIterable[int]:
raise RuntimeError("boom")
yield 0 # noqa: B027

with pytest.raises(ExceptionGroup) as exc_info:
await _collect(util.merge(SimpleIter(), failing()))

assert exc_info.group_contains(RuntimeError, match="boom")
11 changes: 11 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading