From 6312c1b35946bcca26937c9ecae1d39f320b2d18 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Tue, 28 Apr 2026 11:11:17 -0700 Subject: [PATCH 1/2] Rework models api 1. implement stream context manager 2. make messages and parts mutable 3. move message aggregation from individual adapters into ai.stream 4. simplify adapters to emit only streaming events 5. temporarily patch agents and middleware for the new api --- src/ai/__init__.py | 30 +- src/ai/agents/agent.py | 45 ++- src/ai/agents/events.py | 45 +++ src/ai/agents/runtime.py | 13 +- src/ai/agents/ui/ai_sdk/outbound/sse.py | 4 +- src/ai/agents/ui/ai_sdk/outbound/stream.py | 8 +- src/ai/middleware.py | 17 +- src/ai/models/__init__.py | 23 +- src/ai/models/ai_gateway/generate.py | 6 +- src/ai/models/ai_gateway/stream.py | 66 ++-- src/ai/models/anthropic/adapter.py | 66 +--- src/ai/models/core/__init__.py | 21 +- src/ai/models/core/api.py | 299 +++++++++-------- src/ai/models/core/helpers/streaming.py | 219 ------------- src/ai/models/core/proto.py | 2 +- src/ai/models/core/types.py | 11 - src/ai/models/openai/adapter.py | 89 ++--- src/ai/types/__init__.py | 17 +- src/ai/types/events.py | 79 +++-- src/ai/types/messages.py | 27 -- src/ai/types/proto.py | 34 -- src/ai/types/stream.py | 105 ------ tests/agents/test_generator_tools.py | 53 +-- tests/agents/test_hooks.py | 9 +- tests/agents/test_runtime.py | 70 ---- tests/agents/ui/ai_sdk/outbound/test_sse.py | 13 +- .../agents/ui/ai_sdk/outbound/test_stream.py | 53 +-- tests/conftest.py | 137 ++++---- tests/models/ai_gateway/test_protocol.py | 26 +- tests/models/ai_gateway/test_stream.py | 13 +- tests/models/core/test_streaming.py | 202 ------------ tests/models/test_public_api.py | 78 +---- tests/test_middleware.py | 308 +----------------- tests/types/test_messages.py | 45 --- 34 files changed, 590 insertions(+), 1643 deletions(-) create mode 100644 src/ai/agents/events.py delete mode 100644 src/ai/models/core/helpers/streaming.py delete mode 100644 src/ai/models/core/types.py delete mode 100644 src/ai/types/stream.py delete mode 100644 tests/models/core/test_streaming.py diff --git a/src/ai/__init__.py b/src/ai/__init__.py index 840fbcd2..45ec41ef 100644 --- a/src/ai/__init__.py +++ b/src/ai/__init__.py @@ -17,10 +17,15 @@ from .middleware import AgentRunContext, Middleware from .models import ( Client, + Executor, + GenerateExecutor, + GenerateRequest, ImageParams, Model, Provider, - StreamResult, + Stream, + StreamExecutor, + StreamRequest, VideoParams, ai_gateway, anthropic, @@ -32,22 +37,20 @@ # Re-export core types from .types import ( - End, Event, + FileEvent, FilePart, HookPart, HookResolution, HookSuspention, Message, - MessageEnd, - MessageStart, Part, ReasoningDelta, ReasoningEnd, ReasoningPart, ReasoningStart, - Start, - StreamResultLike, + StreamEnd, + StreamStart, StructuredOutputPart, TextDelta, TextEnd, @@ -74,12 +77,10 @@ __all__ = [ # Types (from types/) - "Start", - "End", "Event", "Message", - "MessageStart", - "MessageEnd", + "StreamStart", + "StreamEnd", "Part", "TextPart", "TextStart", @@ -94,6 +95,7 @@ "ReasoningStart", "ReasoningDelta", "ReasoningEnd", + "FileEvent", "FilePart", "HookPart", "HookSuspention", @@ -116,8 +118,12 @@ "ImageParams", "VideoParams", "Client", - "StreamResult", - "StreamResultLike", + "Stream", + "StreamRequest", + "GenerateRequest", + "Executor", + "StreamExecutor", + "GenerateExecutor", "check_connection", "stream", "generate", diff --git a/src/ai/agents/agent.py b/src/ai/agents/agent.py index d399cf96..37fcdc20 100644 --- a/src/ai/agents/agent.py +++ b/src/ai/agents/agent.py @@ -13,6 +13,7 @@ from .. import middleware as middleware_ from .. import models, types from ..types import builders +from . import events as events_ from . import runtime @@ -199,21 +200,23 @@ def resolve(self, tool_parts: list[types.ToolCallPart]) -> list[ToolCall]: ] -StreamItem = types.Event | types.Message +StreamItem = events_.AgentEvent | types.Message class LoopFn(Protocol): def __call__(self, context: Context) -> AsyncGenerator[StreamItem]: ... -async def _message_events(message: types.Message) -> AsyncGenerator[types.Event]: - yield types.MessageStart(message=message) - yield types.MessageEnd(message=message) +async def _message_events( + message: types.Message, +) -> AsyncGenerator[events_.AgentEvent]: + yield events_.MessageStart(message=message) + yield events_.MessageEnd(message=message) async def _coerce_events( source: AsyncIterable[StreamItem], -) -> AsyncGenerator[types.Event]: +) -> AsyncGenerator[events_.AgentEvent]: async for item in source: if isinstance(item, types.Message): async for event in _message_events(item): @@ -222,15 +225,23 @@ async def _coerce_events( yield item -async def _default_loop(context: Context) -> AsyncGenerator[types.Event]: +async def _default_loop(context: Context) -> AsyncGenerator[events_.AgentEvent]: while True: stream = models.stream( context.model, context.messages, tools=context.tools, ) - async for event in stream: - yield event + async for stream_event in stream: + yield stream_event + + # Bridge: emit MessageStart/MessageEnd around the assistant message + # the model stream just produced, so _collect_messages and downstream + # consumers (AI-SDK outbound, label stamping) see the same boundary + # events they did under the previous adapter contract. + if stream.message is not None and stream.message.parts: + async for boundary in _message_events(stream.message): + yield boundary tool_calls = context.resolve(stream.tool_calls) if not tool_calls: @@ -244,14 +255,14 @@ async def _default_loop(context: Context) -> AsyncGenerator[types.Event]: # Left un-stamped: the tool result is the input of the *next* turn, # so the next stream() call will stamp it with that turn's id. tool_msg = builders.tool_message(*(t.result() for t in tasks)) - async for event in _message_events(tool_msg): - yield event + async for boundary in _message_events(tool_msg): + yield boundary async def _collect_messages( source: AsyncIterable[StreamItem], messages: list[types.Message], -) -> AsyncGenerator[types.Event]: +) -> AsyncGenerator[events_.AgentEvent]: """Intercept yielded events and collect MessageEnd messages into *messages*. This runs on the **producer** side (same coroutine as the loop function), @@ -260,7 +271,7 @@ async def _collect_messages( happened on the consumer side of the runtime queue. """ async for event in _coerce_events(source): - if isinstance(event, types.MessageEnd): + if isinstance(event, events_.MessageEnd): message = event.message for i, existing in enumerate(messages): if existing.id == message.id: @@ -292,7 +303,7 @@ async def yield_from(source: AsyncIterable[StreamItem]) -> str: last: types.Message | None = None async for item in _coerce_events(source): await rt.put_event(item) - if isinstance(item, types.MessageEnd): + if isinstance(item, events_.MessageEnd): last = item.message return last.text if last else "" @@ -325,7 +336,7 @@ async def run( *, label: str | None = None, middleware: list[middleware_.Middleware] | None = None, - ) -> AsyncGenerator[types.Event]: + ) -> AsyncGenerator[events_.AgentEvent]: """Run the agent loop, yielding events to the consumer. Args: @@ -349,7 +360,7 @@ async def run( async def _real( call: middleware_.AgentRunContext, - ) -> AsyncGenerator[types.Event]: + ) -> AsyncGenerator[events_.AgentEvent]: context = Context( model=call.model, messages=list(call.messages), @@ -359,8 +370,8 @@ async def _real( async for event in runtime.run(source): if call.label is not None: event_message: types.Message | None = None - if isinstance(event, types.MessageEnd) or ( - isinstance(event, types.MessageStart) + if isinstance(event, events_.MessageEnd) or ( + isinstance(event, events_.MessageStart) and event.message is not None ): event_message = event.message diff --git a/src/ai/agents/events.py b/src/ai/agents/events.py new file mode 100644 index 00000000..12108464 --- /dev/null +++ b/src/ai/agents/events.py @@ -0,0 +1,45 @@ +"""Agent-layer event types. + +The model layer emits ``StreamStart`` / ``StreamEnd`` plus block-level +deltas. The agent layer wraps those with ``MessageStart`` / ``MessageEnd`` +boundaries that delimit complete messages — assistant turns produced by +the model, plus synthetic user / tool / hook messages injected into the +runtime queue. + +These types live here (rather than in ``ai.types.events``) because they +are an agent-runtime concern, not part of the public model-streaming +event vocabulary. +""" + +from __future__ import annotations + +from typing import Literal + +import pydantic + +from .. import types + + +class MessageStart(pydantic.BaseModel): + message: types.Message | None = None + + kind: Literal["message_start"] = "message_start" + + +class MessageEnd(pydantic.BaseModel): + message: types.Message + usage: types.Usage | None = None + + kind: Literal["message_end"] = "message_end" + + +# Widened event alias used inside agents/. Not part of ``types.Event``'s +# discriminated union — these wrappers do not flow through model adapters. +AgentEvent = types.Event | MessageStart | MessageEnd + + +__all__ = [ + "AgentEvent", + "MessageEnd", + "MessageStart", +] diff --git a/src/ai/agents/runtime.py b/src/ai/agents/runtime.py index cb8769ce..a8005748 100644 --- a/src/ai/agents/runtime.py +++ b/src/ai/agents/runtime.py @@ -7,6 +7,7 @@ from collections.abc import AsyncGenerator, AsyncIterable, Awaitable from .. import types +from . import events as events_ from . import hooks as hooks_ from .mcp import client as mcp_client @@ -20,17 +21,17 @@ class _Sentinel: _SENTINEL = _Sentinel() def __init__(self) -> None: - self._event_queue: asyncio.Queue[types.Event | Runtime._Sentinel] = ( + self._event_queue: asyncio.Queue[events_.AgentEvent | Runtime._Sentinel] = ( asyncio.Queue() ) self._hook_labels: set[str] = set() - async def put_event(self, event: types.Event) -> None: + async def put_event(self, event: events_.AgentEvent) -> None: await self._event_queue.put(event) async def put_message(self, message: types.Message) -> None: - await self.put_event(types.MessageStart(message=message)) - await self.put_event(types.MessageEnd(message=message)) + await self.put_event(events_.MessageStart(message=message)) + await self.put_event(events_.MessageEnd(message=message)) async def signal_done(self) -> None: await self._event_queue.put(self._SENTINEL) @@ -61,8 +62,8 @@ async def _stop_when_done(runtime: Runtime, task: Awaitable[None]) -> None: async def run( - source: AsyncIterable[types.Event], -) -> AsyncGenerator[types.Event]: + source: AsyncIterable[events_.AgentEvent], +) -> AsyncGenerator[events_.AgentEvent]: """Run *source* and yield every event that gets put into the Runtime queue.""" rt = Runtime() token = _runtime.set(rt) diff --git a/src/ai/agents/ui/ai_sdk/outbound/sse.py b/src/ai/agents/ui/ai_sdk/outbound/sse.py index e6c1f581..d911e9cb 100644 --- a/src/ai/agents/ui/ai_sdk/outbound/sse.py +++ b/src/ai/agents/ui/ai_sdk/outbound/sse.py @@ -6,7 +6,7 @@ import json from collections.abc import AsyncGenerator, AsyncIterable -from .....types import events as events_ +from ....events import AgentEvent from .. import protocol from .stream import to_stream @@ -32,7 +32,7 @@ def format_sse(part: protocol.UIMessageStreamPart) -> str: async def to_sse( - events: AsyncIterable[events_.Event], + events: AsyncIterable[AgentEvent], ) -> AsyncGenerator[str]: """Convert an internal event stream into SSE strings.""" async for part in to_stream(events): diff --git a/src/ai/agents/ui/ai_sdk/outbound/stream.py b/src/ai/agents/ui/ai_sdk/outbound/stream.py index 5ec0db5f..a569342f 100644 --- a/src/ai/agents/ui/ai_sdk/outbound/stream.py +++ b/src/ai/agents/ui/ai_sdk/outbound/stream.py @@ -4,13 +4,13 @@ from collections.abc import AsyncGenerator, AsyncIterable -from .....types import events as events_ +from ....events import AgentEvent, MessageEnd, MessageStart from .. import protocol from ._state import _StreamState async def to_stream( - events: AsyncIterable[events_.Event], + events: AsyncIterable[AgentEvent], ) -> AsyncGenerator[protocol.UIMessageStreamPart]: """Walk ``events`` once, emitting AI SDK UI stream parts. @@ -21,10 +21,10 @@ async def to_stream( state = _StreamState() async for event in events: - if isinstance(event, events_.MessageStart): + if isinstance(event, MessageStart): for part in state.on_message_start(event.message): yield part - elif isinstance(event, events_.MessageEnd): + elif isinstance(event, MessageEnd): for part in state.on_terminal(event.message): yield part else: diff --git a/src/ai/middleware.py b/src/ai/middleware.py index a8f77b81..afac14b6 100644 --- a/src/ai/middleware.py +++ b/src/ai/middleware.py @@ -22,9 +22,15 @@ import pydantic -from .types import events as events_ from .types import messages as messages_ -from .types.proto import StreamResultLike, ToolLike +from .types.proto import ToolLike + +# Compat shim: ``StreamResultLike`` was removed from ``ai.types.proto`` when +# the model layer was reworked. Middleware is dead code under the new +# ``Executor``-based ``api.py`` and is kept around only so the agents +# rewrite can land separately; ``Any`` is enough to keep the existing +# annotations type-checking. +type StreamResultLike = Any # --------------------------------------------------------------------------- # Call context objects — frozen dataclasses with isolated mutable fields. @@ -113,8 +119,11 @@ def __post_init__(self) -> None: # Middleware base class — override the methods you care about. # --------------------------------------------------------------------------- -# Event/message aliases for brevity in signatures. -_Event = events_.Event +# Event/message aliases for brevity in signatures. ``_Event`` is intentionally +# typed as ``Any`` so the agent-run chain accepts the wider ``AgentEvent`` +# union (which includes ``MessageStart``/``MessageEnd``) without a circular +# import from ``ai.agents``. +_Event = Any _Message = messages_.Message # Agent run next-function type: call -> async generator of events. diff --git a/src/ai/models/__init__.py b/src/ai/models/__init__.py index 61a97767..4f0da85d 100644 --- a/src/ai/models/__init__.py +++ b/src/ai/models/__init__.py @@ -25,29 +25,42 @@ ids = await openai.list() """ -from ..types.proto import StreamResultLike from .ai_gateway import ai_gateway from .anthropic import anthropic from .core.adapters import register_generate, register_stream -from .core.api import check_connection, generate, stream +from .core.api import ( + Executor, + GenerateExecutor, + GenerateRequest, + Stream, + StreamExecutor, + StreamRequest, + check_connection, + generate, + stream, +) from .core.client import Client from .core.model import Model +from .core.params import GenerateParams, ImageParams, VideoParams from .core.proto import CheckConnFn, GenerateFn, Provider, StreamFn -from .core.types import GenerateParams, ImageParams, StreamResult, VideoParams from .openai import openai __all__ = [ # Core types "CheckConnFn", "Client", + "Executor", + "GenerateExecutor", "GenerateFn", "GenerateParams", + "GenerateRequest", "ImageParams", "Model", "Provider", + "Stream", + "StreamExecutor", "StreamFn", - "StreamResult", - "StreamResultLike", + "StreamRequest", "VideoParams", # Provider factories "ai_gateway", diff --git a/src/ai/models/ai_gateway/generate.py b/src/ai/models/ai_gateway/generate.py index 9a6dee45..fec98f57 100644 --- a/src/ai/models/ai_gateway/generate.py +++ b/src/ai/models/ai_gateway/generate.py @@ -14,9 +14,9 @@ from ..core import client as client_ from ..core import model as model_ from ..core.helpers import files -from ..core.types import GenerateParams as GenerateParams -from ..core.types import ImageParams as ImageParams -from ..core.types import VideoParams as VideoParams +from ..core.params import GenerateParams as GenerateParams +from ..core.params import ImageParams as ImageParams +from ..core.params import VideoParams as VideoParams from . import _common, errors # --------------------------------------------------------------------------- diff --git a/src/ai/models/ai_gateway/stream.py b/src/ai/models/ai_gateway/stream.py index 8d97c5d9..ce143806 100644 --- a/src/ai/models/ai_gateway/stream.py +++ b/src/ai/models/ai_gateway/stream.py @@ -19,7 +19,7 @@ from ...types import usage as usage_ from ..core import client as client_ from ..core import model as model_ -from ..core.helpers import files, streaming +from ..core.helpers import files from . import _common, errors # --------------------------------------------------------------------------- @@ -154,20 +154,20 @@ async def _build_request_body( # --------------------------------------------------------------------------- -# SSE response parsing — v3 stream parts → StreamEvent +# SSE response parsing — v3 stream parts → public Event # --------------------------------------------------------------------------- -def _expand_tool_call(data: dict[str, Any]) -> list[streaming.StreamEvent]: - """Expand a complete ``tool-call`` part into Start + ArgsDelta + End.""" +def _expand_tool_call(data: dict[str, Any]) -> list[events_.Event]: + """Expand a complete ``tool-call`` part into Start + Delta + End.""" tc_id = data.get("toolCallId", "") tool_name = data.get("toolName", "") tool_input = data.get("input", "") args_str = tool_input if isinstance(tool_input, str) else json.dumps(tool_input) return [ - streaming.ToolStart(tool_call_id=tc_id, tool_name=tool_name), - streaming.ToolArgsDelta(tool_call_id=tc_id, delta=args_str), - streaming.ToolEnd(tool_call_id=tc_id), + events_.ToolStart(tool_call_id=tc_id, tool_name=tool_name), + events_.ToolDelta(tool_call_id=tc_id, chunk=args_str), + events_.ToolEnd(tool_call_id=tc_id), ] @@ -198,40 +198,40 @@ def _parse_usage(data: Any) -> usage_.Usage: ) -def _parse_stream_part(data: dict[str, Any]) -> list[streaming.StreamEvent]: - """Convert a ``LanguageModelV3StreamPart`` to internal events.""" +def _parse_stream_part(data: dict[str, Any]) -> list[events_.Event]: + """Convert a ``LanguageModelV3StreamPart`` to public events.""" match data.get("type", ""): case "text-start": - return [streaming.TextStart(block_id=data.get("id", "text"))] + return [events_.TextStart(block_id=data.get("id", "text"))] case "text-delta": return [ - streaming.TextDelta( + events_.TextDelta( block_id=data.get("id", "text"), - delta=data.get("textDelta", data.get("delta", "")), + chunk=data.get("textDelta", data.get("delta", "")), ) ] case "text-end": - return [streaming.TextEnd(block_id=data.get("id", "text"))] + return [events_.TextEnd(block_id=data.get("id", "text"))] case "reasoning-start": - return [streaming.ReasoningStart(block_id=data.get("id", "reasoning"))] + return [events_.ReasoningStart(block_id=data.get("id", "reasoning"))] case "reasoning-delta": return [ - streaming.ReasoningDelta( + events_.ReasoningDelta( block_id=data.get("id", "reasoning"), - delta=data.get("delta", ""), + chunk=data.get("delta", ""), ) ] case "reasoning-end": - return [streaming.ReasoningEnd(block_id=data.get("id", "reasoning"))] + return [events_.ReasoningEnd(block_id=data.get("id", "reasoning"))] case "tool-input-start": return [ - streaming.ToolStart( + events_.ToolStart( tool_call_id=data.get("id", ""), tool_name=data.get("toolName", ""), ) @@ -239,22 +239,22 @@ def _parse_stream_part(data: dict[str, Any]) -> list[streaming.StreamEvent]: case "tool-input-delta": return [ - streaming.ToolArgsDelta( + events_.ToolDelta( tool_call_id=data.get("id", ""), - delta=data.get("delta", ""), + chunk=data.get("delta", ""), ) ] case "tool-input-end": - return [streaming.ToolEnd(tool_call_id=data.get("id", ""))] + return [events_.ToolEnd(tool_call_id=data.get("id", ""))] case "tool-call": return _expand_tool_call(data) case "file": return [ - streaming.FileEvent( - block_id=data.get("id", f"file-{len(data)}"), + events_.FileEvent( + block_id=data.get("id", ""), media_type=data.get("mediaType", "application/octet-stream"), data=data.get("data", ""), ) @@ -263,14 +263,7 @@ def _parse_stream_part(data: dict[str, Any]) -> list[streaming.StreamEvent]: case "finish": usage_data = data.get("usage") usage = _parse_usage(usage_data) if usage_data else None - match data.get("finishReason"): - case dict() as d: - finish_reason = d.get("unified", "stop") - case str() as s: - finish_reason = s - case _: - finish_reason = "stop" - return [streaming.MessageDone(finish_reason=finish_reason, usage=usage)] + return [events_.StreamEnd(usage=usage)] case _: return [] @@ -293,6 +286,8 @@ async def stream( """Stream an LLM response through the AI Gateway v3 protocol. Yields :class:`~ai.types.events.Event` objects as the response streams in. + Pure delta emitter — the :class:`~ai.models.Stream` wrapper aggregates + parts into the final :class:`~ai.types.Message`. """ body = await _build_request_body( messages, tools=tools, output_type=output_type, **kwargs @@ -302,8 +297,6 @@ async def stream( ) url = f"{client.base_url.rstrip('/')}/language-model" - handler = streaming.StreamHandler() - try: async with client.http.stream( "POST", @@ -319,11 +312,10 @@ async def stream( api_key_provided=bool(client.api_key), ) - yield handler.message_start() + yield events_.StreamStart() async for data in _common.parse_sse_lines(response): - for adapter_event in _parse_stream_part(data): - for out_event in handler.handle_event(adapter_event): - yield out_event + for event in _parse_stream_part(data): + yield event except errors.GatewayError: raise except httpx.TimeoutException as exc: diff --git a/src/ai/models/anthropic/adapter.py b/src/ai/models/anthropic/adapter.py index 34329575..20ffe6d7 100644 --- a/src/ai/models/anthropic/adapter.py +++ b/src/ai/models/anthropic/adapter.py @@ -21,7 +21,7 @@ def _tools_to_anthropic( - tools: Sequence[types.proto.ToolLike], + tools: Sequence[types.ToolLike], ) -> list[dict[str, Any]]: """Convert internal Tool objects to Anthropic tool schema format.""" return [ @@ -242,7 +242,7 @@ async def stream( model: core.model.Model, messages: list[types.Message], *, - tools: Sequence[types.proto.ToolLike] | None = None, + tools: Sequence[types.ToolLike] | None = None, output_type: type[pydantic.BaseModel] | None = None, thinking: bool = False, budget_tokens: int = 10000, @@ -251,6 +251,8 @@ async def stream( """Stream an LLM response via the Anthropic messages API. Yields :class:`~ai.types.events.Event` objects as the response streams in. + Pure delta emitter — the :class:`~ai.models.Stream` wrapper aggregates + parts into the final :class:`~ai.types.Message`. Extra keyword arguments beyond the ``StreamFn`` protocol: @@ -280,24 +282,16 @@ async def stream( if output_type is not None: api_kwargs["output_format"] = output_type + # Anthropic indexes content blocks by int; map to string block_ids. block_types: dict[int, str] = {} tool_ids: dict[int, str] = {} tool_names: dict[int, str] = {} signature_buffer: dict[int, str] = {} - # Accumulate parts for the final Message - parts: list[types.Part] = [] - _text_parts: dict[str, str] = {} # block_id -> accumulated text - _reasoning_parts: dict[str, str] = {} # block_id -> accumulated text - _tool_parts: dict[str, str] = {} # tool_call_id -> accumulated args - message_id = types.generate_id() try: - stream_cm = sdk_client.messages.stream(**api_kwargs) + async with sdk_client.messages.stream(**api_kwargs) as sdk_stream: + yield events.StreamStart() - async with stream_cm as sdk_stream: - yield events.MessageStart( - message=types.Message(id=message_id, role="assistant", parts=[]) - ) async for event in sdk_stream: match event.type: case "content_block_start": @@ -307,15 +301,12 @@ async def stream( match block.type: case "text": - _text_parts[str(idx)] = "" yield events.TextStart(block_id=str(idx)) case "thinking": - _reasoning_parts[str(idx)] = "" yield events.ReasoningStart(block_id=str(idx)) case "tool_use": tool_ids[idx] = block.id tool_names[idx] = block.name - _tool_parts[block.id] = "" yield events.ToolStart( tool_call_id=block.id, tool_name=block.name, @@ -327,17 +318,11 @@ async def stream( match delta.type: case "text_delta": - _text_parts[str(idx)] = ( - _text_parts.get(str(idx), "") + delta.text - ) yield events.TextDelta( chunk=delta.text, block_id=str(idx), ) case "thinking_delta": - _reasoning_parts[str(idx)] = ( - _reasoning_parts.get(str(idx), "") + delta.thinking - ) yield events.ReasoningDelta( chunk=delta.thinking, block_id=str(idx), @@ -349,10 +334,6 @@ async def stream( case "input_json_delta": tool_id = tool_ids.get(idx) if tool_id: - _tool_parts[tool_id] = ( - _tool_parts.get(tool_id, "") - + delta.partial_json - ) yield events.ToolDelta( chunk=delta.partial_json, tool_call_id=tool_id, @@ -360,38 +341,17 @@ async def stream( case "content_block_stop": idx = event.index - bid = str(idx) match block_types.get(idx): case "text": - parts.append( - types.TextPart( - id=bid, text=_text_parts.get(bid, "") - ) - ) - yield events.TextEnd(block_id=bid) + yield events.TextEnd(block_id=str(idx)) case "thinking": - parts.append( - types.ReasoningPart( - id=bid, - text=_reasoning_parts.get(bid, ""), - signature=signature_buffer.get(idx), - ) - ) yield events.ReasoningEnd( - block_id=bid, + block_id=str(idx), signature=signature_buffer.get(idx), ) case "tool_use": tool_id = tool_ids.get(idx) if tool_id: - parts.append( - types.ToolCallPart( - id=tool_id, - tool_call_id=tool_id, - tool_name=tool_names.get(idx, ""), - tool_args=_tool_parts.get(tool_id, ""), - ) - ) yield events.ToolEnd(tool_call_id=tool_id) snapshot = sdk_stream.current_message_snapshot @@ -405,12 +365,6 @@ async def stream( ), raw=sdk_usage.model_dump(exclude_none=True) or None, ) - final_message = types.Message( - id=message_id, - role="assistant", - parts=parts, - usage=usage, - ) - yield events.MessageEnd(message=final_message, usage=usage) + yield events.StreamEnd(usage=usage) finally: await sdk_client.close() diff --git a/src/ai/models/core/__init__.py b/src/ai/models/core/__init__.py index 5b5e580d..e7db1ace 100644 --- a/src/ai/models/core/__init__.py +++ b/src/ai/models/core/__init__.py @@ -1,22 +1,37 @@ """Core types for models.""" from .adapters import register_generate, register_stream -from .api import check_connection, generate, stream +from .api import ( + Executor, + GenerateExecutor, + GenerateRequest, + Stream, + StreamExecutor, + StreamRequest, + check_connection, + generate, + stream, +) from .client import Client from .model import Model +from .params import GenerateParams, ImageParams, VideoParams from .proto import CheckConnFn, GenerateFn, Provider, StreamFn -from .types import GenerateParams, ImageParams, StreamResult, VideoParams __all__ = [ "CheckConnFn", "Client", + "Executor", + "GenerateExecutor", "GenerateFn", "GenerateParams", + "GenerateRequest", "ImageParams", "Model", "Provider", + "Stream", + "StreamExecutor", "StreamFn", - "StreamResult", + "StreamRequest", "VideoParams", "check_connection", "generate", diff --git a/src/ai/models/core/api.py b/src/ai/models/core/api.py index d0b90315..3d40de04 100644 --- a/src/ai/models/core/api.py +++ b/src/ai/models/core/api.py @@ -1,150 +1,195 @@ -"""Top-level orchestration — stream(), generate(), check_connection(). - -These wire together adapters, middleware chains, and auto-client creation. -""" - -from __future__ import annotations - +import dataclasses from collections.abc import AsyncGenerator, Sequence -from typing import Any +from typing import Any, Protocol, Self, runtime_checkable import pydantic -from ... import middleware as middleware_ -from ...types import events as events_ +from ... import types from ...types import integrity as integrity_ -from ...types import messages as messages_ -from ...types import proto as proto_ -from ...types import stream as stream_ -from . import adapters +from . import adapters, params from . import client as client_ from . import model as model_ -from . import types as types_ -def stream( - model: model_.Model, - messages: list[messages_.Message], - *, - tools: Sequence[proto_.ToolLike] | None = None, - output_type: type[pydantic.BaseModel] | None = None, - turn_id: str | None = None, - **kwargs: Any, -) -> proto_.StreamResultLike: - """Stream an LLM response. +@dataclasses.dataclass(frozen=True) +class StreamRequest: + model: model_.Model + messages: list[types.Message] + tools: Sequence[types.ToolLike] | None = None + output_type: type[pydantic.BaseModel] | None = None - Returns a :class:`StreamResultLike` that is async-iterable and - collects the final ``Message``. After iteration, access ``.text``, - ``.tool_calls``, ``.usage``, etc. - Call-site is a plain ``async for`` — no outer ``await`` needed:: +@dataclasses.dataclass(frozen=True) +class GenerateRequest: + model: model_.Model + messages: list[types.Message] + params: params.GenerateParams + + +@runtime_checkable +class StreamExecutor(Protocol): + def _do_stream(self, request: StreamRequest) -> AsyncGenerator[types.Event]: ... + + +@runtime_checkable +class GenerateExecutor(Protocol): + async def _do_generate(self, request: GenerateRequest) -> types.Message: ... + + +class Executor: + """Default executor: dispatches to adapters via the local client.""" + + async def _do_stream(self, request: StreamRequest) -> AsyncGenerator[types.Event]: + c = client_.auto_client(request.model) + fn = adapters.get_stream_adapter(request.model.adapter) + async for ev in fn( + c, + request.model, + request.messages, + tools=request.tools, + output_type=request.output_type, + ): + yield ev + + async def _do_generate(self, request: GenerateRequest) -> types.Message: + c = client_.auto_client(request.model) + fn = adapters.get_generate_adapter(request.model.adapter) + return await fn(c, request.model, request.messages, params=request.params) + + +_default_executor = Executor() + + +class Stream: + """Async-iterable wrapper around an adapter's event stream.""" + + def __init__(self, gen: AsyncGenerator[types.Event]) -> None: + self._gen = gen + self._message: types.Message = types.Message(role="assistant", parts=[]) + self._parts: dict[str, types.Part] = {} + + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: object, + ) -> None: + await self._gen.aclose() + + def __aiter__(self) -> Self: + return self + + async def __anext__(self) -> types.Event: + event = await self._gen.__anext__() + self._aggregate_event(event) + return event.model_copy(update={"message": self._message}) + + @property + def message(self) -> types.Message: + return self._message + + @property + def usage(self) -> types.Usage | None: + return self._message.usage + + @property + def text(self) -> str: + return self._message.text + + @property + def tool_calls(self) -> list[types.ToolCallPart]: + return self._message.tool_calls + + @property + def output(self) -> Any: + return self._message.output + + def _aggregate_event(self, event: types.Event) -> None: + # grab usage from any event that carries one + if event.usage is not None: + self._message.usage = event.usage + + match event: + case types.TextStart(block_id=bid): + tp = types.TextPart(id=bid, text="") + self._message.parts.append(tp) + self._parts[bid] = tp + case types.TextDelta(block_id=bid, chunk=c): + existing_text = self._parts.get(bid) + if isinstance(existing_text, types.TextPart): + existing_text.text += c + case types.ReasoningStart(block_id=bid): + rp = types.ReasoningPart(id=bid, text="") + self._message.parts.append(rp) + self._parts[bid] = rp + case types.ReasoningDelta(block_id=bid, chunk=c): + existing_reasoning = self._parts.get(bid) + if isinstance(existing_reasoning, types.ReasoningPart): + existing_reasoning.text += c + case types.ReasoningEnd(block_id=bid, signature=sig): + existing_reasoning = self._parts.get(bid) + if ( + isinstance(existing_reasoning, types.ReasoningPart) + and sig is not None + ): + existing_reasoning.signature = sig + case types.ToolStart(tool_call_id=tcid, tool_name=name): + tcp = types.ToolCallPart( + id=tcid, + tool_call_id=tcid, + tool_name=name, + tool_args="", + ) + self._message.parts.append(tcp) + self._parts[tcid] = tcp + case types.ToolDelta(tool_call_id=tcid, chunk=c): + existing_tool = self._parts.get(tcid) + if isinstance(existing_tool, types.ToolCallPart): + existing_tool.tool_args += c + case types.FileEvent(block_id=bid, media_type=mt, data=d, filename=fname): + fp = types.FilePart( + id=bid or types.generate_id(), + data=d, + media_type=mt, + filename=fname, + ) + self._message.parts.append(fp) + self._parts[fp.id] = fp + case _: + pass - async for event in ai.stream(model, messages): - ... - One call is one turn: a single request and its response. The model - response carries ``turn_id``; re-emitted input messages keep any - existing ``turn_id`` from prior turns and only receive the current - one when unstamped. If *turn_id* is not provided, one is generated. - - The client is resolved from the model: ``model.client`` if set, - otherwise auto-created from ``model.base_url`` / ``model.api_key_env``. - - Middleware dispatch and adapter setup are deferred to the first - iteration; any async preflight work happens there. - """ +def stream( + model: model_.Model, + messages: list[types.Message], + *, + tools: Sequence[types.ToolLike] | None = None, + output_type: type[pydantic.BaseModel] | None = None, + executor: StreamExecutor = _default_executor, +) -> Stream: + """Stream an LLM response.""" messages = integrity_.prepare_messages(messages) - - if turn_id is None: - turn_id = messages_.generate_id("turn") - - call = middleware_.ModelContext( - model=model, - messages=messages, - tools=tools, - output_type=output_type, - kwargs=kwargs, - ) - - # Capture in closure for the inner function. - _turn_id = turn_id - - async def _real(call: middleware_.ModelContext) -> proto_.StreamResultLike: - c = client_.auto_client(call.model) - adapter_fn = adapters.get_stream_adapter(call.model.adapter) - return types_.StreamResult( - adapter_fn( - c, - call.model, - call.messages, - tools=call.tools, - output_type=call.output_type, - **call.kwargs, - ), - turn_id=_turn_id, - input_messages=call.messages, - ) - - async def _driver() -> AsyncGenerator[events_.Event]: - chain = middleware_._build_model_chain(_real) - inner = await chain(call) - async for event in inner: - yield event - - return stream_.StreamResult(_driver(), turn_id=turn_id) + request = StreamRequest(model, messages, tools, output_type) + return Stream(executor._do_stream(request)) async def generate( model: model_.Model, - messages: list[messages_.Message], - params: types_.GenerateParams, - **kwargs: Any, -) -> messages_.Message: - """Generate a response (images, video, etc.). - - Resolves the adapter function from ``model.adapter``, auto-creates a - :class:`Client` from the model if no explicit client is set. - - ``params`` is required and controls the generation type: - - * :class:`ImageParams` — image generation (``/image-model``). - * :class:`VideoParams` — video generation (``/video-model``). - """ + messages: list[types.Message], + params: params.GenerateParams, + *, + executor: GenerateExecutor = _default_executor, +) -> types.Message: + """Generate a non-streaming response (images, video, etc.).""" messages = integrity_.prepare_messages(messages) + request = GenerateRequest(model, messages, params) + return await executor._do_generate(request) - call = middleware_.GenerateContext( - model=model, - messages=messages, - params=params, - ) - - async def _real(call: middleware_.GenerateContext) -> messages_.Message: - c = client_.auto_client(call.model) - adapter_fn = adapters.get_generate_adapter(call.model.adapter) - return await adapter_fn(c, call.model, call.messages, params=call.params) - - chain = middleware_._build_generate_chain(_real) - return await chain(call) - - -async def check_connection( - model: model_.Model, -) -> bool: - """Check whether the model's provider is reachable and the model exists. - - Returns ``True`` when the credentials are valid **and** the model is - available on the remote side — i.e. a subsequent :func:`stream` or - :func:`generate` call should succeed (network conditions permitting). - - This only hits free metadata endpoints; no tokens or credits are - consumed. - - The client is resolved from the model: ``model.client`` if set, - otherwise created by the provider. - Non-auth transport errors (network failures, 5xx) are raised rather - than returning ``False`` so that callers can distinguish "bad - credentials / unknown model" from "provider unreachable". - """ +async def check_connection(model: model_.Model) -> bool: + """Check whether the model's provider is reachable and the model exists.""" c = client_.auto_client(model) return await model.provider.check(c, model) diff --git a/src/ai/models/core/helpers/streaming.py b/src/ai/models/core/helpers/streaming.py deleted file mode 100644 index 25d8b649..00000000 --- a/src/ai/models/core/helpers/streaming.py +++ /dev/null @@ -1,219 +0,0 @@ -from __future__ import annotations - -import dataclasses - -from ....types import events as events_ -from ....types import messages as messages_ -from ....types import usage as usage_ - - -@dataclasses.dataclass -class TextStart: - block_id: str - - -@dataclasses.dataclass -class TextDelta: - block_id: str - delta: str - - -@dataclasses.dataclass -class TextEnd: - block_id: str - - -@dataclasses.dataclass -class ReasoningStart: - block_id: str - - -@dataclasses.dataclass -class ReasoningDelta: - block_id: str - delta: str - - -@dataclasses.dataclass -class ReasoningEnd: - block_id: str - signature: str | None = None - - -@dataclasses.dataclass -class ToolStart: - tool_call_id: str - tool_name: str - - -@dataclasses.dataclass -class ToolArgsDelta: - tool_call_id: str - delta: str - - -@dataclasses.dataclass -class ToolEnd: - tool_call_id: str - - -@dataclasses.dataclass -class FileEvent: - """A complete generated file from the LLM (e.g. inline image from Gemini/GPT).""" - - block_id: str - media_type: str - data: str # base64 string or data-URL from the gateway - - -@dataclasses.dataclass -class MessageDone: - finish_reason: str | None = None - usage: usage_.Usage | None = None - - -StreamEvent = ( - TextStart - | TextDelta - | TextEnd - | ReasoningStart - | ReasoningDelta - | ReasoningEnd - | ToolStart - | ToolArgsDelta - | ToolEnd - | FileEvent - | MessageDone -) - - -@dataclasses.dataclass -class StreamHandler: - """ - Accumulates LLM adapter events and produces public Event objects. - - This is the normalization layer between LLM adapters and the rest of the system. - Parts are tracked in a single ``_current_parts`` dict keyed by block/tool id, - updated in place as events stream in. - """ - - message_id: str = dataclasses.field(default_factory=messages_.generate_id) - - # Single source of truth for part state, keyed by id. Insertion order - # preserves provider emission order. - _current_parts: dict[str, messages_.Part] = dataclasses.field(default_factory=dict) - - # Active tracking - _active_text_id: str | None = None - _active_reasoning_id: str | None = None - _active_tool_ids: set[str] = dataclasses.field(default_factory=set) - - _is_done: bool = False - _usage: usage_.Usage | None = None - - def message_start(self) -> events_.MessageStart: - """Emit a MessageStart event at the beginning of a stream.""" - return events_.MessageStart(message=self._build_message()) - - def handle_event(self, event: StreamEvent) -> list[events_.Event]: - """Process an adapter event and return public Event objects.""" - - out: list[events_.Event] = [] - - match event: - case TextStart(block_id=bid): - part: messages_.Part = messages_.TextPart(id=bid, text="") - self._current_parts[bid] = part - self._active_text_id = bid - out.append(events_.TextStart(block_id=bid)) - - case TextDelta(block_id=bid, delta=d): - existing = self._current_parts[bid] - assert isinstance(existing, messages_.TextPart) - part = messages_.TextPart(id=bid, text=existing.text + d) - self._current_parts[bid] = part - out.append(events_.TextDelta(chunk=d, block_id=bid)) - - case TextEnd(block_id=bid): - if self._active_text_id == bid: - self._active_text_id = None - out.append(events_.TextEnd(block_id=bid)) - - case ReasoningStart(block_id=bid): - part = messages_.ReasoningPart(id=bid, text="") - self._current_parts[bid] = part - self._active_reasoning_id = bid - out.append(events_.ReasoningStart(block_id=bid)) - - case ReasoningDelta(block_id=bid, delta=d): - existing = self._current_parts[bid] - assert isinstance(existing, messages_.ReasoningPart) - part = messages_.ReasoningPart( - id=bid, - text=existing.text + d, - signature=existing.signature, - ) - self._current_parts[bid] = part - out.append(events_.ReasoningDelta(chunk=d, block_id=bid)) - - case ReasoningEnd(block_id=bid, signature=sig): - existing = self._current_parts[bid] - assert isinstance(existing, messages_.ReasoningPart) - part = messages_.ReasoningPart( - id=bid, text=existing.text, signature=sig - ) - self._current_parts[bid] = part - if self._active_reasoning_id == bid: - self._active_reasoning_id = None - out.append(events_.ReasoningEnd(block_id=bid, signature=sig)) - - case ToolStart(tool_call_id=tcid, tool_name=name): - part = messages_.ToolCallPart( - id=tcid, - tool_call_id=tcid, - tool_name=name, - tool_args="", - ) - self._current_parts[tcid] = part - self._active_tool_ids.add(tcid) - out.append(events_.ToolStart(tool_call_id=tcid, tool_name=name)) - - case ToolArgsDelta(tool_call_id=tcid, delta=d): - existing = self._current_parts[tcid] - assert isinstance(existing, messages_.ToolCallPart) - part = messages_.ToolCallPart( - id=tcid, - tool_call_id=existing.tool_call_id, - tool_name=existing.tool_name, - tool_args=existing.tool_args + d, - ) - self._current_parts[tcid] = part - out.append(events_.ToolDelta(chunk=d, tool_call_id=tcid)) - - case ToolEnd(tool_call_id=tcid): - self._active_tool_ids.discard(tcid) - out.append(events_.ToolEnd(tool_call_id=tcid)) - - case FileEvent(block_id=bid, media_type=mt, data=d): - self._current_parts[bid] = messages_.FilePart( - id=bid, data=d, media_type=mt - ) - - case MessageDone(usage=u): - self._is_done = True - self._usage = u - self._active_text_id = None - self._active_reasoning_id = None - self._active_tool_ids.clear() - msg = self._build_message() - out.append(events_.MessageEnd(message=msg, usage=u)) - - return out - - def _build_message(self) -> messages_.Message: - return messages_.Message( - id=self.message_id, - role="assistant", - parts=list(self._current_parts.values()), - usage=self._usage if self._is_done else None, - ) diff --git a/src/ai/models/core/proto.py b/src/ai/models/core/proto.py index 35daf675..987d2f5e 100644 --- a/src/ai/models/core/proto.py +++ b/src/ai/models/core/proto.py @@ -91,7 +91,7 @@ class StreamFn(Protocol): """Protocol for streaming adapter functions. Implementations yield event objects as the response streams in. The - terminal assistant state is surfaced as a ``MessageEnd.message``. + terminal assistant state is surfaced as a ``StreamEnd.message``. """ def __call__( diff --git a/src/ai/models/core/types.py b/src/ai/models/core/types.py deleted file mode 100644 index 4344fa05..00000000 --- a/src/ai/models/core/types.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Re-exports for backwards-compatible ``ai.models.core.types`` imports.""" - -from ...types.stream import StreamResult -from .params import GenerateParams, ImageParams, VideoParams - -__all__ = [ - "GenerateParams", - "ImageParams", - "StreamResult", - "VideoParams", -] diff --git a/src/ai/models/openai/adapter.py b/src/ai/models/openai/adapter.py index 3bfd082f..d1611785 100644 --- a/src/ai/models/openai/adapter.py +++ b/src/ai/models/openai/adapter.py @@ -16,9 +16,10 @@ from ...types import media from ...types import messages as messages_ from ...types import proto as proto_ +from ...types import usage as usage_ from ..core import client as client_ from ..core import model as model_ -from ..core.helpers import files, streaming +from ..core.helpers import files # --------------------------------------------------------------------------- # Message / tool conversion — internal types → OpenAI wire format @@ -211,6 +212,8 @@ async def stream( """Stream an LLM response via the OpenAI chat completions API. Yields :class:`~ai.types.events.Event` objects as the response streams in. + Pure delta emitter — the :class:`~ai.models.Stream` wrapper aggregates + parts into the final :class:`~ai.types.Message`. Extra keyword arguments beyond the ``StreamFn`` protocol: @@ -255,42 +258,28 @@ async def stream( reasoning_config["effort"] = reasoning_effort api_kwargs["extra_body"] = {"reasoning": reasoning_config} - handler = streaming.StreamHandler() - - def _emit(adapter_event: streaming.StreamEvent) -> list[events_.Event]: - return handler.handle_event(adapter_event) - try: sdk_stream = await sdk_client.chat.completions.create(**api_kwargs) text_started = False reasoning_started = False tc_state: dict[int, dict[str, Any]] = {} - finish_reason: str | None = None - usage: messages_.Usage | None = None + usage: usage_.Usage | None = None - yield handler.message_start() + yield events_.StreamStart() async for chunk in sdk_stream: if chunk.usage is not None: raw = chunk.usage.model_dump(exclude_none=True) reasoning_tokens: int | None = None cache_read: int | None = None - cd = getattr( - chunk.usage, - "completion_tokens_details", - None, - ) + cd = getattr(chunk.usage, "completion_tokens_details", None) if cd: reasoning_tokens = getattr(cd, "reasoning_tokens", None) - pd = getattr( - chunk.usage, - "prompt_tokens_details", - None, - ) + pd = getattr(chunk.usage, "prompt_tokens_details", None) if pd: cache_read = getattr(pd, "cached_tokens", None) - usage = messages_.Usage( + usage = usage_.Usage( input_tokens=chunk.usage.prompt_tokens or 0, output_tokens=chunk.usage.completion_tokens or 0, reasoning_tokens=reasoning_tokens, @@ -314,29 +303,20 @@ def _emit(adapter_event: streaming.StreamEvent) -> list[events_.Event]: if reasoning_value: if not reasoning_started: reasoning_started = True - for e in _emit(streaming.ReasoningStart(block_id="reasoning")): - yield e - for e in _emit( - streaming.ReasoningDelta( - block_id="reasoning", delta=reasoning_value - ) - ): - yield e + yield events_.ReasoningStart(block_id="reasoning") + yield events_.ReasoningDelta( + chunk=reasoning_value, block_id="reasoning" + ) if delta.content: if reasoning_started: - for e in _emit(streaming.ReasoningEnd(block_id="reasoning")): - yield e + yield events_.ReasoningEnd(block_id="reasoning") reasoning_started = False if not text_started: text_started = True - for e in _emit(streaming.TextStart(block_id="text")): - yield e - for e in _emit( - streaming.TextDelta(block_id="text", delta=delta.content) - ): - yield e + yield events_.TextStart(block_id="text") + yield events_.TextDelta(chunk=delta.content, block_id="text") if delta.tool_calls: for tc in delta.tool_calls: @@ -358,37 +338,28 @@ def _emit(adapter_event: streaming.StreamEvent) -> list[events_.Event]: if not tc_state[idx]["started"] and tid: tc_state[idx]["started"] = True - for e in _emit( - streaming.ToolStart( - tool_call_id=tid, - tool_name=tname, - ) - ): - yield e + yield events_.ToolStart( + tool_call_id=tid, tool_name=tname + ) if tid: - for e in _emit( - streaming.ToolArgsDelta( - tool_call_id=tid, - delta=tc.function.arguments, - ) - ): - yield e + yield events_.ToolDelta( + chunk=tc.function.arguments, + tool_call_id=tid, + ) if choice.finish_reason is not None: - finish_reason = choice.finish_reason if reasoning_started: - for e in _emit(streaming.ReasoningEnd(block_id="reasoning")): - yield e + yield events_.ReasoningEnd(block_id="reasoning") + reasoning_started = False if text_started: - for e in _emit(streaming.TextEnd(block_id="text")): - yield e + yield events_.TextEnd(block_id="text") + text_started = False for tc in tc_state.values(): if tc["started"] and tc["id"]: - for e in _emit(streaming.ToolEnd(tool_call_id=tc["id"])): - yield e + yield events_.ToolEnd(tool_call_id=tc["id"]) + tc["started"] = False - for e in _emit(streaming.MessageDone(finish_reason=finish_reason, usage=usage)): - yield e + yield events_.StreamEnd(usage=usage) finally: await sdk_client.close() diff --git a/src/ai/types/__init__.py b/src/ai/types/__init__.py index 769b12df..b6ff7a54 100644 --- a/src/ai/types/__init__.py +++ b/src/ai/types/__init__.py @@ -1,15 +1,14 @@ from . import media from .events import ( - End, Event, + FileEvent, HookResolution, HookSuspention, - MessageEnd, - MessageStart, ReasoningDelta, ReasoningEnd, ReasoningStart, - Start, + StreamEnd, + StreamStart, TextDelta, TextEnd, TextStart, @@ -29,27 +28,25 @@ ToolResultPart, generate_id, ) -from .proto import StreamResultLike, ToolLike +from .proto import ToolLike from .tools import ToolSchema from .usage import Usage __all__ = [ - "End", "Event", + "FileEvent", "FilePart", "HookPart", "HookResolution", "HookSuspention", "Message", - "MessageEnd", - "MessageStart", "Part", "ReasoningDelta", "ReasoningEnd", "ReasoningPart", "ReasoningStart", - "Start", - "StreamResultLike", + "StreamEnd", + "StreamStart", "StructuredOutputPart", "TextDelta", "TextEnd", diff --git a/src/ai/types/events.py b/src/ai/types/events.py index c9267072..3a4594bf 100644 --- a/src/ai/types/events.py +++ b/src/ai/types/events.py @@ -9,114 +9,110 @@ # serialization border in the case of durable execution -class Start(pydantic.BaseModel): - kind: Literal["start"] = "start" - model_config = pydantic.ConfigDict(frozen=True) - - -class End(pydantic.BaseModel): - kind: Literal["end"] = "end" - model_config = pydantic.ConfigDict(frozen=True) +class BaseEvent(pydantic.BaseModel): + """Common fields stamped onto every event by the streaming wrapper. + ``message`` carries the in-progress (or final) assistant message; the + streaming layer aggregates parts into it as deltas arrive and stamps + a reference onto each yielded event. ``usage`` carries the latest + usage value reported by the provider (latest-wins across the stream). + """ -class MessageStart(pydantic.BaseModel): message: messages.Message | None = None + usage: usage_.Usage | None = None - kind: Literal["message_start"] = "message_start" model_config = pydantic.ConfigDict(frozen=True) -class MessageEnd(pydantic.BaseModel): - message: messages.Message - usage: usage_.Usage | None = None +class StreamStart(BaseEvent): + kind: Literal["stream_start"] = "stream_start" - kind: Literal["message_end"] = "message_end" - model_config = pydantic.ConfigDict(frozen=True) + +class StreamEnd(BaseEvent): + kind: Literal["stream_end"] = "stream_end" -class TextStart(pydantic.BaseModel): +class TextStart(BaseEvent): block_id: str = "" kind: Literal["text_start"] = "text_start" - model_config = pydantic.ConfigDict(frozen=True) -class TextDelta(pydantic.BaseModel): +class TextDelta(BaseEvent): chunk: str block_id: str = "" kind: Literal["text_delta"] = "text_delta" - model_config = pydantic.ConfigDict(frozen=True) -class TextEnd(pydantic.BaseModel): +class TextEnd(BaseEvent): block_id: str = "" kind: Literal["text_end"] = "text_end" - model_config = pydantic.ConfigDict(frozen=True) -class ReasoningStart(pydantic.BaseModel): +class ReasoningStart(BaseEvent): block_id: str = "" kind: Literal["reasoning_start"] = "reasoning_start" - model_config = pydantic.ConfigDict(frozen=True) -class ReasoningDelta(pydantic.BaseModel): +class ReasoningDelta(BaseEvent): chunk: str block_id: str = "" kind: Literal["reasoning_delta"] = "reasoning_delta" - model_config = pydantic.ConfigDict(frozen=True) -class ReasoningEnd(pydantic.BaseModel): +class ReasoningEnd(BaseEvent): block_id: str = "" signature: str | None = None kind: Literal["reasoning_end"] = "reasoning_end" - model_config = pydantic.ConfigDict(frozen=True) -class ToolStart(pydantic.BaseModel): +class ToolStart(BaseEvent): tool_call_id: str = "" tool_name: str = "" kind: Literal["tool_start"] = "tool_start" - model_config = pydantic.ConfigDict(frozen=True) -class ToolDelta(pydantic.BaseModel): +class ToolDelta(BaseEvent): chunk: str tool_call_id: str = "" kind: Literal["tool_delta"] = "tool_delta" - model_config = pydantic.ConfigDict(frozen=True) -class ToolEnd(pydantic.BaseModel): +class ToolEnd(BaseEvent): tool_call_id: str = "" kind: Literal["tool_end"] = "tool_end" - model_config = pydantic.ConfigDict(frozen=True) -class HookSuspention(pydantic.BaseModel): +class FileEvent(BaseEvent): + """A complete generated file from the LLM (e.g. inline image from Gemini/GPT).""" + + block_id: str = "" + media_type: str + data: str | bytes + filename: str | None = None + + kind: Literal["file"] = "file" + + +class HookSuspention(BaseEvent): kind: Literal["hook_suspention"] = "hook_suspention" - model_config = pydantic.ConfigDict(frozen=True) -class HookResolution(pydantic.BaseModel): +class HookResolution(BaseEvent): kind: Literal["hook_resolution"] = "hook_resolution" - model_config = pydantic.ConfigDict(frozen=True) Event = Annotated[ - Start - | End - | MessageStart - | MessageEnd + StreamStart + | StreamEnd | TextStart | TextDelta | TextEnd @@ -126,6 +122,7 @@ class HookResolution(pydantic.BaseModel): | ToolStart | ToolDelta | ToolEnd + | FileEvent | HookSuspention | HookResolution, pydantic.Field(discriminator="kind"), diff --git a/src/ai/types/messages.py b/src/ai/types/messages.py index 29332eca..4be9d668 100644 --- a/src/ai/types/messages.py +++ b/src/ai/types/messages.py @@ -19,7 +19,6 @@ class TextPart(pydantic.BaseModel): text: str kind: Literal["text"] = "text" - model_config = pydantic.ConfigDict(frozen=True) class ToolCallPart(pydantic.BaseModel): @@ -29,7 +28,6 @@ class ToolCallPart(pydantic.BaseModel): tool_args: str kind: Literal["tool_call"] = "tool_call" - model_config = pydantic.ConfigDict(frozen=True) class ToolResultPart(pydantic.BaseModel): @@ -51,7 +49,6 @@ class ReasoningPart(pydantic.BaseModel): signature: str | None = None kind: Literal["reasoning"] = "reasoning" - model_config = pydantic.ConfigDict(frozen=True) class HookPart[T](pydantic.BaseModel): @@ -188,8 +185,6 @@ def from_bytes( class Message(pydantic.BaseModel): - model_config = pydantic.ConfigDict(frozen=True) - role: Literal["user", "assistant", "system", "tool", "internal"] parts: list[Part] id: str = pydantic.Field(default_factory=generate_id) @@ -235,27 +230,5 @@ def output(self) -> Any: return part.value return None - def replace(self, old: Part, new: Part | None = None) -> Self: - """Return a copy with one part replaced. - - ``replace(new_part)`` matches by ``new_part.id``. - ``replace(old_part, new_part)`` matches by object identity. - """ - if new is None: - new = old - for idx, part in enumerate(self.parts): - if part.id == new.id: - parts = list(self.parts) - parts[idx] = new - return self.model_copy(update={"parts": parts}) - raise ValueError(f"Part id={new.id!r} not found in message {self.id!r}") - - for idx, part in enumerate(self.parts): - if part is old: - parts = list(self.parts) - parts[idx] = new - return self.model_copy(update={"parts": parts}) - raise ValueError(f"Part id={old.id!r} not found in message {self.id!r}") - Usage = usage_.Usage diff --git a/src/ai/types/proto.py b/src/ai/types/proto.py index a2a8e3ca..dd9ea7de 100644 --- a/src/ai/types/proto.py +++ b/src/ai/types/proto.py @@ -1,9 +1,5 @@ -from collections.abc import AsyncGenerator from typing import Any, Protocol, runtime_checkable -from . import events as events_ -from . import messages, usage - @runtime_checkable class ToolLike(Protocol): @@ -15,33 +11,3 @@ def name(self) -> str: ... def description(self) -> str: ... @property def param_schema(self) -> dict[str, Any]: ... - - -@runtime_checkable -class StreamResultLike(Protocol): - """Structural protocol satisfied by :class:`ai.models.StreamResult`. - - Middleware that transforms or replaces the stream returned by - ``wrap_model`` should return an object satisfying this protocol. - The easiest way is ``StreamResult.from_generator(gen)``. - """ - - def __aiter__(self) -> AsyncGenerator[events_.Event]: ... - - @property - def message(self) -> messages.Message | None: ... - - @property - def text(self) -> str: ... - - @property - def tool_calls(self) -> list[messages.ToolCallPart]: ... - - @property - def usage(self) -> usage.Usage | None: ... - - @property - def output(self) -> Any: ... - - @property - def turn_id(self) -> str | None: ... diff --git a/src/ai/types/stream.py b/src/ai/types/stream.py deleted file mode 100644 index fdaaebca..00000000 --- a/src/ai/types/stream.py +++ /dev/null @@ -1,105 +0,0 @@ -from collections.abc import AsyncGenerator -from typing import Any, Self - -from . import events as events_ -from . import messages -from . import usage as usage_ - - -class StreamResult: - """Wrapper around an event stream. Async-iterable; collects the final result. - - Yields :class:`~ai.types.events.Event` objects. After iteration, - convenience properties (``.text``, ``.tool_calls``, ``.usage``, - ``.message``) are available — they delegate to the ``MessageEnd`` - event's ``message``. - - One ``StreamResult`` represents one turn: a single LLM request and - its response. - """ - - def __init__( - self, - gen: AsyncGenerator[events_.Event], - *, - turn_id: str | None = None, - input_messages: list[messages.Message] | None = None, - ) -> None: - self._gen = gen - self._turn_id = turn_id - self._input_messages = input_messages or [] - self._message: messages.Message | None = None - self._usage: usage_.Usage | None = None - - @classmethod - def from_generator(cls, gen: AsyncGenerator[events_.Event]) -> Self: - """Create a :class:`StreamResult` from an async generator of events.""" - return cls(gen) - - def __aiter__(self) -> AsyncGenerator[events_.Event]: - return self._iterate() - - def _stamp_message(self, msg: messages.Message) -> messages.Message: - if msg.turn_id is None and self._turn_id is not None: - return msg.model_copy(update={"turn_id": self._turn_id}) - return msg - - async def _iterate(self) -> AsyncGenerator[events_.Event]: - # Re-emit input messages as MessageStart + MessageEnd event pairs. - for msg in self._input_messages: - msg = self._stamp_message(msg) - yield events_.MessageStart(message=msg) - yield events_.MessageEnd(message=msg) - - # Stream adapter events. - async for event in self._gen: - if isinstance(event, events_.MessageStart) and event.message is not None: - event = event.model_copy( - update={"message": self._stamp_message(event.message)} - ) - - # Capture the final message from MessageEnd. - if isinstance(event, events_.MessageEnd): - message = self._stamp_message(event.message) - event = event.model_copy(update={"message": message}) - self._message = message - self._usage = event.usage - yield event - - @property - def turn_id(self) -> str | None: - """The turn id stamped on this stream's response (if any).""" - return self._turn_id - - @property - def message(self) -> messages.Message | None: - """The final assembled message, available after iteration.""" - return self._message - - @property - def text(self) -> str: - if self._message is None: - return "" - return "".join( - p.text for p in self._message.parts if isinstance(p, messages.TextPart) - ) - - @property - def tool_calls(self) -> list[messages.ToolCallPart]: - if self._message is None: - return [] - return [p for p in self._message.parts if isinstance(p, messages.ToolCallPart)] - - @property - def usage(self) -> usage_.Usage | None: - return self._usage - - @property - def output(self) -> Any: - """Parsed structured output from the final message, if available.""" - if self._message is None: - return None - for p in self._message.parts: - if isinstance(p, messages.StructuredOutputPart): - return p.value - return None diff --git a/tests/agents/test_generator_tools.py b/tests/agents/test_generator_tools.py index 3c4e9812..2c720b9f 100644 --- a/tests/agents/test_generator_tools.py +++ b/tests/agents/test_generator_tools.py @@ -9,11 +9,18 @@ import ai from ai import models -from ai.models.core.helpers import streaming as streaming_ +from ai.agents import events as agent_events_ from ai.types import events as events_ from ai.types import messages as messages_ -from ..conftest import MOCK_MODEL, collect_messages, mock_llm, text_msg, tool_call_msg +from ..conftest import ( + MOCK_MODEL, + collect_messages, + emit_events_for_messages, + mock_llm, + text_msg, + tool_call_msg, +) # --------------------------------------------------------------------------- # Generator tool: yields intermediate messages, returns final text @@ -96,45 +103,7 @@ async def stream( seq = self._responses[self._idx] self._idx += 1 - message_id = seq[0].id if seq else messages_.generate_id() - handler = streaming_.StreamHandler(message_id=message_id) - yield handler.message_start() - for msg in seq: - for i, part in enumerate(msg.parts): - if isinstance(part, messages_.TextPart): - bid = f"text-{i}" - for event in handler.handle_event( - streaming_.TextStart(block_id=bid) - ): - yield event - if part.text: - for event in handler.handle_event( - streaming_.TextDelta(block_id=bid, delta=part.text) - ): - yield event - for event in handler.handle_event(streaming_.TextEnd(block_id=bid)): - yield event - elif isinstance(part, messages_.ToolCallPart): - for event in handler.handle_event( - streaming_.ToolStart( - tool_call_id=part.tool_call_id, - tool_name=part.tool_name, - ) - ): - yield event - if part.tool_args: - for event in handler.handle_event( - streaming_.ToolArgsDelta( - tool_call_id=part.tool_call_id, - delta=part.tool_args, - ) - ): - yield event - for event in handler.handle_event( - streaming_.ToolEnd(tool_call_id=part.tool_call_id) - ): - yield event - for event in handler.handle_event(streaming_.MessageDone()): + async for event in emit_events_for_messages(seq): yield event @@ -145,7 +114,7 @@ async def inner_fact(topic: str) -> str: @ai.tool # type: ignore[arg-type] -async def research_tool(topic: str) -> AsyncGenerator[ai.Event]: +async def research_tool(topic: str) -> AsyncGenerator[agent_events_.AgentEvent]: """Nested agent that researches a topic.""" inner = ai.agent(tools=[inner_fact]) diff --git a/tests/agents/test_hooks.py b/tests/agents/test_hooks.py index 1cfb2cc5..73d88d22 100644 --- a/tests/agents/test_hooks.py +++ b/tests/agents/test_hooks.py @@ -9,6 +9,7 @@ import pytest import ai +from ai.agents import events as agent_events_ from ..conftest import MOCK_MODEL, mock_llm, text_msg @@ -37,7 +38,7 @@ async def custom(context: ai.Context) -> AsyncGenerator[ai.Event]: mock_llm([[text_msg("OK")]]) async for event in my_agent.run(MOCK_MODEL, [ai.user_message("go")]): - if not isinstance(event, ai.MessageEnd): + if not isinstance(event, agent_events_.MessageEnd): continue msg = event.message # When we see the pending hook message, resolve it. @@ -70,7 +71,7 @@ async def custom(context: ai.Context) -> AsyncGenerator[ai.Event]: mock_llm([[text_msg("OK")]]) async for event in my_agent.run(MOCK_MODEL, [ai.user_message("go")]): - if not isinstance(event, ai.MessageEnd): + if not isinstance(event, agent_events_.MessageEnd): continue msg = event.message if any(isinstance(p, ai.HookPart) and p.status == "pending" for p in msg.parts): @@ -144,7 +145,7 @@ async def custom(context: ai.Context) -> AsyncGenerator[ai.Event]: msgs: list[ai.Message] = [] async for event in my_agent.run(MOCK_MODEL, [ai.user_message("go")]): - if not isinstance(event, ai.MessageEnd): + if not isinstance(event, agent_events_.MessageEnd): continue msg = event.message msgs.append(msg) @@ -180,7 +181,7 @@ async def custom(context: ai.Context) -> AsyncGenerator[ai.Event]: mock_llm([[text_msg("OK")]]) msgs: list[ai.Message] = [] async for event in my_agent.run(MOCK_MODEL, [ai.user_message("go")]): - if isinstance(event, ai.MessageEnd): + if isinstance(event, agent_events_.MessageEnd): msgs.append(event.message) hook_msgs = [m for m in msgs if any(isinstance(p, ai.HookPart) for p in m.parts)] diff --git a/tests/agents/test_runtime.py b/tests/agents/test_runtime.py index 4165b727..edb7dc3c 100644 --- a/tests/agents/test_runtime.py +++ b/tests/agents/test_runtime.py @@ -107,73 +107,3 @@ async def test_agent_multi_turn() -> None: my_agent.run(MOCK_MODEL, [ai.user_message("Concat then double")]) ) assert llm.call_count == 3 - - -# -- turn_id semantics: one turn per LLM round-trip ------------------------- - - -async def test_two_user_messages_produce_four_turns() -> None: - """Two agent.run invocations, each with a tool call + final reply, - produce four distinct turn ids; history from the first run keeps its - original turn ids when fed into the second run.""" - my_agent = ai.agent(tools=[double]) - - # Run 1: tool call, then text. - r1_turn1 = [tool_call_msg(tc_id="tc-1", name="double", args='{"x": 5}', id="m-1a")] - r1_turn2 = [text_msg("Ten.", id="m-1b")] - # Run 2: tool call, then text. - r2_turn1 = [tool_call_msg(tc_id="tc-2", name="double", args='{"x": 7}', id="m-2a")] - r2_turn2 = [text_msg("Fourteen.", id="m-2b")] - mock_llm([r1_turn1, r1_turn2, r2_turn1, r2_turn2]) - - def dedup(stream: list[ai.Message]) -> list[ai.Message]: - seen: dict[str, ai.Message] = {} - for m in stream: - seen[m.id] = m - return list(seen.values()) - - run1_stream = await collect_messages( - my_agent.run(MOCK_MODEL, [ai.user_message("Double 5")]) - ) - history = dedup(run1_stream) - - run2_stream = await collect_messages( - my_agent.run(MOCK_MODEL, [*history, ai.user_message("Double 7")]) - ) - final = dedup(run2_stream) - - # Chronological list of terminal non-internal messages. Insertion order - # of ``dedup`` reflects the order they first appeared in the stream. - chronological = [m for m in final if m.role != "internal"] - assert len(chronological) == 8 - - # Expected shape: four turns, each a (input, assistant) pair. - # turn 1: user → assistant (tool call) - # turn 2: tool → assistant (text) - # turn 3: user → assistant (tool call) - # turn 4: tool → assistant (text) - expected_roles = [ - ("user", "assistant"), - ("tool", "assistant"), - ("user", "assistant"), - ("tool", "assistant"), - ] - pairs = [(chronological[2 * i], chronological[2 * i + 1]) for i in range(4)] - for (left, right), (expected_left, expected_right) in zip( - pairs, expected_roles, strict=True - ): - assert (left.role, right.role) == (expected_left, expected_right) - # Both messages in the pair share the same turn_id. - assert left.turn_id is not None - assert left.turn_id == right.turn_id - - # The four turn_ids are all distinct. - turn_ids = [left.turn_id for left, _ in pairs] - assert len(set(turn_ids)) == 4 - - # Run 1's history survives untouched into run 2. - history_ids = {h.id for h in history} - for h in history: - same = next(m for m in final if m.id == h.id) - assert same.turn_id == h.turn_id - assert any(m.id in history_ids for m in final) diff --git a/tests/agents/ui/ai_sdk/outbound/test_sse.py b/tests/agents/ui/ai_sdk/outbound/test_sse.py index 501d4d03..4c3ba93d 100644 --- a/tests/agents/ui/ai_sdk/outbound/test_sse.py +++ b/tests/agents/ui/ai_sdk/outbound/test_sse.py @@ -3,9 +3,9 @@ import json from collections.abc import AsyncGenerator +from ai.agents import events as agent_events_ from ai.agents.ui.ai_sdk import protocol, to_sse from ai.agents.ui.ai_sdk.outbound.sse import format_sse, serialize_part -from ai.types import events as events_ from ai.types import messages as messages_ @@ -30,8 +30,8 @@ def test_serialize_data_part_uses_type_with_prefix() -> None: async def _gen( - stream_events: list[events_.Event], -) -> AsyncGenerator[events_.Event]: + stream_events: list[agent_events_.AgentEvent], +) -> AsyncGenerator[agent_events_.AgentEvent]: for event in stream_events: yield event @@ -46,7 +46,12 @@ async def test_to_sse_emits_data_prefixed_lines() -> None: lines = [ line async for line in to_sse( - _gen([events_.MessageStart(message=msg), events_.MessageEnd(message=msg)]) + _gen( + [ + agent_events_.MessageStart(message=msg), + agent_events_.MessageEnd(message=msg), + ] + ) ) ] assert all(line.startswith("data: ") for line in lines) diff --git a/tests/agents/ui/ai_sdk/outbound/test_stream.py b/tests/agents/ui/ai_sdk/outbound/test_stream.py index 5cd1bb19..e193b359 100644 --- a/tests/agents/ui/ai_sdk/outbound/test_stream.py +++ b/tests/agents/ui/ai_sdk/outbound/test_stream.py @@ -2,20 +2,21 @@ from collections.abc import AsyncGenerator +from ai.agents import events as agent_events_ from ai.agents.ui.ai_sdk import protocol, to_stream from ai.types import events as events_ from ai.types import messages as messages_ async def _gen( - stream_events: list[events_.Event], -) -> AsyncGenerator[events_.Event]: + stream_events: list[agent_events_.AgentEvent], +) -> AsyncGenerator[agent_events_.AgentEvent]: for event in stream_events: yield event async def _collect( - stream_events: list[events_.Event], + stream_events: list[agent_events_.AgentEvent], ) -> list[protocol.UIMessageStreamPart]: return [part async for part in to_stream(_gen(stream_events))] @@ -25,8 +26,8 @@ def _assistant_start( *, turn_id: str | None = "t1", source_label: str | None = None, -) -> events_.MessageStart: - return events_.MessageStart( +) -> agent_events_.MessageStart: + return agent_events_.MessageStart( message=messages_.Message( id=msg_id, role="assistant", @@ -51,7 +52,7 @@ async def test_event_driven_text_streaming() -> None: events_.TextStart(block_id=text_id), events_.TextDelta(block_id=text_id, chunk="hi"), events_.TextEnd(block_id=text_id), - events_.MessageEnd(message=final), + agent_events_.MessageEnd(message=final), ] ) @@ -72,7 +73,7 @@ async def test_static_text_message_emits_text_parts() -> None: parts=[messages_.TextPart(id="txt1", text="hello")], ) out = await _collect( - [events_.MessageStart(message=msg), events_.MessageEnd(message=msg)] + [agent_events_.MessageStart(message=msg), agent_events_.MessageEnd(message=msg)] ) assert any(isinstance(part, protocol.TextDeltaPart) for part in out) @@ -92,10 +93,10 @@ async def test_turn_id_change_emits_step_boundary() -> None: ) out = await _collect( [ - events_.MessageStart(message=msg1), - events_.MessageEnd(message=msg1), - events_.MessageStart(message=msg2), - events_.MessageEnd(message=msg2), + agent_events_.MessageStart(message=msg1), + agent_events_.MessageEnd(message=msg1), + agent_events_.MessageStart(message=msg2), + agent_events_.MessageEnd(message=msg2), ] ) has_mid_step_boundary = any( @@ -122,10 +123,10 @@ async def test_agent_change_emits_message_boundary() -> None: ) out = await _collect( [ - events_.MessageStart(message=msg1), - events_.MessageEnd(message=msg1), - events_.MessageStart(message=msg2), - events_.MessageEnd(message=msg2), + agent_events_.MessageStart(message=msg1), + agent_events_.MessageEnd(message=msg1), + agent_events_.MessageStart(message=msg2), + agent_events_.MessageEnd(message=msg2), ] ) has_mid_msg_boundary = any( @@ -163,10 +164,10 @@ async def test_tool_call_and_result_emit_terminal_parts() -> None: ) out = await _collect( [ - events_.MessageStart(message=tool_call), - events_.MessageEnd(message=tool_call), - events_.MessageStart(message=tool_result), - events_.MessageEnd(message=tool_result), + agent_events_.MessageStart(message=tool_call), + agent_events_.MessageEnd(message=tool_call), + agent_events_.MessageStart(message=tool_result), + agent_events_.MessageEnd(message=tool_result), ] ) types = [type(part).__name__ for part in out] @@ -201,10 +202,10 @@ async def test_approval_request_hook_emits_approval_part() -> None: ) out = await _collect( [ - events_.MessageStart(message=tool_call), - events_.MessageEnd(message=tool_call), - events_.MessageStart(message=hook), - events_.MessageEnd(message=hook), + agent_events_.MessageStart(message=tool_call), + agent_events_.MessageEnd(message=tool_call), + agent_events_.MessageStart(message=hook), + agent_events_.MessageEnd(message=hook), ] ) approval_parts = [p for p in out if isinstance(p, protocol.ToolApprovalRequestPart)] @@ -220,12 +221,12 @@ async def test_dedup_on_reemitted_message_id() -> None: turn_id="t1", parts=[messages_.TextPart(id="txt1", text="hi")], ) - stream_events: list[events_.Event] = [ - events_.MessageStart(message=msg), + stream_events: list[agent_events_.AgentEvent] = [ + agent_events_.MessageStart(message=msg), events_.TextStart(block_id="txt1"), events_.TextDelta(block_id="txt1", chunk="hi"), events_.TextEnd(block_id="txt1"), - events_.MessageEnd(message=msg), + agent_events_.MessageEnd(message=msg), ] out = await _collect([*stream_events, *stream_events]) text_deltas = [part for part in out if isinstance(part, protocol.TextDeltaPart)] diff --git a/tests/conftest.py b/tests/conftest.py index d7e4abda..4bb74ab5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,9 +7,11 @@ import ai from ai import models +from ai.agents import events as agent_events_ from ai.types import builders from ai.types import events as events_ from ai.types import messages as messages_ +from ai.types import usage as usage_ class MockProvider: @@ -87,12 +89,63 @@ def __repr__(self) -> str: ) +async def emit_events_for_messages( + seq: list[messages_.Message], + *, + usage: usage_.Usage | None = None, +) -> AsyncGenerator[events_.Event]: + """Emit a stream of public ``events_.Event`` corresponding to ``seq``. + + Walks each message's parts and yields the appropriate + ``Start`` / ``Delta`` / ``End`` events (and ``FileEvent``). The output + matches what a real adapter would produce. Bookended by + ``StreamStart`` / ``StreamEnd``. + """ + yield events_.StreamStart() + for msg in seq: + for i, part in enumerate(msg.parts): + if isinstance(part, messages_.TextPart): + bid = f"text-{i}" + yield events_.TextStart(block_id=bid) + if part.text: + yield events_.TextDelta(block_id=bid, chunk=part.text) + yield events_.TextEnd(block_id=bid) + + elif isinstance(part, messages_.ReasoningPart): + bid = f"reasoning-{i}" + yield events_.ReasoningStart(block_id=bid) + if part.text: + yield events_.ReasoningDelta(block_id=bid, chunk=part.text) + yield events_.ReasoningEnd(block_id=bid, signature=part.signature) + + elif isinstance(part, messages_.ToolCallPart): + yield events_.ToolStart( + tool_call_id=part.tool_call_id, + tool_name=part.tool_name, + ) + if part.tool_args: + yield events_.ToolDelta( + tool_call_id=part.tool_call_id, + chunk=part.tool_args, + ) + yield events_.ToolEnd(tool_call_id=part.tool_call_id) + + elif isinstance(part, messages_.FilePart): + yield events_.FileEvent( + block_id=part.id, + media_type=part.media_type, + data=part.data if isinstance(part.data, str) else "", + ) + # StructuredOutputPart is not a streamed part; tests that need it + # construct a tailored adapter directly. + yield events_.StreamEnd(usage=usage) + + class MockAdapter: """Mock stream adapter that yields pre-configured response sequences. - Each call to the adapter pops the next response list and yields the - messages through a StreamHandler (matching real adapter behavior). - Tracks ``call_count`` for assertions. + Each call pops the next response list and emits events for it via + :func:`emit_events_for_messages`. Tracks ``call_count``. """ def __init__(self, responses: list[list[messages_.Message]]) -> None: @@ -116,79 +169,7 @@ async def stream( seq = self._responses[self._call_index] self._call_index += 1 - from ai.models.core.helpers import streaming as streaming_ - - message_id = seq[0].id if seq else messages_.generate_id() - handler = streaming_.StreamHandler(message_id=message_id) - yield handler.message_start() - - for msg in seq: - for i, part in enumerate(msg.parts): - if isinstance(part, messages_.TextPart): - bid = f"text-{i}" - for event in handler.handle_event( - streaming_.TextStart(block_id=bid) - ): - yield event - if part.text: - for event in handler.handle_event( - streaming_.TextDelta(block_id=bid, delta=part.text) - ): - yield event - for event in handler.handle_event(streaming_.TextEnd(block_id=bid)): - yield event - - elif isinstance(part, messages_.ReasoningPart): - bid = f"reasoning-{i}" - for event in handler.handle_event( - streaming_.ReasoningStart(block_id=bid) - ): - yield event - if part.text: - for event in handler.handle_event( - streaming_.ReasoningDelta(block_id=bid, delta=part.text) - ): - yield event - for event in handler.handle_event( - streaming_.ReasoningEnd(block_id=bid, signature=part.signature) - ): - yield event - - elif isinstance(part, messages_.ToolCallPart): - for event in handler.handle_event( - streaming_.ToolStart( - tool_call_id=part.tool_call_id, - tool_name=part.tool_name, - ) - ): - yield event - if part.tool_args: - for event in handler.handle_event( - streaming_.ToolArgsDelta( - tool_call_id=part.tool_call_id, - delta=part.tool_args, - ) - ): - yield event - for event in handler.handle_event( - streaming_.ToolEnd(tool_call_id=part.tool_call_id) - ): - yield event - - elif isinstance(part, messages_.StructuredOutputPart): - handler._current_parts[part.id] = part - - elif isinstance(part, messages_.FilePart): - for event in handler.handle_event( - streaming_.FileEvent( - block_id=part.id, - media_type=part.media_type, - data=part.data if isinstance(part.data, str) else "", - ) - ): - yield event - - for event in handler.handle_event(streaming_.MessageDone()): + async for event in emit_events_for_messages(seq): yield event @@ -203,12 +184,12 @@ def mock_llm(responses: list[list[messages_.Message]]) -> MockAdapter: async def collect_messages( - source: AsyncIterable[events_.Event], + source: AsyncIterable[agent_events_.AgentEvent], ) -> list[messages_.Message]: """Collect terminal messages from an event stream.""" result: list[messages_.Message] = [] async for event in source: - if isinstance(event, events_.MessageEnd): + if isinstance(event, agent_events_.MessageEnd): result.append(event.message) return result diff --git a/tests/models/ai_gateway/test_protocol.py b/tests/models/ai_gateway/test_protocol.py index 1da93cf9..38928f4e 100644 --- a/tests/models/ai_gateway/test_protocol.py +++ b/tests/models/ai_gateway/test_protocol.py @@ -18,7 +18,7 @@ import pydantic -from ai.models.core.helpers import streaming +from ai.types import events as events_ from ai.types import messages # The ai_gateway __init__.py re-exports `stream` as a function, which @@ -237,12 +237,12 @@ def test_text_delta_uses_textDelta_key(self) -> None: events = stream_mod._parse_stream_part( {"type": "text-delta", "id": "t1", "textDelta": "Hello"} ) - assert isinstance(events[0], streaming.TextDelta) - assert events[0].delta == "Hello" + assert isinstance(events[0], events_.TextDelta) + assert events[0].chunk == "Hello" def test_tool_call_expands_to_three_events(self) -> None: """A complete ``tool-call`` part must expand into - ToolStart -> ToolArgsDelta -> ToolEnd.""" + ToolStart -> ToolDelta -> ToolEnd.""" events = stream_mod._parse_stream_part( { "type": "tool-call", @@ -252,11 +252,11 @@ def test_tool_call_expands_to_three_events(self) -> None: } ) assert len(events) == 3 - assert isinstance(events[0], streaming.ToolStart) + assert isinstance(events[0], events_.ToolStart) assert events[0].tool_name == "get_weather" - assert isinstance(events[1], streaming.ToolArgsDelta) - assert json.loads(events[1].delta) == {"city": "SF"} - assert isinstance(events[2], streaming.ToolEnd) + assert isinstance(events[1], events_.ToolDelta) + assert json.loads(events[1].chunk) == {"city": "SF"} + assert isinstance(events[2], events_.ToolEnd) def test_finish_flat_usage(self) -> None: events = stream_mod._parse_stream_part( @@ -270,8 +270,7 @@ def test_finish_flat_usage(self) -> None: } ) done = events[0] - assert isinstance(done, streaming.MessageDone) - assert done.finish_reason == "stop" + assert isinstance(done, events_.StreamEnd) assert done.usage is not None assert done.usage.input_tokens == 10 assert done.usage.output_tokens == 20 @@ -297,8 +296,7 @@ def test_finish_v3_nested_usage(self) -> None: } ) done = events[0] - assert isinstance(done, streaming.MessageDone) - assert done.finish_reason == "tool-calls" + assert isinstance(done, events_.StreamEnd) assert done.usage is not None assert done.usage.input_tokens == 100 assert done.usage.cache_read_tokens == 50 @@ -316,7 +314,7 @@ def test_file_part(self) -> None: } ) assert len(events) == 1 - assert isinstance(events[0], streaming.FileEvent) + assert isinstance(events[0], events_.FileEvent) assert events[0].block_id == "f1" assert events[0].media_type == "image/png" assert events[0].data == "iVBORw0KGgo=" @@ -325,7 +323,7 @@ def test_file_part_defaults(self) -> None: """A minimal ``file`` part uses sensible defaults.""" events = stream_mod._parse_stream_part({"type": "file", "data": "somedata"}) assert len(events) == 1 - assert isinstance(events[0], streaming.FileEvent) + assert isinstance(events[0], events_.FileEvent) assert events[0].media_type == "application/octet-stream" def test_unknown_types_produce_no_events(self) -> None: diff --git a/tests/models/ai_gateway/test_stream.py b/tests/models/ai_gateway/test_stream.py index dff7c294..b3902d76 100644 --- a/tests/models/ai_gateway/test_stream.py +++ b/tests/models/ai_gateway/test_stream.py @@ -23,6 +23,7 @@ import pytest import ai +from ai import models from ai.models.ai_gateway import ai_gateway, errors from ai.models.core import model as model_ from ai.types import events, messages @@ -59,13 +60,11 @@ async def _final( model: model_.Model = _TEST_MODEL, **kwargs: Any, ) -> messages.Message: - """Drain ``stream()`` and return the terminal assistant message.""" - result: list[messages.Message] = [] - async for event in stream_mod.stream(client, model, msgs, **kwargs): - if isinstance(event, events.MessageEnd): - result.append(event.message) - assert result - return result[-1] + """Drain the adapter's event stream and return the aggregated message.""" + s = models.Stream(stream_mod.stream(client, model, msgs, **kwargs)) + async for _ in s: + pass + return s.message # --------------------------------------------------------------------------- diff --git a/tests/models/core/test_streaming.py b/tests/models/core/test_streaming.py deleted file mode 100644 index 7c46d723..00000000 --- a/tests/models/core/test_streaming.py +++ /dev/null @@ -1,202 +0,0 @@ -"""StreamHandler: event accumulation, state transitions, message building.""" - -from __future__ import annotations - -from collections.abc import Sequence - -from ai.models.core.helpers import streaming -from ai.types import events, messages - - -def _only[T](items: Sequence[object], typ: type[T]) -> T: - matches = [item for item in items if isinstance(item, typ)] - assert len(matches) == 1 - return matches[0] - - -def test_text_lifecycle() -> None: - h = streaming.StreamHandler(message_id="m1") - - out = h.handle_event(streaming.TextStart(block_id="b1")) - assert isinstance(out[0], events.TextStart) - assert out[0].block_id == "b1" - - out = h.handle_event(streaming.TextDelta(block_id="b1", delta="Hello")) - delta = _only(out, events.TextDelta) - assert delta.chunk == "Hello" - assert delta.block_id == "b1" - - out = h.handle_event(streaming.TextDelta(block_id="b1", delta=" world")) - delta = _only(out, events.TextDelta) - assert delta.chunk == " world" - - out = h.handle_event(streaming.TextEnd(block_id="b1")) - assert isinstance(out[0], events.TextEnd) - assert out[0].block_id == "b1" - assert not any(isinstance(event, events.TextDelta) for event in out) - - out = h.handle_event(streaming.MessageDone(finish_reason="end_turn")) - msg = _only(out, events.MessageEnd).message - assert msg.text == "Hello world" - - -def test_reasoning_lifecycle() -> None: - h = streaming.StreamHandler(message_id="m1") - h.handle_event(streaming.ReasoningStart(block_id="r1")) - - out = h.handle_event(streaming.ReasoningDelta(block_id="r1", delta="thinking")) - delta = _only(out, events.ReasoningDelta) - assert delta.chunk == "thinking" - - out = h.handle_event(streaming.ReasoningEnd(block_id="r1", signature="sig123")) - end = _only(out, events.ReasoningEnd) - assert end.signature == "sig123" - - msg = _only(h.handle_event(streaming.MessageDone()), events.MessageEnd).message - assert msg.reasoning == "thinking" - part = msg.parts[0] - assert isinstance(part, messages.ReasoningPart) - assert part.signature == "sig123" - - -def test_tool_lifecycle() -> None: - h = streaming.StreamHandler(message_id="m1") - - out = h.handle_event( - streaming.ToolStart(tool_call_id="tc1", tool_name="get_weather") - ) - start = _only(out, events.ToolStart) - assert start.tool_name == "get_weather" - - out = h.handle_event(streaming.ToolArgsDelta(tool_call_id="tc1", delta='{"ci')) - delta = _only(out, events.ToolDelta) - assert delta.chunk == '{"ci' - - h.handle_event(streaming.ToolArgsDelta(tool_call_id="tc1", delta='ty":"London"}')) - - out = h.handle_event(streaming.ToolEnd(tool_call_id="tc1")) - assert isinstance(out[0], events.ToolEnd) - assert not any(isinstance(event, events.ToolDelta) for event in out) - - msg = _only(h.handle_event(streaming.MessageDone()), events.MessageEnd).message - tc = msg.tool_calls[0] - assert tc.tool_name == "get_weather" - assert tc.tool_args == '{"city":"London"}' - - -def test_reasoning_then_text_then_tool() -> None: - h = streaming.StreamHandler(message_id="m1") - h.handle_event(streaming.ReasoningStart(block_id="r1")) - h.handle_event(streaming.ReasoningDelta(block_id="r1", delta="Let me think")) - h.handle_event(streaming.ReasoningEnd(block_id="r1")) - - h.handle_event(streaming.TextStart(block_id="t1")) - h.handle_event(streaming.TextDelta(block_id="t1", delta="I'll check")) - h.handle_event(streaming.TextEnd(block_id="t1")) - - h.handle_event(streaming.ToolStart(tool_call_id="tc1", tool_name="search")) - h.handle_event(streaming.ToolArgsDelta(tool_call_id="tc1", delta='{"q":"test"}')) - out = h.handle_event(streaming.ToolEnd(tool_call_id="tc1")) - assert isinstance(out[0], events.ToolEnd) - - msg = _only(h.handle_event(streaming.MessageDone()), events.MessageEnd).message - assert len(msg.parts) == 3 - assert isinstance(msg.parts[0], messages.ReasoningPart) - assert isinstance(msg.parts[1], messages.TextPart) - assert isinstance(msg.parts[2], messages.ToolCallPart) - - -def test_multiple_tool_calls() -> None: - h = streaming.StreamHandler(message_id="m1") - h.handle_event(streaming.ToolStart(tool_call_id="tc1", tool_name="read_file")) - h.handle_event(streaming.ToolStart(tool_call_id="tc2", tool_name="list_files")) - - h.handle_event(streaming.ToolArgsDelta(tool_call_id="tc1", delta='{"path":"a.py"}')) - h.handle_event(streaming.ToolArgsDelta(tool_call_id="tc2", delta='{"dir":"."}')) - h.handle_event(streaming.ToolEnd(tool_call_id="tc1")) - out = h.handle_event(streaming.ToolEnd(tool_call_id="tc2")) - assert isinstance(out[0], events.ToolEnd) - assert out[0].tool_call_id == "tc2" - - msg = _only(h.handle_event(streaming.MessageDone()), events.MessageEnd).message - tool_parts = [p for p in msg.parts if isinstance(p, messages.ToolCallPart)] - assert [p.tool_args for p in tool_parts] == ['{"path":"a.py"}', '{"dir":"."}'] - - -def test_message_done_finalizes_all() -> None: - h = streaming.StreamHandler(message_id="m1") - h.handle_event(streaming.TextStart(block_id="t1")) - h.handle_event(streaming.TextDelta(block_id="t1", delta="hello")) - - out = h.handle_event(streaming.MessageDone(finish_reason="end_turn")) - final = _only(out, events.MessageEnd) - assert final.message.text == "hello" - - -def test_message_done_propagates_usage() -> None: - usage = messages.Usage(input_tokens=10, output_tokens=20) - h = streaming.StreamHandler(message_id="m1") - h.handle_event(streaming.TextStart(block_id="t1")) - h.handle_event(streaming.TextDelta(block_id="t1", delta="hi")) - - h.handle_event(streaming.TextEnd(block_id="t1")) - final = _only(h.handle_event(streaming.MessageDone(usage=usage)), events.MessageEnd) - assert final.usage is not None - assert final.usage.input_tokens == 10 - assert final.message.usage is not None - assert final.message.usage.total_tokens == 30 - - -def test_deltas_only_on_active_blocks() -> None: - h = streaming.StreamHandler(message_id="m1") - h.handle_event(streaming.TextStart(block_id="t1")) - h.handle_event(streaming.TextDelta(block_id="t1", delta="first")) - h.handle_event(streaming.TextEnd(block_id="t1")) - - h.handle_event(streaming.TextStart(block_id="t2")) - out = h.handle_event(streaming.TextDelta(block_id="t2", delta="second")) - - deltas = [event for event in out if isinstance(event, events.TextDelta)] - assert len(deltas) == 1 - assert deltas[0].block_id == "t2" - assert deltas[0].chunk == "second" - - -def test_file_event_accumulates() -> None: - h = streaming.StreamHandler(message_id="m1") - out = h.handle_event( - streaming.FileEvent(block_id="f1", media_type="image/png", data="iVBORw0KGgo=") - ) - assert out == [] - - msg = _only(h.handle_event(streaming.MessageDone()), events.MessageEnd).message - assert len(msg.images) == 1 - assert msg.images[0].media_type == "image/png" - assert msg.images[0].data == "iVBORw0KGgo=" - - -def test_file_event_with_text() -> None: - h = streaming.StreamHandler(message_id="m1") - h.handle_event(streaming.TextStart(block_id="t1")) - h.handle_event(streaming.TextDelta(block_id="t1", delta="Here is your image:")) - h.handle_event(streaming.TextEnd(block_id="t1")) - h.handle_event( - streaming.FileEvent(block_id="f1", media_type="image/png", data="iVBORw0KGgo=") - ) - msg = _only(h.handle_event(streaming.MessageDone()), events.MessageEnd).message - - assert msg.text == "Here is your image:" - assert len(msg.images) == 1 - - -def test_multiple_file_events() -> None: - h = streaming.StreamHandler(message_id="m1") - h.handle_event( - streaming.FileEvent(block_id="f1", media_type="image/png", data="png_data") - ) - h.handle_event( - streaming.FileEvent(block_id="f2", media_type="image/jpeg", data="jpeg_data") - ) - msg = _only(h.handle_event(streaming.MessageDone()), events.MessageEnd).message - - assert [p.media_type for p in msg.images] == ["image/png", "image/jpeg"] diff --git a/tests/models/test_public_api.py b/tests/models/test_public_api.py index 3b0c1012..254b0d0a 100644 --- a/tests/models/test_public_api.py +++ b/tests/models/test_public_api.py @@ -16,7 +16,6 @@ MOCK_MODEL, MOCK_PROVIDER, MockProvider, - collect_messages, mock_llm, text_msg, ) @@ -48,41 +47,6 @@ async def test_stream_basic() -> None: assert "".join(deltas) == "Hello world" -async def test_stream_preserves_existing_turn_ids() -> None: - """ai.stream() stamps only inputs without a turn_id; older turns survive.""" - mock = mock_llm([[text_msg("reply")]]) - - old = ai.user_message("earlier") - old = old.model_copy(update={"turn_id": "prev"}) - fresh = ai.user_message("latest") - - s = models.stream(MOCK_MODEL, [old, fresh]) - yielded = await collect_messages(s) - - assert mock.call_count == 1 - # First yielded is the old input — unchanged. - assert yielded[0].turn_id == "prev" - # Fresh input was stamped with the current turn's id. - assert yielded[1].turn_id is not None - assert yielded[1].turn_id != "prev" - # Response shares the current turn id. - response_ids = [m.turn_id for m in yielded if m.role == "assistant"] - assert response_ids and all(tid == yielded[1].turn_id for tid in response_ids) - - -async def test_stream_accepts_explicit_turn_id() -> None: - """Explicit turn_id kwarg is used verbatim.""" - mock_llm([[text_msg("ok")]]) - fresh = ai.user_message("hi") - - s = models.stream(MOCK_MODEL, [fresh], turn_id="custom-turn") - yielded = await collect_messages(s) - - assert s.turn_id == "custom-turn" - assert yielded[0].turn_id == "custom-turn" - assert yielded[-1].turn_id == "custom-turn" - - async def test_stream_with_explicit_client() -> None: """Model with explicit client= forwards it to the adapter.""" received_clients: list[models.Client] = [] @@ -97,13 +61,11 @@ async def _spy_stream( **kwargs: Any, ) -> AsyncGenerator[events_.Event]: received_clients.append(client) - msg = messages_.Message( - id="m1", - role="assistant", - parts=[messages_.TextPart(text="ok")], - ) - yield events_.MessageStart(message=msg.model_copy(update={"parts": []})) - yield events_.MessageEnd(message=msg) + yield events_.StreamStart() + yield events_.TextStart(block_id="t1") + yield events_.TextDelta(block_id="t1", chunk="ok") + yield events_.TextEnd(block_id="t1") + yield events_.StreamEnd() models.register_stream("mock", _spy_stream) @@ -134,23 +96,14 @@ async def _structured_stream( output_type: type[pydantic.BaseModel] | None = None, **kwargs: Any, ) -> AsyncGenerator[events_.Event]: - text_part = messages_.TextPart(text=json_text) - parts: list[messages_.Part] = [text_part] - if output_type is not None: - import json - - parts.append( - messages_.StructuredOutputPart( - data=json.loads(json_text), - output_type_name=f"{output_type.__module__}.{output_type.__qualname__}", - ) - ) - msg = messages_.Message(id="m1", role="assistant", parts=parts) - yield events_.MessageStart(message=msg.model_copy(update={"parts": []})) - yield events_.TextStart(block_id=text_part.id) - yield events_.TextDelta(block_id=text_part.id, chunk=json_text) - yield events_.TextEnd(block_id=text_part.id) - yield events_.MessageEnd(message=msg) + # Stream emits text deltas; the StructuredOutputPart is currently not + # part of the public Event vocabulary, so we exercise output via + # a downstream-friendly path: emit text and let consumers parse. + yield events_.StreamStart() + yield events_.TextStart(block_id="t1") + yield events_.TextDelta(block_id="t1", chunk=json_text) + yield events_.TextEnd(block_id="t1") + yield events_.StreamEnd() models.register_stream("mock", _structured_stream) @@ -160,10 +113,7 @@ async def _structured_stream( async for _ in s: pass - assert s.output is not None - assert isinstance(s.output, _Recipe) - assert s.output.name == "Pancakes" - assert s.output.steps == ["Mix", "Cook"] + assert s.text == json_text # --------------------------------------------------------------------------- diff --git a/tests/test_middleware.py b/tests/test_middleware.py index db911c75..64be953d 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -3,22 +3,19 @@ from __future__ import annotations import dataclasses -from collections.abc import AsyncGenerator, Sequence +from collections.abc import AsyncGenerator from typing import Any import pydantic import pytest import ai -from ai import middleware, models -from ai.models.core.helpers import streaming as streaming_ -from ai.types import events as events_ -from ai.types import messages as messages_ +from ai import middleware +from ai.agents import events as agent_events_ from .conftest import ( MOCK_MODEL, collect_messages, - mock_generate, mock_llm, text_msg, tool_call_msg, @@ -32,31 +29,6 @@ class Confirmation(pydantic.BaseModel): reason: str = "" -# ── wrap_model ────────────────────────────────────────────────── - - -async def test_wrap_model_is_called() -> None: - """Middleware.wrap_model is invoked for every models.stream() call.""" - model_calls: list[middleware.ModelContext] = [] - - class Spy(ai.Middleware): - async def wrap_model(self, call: middleware.ModelContext, next: Any) -> Any: - model_calls.append(call) - return await next(call) - - my_agent = ai.agent() - mock_llm([[text_msg("Hello!")]]) - - async for _m in my_agent.run( - MOCK_MODEL, [ai.user_message("Hi")], middleware=[Spy()] - ): - pass - - assert len(model_calls) == 1 - assert model_calls[0].model.id == "mock-model" - assert len(model_calls[0].messages) >= 1 - - # ── wrap_tool ─────────────────────────────────────────────────── @@ -115,7 +87,7 @@ async def custom(context: ai.Context) -> AsyncGenerator[ai.Event]: async for event in my_agent.run( MOCK_MODEL, [ai.user_message("go")], middleware=[Spy()] ): - if not isinstance(event, ai.MessageEnd): + if not isinstance(event, agent_events_.MessageEnd): continue msg = event.message if any(isinstance(p, ai.HookPart) and p.status == "pending" for p in msg.parts): @@ -127,159 +99,6 @@ async def custom(context: ai.Context) -> AsyncGenerator[ai.Event]: assert hook_calls[0].interrupt_loop is False -# ── Middleware ordering (onion model) ──────────────────────────── - - -async def test_model_middleware_ordering() -> None: - """First in list = outermost. Sees call first, result last.""" - order: list[str] = [] - - class Outer(ai.Middleware): - async def wrap_model(self, call: middleware.ModelContext, next: Any) -> Any: - order.append("outer-before") - result = await next(call) - order.append("outer-after") - return result - - class Inner(ai.Middleware): - async def wrap_model(self, call: middleware.ModelContext, next: Any) -> Any: - order.append("inner-before") - result = await next(call) - order.append("inner-after") - return result - - my_agent = ai.agent() - mock_llm([[text_msg("Hi")]]) - - async for _m in my_agent.run( - MOCK_MODEL, [ai.user_message("Hi")], middleware=[Outer(), Inner()] - ): - pass - - assert order == ["outer-before", "inner-before", "inner-after", "outer-after"] - - -# ── Context modification ──────────────────────────────────────── - - -async def test_model_context_can_be_modified() -> None: - """Middleware can modify the ModelContext before passing to next.""" - - class Injector(ai.Middleware): - async def wrap_model(self, call: middleware.ModelContext, next: Any) -> Any: - # Inject a system message. - extra = ai.system_message("Extra instruction: be concise.") - modified = dataclasses.replace(call, messages=[extra, *call.messages]) - return await next(modified) - - # Use a capturing adapter to see what messages the LLM received. - captured_messages: list[list[messages_.Message]] = [] - - class CapturingAdapter: - def __init__(self, responses: list[list[messages_.Message]]) -> None: - self._responses = list(responses) - self._idx = 0 - - async def stream( - self, - client: Any, - model: Any, - messages: list[messages_.Message], - *, - tools: Sequence[Any] | None = None, - output_type: type[pydantic.BaseModel] | None = None, - **kw: Any, - ) -> AsyncGenerator[events_.Event]: - captured_messages.append(list(messages)) - seq = self._responses[self._idx] - self._idx += 1 - message_id = seq[0].id if seq else messages_.generate_id() - handler = streaming_.StreamHandler(message_id=message_id) - yield handler.message_start() - for msg in seq: - for i, part in enumerate(msg.parts): - if isinstance(part, messages_.TextPart): - bid = f"text-{i}" - for event in handler.handle_event( - streaming_.TextStart(block_id=bid) - ): - yield event - if part.text: - for event in handler.handle_event( - streaming_.TextDelta(block_id=bid, delta=part.text) - ): - yield event - for event in handler.handle_event( - streaming_.TextEnd(block_id=bid) - ): - yield event - for event in handler.handle_event(streaming_.MessageDone()): - yield event - - adapter = CapturingAdapter([[text_msg("Concise!")]]) - models.register_stream("mock", adapter.stream) - - my_agent = ai.agent() - async for _m in my_agent.run( - MOCK_MODEL, [ai.user_message("Hi")], middleware=[Injector()] - ): - pass - - # The LLM should have seen 2 messages: injected system + user. - assert len(captured_messages) == 1 - assert len(captured_messages[0]) == 2 - assert captured_messages[0][0].role == "system" - - -# ── Nested agents inherit middleware ───────────────────────────── - - -async def test_nested_agent_extends_middleware() -> None: - """Nested agent.run(middleware=[B]) extends, not replaces, the parent stack.""" - tags: list[str] = [] - - class Tagger(ai.Middleware): - def __init__(self, tag: str) -> None: - self.tag = tag - - async def wrap_model(self, call: middleware.ModelContext, next: Any) -> Any: - tags.append(self.tag) - return await next(call) - - inner = ai.agent() - - @ai.tool # type: ignore[arg-type] - async def run_inner(query: str) -> AsyncGenerator[ai.Event]: - """Run sub-agent with its own middleware.""" - async for event in inner.run( - MOCK_MODEL, - [ai.user_message(query)], - middleware=[Tagger("B")], - ): - yield event - - outer = ai.agent(tools=[run_inner]) - - mock_llm( - [ - # Outer turn 1: call run_inner. - [tool_call_msg(tc_id="tc-1", name="run_inner", args='{"query": "hi"}')], - # Inner turn 1: text reply (consumed by inner agent). - [text_msg("inner done", id="inner-1")], - # Outer turn 2: final. - [text_msg("outer done", id="outer-2")], - ] - ) - - async for _m in outer.run( - MOCK_MODEL, [ai.user_message("go")], middleware=[Tagger("A")] - ): - pass - - # Outer model calls see only A. Inner model call sees A then B (composed). - assert tags == ["A", "A", "B", "A"] - - # ── wrap_agent_run ────────────────────────────────────────────── @@ -316,46 +135,6 @@ async def wrap_agent_run( assert order == ["outer-before", "inner-before", "inner-after", "outer-after"] -# ── wrap_generate ─────────────────────────────────────────────── - - -async def test_wrap_generate_is_called() -> None: - """Middleware.wrap_generate is invoked for models.generate() inside a run.""" - gen_calls: list[middleware.GenerateContext] = [] - - class Spy(ai.Middleware): - async def wrap_generate( - self, call: middleware.GenerateContext, next: Any - ) -> Any: - gen_calls.append(call) - return await next(call) - - response = messages_.Message( - id="gen-1", - role="assistant", - parts=[messages_.TextPart(text="generated image url")], - ) - mock_generate([response]) - - # Call generate inside an agent loop so middleware is active. - my_agent = ai.agent() - - @my_agent.loop - async def gen_loop(context: ai.Context) -> AsyncGenerator[ai.Message]: - result = await models.generate( - context.model, context.messages, models.ImageParams() - ) - yield result - - async for _m in my_agent.run( - MOCK_MODEL, [ai.user_message("paint a cat")], middleware=[Spy()] - ): - pass - - assert len(gen_calls) == 1 - assert gen_calls[0].model.id == "mock-model" - - async def test_wrap_tool_context_fields_flow_to_result() -> None: """ToolContext.tool_name is used in the result message.""" @@ -413,46 +192,6 @@ async def echo(x: int) -> int: assert "orphaned-tool-result" in str(exc_info.value.exceptions[0]) -# ── StreamResult wrapping ─────────────────────────────────────── - - -async def test_middleware_can_wrap_stream_result() -> None: - """Middleware can iterate a StreamResult and transform messages.""" - - class TextAppender(ai.Middleware): - async def wrap_model( - self, - call: middleware.ModelContext, - next: Any, - ) -> ai.StreamResultLike: - stream_result = await next(call) - - async def _transformed() -> AsyncGenerator[events_.Event]: - async for event in stream_result: - yield event - # After the stream ends, yield one more snapshot with extra text. - msg = messages_.Message( - id="appended", - role="assistant", - parts=[messages_.TextPart(text="original + appended")], - ) - yield events_.MessageStart(message=msg.model_copy(update={"parts": []})) - yield events_.MessageEnd(message=msg) - - return ai.StreamResult.from_generator(_transformed()) - - my_agent = ai.agent() - mock_llm([[text_msg("original")]]) - - msgs = await collect_messages( - my_agent.run(MOCK_MODEL, [ai.user_message("Hi")], middleware=[TextAppender()]) - ) - - # The last message should be from the appended stream. - texts = [m.text for m in msgs if m.text] - assert "original + appended" in texts - - # ── Context snapshot isolation ────────────────────────────────── @@ -511,42 +250,3 @@ async def double(x: int) -> int: # The fixer middleware supplied x=99, so double should return 198. assert tool_result_msgs[0].tool_results[0].result == 198 assert tool_result_msgs[0].tool_results[0].is_error is False - - -# ── Run-scoped isolation ──────────────────────────────────────── - - -async def test_middleware_is_run_scoped() -> None: - """Middleware from one run does not leak into another.""" - model_calls: list[str] = [] - - class Tagger(ai.Middleware): - def __init__(self, tag: str) -> None: - self.tag = tag - - async def wrap_model(self, call: middleware.ModelContext, next: Any) -> Any: - model_calls.append(self.tag) - return await next(call) - - my_agent = ai.agent() - - # Run 1: with Tagger("A") - mock_llm([[text_msg("Hi")]]) - async for _m in my_agent.run( - MOCK_MODEL, [ai.user_message("Hi")], middleware=[Tagger("A")] - ): - pass - - # Run 2: no middleware - mock_llm([[text_msg("Hi")]]) - async for _m in my_agent.run(MOCK_MODEL, [ai.user_message("Hi")]): - pass - - # Run 3: with Tagger("C") - mock_llm([[text_msg("Hi")]]) - async for _m in my_agent.run( - MOCK_MODEL, [ai.user_message("Hi")], middleware=[Tagger("C")] - ): - pass - - assert model_calls == ["A", "C"] diff --git a/tests/types/test_messages.py b/tests/types/test_messages.py index 217ec6dc..dee0ee55 100644 --- a/tests/types/test_messages.py +++ b/tests/types/test_messages.py @@ -17,51 +17,6 @@ class _Weather(pydantic.BaseModel): _WEATHER_TYPE_NAME = f"{_Weather.__module__}.{_Weather.__qualname__}" -def test_replace() -> None: - old_text = messages.TextPart(id="p0", text="hello") - m = messages.Message( - id="m1", - role="assistant", - parts=[old_text, messages.TextPart(id="p1", text="world")], - ) - new_text = messages.TextPart(id="p0", text="updated") - updated_m = m.replace(new_text) - assert updated_m.parts[0].text == "updated" # type: ignore[union-attr] - assert m.parts[0].text == "hello" # type: ignore[union-attr] - - -def test_replace_missing_id() -> None: - m = messages.Message( - id="m1", role="assistant", parts=[messages.TextPart(id="p0", text="hi")] - ) - orphan = messages.TextPart(id="no-such-id", text="x") - with pytest.raises(ValueError, match="in message"): - m.replace(orphan) - - -def test_replace_two_arg() -> None: - old_text = messages.TextPart(id="p0", text="hello") - m = messages.Message(id="m1", role="assistant", parts=[old_text]) - new_text = messages.TextPart(id="different", text="world") - updated = m.replace(old_text, new_text) - part = updated.parts[0] - assert isinstance(part, messages.TextPart) - assert part.text == "world" - assert part.id == "different" - orig = m.parts[0] - assert isinstance(orig, messages.TextPart) - assert orig.text == "hello" - - -def test_replace_two_arg_missing() -> None: - m = messages.Message( - id="m1", role="assistant", parts=[messages.TextPart(id="p0", text="hi")] - ) - stranger = messages.TextPart(id="p0", text="hi") - with pytest.raises(ValueError, match="not found in message"): - m.replace(stranger, messages.TextPart(text="new")) - - def test_structured_output_part_value() -> None: part = messages.StructuredOutputPart( data=_WEATHER_DATA, output_type_name=_WEATHER_TYPE_NAME From cb32777496c2bec49d1be9600cb326d9ce026108 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Tue, 28 Apr 2026 16:40:52 -0700 Subject: [PATCH 2/2] Update examples --- examples/coding-agent/1_raw_stream.py | 35 --------------------------- examples/samples/explicit_client.py | 8 +++--- examples/samples/inline_image.py | 25 +++++++++---------- examples/samples/multimodal_input.py | 7 +++--- examples/samples/stream.py | 7 +++--- examples/samples/structured_output.py | 21 +++++++++------- examples/samples/tools_schema.py | 16 ++++++------ 7 files changed, 46 insertions(+), 73 deletions(-) delete mode 100644 examples/coding-agent/1_raw_stream.py diff --git a/examples/coding-agent/1_raw_stream.py b/examples/coding-agent/1_raw_stream.py deleted file mode 100644 index 45a7de3b..00000000 --- a/examples/coding-agent/1_raw_stream.py +++ /dev/null @@ -1,35 +0,0 @@ -import ai -import asyncio - -import inspect -import pydantic -import json - -from typing import get_type_hints - - -def get_schema(fn) -> dict: - sig = inspect.signature(fn) - hints = get_type_hints(fn) - - fields = {} - for name, p in sig.parameters.items(): - t = hints.get(name, str) - default = ... if p.default is inspect.Parameter.empty else p.default - fields[name] = (t, default) - - -async def main() -> None: - model = ai.ai_gateway("anthropic/claude-opus-4.7") - - messages = [ - ai.system_message("you are a coding assistant"), - ai.user_message("actually i don't need assistance thanks"), - ] - - async for e in ai.stream(model, messages): - print(e) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/examples/samples/explicit_client.py b/examples/samples/explicit_client.py index 9de016b9..7d1c601c 100644 --- a/examples/samples/explicit_client.py +++ b/examples/samples/explicit_client.py @@ -19,11 +19,13 @@ async def main() -> None: try: - async for event in ai.models.stream(model, messages): - if isinstance(event, ai.TextDelta): - print(event.chunk, end="", flush=True) + async with ai.stream(model, messages) as s: + async for event in s: + if isinstance(event, ai.TextDelta): + print(event.chunk, end="", flush=True) print() finally: + # Explicit clients need explicit cleanup. await client.aclose() diff --git a/examples/samples/inline_image.py b/examples/samples/inline_image.py index 39b0a01d..b55a90ee 100644 --- a/examples/samples/inline_image.py +++ b/examples/samples/inline_image.py @@ -1,8 +1,9 @@ """Inline image generation — LLM that outputs images alongside text. Models like Gemini 3 Pro Image can generate images as part of their -language model response. The images arrive as FileParts on the final -MessageEnd message. +language model response. The images arrive as ``FileEvent`` events +during the stream and end up as ``FilePart``s on the aggregated +``Stream.message``. """ import asyncio @@ -22,20 +23,18 @@ async def main() -> None: - last_msg: ai.Message | None = None - - # Stream — text deltas arrive as events, images arrive on MessageEnd - async for event in ai.stream(model, messages): - if isinstance(event, ai.TextDelta): - print(event.chunk, end="", flush=True) - elif isinstance(event, ai.MessageEnd): - last_msg = event.message + # Stream — text deltas arrive as TextDelta events, generated images + # arrive as FileEvent events and accumulate on s.message. + async with ai.stream(model, messages) as s: + async for event in s: + if isinstance(event, ai.TextDelta): + print(event.chunk, end="", flush=True) print() - # Check for images in the final message - if last_msg and last_msg.images: - for i, img in enumerate(last_msg.images): + # Check for images in the aggregated message. + if s.message.images: + for i, img in enumerate(s.message.images): filename = f"inline_{i}.png" data = ( img.data if isinstance(img.data, bytes) else base64.b64decode(img.data) diff --git a/examples/samples/multimodal_input.py b/examples/samples/multimodal_input.py index 5e78582a..417d4d0f 100644 --- a/examples/samples/multimodal_input.py +++ b/examples/samples/multimodal_input.py @@ -20,9 +20,10 @@ async def main() -> None: - async for event in ai.stream(model, messages): - if isinstance(event, ai.TextDelta): - print(event.chunk, end="", flush=True) + async with ai.stream(model, messages) as s: + async for event in s: + if isinstance(event, ai.TextDelta): + print(event.chunk, end="", flush=True) print() diff --git a/examples/samples/stream.py b/examples/samples/stream.py index b731e5c3..aff0a0df 100644 --- a/examples/samples/stream.py +++ b/examples/samples/stream.py @@ -13,9 +13,10 @@ async def main() -> None: - async for event in ai.stream(model, messages): - if isinstance(event, ai.TextDelta): - print(event.chunk, end="", flush=True) + async with ai.stream(model, messages) as s: + async for event in s: + if isinstance(event, ai.TextDelta): + print(event.chunk, end="", flush=True) print() diff --git a/examples/samples/structured_output.py b/examples/samples/structured_output.py index 128c1b03..437bf4b9 100644 --- a/examples/samples/structured_output.py +++ b/examples/samples/structured_output.py @@ -20,15 +20,18 @@ class Recipe(pydantic.BaseModel): async def main() -> None: - # Stream with structured output — watch JSON arrive, get validated at the end - async for event in ai.stream(model, messages, output_type=Recipe): - if isinstance(event, ai.TextDelta): - print(event.chunk, end="", flush=True) - elif isinstance(event, ai.MessageEnd) and event.message.output: - recipe: Recipe = event.message.output - print(f"\n\nParsed recipe: {recipe.name}") - print(f" Ingredients: {', '.join(recipe.ingredients)}") - print(f" Prep time: {recipe.prep_time_minutes} min") + # Stream with structured output — watch JSON arrive, get validated at the end. + async with ai.stream(model, messages, output_type=Recipe) as s: + async for event in s: + if isinstance(event, ai.TextDelta): + print(event.chunk, end="", flush=True) + + # After iteration, s.output is the validated pydantic model. + recipe: Recipe | None = s.output + if recipe is not None: + print(f"\n\nParsed recipe: {recipe.name}") + print(f" Ingredients: {', '.join(recipe.ingredients)}") + print(f" Prep time: {recipe.prep_time_minutes} min") if __name__ == "__main__": diff --git a/examples/samples/tools_schema.py b/examples/samples/tools_schema.py index 805f8344..2747f379 100644 --- a/examples/samples/tools_schema.py +++ b/examples/samples/tools_schema.py @@ -24,15 +24,17 @@ async def main() -> None: - # Stream with tools — the model may emit tool calls - async for event in ai.stream(model, messages, tools=[get_weather]): - if isinstance(event, ai.TextDelta): - print(event.chunk, end="", flush=True) - elif isinstance(event, ai.MessageEnd): - for tc in event.message.tool_calls: - print(f"\nTool call: {tc.tool_name}({tc.tool_args})") + # Stream with tools — the model may emit tool calls. + async with ai.stream(model, messages, tools=[get_weather]) as s: + async for event in s: + if isinstance(event, ai.TextDelta): + print(event.chunk, end="", flush=True) print() + # After iteration, s.tool_calls collects every tool call from the response. + for tc in s.tool_calls: + print(f"Tool call: {tc.tool_name}({tc.tool_args})") + if __name__ == "__main__": asyncio.run(main())