diff --git a/examples/coding-agent/1_raw_stream.py b/examples/coding-agent/1_raw_stream.py deleted file mode 100644 index 45a7de3b..00000000 --- a/examples/coding-agent/1_raw_stream.py +++ /dev/null @@ -1,35 +0,0 @@ -import ai -import asyncio - -import inspect -import pydantic -import json - -from typing import get_type_hints - - -def get_schema(fn) -> dict: - sig = inspect.signature(fn) - hints = get_type_hints(fn) - - fields = {} - for name, p in sig.parameters.items(): - t = hints.get(name, str) - default = ... if p.default is inspect.Parameter.empty else p.default - fields[name] = (t, default) - - -async def main() -> None: - model = ai.ai_gateway("anthropic/claude-opus-4.7") - - messages = [ - ai.system_message("you are a coding assistant"), - ai.user_message("actually i don't need assistance thanks"), - ] - - async for e in ai.stream(model, messages): - print(e) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/examples/samples/explicit_client.py b/examples/samples/explicit_client.py index 9de016b9..7d1c601c 100644 --- a/examples/samples/explicit_client.py +++ b/examples/samples/explicit_client.py @@ -19,11 +19,13 @@ async def main() -> None: try: - async for event in ai.models.stream(model, messages): - if isinstance(event, ai.TextDelta): - print(event.chunk, end="", flush=True) + async with ai.stream(model, messages) as s: + async for event in s: + if isinstance(event, ai.TextDelta): + print(event.chunk, end="", flush=True) print() finally: + # Explicit clients need explicit cleanup. await client.aclose() diff --git a/examples/samples/inline_image.py b/examples/samples/inline_image.py index 39b0a01d..b55a90ee 100644 --- a/examples/samples/inline_image.py +++ b/examples/samples/inline_image.py @@ -1,8 +1,9 @@ """Inline image generation — LLM that outputs images alongside text. Models like Gemini 3 Pro Image can generate images as part of their -language model response. The images arrive as FileParts on the final -MessageEnd message. +language model response. The images arrive as ``FileEvent`` events +during the stream and end up as ``FilePart``s on the aggregated +``Stream.message``. """ import asyncio @@ -22,20 +23,18 @@ async def main() -> None: - last_msg: ai.Message | None = None - - # Stream — text deltas arrive as events, images arrive on MessageEnd - async for event in ai.stream(model, messages): - if isinstance(event, ai.TextDelta): - print(event.chunk, end="", flush=True) - elif isinstance(event, ai.MessageEnd): - last_msg = event.message + # Stream — text deltas arrive as TextDelta events, generated images + # arrive as FileEvent events and accumulate on s.message. + async with ai.stream(model, messages) as s: + async for event in s: + if isinstance(event, ai.TextDelta): + print(event.chunk, end="", flush=True) print() - # Check for images in the final message - if last_msg and last_msg.images: - for i, img in enumerate(last_msg.images): + # Check for images in the aggregated message. + if s.message.images: + for i, img in enumerate(s.message.images): filename = f"inline_{i}.png" data = ( img.data if isinstance(img.data, bytes) else base64.b64decode(img.data) diff --git a/examples/samples/multimodal_input.py b/examples/samples/multimodal_input.py index 5e78582a..417d4d0f 100644 --- a/examples/samples/multimodal_input.py +++ b/examples/samples/multimodal_input.py @@ -20,9 +20,10 @@ async def main() -> None: - async for event in ai.stream(model, messages): - if isinstance(event, ai.TextDelta): - print(event.chunk, end="", flush=True) + async with ai.stream(model, messages) as s: + async for event in s: + if isinstance(event, ai.TextDelta): + print(event.chunk, end="", flush=True) print() diff --git a/examples/samples/stream.py b/examples/samples/stream.py index b731e5c3..aff0a0df 100644 --- a/examples/samples/stream.py +++ b/examples/samples/stream.py @@ -13,9 +13,10 @@ async def main() -> None: - async for event in ai.stream(model, messages): - if isinstance(event, ai.TextDelta): - print(event.chunk, end="", flush=True) + async with ai.stream(model, messages) as s: + async for event in s: + if isinstance(event, ai.TextDelta): + print(event.chunk, end="", flush=True) print() diff --git a/examples/samples/structured_output.py b/examples/samples/structured_output.py index 128c1b03..437bf4b9 100644 --- a/examples/samples/structured_output.py +++ b/examples/samples/structured_output.py @@ -20,15 +20,18 @@ class Recipe(pydantic.BaseModel): async def main() -> None: - # Stream with structured output — watch JSON arrive, get validated at the end - async for event in ai.stream(model, messages, output_type=Recipe): - if isinstance(event, ai.TextDelta): - print(event.chunk, end="", flush=True) - elif isinstance(event, ai.MessageEnd) and event.message.output: - recipe: Recipe = event.message.output - print(f"\n\nParsed recipe: {recipe.name}") - print(f" Ingredients: {', '.join(recipe.ingredients)}") - print(f" Prep time: {recipe.prep_time_minutes} min") + # Stream with structured output — watch JSON arrive, get validated at the end. + async with ai.stream(model, messages, output_type=Recipe) as s: + async for event in s: + if isinstance(event, ai.TextDelta): + print(event.chunk, end="", flush=True) + + # After iteration, s.output is the validated pydantic model. + recipe: Recipe | None = s.output + if recipe is not None: + print(f"\n\nParsed recipe: {recipe.name}") + print(f" Ingredients: {', '.join(recipe.ingredients)}") + print(f" Prep time: {recipe.prep_time_minutes} min") if __name__ == "__main__": diff --git a/examples/samples/tools_schema.py b/examples/samples/tools_schema.py index 805f8344..2747f379 100644 --- a/examples/samples/tools_schema.py +++ b/examples/samples/tools_schema.py @@ -24,15 +24,17 @@ async def main() -> None: - # Stream with tools — the model may emit tool calls - async for event in ai.stream(model, messages, tools=[get_weather]): - if isinstance(event, ai.TextDelta): - print(event.chunk, end="", flush=True) - elif isinstance(event, ai.MessageEnd): - for tc in event.message.tool_calls: - print(f"\nTool call: {tc.tool_name}({tc.tool_args})") + # Stream with tools — the model may emit tool calls. + async with ai.stream(model, messages, tools=[get_weather]) as s: + async for event in s: + if isinstance(event, ai.TextDelta): + print(event.chunk, end="", flush=True) print() + # After iteration, s.tool_calls collects every tool call from the response. + for tc in s.tool_calls: + print(f"Tool call: {tc.tool_name}({tc.tool_args})") + if __name__ == "__main__": asyncio.run(main()) diff --git a/src/ai/__init__.py b/src/ai/__init__.py index 840fbcd2..45ec41ef 100644 --- a/src/ai/__init__.py +++ b/src/ai/__init__.py @@ -17,10 +17,15 @@ from .middleware import AgentRunContext, Middleware from .models import ( Client, + Executor, + GenerateExecutor, + GenerateRequest, ImageParams, Model, Provider, - StreamResult, + Stream, + StreamExecutor, + StreamRequest, VideoParams, ai_gateway, anthropic, @@ -32,22 +37,20 @@ # Re-export core types from .types import ( - End, Event, + FileEvent, FilePart, HookPart, HookResolution, HookSuspention, Message, - MessageEnd, - MessageStart, Part, ReasoningDelta, ReasoningEnd, ReasoningPart, ReasoningStart, - Start, - StreamResultLike, + StreamEnd, + StreamStart, StructuredOutputPart, TextDelta, TextEnd, @@ -74,12 +77,10 @@ __all__ = [ # Types (from types/) - "Start", - "End", "Event", "Message", - "MessageStart", - "MessageEnd", + "StreamStart", + "StreamEnd", "Part", "TextPart", "TextStart", @@ -94,6 +95,7 @@ "ReasoningStart", "ReasoningDelta", "ReasoningEnd", + "FileEvent", "FilePart", "HookPart", "HookSuspention", @@ -116,8 +118,12 @@ "ImageParams", "VideoParams", "Client", - "StreamResult", - "StreamResultLike", + "Stream", + "StreamRequest", + "GenerateRequest", + "Executor", + "StreamExecutor", + "GenerateExecutor", "check_connection", "stream", "generate", diff --git a/src/ai/agents/agent.py b/src/ai/agents/agent.py index d399cf96..37fcdc20 100644 --- a/src/ai/agents/agent.py +++ b/src/ai/agents/agent.py @@ -13,6 +13,7 @@ from .. import middleware as middleware_ from .. import models, types from ..types import builders +from . import events as events_ from . import runtime @@ -199,21 +200,23 @@ def resolve(self, tool_parts: list[types.ToolCallPart]) -> list[ToolCall]: ] -StreamItem = types.Event | types.Message +StreamItem = events_.AgentEvent | types.Message class LoopFn(Protocol): def __call__(self, context: Context) -> AsyncGenerator[StreamItem]: ... -async def _message_events(message: types.Message) -> AsyncGenerator[types.Event]: - yield types.MessageStart(message=message) - yield types.MessageEnd(message=message) +async def _message_events( + message: types.Message, +) -> AsyncGenerator[events_.AgentEvent]: + yield events_.MessageStart(message=message) + yield events_.MessageEnd(message=message) async def _coerce_events( source: AsyncIterable[StreamItem], -) -> AsyncGenerator[types.Event]: +) -> AsyncGenerator[events_.AgentEvent]: async for item in source: if isinstance(item, types.Message): async for event in _message_events(item): @@ -222,15 +225,23 @@ async def _coerce_events( yield item -async def _default_loop(context: Context) -> AsyncGenerator[types.Event]: +async def _default_loop(context: Context) -> AsyncGenerator[events_.AgentEvent]: while True: stream = models.stream( context.model, context.messages, tools=context.tools, ) - async for event in stream: - yield event + async for stream_event in stream: + yield stream_event + + # Bridge: emit MessageStart/MessageEnd around the assistant message + # the model stream just produced, so _collect_messages and downstream + # consumers (AI-SDK outbound, label stamping) see the same boundary + # events they did under the previous adapter contract. + if stream.message is not None and stream.message.parts: + async for boundary in _message_events(stream.message): + yield boundary tool_calls = context.resolve(stream.tool_calls) if not tool_calls: @@ -244,14 +255,14 @@ async def _default_loop(context: Context) -> AsyncGenerator[types.Event]: # Left un-stamped: the tool result is the input of the *next* turn, # so the next stream() call will stamp it with that turn's id. tool_msg = builders.tool_message(*(t.result() for t in tasks)) - async for event in _message_events(tool_msg): - yield event + async for boundary in _message_events(tool_msg): + yield boundary async def _collect_messages( source: AsyncIterable[StreamItem], messages: list[types.Message], -) -> AsyncGenerator[types.Event]: +) -> AsyncGenerator[events_.AgentEvent]: """Intercept yielded events and collect MessageEnd messages into *messages*. This runs on the **producer** side (same coroutine as the loop function), @@ -260,7 +271,7 @@ async def _collect_messages( happened on the consumer side of the runtime queue. """ async for event in _coerce_events(source): - if isinstance(event, types.MessageEnd): + if isinstance(event, events_.MessageEnd): message = event.message for i, existing in enumerate(messages): if existing.id == message.id: @@ -292,7 +303,7 @@ async def yield_from(source: AsyncIterable[StreamItem]) -> str: last: types.Message | None = None async for item in _coerce_events(source): await rt.put_event(item) - if isinstance(item, types.MessageEnd): + if isinstance(item, events_.MessageEnd): last = item.message return last.text if last else "" @@ -325,7 +336,7 @@ async def run( *, label: str | None = None, middleware: list[middleware_.Middleware] | None = None, - ) -> AsyncGenerator[types.Event]: + ) -> AsyncGenerator[events_.AgentEvent]: """Run the agent loop, yielding events to the consumer. Args: @@ -349,7 +360,7 @@ async def run( async def _real( call: middleware_.AgentRunContext, - ) -> AsyncGenerator[types.Event]: + ) -> AsyncGenerator[events_.AgentEvent]: context = Context( model=call.model, messages=list(call.messages), @@ -359,8 +370,8 @@ async def _real( async for event in runtime.run(source): if call.label is not None: event_message: types.Message | None = None - if isinstance(event, types.MessageEnd) or ( - isinstance(event, types.MessageStart) + if isinstance(event, events_.MessageEnd) or ( + isinstance(event, events_.MessageStart) and event.message is not None ): event_message = event.message diff --git a/src/ai/agents/events.py b/src/ai/agents/events.py new file mode 100644 index 00000000..12108464 --- /dev/null +++ b/src/ai/agents/events.py @@ -0,0 +1,45 @@ +"""Agent-layer event types. + +The model layer emits ``StreamStart`` / ``StreamEnd`` plus block-level +deltas. The agent layer wraps those with ``MessageStart`` / ``MessageEnd`` +boundaries that delimit complete messages — assistant turns produced by +the model, plus synthetic user / tool / hook messages injected into the +runtime queue. + +These types live here (rather than in ``ai.types.events``) because they +are an agent-runtime concern, not part of the public model-streaming +event vocabulary. +""" + +from __future__ import annotations + +from typing import Literal + +import pydantic + +from .. import types + + +class MessageStart(pydantic.BaseModel): + message: types.Message | None = None + + kind: Literal["message_start"] = "message_start" + + +class MessageEnd(pydantic.BaseModel): + message: types.Message + usage: types.Usage | None = None + + kind: Literal["message_end"] = "message_end" + + +# Widened event alias used inside agents/. Not part of ``types.Event``'s +# discriminated union — these wrappers do not flow through model adapters. +AgentEvent = types.Event | MessageStart | MessageEnd + + +__all__ = [ + "AgentEvent", + "MessageEnd", + "MessageStart", +] diff --git a/src/ai/agents/runtime.py b/src/ai/agents/runtime.py index cb8769ce..a8005748 100644 --- a/src/ai/agents/runtime.py +++ b/src/ai/agents/runtime.py @@ -7,6 +7,7 @@ from collections.abc import AsyncGenerator, AsyncIterable, Awaitable from .. import types +from . import events as events_ from . import hooks as hooks_ from .mcp import client as mcp_client @@ -20,17 +21,17 @@ class _Sentinel: _SENTINEL = _Sentinel() def __init__(self) -> None: - self._event_queue: asyncio.Queue[types.Event | Runtime._Sentinel] = ( + self._event_queue: asyncio.Queue[events_.AgentEvent | Runtime._Sentinel] = ( asyncio.Queue() ) self._hook_labels: set[str] = set() - async def put_event(self, event: types.Event) -> None: + async def put_event(self, event: events_.AgentEvent) -> None: await self._event_queue.put(event) async def put_message(self, message: types.Message) -> None: - await self.put_event(types.MessageStart(message=message)) - await self.put_event(types.MessageEnd(message=message)) + await self.put_event(events_.MessageStart(message=message)) + await self.put_event(events_.MessageEnd(message=message)) async def signal_done(self) -> None: await self._event_queue.put(self._SENTINEL) @@ -61,8 +62,8 @@ async def _stop_when_done(runtime: Runtime, task: Awaitable[None]) -> None: async def run( - source: AsyncIterable[types.Event], -) -> AsyncGenerator[types.Event]: + source: AsyncIterable[events_.AgentEvent], +) -> AsyncGenerator[events_.AgentEvent]: """Run *source* and yield every event that gets put into the Runtime queue.""" rt = Runtime() token = _runtime.set(rt) diff --git a/src/ai/agents/ui/ai_sdk/outbound/sse.py b/src/ai/agents/ui/ai_sdk/outbound/sse.py index e6c1f581..d911e9cb 100644 --- a/src/ai/agents/ui/ai_sdk/outbound/sse.py +++ b/src/ai/agents/ui/ai_sdk/outbound/sse.py @@ -6,7 +6,7 @@ import json from collections.abc import AsyncGenerator, AsyncIterable -from .....types import events as events_ +from ....events import AgentEvent from .. import protocol from .stream import to_stream @@ -32,7 +32,7 @@ def format_sse(part: protocol.UIMessageStreamPart) -> str: async def to_sse( - events: AsyncIterable[events_.Event], + events: AsyncIterable[AgentEvent], ) -> AsyncGenerator[str]: """Convert an internal event stream into SSE strings.""" async for part in to_stream(events): diff --git a/src/ai/agents/ui/ai_sdk/outbound/stream.py b/src/ai/agents/ui/ai_sdk/outbound/stream.py index 5ec0db5f..a569342f 100644 --- a/src/ai/agents/ui/ai_sdk/outbound/stream.py +++ b/src/ai/agents/ui/ai_sdk/outbound/stream.py @@ -4,13 +4,13 @@ from collections.abc import AsyncGenerator, AsyncIterable -from .....types import events as events_ +from ....events import AgentEvent, MessageEnd, MessageStart from .. import protocol from ._state import _StreamState async def to_stream( - events: AsyncIterable[events_.Event], + events: AsyncIterable[AgentEvent], ) -> AsyncGenerator[protocol.UIMessageStreamPart]: """Walk ``events`` once, emitting AI SDK UI stream parts. @@ -21,10 +21,10 @@ async def to_stream( state = _StreamState() async for event in events: - if isinstance(event, events_.MessageStart): + if isinstance(event, MessageStart): for part in state.on_message_start(event.message): yield part - elif isinstance(event, events_.MessageEnd): + elif isinstance(event, MessageEnd): for part in state.on_terminal(event.message): yield part else: diff --git a/src/ai/middleware.py b/src/ai/middleware.py index a8f77b81..afac14b6 100644 --- a/src/ai/middleware.py +++ b/src/ai/middleware.py @@ -22,9 +22,15 @@ import pydantic -from .types import events as events_ from .types import messages as messages_ -from .types.proto import StreamResultLike, ToolLike +from .types.proto import ToolLike + +# Compat shim: ``StreamResultLike`` was removed from ``ai.types.proto`` when +# the model layer was reworked. Middleware is dead code under the new +# ``Executor``-based ``api.py`` and is kept around only so the agents +# rewrite can land separately; ``Any`` is enough to keep the existing +# annotations type-checking. +type StreamResultLike = Any # --------------------------------------------------------------------------- # Call context objects — frozen dataclasses with isolated mutable fields. @@ -113,8 +119,11 @@ def __post_init__(self) -> None: # Middleware base class — override the methods you care about. # --------------------------------------------------------------------------- -# Event/message aliases for brevity in signatures. -_Event = events_.Event +# Event/message aliases for brevity in signatures. ``_Event`` is intentionally +# typed as ``Any`` so the agent-run chain accepts the wider ``AgentEvent`` +# union (which includes ``MessageStart``/``MessageEnd``) without a circular +# import from ``ai.agents``. +_Event = Any _Message = messages_.Message # Agent run next-function type: call -> async generator of events. diff --git a/src/ai/models/__init__.py b/src/ai/models/__init__.py index 61a97767..4f0da85d 100644 --- a/src/ai/models/__init__.py +++ b/src/ai/models/__init__.py @@ -25,29 +25,42 @@ ids = await openai.list() """ -from ..types.proto import StreamResultLike from .ai_gateway import ai_gateway from .anthropic import anthropic from .core.adapters import register_generate, register_stream -from .core.api import check_connection, generate, stream +from .core.api import ( + Executor, + GenerateExecutor, + GenerateRequest, + Stream, + StreamExecutor, + StreamRequest, + check_connection, + generate, + stream, +) from .core.client import Client from .core.model import Model +from .core.params import GenerateParams, ImageParams, VideoParams from .core.proto import CheckConnFn, GenerateFn, Provider, StreamFn -from .core.types import GenerateParams, ImageParams, StreamResult, VideoParams from .openai import openai __all__ = [ # Core types "CheckConnFn", "Client", + "Executor", + "GenerateExecutor", "GenerateFn", "GenerateParams", + "GenerateRequest", "ImageParams", "Model", "Provider", + "Stream", + "StreamExecutor", "StreamFn", - "StreamResult", - "StreamResultLike", + "StreamRequest", "VideoParams", # Provider factories "ai_gateway", diff --git a/src/ai/models/ai_gateway/generate.py b/src/ai/models/ai_gateway/generate.py index 9a6dee45..fec98f57 100644 --- a/src/ai/models/ai_gateway/generate.py +++ b/src/ai/models/ai_gateway/generate.py @@ -14,9 +14,9 @@ from ..core import client as client_ from ..core import model as model_ from ..core.helpers import files -from ..core.types import GenerateParams as GenerateParams -from ..core.types import ImageParams as ImageParams -from ..core.types import VideoParams as VideoParams +from ..core.params import GenerateParams as GenerateParams +from ..core.params import ImageParams as ImageParams +from ..core.params import VideoParams as VideoParams from . import _common, errors # --------------------------------------------------------------------------- diff --git a/src/ai/models/ai_gateway/stream.py b/src/ai/models/ai_gateway/stream.py index 8d97c5d9..ce143806 100644 --- a/src/ai/models/ai_gateway/stream.py +++ b/src/ai/models/ai_gateway/stream.py @@ -19,7 +19,7 @@ from ...types import usage as usage_ from ..core import client as client_ from ..core import model as model_ -from ..core.helpers import files, streaming +from ..core.helpers import files from . import _common, errors # --------------------------------------------------------------------------- @@ -154,20 +154,20 @@ async def _build_request_body( # --------------------------------------------------------------------------- -# SSE response parsing — v3 stream parts → StreamEvent +# SSE response parsing — v3 stream parts → public Event # --------------------------------------------------------------------------- -def _expand_tool_call(data: dict[str, Any]) -> list[streaming.StreamEvent]: - """Expand a complete ``tool-call`` part into Start + ArgsDelta + End.""" +def _expand_tool_call(data: dict[str, Any]) -> list[events_.Event]: + """Expand a complete ``tool-call`` part into Start + Delta + End.""" tc_id = data.get("toolCallId", "") tool_name = data.get("toolName", "") tool_input = data.get("input", "") args_str = tool_input if isinstance(tool_input, str) else json.dumps(tool_input) return [ - streaming.ToolStart(tool_call_id=tc_id, tool_name=tool_name), - streaming.ToolArgsDelta(tool_call_id=tc_id, delta=args_str), - streaming.ToolEnd(tool_call_id=tc_id), + events_.ToolStart(tool_call_id=tc_id, tool_name=tool_name), + events_.ToolDelta(tool_call_id=tc_id, chunk=args_str), + events_.ToolEnd(tool_call_id=tc_id), ] @@ -198,40 +198,40 @@ def _parse_usage(data: Any) -> usage_.Usage: ) -def _parse_stream_part(data: dict[str, Any]) -> list[streaming.StreamEvent]: - """Convert a ``LanguageModelV3StreamPart`` to internal events.""" +def _parse_stream_part(data: dict[str, Any]) -> list[events_.Event]: + """Convert a ``LanguageModelV3StreamPart`` to public events.""" match data.get("type", ""): case "text-start": - return [streaming.TextStart(block_id=data.get("id", "text"))] + return [events_.TextStart(block_id=data.get("id", "text"))] case "text-delta": return [ - streaming.TextDelta( + events_.TextDelta( block_id=data.get("id", "text"), - delta=data.get("textDelta", data.get("delta", "")), + chunk=data.get("textDelta", data.get("delta", "")), ) ] case "text-end": - return [streaming.TextEnd(block_id=data.get("id", "text"))] + return [events_.TextEnd(block_id=data.get("id", "text"))] case "reasoning-start": - return [streaming.ReasoningStart(block_id=data.get("id", "reasoning"))] + return [events_.ReasoningStart(block_id=data.get("id", "reasoning"))] case "reasoning-delta": return [ - streaming.ReasoningDelta( + events_.ReasoningDelta( block_id=data.get("id", "reasoning"), - delta=data.get("delta", ""), + chunk=data.get("delta", ""), ) ] case "reasoning-end": - return [streaming.ReasoningEnd(block_id=data.get("id", "reasoning"))] + return [events_.ReasoningEnd(block_id=data.get("id", "reasoning"))] case "tool-input-start": return [ - streaming.ToolStart( + events_.ToolStart( tool_call_id=data.get("id", ""), tool_name=data.get("toolName", ""), ) @@ -239,22 +239,22 @@ def _parse_stream_part(data: dict[str, Any]) -> list[streaming.StreamEvent]: case "tool-input-delta": return [ - streaming.ToolArgsDelta( + events_.ToolDelta( tool_call_id=data.get("id", ""), - delta=data.get("delta", ""), + chunk=data.get("delta", ""), ) ] case "tool-input-end": - return [streaming.ToolEnd(tool_call_id=data.get("id", ""))] + return [events_.ToolEnd(tool_call_id=data.get("id", ""))] case "tool-call": return _expand_tool_call(data) case "file": return [ - streaming.FileEvent( - block_id=data.get("id", f"file-{len(data)}"), + events_.FileEvent( + block_id=data.get("id", ""), media_type=data.get("mediaType", "application/octet-stream"), data=data.get("data", ""), ) @@ -263,14 +263,7 @@ def _parse_stream_part(data: dict[str, Any]) -> list[streaming.StreamEvent]: case "finish": usage_data = data.get("usage") usage = _parse_usage(usage_data) if usage_data else None - match data.get("finishReason"): - case dict() as d: - finish_reason = d.get("unified", "stop") - case str() as s: - finish_reason = s - case _: - finish_reason = "stop" - return [streaming.MessageDone(finish_reason=finish_reason, usage=usage)] + return [events_.StreamEnd(usage=usage)] case _: return [] @@ -293,6 +286,8 @@ async def stream( """Stream an LLM response through the AI Gateway v3 protocol. Yields :class:`~ai.types.events.Event` objects as the response streams in. + Pure delta emitter — the :class:`~ai.models.Stream` wrapper aggregates + parts into the final :class:`~ai.types.Message`. """ body = await _build_request_body( messages, tools=tools, output_type=output_type, **kwargs @@ -302,8 +297,6 @@ async def stream( ) url = f"{client.base_url.rstrip('/')}/language-model" - handler = streaming.StreamHandler() - try: async with client.http.stream( "POST", @@ -319,11 +312,10 @@ async def stream( api_key_provided=bool(client.api_key), ) - yield handler.message_start() + yield events_.StreamStart() async for data in _common.parse_sse_lines(response): - for adapter_event in _parse_stream_part(data): - for out_event in handler.handle_event(adapter_event): - yield out_event + for event in _parse_stream_part(data): + yield event except errors.GatewayError: raise except httpx.TimeoutException as exc: diff --git a/src/ai/models/anthropic/adapter.py b/src/ai/models/anthropic/adapter.py index 34329575..20ffe6d7 100644 --- a/src/ai/models/anthropic/adapter.py +++ b/src/ai/models/anthropic/adapter.py @@ -21,7 +21,7 @@ def _tools_to_anthropic( - tools: Sequence[types.proto.ToolLike], + tools: Sequence[types.ToolLike], ) -> list[dict[str, Any]]: """Convert internal Tool objects to Anthropic tool schema format.""" return [ @@ -242,7 +242,7 @@ async def stream( model: core.model.Model, messages: list[types.Message], *, - tools: Sequence[types.proto.ToolLike] | None = None, + tools: Sequence[types.ToolLike] | None = None, output_type: type[pydantic.BaseModel] | None = None, thinking: bool = False, budget_tokens: int = 10000, @@ -251,6 +251,8 @@ async def stream( """Stream an LLM response via the Anthropic messages API. Yields :class:`~ai.types.events.Event` objects as the response streams in. + Pure delta emitter — the :class:`~ai.models.Stream` wrapper aggregates + parts into the final :class:`~ai.types.Message`. Extra keyword arguments beyond the ``StreamFn`` protocol: @@ -280,24 +282,16 @@ async def stream( if output_type is not None: api_kwargs["output_format"] = output_type + # Anthropic indexes content blocks by int; map to string block_ids. block_types: dict[int, str] = {} tool_ids: dict[int, str] = {} tool_names: dict[int, str] = {} signature_buffer: dict[int, str] = {} - # Accumulate parts for the final Message - parts: list[types.Part] = [] - _text_parts: dict[str, str] = {} # block_id -> accumulated text - _reasoning_parts: dict[str, str] = {} # block_id -> accumulated text - _tool_parts: dict[str, str] = {} # tool_call_id -> accumulated args - message_id = types.generate_id() try: - stream_cm = sdk_client.messages.stream(**api_kwargs) + async with sdk_client.messages.stream(**api_kwargs) as sdk_stream: + yield events.StreamStart() - async with stream_cm as sdk_stream: - yield events.MessageStart( - message=types.Message(id=message_id, role="assistant", parts=[]) - ) async for event in sdk_stream: match event.type: case "content_block_start": @@ -307,15 +301,12 @@ async def stream( match block.type: case "text": - _text_parts[str(idx)] = "" yield events.TextStart(block_id=str(idx)) case "thinking": - _reasoning_parts[str(idx)] = "" yield events.ReasoningStart(block_id=str(idx)) case "tool_use": tool_ids[idx] = block.id tool_names[idx] = block.name - _tool_parts[block.id] = "" yield events.ToolStart( tool_call_id=block.id, tool_name=block.name, @@ -327,17 +318,11 @@ async def stream( match delta.type: case "text_delta": - _text_parts[str(idx)] = ( - _text_parts.get(str(idx), "") + delta.text - ) yield events.TextDelta( chunk=delta.text, block_id=str(idx), ) case "thinking_delta": - _reasoning_parts[str(idx)] = ( - _reasoning_parts.get(str(idx), "") + delta.thinking - ) yield events.ReasoningDelta( chunk=delta.thinking, block_id=str(idx), @@ -349,10 +334,6 @@ async def stream( case "input_json_delta": tool_id = tool_ids.get(idx) if tool_id: - _tool_parts[tool_id] = ( - _tool_parts.get(tool_id, "") - + delta.partial_json - ) yield events.ToolDelta( chunk=delta.partial_json, tool_call_id=tool_id, @@ -360,38 +341,17 @@ async def stream( case "content_block_stop": idx = event.index - bid = str(idx) match block_types.get(idx): case "text": - parts.append( - types.TextPart( - id=bid, text=_text_parts.get(bid, "") - ) - ) - yield events.TextEnd(block_id=bid) + yield events.TextEnd(block_id=str(idx)) case "thinking": - parts.append( - types.ReasoningPart( - id=bid, - text=_reasoning_parts.get(bid, ""), - signature=signature_buffer.get(idx), - ) - ) yield events.ReasoningEnd( - block_id=bid, + block_id=str(idx), signature=signature_buffer.get(idx), ) case "tool_use": tool_id = tool_ids.get(idx) if tool_id: - parts.append( - types.ToolCallPart( - id=tool_id, - tool_call_id=tool_id, - tool_name=tool_names.get(idx, ""), - tool_args=_tool_parts.get(tool_id, ""), - ) - ) yield events.ToolEnd(tool_call_id=tool_id) snapshot = sdk_stream.current_message_snapshot @@ -405,12 +365,6 @@ async def stream( ), raw=sdk_usage.model_dump(exclude_none=True) or None, ) - final_message = types.Message( - id=message_id, - role="assistant", - parts=parts, - usage=usage, - ) - yield events.MessageEnd(message=final_message, usage=usage) + yield events.StreamEnd(usage=usage) finally: await sdk_client.close() diff --git a/src/ai/models/core/__init__.py b/src/ai/models/core/__init__.py index 5b5e580d..e7db1ace 100644 --- a/src/ai/models/core/__init__.py +++ b/src/ai/models/core/__init__.py @@ -1,22 +1,37 @@ """Core types for models.""" from .adapters import register_generate, register_stream -from .api import check_connection, generate, stream +from .api import ( + Executor, + GenerateExecutor, + GenerateRequest, + Stream, + StreamExecutor, + StreamRequest, + check_connection, + generate, + stream, +) from .client import Client from .model import Model +from .params import GenerateParams, ImageParams, VideoParams from .proto import CheckConnFn, GenerateFn, Provider, StreamFn -from .types import GenerateParams, ImageParams, StreamResult, VideoParams __all__ = [ "CheckConnFn", "Client", + "Executor", + "GenerateExecutor", "GenerateFn", "GenerateParams", + "GenerateRequest", "ImageParams", "Model", "Provider", + "Stream", + "StreamExecutor", "StreamFn", - "StreamResult", + "StreamRequest", "VideoParams", "check_connection", "generate", diff --git a/src/ai/models/core/api.py b/src/ai/models/core/api.py index d0b90315..3d40de04 100644 --- a/src/ai/models/core/api.py +++ b/src/ai/models/core/api.py @@ -1,150 +1,195 @@ -"""Top-level orchestration — stream(), generate(), check_connection(). - -These wire together adapters, middleware chains, and auto-client creation. -""" - -from __future__ import annotations - +import dataclasses from collections.abc import AsyncGenerator, Sequence -from typing import Any +from typing import Any, Protocol, Self, runtime_checkable import pydantic -from ... import middleware as middleware_ -from ...types import events as events_ +from ... import types from ...types import integrity as integrity_ -from ...types import messages as messages_ -from ...types import proto as proto_ -from ...types import stream as stream_ -from . import adapters +from . import adapters, params from . import client as client_ from . import model as model_ -from . import types as types_ -def stream( - model: model_.Model, - messages: list[messages_.Message], - *, - tools: Sequence[proto_.ToolLike] | None = None, - output_type: type[pydantic.BaseModel] | None = None, - turn_id: str | None = None, - **kwargs: Any, -) -> proto_.StreamResultLike: - """Stream an LLM response. +@dataclasses.dataclass(frozen=True) +class StreamRequest: + model: model_.Model + messages: list[types.Message] + tools: Sequence[types.ToolLike] | None = None + output_type: type[pydantic.BaseModel] | None = None - Returns a :class:`StreamResultLike` that is async-iterable and - collects the final ``Message``. After iteration, access ``.text``, - ``.tool_calls``, ``.usage``, etc. - Call-site is a plain ``async for`` — no outer ``await`` needed:: +@dataclasses.dataclass(frozen=True) +class GenerateRequest: + model: model_.Model + messages: list[types.Message] + params: params.GenerateParams + + +@runtime_checkable +class StreamExecutor(Protocol): + def _do_stream(self, request: StreamRequest) -> AsyncGenerator[types.Event]: ... + + +@runtime_checkable +class GenerateExecutor(Protocol): + async def _do_generate(self, request: GenerateRequest) -> types.Message: ... + + +class Executor: + """Default executor: dispatches to adapters via the local client.""" + + async def _do_stream(self, request: StreamRequest) -> AsyncGenerator[types.Event]: + c = client_.auto_client(request.model) + fn = adapters.get_stream_adapter(request.model.adapter) + async for ev in fn( + c, + request.model, + request.messages, + tools=request.tools, + output_type=request.output_type, + ): + yield ev + + async def _do_generate(self, request: GenerateRequest) -> types.Message: + c = client_.auto_client(request.model) + fn = adapters.get_generate_adapter(request.model.adapter) + return await fn(c, request.model, request.messages, params=request.params) + + +_default_executor = Executor() + + +class Stream: + """Async-iterable wrapper around an adapter's event stream.""" + + def __init__(self, gen: AsyncGenerator[types.Event]) -> None: + self._gen = gen + self._message: types.Message = types.Message(role="assistant", parts=[]) + self._parts: dict[str, types.Part] = {} + + async def __aenter__(self) -> Self: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: object, + ) -> None: + await self._gen.aclose() + + def __aiter__(self) -> Self: + return self + + async def __anext__(self) -> types.Event: + event = await self._gen.__anext__() + self._aggregate_event(event) + return event.model_copy(update={"message": self._message}) + + @property + def message(self) -> types.Message: + return self._message + + @property + def usage(self) -> types.Usage | None: + return self._message.usage + + @property + def text(self) -> str: + return self._message.text + + @property + def tool_calls(self) -> list[types.ToolCallPart]: + return self._message.tool_calls + + @property + def output(self) -> Any: + return self._message.output + + def _aggregate_event(self, event: types.Event) -> None: + # grab usage from any event that carries one + if event.usage is not None: + self._message.usage = event.usage + + match event: + case types.TextStart(block_id=bid): + tp = types.TextPart(id=bid, text="") + self._message.parts.append(tp) + self._parts[bid] = tp + case types.TextDelta(block_id=bid, chunk=c): + existing_text = self._parts.get(bid) + if isinstance(existing_text, types.TextPart): + existing_text.text += c + case types.ReasoningStart(block_id=bid): + rp = types.ReasoningPart(id=bid, text="") + self._message.parts.append(rp) + self._parts[bid] = rp + case types.ReasoningDelta(block_id=bid, chunk=c): + existing_reasoning = self._parts.get(bid) + if isinstance(existing_reasoning, types.ReasoningPart): + existing_reasoning.text += c + case types.ReasoningEnd(block_id=bid, signature=sig): + existing_reasoning = self._parts.get(bid) + if ( + isinstance(existing_reasoning, types.ReasoningPart) + and sig is not None + ): + existing_reasoning.signature = sig + case types.ToolStart(tool_call_id=tcid, tool_name=name): + tcp = types.ToolCallPart( + id=tcid, + tool_call_id=tcid, + tool_name=name, + tool_args="", + ) + self._message.parts.append(tcp) + self._parts[tcid] = tcp + case types.ToolDelta(tool_call_id=tcid, chunk=c): + existing_tool = self._parts.get(tcid) + if isinstance(existing_tool, types.ToolCallPart): + existing_tool.tool_args += c + case types.FileEvent(block_id=bid, media_type=mt, data=d, filename=fname): + fp = types.FilePart( + id=bid or types.generate_id(), + data=d, + media_type=mt, + filename=fname, + ) + self._message.parts.append(fp) + self._parts[fp.id] = fp + case _: + pass - async for event in ai.stream(model, messages): - ... - One call is one turn: a single request and its response. The model - response carries ``turn_id``; re-emitted input messages keep any - existing ``turn_id`` from prior turns and only receive the current - one when unstamped. If *turn_id* is not provided, one is generated. - - The client is resolved from the model: ``model.client`` if set, - otherwise auto-created from ``model.base_url`` / ``model.api_key_env``. - - Middleware dispatch and adapter setup are deferred to the first - iteration; any async preflight work happens there. - """ +def stream( + model: model_.Model, + messages: list[types.Message], + *, + tools: Sequence[types.ToolLike] | None = None, + output_type: type[pydantic.BaseModel] | None = None, + executor: StreamExecutor = _default_executor, +) -> Stream: + """Stream an LLM response.""" messages = integrity_.prepare_messages(messages) - - if turn_id is None: - turn_id = messages_.generate_id("turn") - - call = middleware_.ModelContext( - model=model, - messages=messages, - tools=tools, - output_type=output_type, - kwargs=kwargs, - ) - - # Capture in closure for the inner function. - _turn_id = turn_id - - async def _real(call: middleware_.ModelContext) -> proto_.StreamResultLike: - c = client_.auto_client(call.model) - adapter_fn = adapters.get_stream_adapter(call.model.adapter) - return types_.StreamResult( - adapter_fn( - c, - call.model, - call.messages, - tools=call.tools, - output_type=call.output_type, - **call.kwargs, - ), - turn_id=_turn_id, - input_messages=call.messages, - ) - - async def _driver() -> AsyncGenerator[events_.Event]: - chain = middleware_._build_model_chain(_real) - inner = await chain(call) - async for event in inner: - yield event - - return stream_.StreamResult(_driver(), turn_id=turn_id) + request = StreamRequest(model, messages, tools, output_type) + return Stream(executor._do_stream(request)) async def generate( model: model_.Model, - messages: list[messages_.Message], - params: types_.GenerateParams, - **kwargs: Any, -) -> messages_.Message: - """Generate a response (images, video, etc.). - - Resolves the adapter function from ``model.adapter``, auto-creates a - :class:`Client` from the model if no explicit client is set. - - ``params`` is required and controls the generation type: - - * :class:`ImageParams` — image generation (``/image-model``). - * :class:`VideoParams` — video generation (``/video-model``). - """ + messages: list[types.Message], + params: params.GenerateParams, + *, + executor: GenerateExecutor = _default_executor, +) -> types.Message: + """Generate a non-streaming response (images, video, etc.).""" messages = integrity_.prepare_messages(messages) + request = GenerateRequest(model, messages, params) + return await executor._do_generate(request) - call = middleware_.GenerateContext( - model=model, - messages=messages, - params=params, - ) - - async def _real(call: middleware_.GenerateContext) -> messages_.Message: - c = client_.auto_client(call.model) - adapter_fn = adapters.get_generate_adapter(call.model.adapter) - return await adapter_fn(c, call.model, call.messages, params=call.params) - - chain = middleware_._build_generate_chain(_real) - return await chain(call) - - -async def check_connection( - model: model_.Model, -) -> bool: - """Check whether the model's provider is reachable and the model exists. - - Returns ``True`` when the credentials are valid **and** the model is - available on the remote side — i.e. a subsequent :func:`stream` or - :func:`generate` call should succeed (network conditions permitting). - - This only hits free metadata endpoints; no tokens or credits are - consumed. - - The client is resolved from the model: ``model.client`` if set, - otherwise created by the provider. - Non-auth transport errors (network failures, 5xx) are raised rather - than returning ``False`` so that callers can distinguish "bad - credentials / unknown model" from "provider unreachable". - """ +async def check_connection(model: model_.Model) -> bool: + """Check whether the model's provider is reachable and the model exists.""" c = client_.auto_client(model) return await model.provider.check(c, model) diff --git a/src/ai/models/core/helpers/streaming.py b/src/ai/models/core/helpers/streaming.py deleted file mode 100644 index 25d8b649..00000000 --- a/src/ai/models/core/helpers/streaming.py +++ /dev/null @@ -1,219 +0,0 @@ -from __future__ import annotations - -import dataclasses - -from ....types import events as events_ -from ....types import messages as messages_ -from ....types import usage as usage_ - - -@dataclasses.dataclass -class TextStart: - block_id: str - - -@dataclasses.dataclass -class TextDelta: - block_id: str - delta: str - - -@dataclasses.dataclass -class TextEnd: - block_id: str - - -@dataclasses.dataclass -class ReasoningStart: - block_id: str - - -@dataclasses.dataclass -class ReasoningDelta: - block_id: str - delta: str - - -@dataclasses.dataclass -class ReasoningEnd: - block_id: str - signature: str | None = None - - -@dataclasses.dataclass -class ToolStart: - tool_call_id: str - tool_name: str - - -@dataclasses.dataclass -class ToolArgsDelta: - tool_call_id: str - delta: str - - -@dataclasses.dataclass -class ToolEnd: - tool_call_id: str - - -@dataclasses.dataclass -class FileEvent: - """A complete generated file from the LLM (e.g. inline image from Gemini/GPT).""" - - block_id: str - media_type: str - data: str # base64 string or data-URL from the gateway - - -@dataclasses.dataclass -class MessageDone: - finish_reason: str | None = None - usage: usage_.Usage | None = None - - -StreamEvent = ( - TextStart - | TextDelta - | TextEnd - | ReasoningStart - | ReasoningDelta - | ReasoningEnd - | ToolStart - | ToolArgsDelta - | ToolEnd - | FileEvent - | MessageDone -) - - -@dataclasses.dataclass -class StreamHandler: - """ - Accumulates LLM adapter events and produces public Event objects. - - This is the normalization layer between LLM adapters and the rest of the system. - Parts are tracked in a single ``_current_parts`` dict keyed by block/tool id, - updated in place as events stream in. - """ - - message_id: str = dataclasses.field(default_factory=messages_.generate_id) - - # Single source of truth for part state, keyed by id. Insertion order - # preserves provider emission order. - _current_parts: dict[str, messages_.Part] = dataclasses.field(default_factory=dict) - - # Active tracking - _active_text_id: str | None = None - _active_reasoning_id: str | None = None - _active_tool_ids: set[str] = dataclasses.field(default_factory=set) - - _is_done: bool = False - _usage: usage_.Usage | None = None - - def message_start(self) -> events_.MessageStart: - """Emit a MessageStart event at the beginning of a stream.""" - return events_.MessageStart(message=self._build_message()) - - def handle_event(self, event: StreamEvent) -> list[events_.Event]: - """Process an adapter event and return public Event objects.""" - - out: list[events_.Event] = [] - - match event: - case TextStart(block_id=bid): - part: messages_.Part = messages_.TextPart(id=bid, text="") - self._current_parts[bid] = part - self._active_text_id = bid - out.append(events_.TextStart(block_id=bid)) - - case TextDelta(block_id=bid, delta=d): - existing = self._current_parts[bid] - assert isinstance(existing, messages_.TextPart) - part = messages_.TextPart(id=bid, text=existing.text + d) - self._current_parts[bid] = part - out.append(events_.TextDelta(chunk=d, block_id=bid)) - - case TextEnd(block_id=bid): - if self._active_text_id == bid: - self._active_text_id = None - out.append(events_.TextEnd(block_id=bid)) - - case ReasoningStart(block_id=bid): - part = messages_.ReasoningPart(id=bid, text="") - self._current_parts[bid] = part - self._active_reasoning_id = bid - out.append(events_.ReasoningStart(block_id=bid)) - - case ReasoningDelta(block_id=bid, delta=d): - existing = self._current_parts[bid] - assert isinstance(existing, messages_.ReasoningPart) - part = messages_.ReasoningPart( - id=bid, - text=existing.text + d, - signature=existing.signature, - ) - self._current_parts[bid] = part - out.append(events_.ReasoningDelta(chunk=d, block_id=bid)) - - case ReasoningEnd(block_id=bid, signature=sig): - existing = self._current_parts[bid] - assert isinstance(existing, messages_.ReasoningPart) - part = messages_.ReasoningPart( - id=bid, text=existing.text, signature=sig - ) - self._current_parts[bid] = part - if self._active_reasoning_id == bid: - self._active_reasoning_id = None - out.append(events_.ReasoningEnd(block_id=bid, signature=sig)) - - case ToolStart(tool_call_id=tcid, tool_name=name): - part = messages_.ToolCallPart( - id=tcid, - tool_call_id=tcid, - tool_name=name, - tool_args="", - ) - self._current_parts[tcid] = part - self._active_tool_ids.add(tcid) - out.append(events_.ToolStart(tool_call_id=tcid, tool_name=name)) - - case ToolArgsDelta(tool_call_id=tcid, delta=d): - existing = self._current_parts[tcid] - assert isinstance(existing, messages_.ToolCallPart) - part = messages_.ToolCallPart( - id=tcid, - tool_call_id=existing.tool_call_id, - tool_name=existing.tool_name, - tool_args=existing.tool_args + d, - ) - self._current_parts[tcid] = part - out.append(events_.ToolDelta(chunk=d, tool_call_id=tcid)) - - case ToolEnd(tool_call_id=tcid): - self._active_tool_ids.discard(tcid) - out.append(events_.ToolEnd(tool_call_id=tcid)) - - case FileEvent(block_id=bid, media_type=mt, data=d): - self._current_parts[bid] = messages_.FilePart( - id=bid, data=d, media_type=mt - ) - - case MessageDone(usage=u): - self._is_done = True - self._usage = u - self._active_text_id = None - self._active_reasoning_id = None - self._active_tool_ids.clear() - msg = self._build_message() - out.append(events_.MessageEnd(message=msg, usage=u)) - - return out - - def _build_message(self) -> messages_.Message: - return messages_.Message( - id=self.message_id, - role="assistant", - parts=list(self._current_parts.values()), - usage=self._usage if self._is_done else None, - ) diff --git a/src/ai/models/core/proto.py b/src/ai/models/core/proto.py index 35daf675..987d2f5e 100644 --- a/src/ai/models/core/proto.py +++ b/src/ai/models/core/proto.py @@ -91,7 +91,7 @@ class StreamFn(Protocol): """Protocol for streaming adapter functions. Implementations yield event objects as the response streams in. The - terminal assistant state is surfaced as a ``MessageEnd.message``. + terminal assistant state is surfaced as a ``StreamEnd.message``. """ def __call__( diff --git a/src/ai/models/core/types.py b/src/ai/models/core/types.py deleted file mode 100644 index 4344fa05..00000000 --- a/src/ai/models/core/types.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Re-exports for backwards-compatible ``ai.models.core.types`` imports.""" - -from ...types.stream import StreamResult -from .params import GenerateParams, ImageParams, VideoParams - -__all__ = [ - "GenerateParams", - "ImageParams", - "StreamResult", - "VideoParams", -] diff --git a/src/ai/models/openai/adapter.py b/src/ai/models/openai/adapter.py index 3bfd082f..d1611785 100644 --- a/src/ai/models/openai/adapter.py +++ b/src/ai/models/openai/adapter.py @@ -16,9 +16,10 @@ from ...types import media from ...types import messages as messages_ from ...types import proto as proto_ +from ...types import usage as usage_ from ..core import client as client_ from ..core import model as model_ -from ..core.helpers import files, streaming +from ..core.helpers import files # --------------------------------------------------------------------------- # Message / tool conversion — internal types → OpenAI wire format @@ -211,6 +212,8 @@ async def stream( """Stream an LLM response via the OpenAI chat completions API. Yields :class:`~ai.types.events.Event` objects as the response streams in. + Pure delta emitter — the :class:`~ai.models.Stream` wrapper aggregates + parts into the final :class:`~ai.types.Message`. Extra keyword arguments beyond the ``StreamFn`` protocol: @@ -255,42 +258,28 @@ async def stream( reasoning_config["effort"] = reasoning_effort api_kwargs["extra_body"] = {"reasoning": reasoning_config} - handler = streaming.StreamHandler() - - def _emit(adapter_event: streaming.StreamEvent) -> list[events_.Event]: - return handler.handle_event(adapter_event) - try: sdk_stream = await sdk_client.chat.completions.create(**api_kwargs) text_started = False reasoning_started = False tc_state: dict[int, dict[str, Any]] = {} - finish_reason: str | None = None - usage: messages_.Usage | None = None + usage: usage_.Usage | None = None - yield handler.message_start() + yield events_.StreamStart() async for chunk in sdk_stream: if chunk.usage is not None: raw = chunk.usage.model_dump(exclude_none=True) reasoning_tokens: int | None = None cache_read: int | None = None - cd = getattr( - chunk.usage, - "completion_tokens_details", - None, - ) + cd = getattr(chunk.usage, "completion_tokens_details", None) if cd: reasoning_tokens = getattr(cd, "reasoning_tokens", None) - pd = getattr( - chunk.usage, - "prompt_tokens_details", - None, - ) + pd = getattr(chunk.usage, "prompt_tokens_details", None) if pd: cache_read = getattr(pd, "cached_tokens", None) - usage = messages_.Usage( + usage = usage_.Usage( input_tokens=chunk.usage.prompt_tokens or 0, output_tokens=chunk.usage.completion_tokens or 0, reasoning_tokens=reasoning_tokens, @@ -314,29 +303,20 @@ def _emit(adapter_event: streaming.StreamEvent) -> list[events_.Event]: if reasoning_value: if not reasoning_started: reasoning_started = True - for e in _emit(streaming.ReasoningStart(block_id="reasoning")): - yield e - for e in _emit( - streaming.ReasoningDelta( - block_id="reasoning", delta=reasoning_value - ) - ): - yield e + yield events_.ReasoningStart(block_id="reasoning") + yield events_.ReasoningDelta( + chunk=reasoning_value, block_id="reasoning" + ) if delta.content: if reasoning_started: - for e in _emit(streaming.ReasoningEnd(block_id="reasoning")): - yield e + yield events_.ReasoningEnd(block_id="reasoning") reasoning_started = False if not text_started: text_started = True - for e in _emit(streaming.TextStart(block_id="text")): - yield e - for e in _emit( - streaming.TextDelta(block_id="text", delta=delta.content) - ): - yield e + yield events_.TextStart(block_id="text") + yield events_.TextDelta(chunk=delta.content, block_id="text") if delta.tool_calls: for tc in delta.tool_calls: @@ -358,37 +338,28 @@ def _emit(adapter_event: streaming.StreamEvent) -> list[events_.Event]: if not tc_state[idx]["started"] and tid: tc_state[idx]["started"] = True - for e in _emit( - streaming.ToolStart( - tool_call_id=tid, - tool_name=tname, - ) - ): - yield e + yield events_.ToolStart( + tool_call_id=tid, tool_name=tname + ) if tid: - for e in _emit( - streaming.ToolArgsDelta( - tool_call_id=tid, - delta=tc.function.arguments, - ) - ): - yield e + yield events_.ToolDelta( + chunk=tc.function.arguments, + tool_call_id=tid, + ) if choice.finish_reason is not None: - finish_reason = choice.finish_reason if reasoning_started: - for e in _emit(streaming.ReasoningEnd(block_id="reasoning")): - yield e + yield events_.ReasoningEnd(block_id="reasoning") + reasoning_started = False if text_started: - for e in _emit(streaming.TextEnd(block_id="text")): - yield e + yield events_.TextEnd(block_id="text") + text_started = False for tc in tc_state.values(): if tc["started"] and tc["id"]: - for e in _emit(streaming.ToolEnd(tool_call_id=tc["id"])): - yield e + yield events_.ToolEnd(tool_call_id=tc["id"]) + tc["started"] = False - for e in _emit(streaming.MessageDone(finish_reason=finish_reason, usage=usage)): - yield e + yield events_.StreamEnd(usage=usage) finally: await sdk_client.close() diff --git a/src/ai/types/__init__.py b/src/ai/types/__init__.py index 769b12df..b6ff7a54 100644 --- a/src/ai/types/__init__.py +++ b/src/ai/types/__init__.py @@ -1,15 +1,14 @@ from . import media from .events import ( - End, Event, + FileEvent, HookResolution, HookSuspention, - MessageEnd, - MessageStart, ReasoningDelta, ReasoningEnd, ReasoningStart, - Start, + StreamEnd, + StreamStart, TextDelta, TextEnd, TextStart, @@ -29,27 +28,25 @@ ToolResultPart, generate_id, ) -from .proto import StreamResultLike, ToolLike +from .proto import ToolLike from .tools import ToolSchema from .usage import Usage __all__ = [ - "End", "Event", + "FileEvent", "FilePart", "HookPart", "HookResolution", "HookSuspention", "Message", - "MessageEnd", - "MessageStart", "Part", "ReasoningDelta", "ReasoningEnd", "ReasoningPart", "ReasoningStart", - "Start", - "StreamResultLike", + "StreamEnd", + "StreamStart", "StructuredOutputPart", "TextDelta", "TextEnd", diff --git a/src/ai/types/events.py b/src/ai/types/events.py index c9267072..3a4594bf 100644 --- a/src/ai/types/events.py +++ b/src/ai/types/events.py @@ -9,114 +9,110 @@ # serialization border in the case of durable execution -class Start(pydantic.BaseModel): - kind: Literal["start"] = "start" - model_config = pydantic.ConfigDict(frozen=True) - - -class End(pydantic.BaseModel): - kind: Literal["end"] = "end" - model_config = pydantic.ConfigDict(frozen=True) +class BaseEvent(pydantic.BaseModel): + """Common fields stamped onto every event by the streaming wrapper. + ``message`` carries the in-progress (or final) assistant message; the + streaming layer aggregates parts into it as deltas arrive and stamps + a reference onto each yielded event. ``usage`` carries the latest + usage value reported by the provider (latest-wins across the stream). + """ -class MessageStart(pydantic.BaseModel): message: messages.Message | None = None + usage: usage_.Usage | None = None - kind: Literal["message_start"] = "message_start" model_config = pydantic.ConfigDict(frozen=True) -class MessageEnd(pydantic.BaseModel): - message: messages.Message - usage: usage_.Usage | None = None +class StreamStart(BaseEvent): + kind: Literal["stream_start"] = "stream_start" - kind: Literal["message_end"] = "message_end" - model_config = pydantic.ConfigDict(frozen=True) + +class StreamEnd(BaseEvent): + kind: Literal["stream_end"] = "stream_end" -class TextStart(pydantic.BaseModel): +class TextStart(BaseEvent): block_id: str = "" kind: Literal["text_start"] = "text_start" - model_config = pydantic.ConfigDict(frozen=True) -class TextDelta(pydantic.BaseModel): +class TextDelta(BaseEvent): chunk: str block_id: str = "" kind: Literal["text_delta"] = "text_delta" - model_config = pydantic.ConfigDict(frozen=True) -class TextEnd(pydantic.BaseModel): +class TextEnd(BaseEvent): block_id: str = "" kind: Literal["text_end"] = "text_end" - model_config = pydantic.ConfigDict(frozen=True) -class ReasoningStart(pydantic.BaseModel): +class ReasoningStart(BaseEvent): block_id: str = "" kind: Literal["reasoning_start"] = "reasoning_start" - model_config = pydantic.ConfigDict(frozen=True) -class ReasoningDelta(pydantic.BaseModel): +class ReasoningDelta(BaseEvent): chunk: str block_id: str = "" kind: Literal["reasoning_delta"] = "reasoning_delta" - model_config = pydantic.ConfigDict(frozen=True) -class ReasoningEnd(pydantic.BaseModel): +class ReasoningEnd(BaseEvent): block_id: str = "" signature: str | None = None kind: Literal["reasoning_end"] = "reasoning_end" - model_config = pydantic.ConfigDict(frozen=True) -class ToolStart(pydantic.BaseModel): +class ToolStart(BaseEvent): tool_call_id: str = "" tool_name: str = "" kind: Literal["tool_start"] = "tool_start" - model_config = pydantic.ConfigDict(frozen=True) -class ToolDelta(pydantic.BaseModel): +class ToolDelta(BaseEvent): chunk: str tool_call_id: str = "" kind: Literal["tool_delta"] = "tool_delta" - model_config = pydantic.ConfigDict(frozen=True) -class ToolEnd(pydantic.BaseModel): +class ToolEnd(BaseEvent): tool_call_id: str = "" kind: Literal["tool_end"] = "tool_end" - model_config = pydantic.ConfigDict(frozen=True) -class HookSuspention(pydantic.BaseModel): +class FileEvent(BaseEvent): + """A complete generated file from the LLM (e.g. inline image from Gemini/GPT).""" + + block_id: str = "" + media_type: str + data: str | bytes + filename: str | None = None + + kind: Literal["file"] = "file" + + +class HookSuspention(BaseEvent): kind: Literal["hook_suspention"] = "hook_suspention" - model_config = pydantic.ConfigDict(frozen=True) -class HookResolution(pydantic.BaseModel): +class HookResolution(BaseEvent): kind: Literal["hook_resolution"] = "hook_resolution" - model_config = pydantic.ConfigDict(frozen=True) Event = Annotated[ - Start - | End - | MessageStart - | MessageEnd + StreamStart + | StreamEnd | TextStart | TextDelta | TextEnd @@ -126,6 +122,7 @@ class HookResolution(pydantic.BaseModel): | ToolStart | ToolDelta | ToolEnd + | FileEvent | HookSuspention | HookResolution, pydantic.Field(discriminator="kind"), diff --git a/src/ai/types/messages.py b/src/ai/types/messages.py index 29332eca..4be9d668 100644 --- a/src/ai/types/messages.py +++ b/src/ai/types/messages.py @@ -19,7 +19,6 @@ class TextPart(pydantic.BaseModel): text: str kind: Literal["text"] = "text" - model_config = pydantic.ConfigDict(frozen=True) class ToolCallPart(pydantic.BaseModel): @@ -29,7 +28,6 @@ class ToolCallPart(pydantic.BaseModel): tool_args: str kind: Literal["tool_call"] = "tool_call" - model_config = pydantic.ConfigDict(frozen=True) class ToolResultPart(pydantic.BaseModel): @@ -51,7 +49,6 @@ class ReasoningPart(pydantic.BaseModel): signature: str | None = None kind: Literal["reasoning"] = "reasoning" - model_config = pydantic.ConfigDict(frozen=True) class HookPart[T](pydantic.BaseModel): @@ -188,8 +185,6 @@ def from_bytes( class Message(pydantic.BaseModel): - model_config = pydantic.ConfigDict(frozen=True) - role: Literal["user", "assistant", "system", "tool", "internal"] parts: list[Part] id: str = pydantic.Field(default_factory=generate_id) @@ -235,27 +230,5 @@ def output(self) -> Any: return part.value return None - def replace(self, old: Part, new: Part | None = None) -> Self: - """Return a copy with one part replaced. - - ``replace(new_part)`` matches by ``new_part.id``. - ``replace(old_part, new_part)`` matches by object identity. - """ - if new is None: - new = old - for idx, part in enumerate(self.parts): - if part.id == new.id: - parts = list(self.parts) - parts[idx] = new - return self.model_copy(update={"parts": parts}) - raise ValueError(f"Part id={new.id!r} not found in message {self.id!r}") - - for idx, part in enumerate(self.parts): - if part is old: - parts = list(self.parts) - parts[idx] = new - return self.model_copy(update={"parts": parts}) - raise ValueError(f"Part id={old.id!r} not found in message {self.id!r}") - Usage = usage_.Usage diff --git a/src/ai/types/proto.py b/src/ai/types/proto.py index a2a8e3ca..dd9ea7de 100644 --- a/src/ai/types/proto.py +++ b/src/ai/types/proto.py @@ -1,9 +1,5 @@ -from collections.abc import AsyncGenerator from typing import Any, Protocol, runtime_checkable -from . import events as events_ -from . import messages, usage - @runtime_checkable class ToolLike(Protocol): @@ -15,33 +11,3 @@ def name(self) -> str: ... def description(self) -> str: ... @property def param_schema(self) -> dict[str, Any]: ... - - -@runtime_checkable -class StreamResultLike(Protocol): - """Structural protocol satisfied by :class:`ai.models.StreamResult`. - - Middleware that transforms or replaces the stream returned by - ``wrap_model`` should return an object satisfying this protocol. - The easiest way is ``StreamResult.from_generator(gen)``. - """ - - def __aiter__(self) -> AsyncGenerator[events_.Event]: ... - - @property - def message(self) -> messages.Message | None: ... - - @property - def text(self) -> str: ... - - @property - def tool_calls(self) -> list[messages.ToolCallPart]: ... - - @property - def usage(self) -> usage.Usage | None: ... - - @property - def output(self) -> Any: ... - - @property - def turn_id(self) -> str | None: ... diff --git a/src/ai/types/stream.py b/src/ai/types/stream.py deleted file mode 100644 index fdaaebca..00000000 --- a/src/ai/types/stream.py +++ /dev/null @@ -1,105 +0,0 @@ -from collections.abc import AsyncGenerator -from typing import Any, Self - -from . import events as events_ -from . import messages -from . import usage as usage_ - - -class StreamResult: - """Wrapper around an event stream. Async-iterable; collects the final result. - - Yields :class:`~ai.types.events.Event` objects. After iteration, - convenience properties (``.text``, ``.tool_calls``, ``.usage``, - ``.message``) are available — they delegate to the ``MessageEnd`` - event's ``message``. - - One ``StreamResult`` represents one turn: a single LLM request and - its response. - """ - - def __init__( - self, - gen: AsyncGenerator[events_.Event], - *, - turn_id: str | None = None, - input_messages: list[messages.Message] | None = None, - ) -> None: - self._gen = gen - self._turn_id = turn_id - self._input_messages = input_messages or [] - self._message: messages.Message | None = None - self._usage: usage_.Usage | None = None - - @classmethod - def from_generator(cls, gen: AsyncGenerator[events_.Event]) -> Self: - """Create a :class:`StreamResult` from an async generator of events.""" - return cls(gen) - - def __aiter__(self) -> AsyncGenerator[events_.Event]: - return self._iterate() - - def _stamp_message(self, msg: messages.Message) -> messages.Message: - if msg.turn_id is None and self._turn_id is not None: - return msg.model_copy(update={"turn_id": self._turn_id}) - return msg - - async def _iterate(self) -> AsyncGenerator[events_.Event]: - # Re-emit input messages as MessageStart + MessageEnd event pairs. - for msg in self._input_messages: - msg = self._stamp_message(msg) - yield events_.MessageStart(message=msg) - yield events_.MessageEnd(message=msg) - - # Stream adapter events. - async for event in self._gen: - if isinstance(event, events_.MessageStart) and event.message is not None: - event = event.model_copy( - update={"message": self._stamp_message(event.message)} - ) - - # Capture the final message from MessageEnd. - if isinstance(event, events_.MessageEnd): - message = self._stamp_message(event.message) - event = event.model_copy(update={"message": message}) - self._message = message - self._usage = event.usage - yield event - - @property - def turn_id(self) -> str | None: - """The turn id stamped on this stream's response (if any).""" - return self._turn_id - - @property - def message(self) -> messages.Message | None: - """The final assembled message, available after iteration.""" - return self._message - - @property - def text(self) -> str: - if self._message is None: - return "" - return "".join( - p.text for p in self._message.parts if isinstance(p, messages.TextPart) - ) - - @property - def tool_calls(self) -> list[messages.ToolCallPart]: - if self._message is None: - return [] - return [p for p in self._message.parts if isinstance(p, messages.ToolCallPart)] - - @property - def usage(self) -> usage_.Usage | None: - return self._usage - - @property - def output(self) -> Any: - """Parsed structured output from the final message, if available.""" - if self._message is None: - return None - for p in self._message.parts: - if isinstance(p, messages.StructuredOutputPart): - return p.value - return None diff --git a/tests/agents/test_generator_tools.py b/tests/agents/test_generator_tools.py index 3c4e9812..2c720b9f 100644 --- a/tests/agents/test_generator_tools.py +++ b/tests/agents/test_generator_tools.py @@ -9,11 +9,18 @@ import ai from ai import models -from ai.models.core.helpers import streaming as streaming_ +from ai.agents import events as agent_events_ from ai.types import events as events_ from ai.types import messages as messages_ -from ..conftest import MOCK_MODEL, collect_messages, mock_llm, text_msg, tool_call_msg +from ..conftest import ( + MOCK_MODEL, + collect_messages, + emit_events_for_messages, + mock_llm, + text_msg, + tool_call_msg, +) # --------------------------------------------------------------------------- # Generator tool: yields intermediate messages, returns final text @@ -96,45 +103,7 @@ async def stream( seq = self._responses[self._idx] self._idx += 1 - message_id = seq[0].id if seq else messages_.generate_id() - handler = streaming_.StreamHandler(message_id=message_id) - yield handler.message_start() - for msg in seq: - for i, part in enumerate(msg.parts): - if isinstance(part, messages_.TextPart): - bid = f"text-{i}" - for event in handler.handle_event( - streaming_.TextStart(block_id=bid) - ): - yield event - if part.text: - for event in handler.handle_event( - streaming_.TextDelta(block_id=bid, delta=part.text) - ): - yield event - for event in handler.handle_event(streaming_.TextEnd(block_id=bid)): - yield event - elif isinstance(part, messages_.ToolCallPart): - for event in handler.handle_event( - streaming_.ToolStart( - tool_call_id=part.tool_call_id, - tool_name=part.tool_name, - ) - ): - yield event - if part.tool_args: - for event in handler.handle_event( - streaming_.ToolArgsDelta( - tool_call_id=part.tool_call_id, - delta=part.tool_args, - ) - ): - yield event - for event in handler.handle_event( - streaming_.ToolEnd(tool_call_id=part.tool_call_id) - ): - yield event - for event in handler.handle_event(streaming_.MessageDone()): + async for event in emit_events_for_messages(seq): yield event @@ -145,7 +114,7 @@ async def inner_fact(topic: str) -> str: @ai.tool # type: ignore[arg-type] -async def research_tool(topic: str) -> AsyncGenerator[ai.Event]: +async def research_tool(topic: str) -> AsyncGenerator[agent_events_.AgentEvent]: """Nested agent that researches a topic.""" inner = ai.agent(tools=[inner_fact]) diff --git a/tests/agents/test_hooks.py b/tests/agents/test_hooks.py index 1cfb2cc5..73d88d22 100644 --- a/tests/agents/test_hooks.py +++ b/tests/agents/test_hooks.py @@ -9,6 +9,7 @@ import pytest import ai +from ai.agents import events as agent_events_ from ..conftest import MOCK_MODEL, mock_llm, text_msg @@ -37,7 +38,7 @@ async def custom(context: ai.Context) -> AsyncGenerator[ai.Event]: mock_llm([[text_msg("OK")]]) async for event in my_agent.run(MOCK_MODEL, [ai.user_message("go")]): - if not isinstance(event, ai.MessageEnd): + if not isinstance(event, agent_events_.MessageEnd): continue msg = event.message # When we see the pending hook message, resolve it. @@ -70,7 +71,7 @@ async def custom(context: ai.Context) -> AsyncGenerator[ai.Event]: mock_llm([[text_msg("OK")]]) async for event in my_agent.run(MOCK_MODEL, [ai.user_message("go")]): - if not isinstance(event, ai.MessageEnd): + if not isinstance(event, agent_events_.MessageEnd): continue msg = event.message if any(isinstance(p, ai.HookPart) and p.status == "pending" for p in msg.parts): @@ -144,7 +145,7 @@ async def custom(context: ai.Context) -> AsyncGenerator[ai.Event]: msgs: list[ai.Message] = [] async for event in my_agent.run(MOCK_MODEL, [ai.user_message("go")]): - if not isinstance(event, ai.MessageEnd): + if not isinstance(event, agent_events_.MessageEnd): continue msg = event.message msgs.append(msg) @@ -180,7 +181,7 @@ async def custom(context: ai.Context) -> AsyncGenerator[ai.Event]: mock_llm([[text_msg("OK")]]) msgs: list[ai.Message] = [] async for event in my_agent.run(MOCK_MODEL, [ai.user_message("go")]): - if isinstance(event, ai.MessageEnd): + if isinstance(event, agent_events_.MessageEnd): msgs.append(event.message) hook_msgs = [m for m in msgs if any(isinstance(p, ai.HookPart) for p in m.parts)] diff --git a/tests/agents/test_runtime.py b/tests/agents/test_runtime.py index 4165b727..edb7dc3c 100644 --- a/tests/agents/test_runtime.py +++ b/tests/agents/test_runtime.py @@ -107,73 +107,3 @@ async def test_agent_multi_turn() -> None: my_agent.run(MOCK_MODEL, [ai.user_message("Concat then double")]) ) assert llm.call_count == 3 - - -# -- turn_id semantics: one turn per LLM round-trip ------------------------- - - -async def test_two_user_messages_produce_four_turns() -> None: - """Two agent.run invocations, each with a tool call + final reply, - produce four distinct turn ids; history from the first run keeps its - original turn ids when fed into the second run.""" - my_agent = ai.agent(tools=[double]) - - # Run 1: tool call, then text. - r1_turn1 = [tool_call_msg(tc_id="tc-1", name="double", args='{"x": 5}', id="m-1a")] - r1_turn2 = [text_msg("Ten.", id="m-1b")] - # Run 2: tool call, then text. - r2_turn1 = [tool_call_msg(tc_id="tc-2", name="double", args='{"x": 7}', id="m-2a")] - r2_turn2 = [text_msg("Fourteen.", id="m-2b")] - mock_llm([r1_turn1, r1_turn2, r2_turn1, r2_turn2]) - - def dedup(stream: list[ai.Message]) -> list[ai.Message]: - seen: dict[str, ai.Message] = {} - for m in stream: - seen[m.id] = m - return list(seen.values()) - - run1_stream = await collect_messages( - my_agent.run(MOCK_MODEL, [ai.user_message("Double 5")]) - ) - history = dedup(run1_stream) - - run2_stream = await collect_messages( - my_agent.run(MOCK_MODEL, [*history, ai.user_message("Double 7")]) - ) - final = dedup(run2_stream) - - # Chronological list of terminal non-internal messages. Insertion order - # of ``dedup`` reflects the order they first appeared in the stream. - chronological = [m for m in final if m.role != "internal"] - assert len(chronological) == 8 - - # Expected shape: four turns, each a (input, assistant) pair. - # turn 1: user → assistant (tool call) - # turn 2: tool → assistant (text) - # turn 3: user → assistant (tool call) - # turn 4: tool → assistant (text) - expected_roles = [ - ("user", "assistant"), - ("tool", "assistant"), - ("user", "assistant"), - ("tool", "assistant"), - ] - pairs = [(chronological[2 * i], chronological[2 * i + 1]) for i in range(4)] - for (left, right), (expected_left, expected_right) in zip( - pairs, expected_roles, strict=True - ): - assert (left.role, right.role) == (expected_left, expected_right) - # Both messages in the pair share the same turn_id. - assert left.turn_id is not None - assert left.turn_id == right.turn_id - - # The four turn_ids are all distinct. - turn_ids = [left.turn_id for left, _ in pairs] - assert len(set(turn_ids)) == 4 - - # Run 1's history survives untouched into run 2. - history_ids = {h.id for h in history} - for h in history: - same = next(m for m in final if m.id == h.id) - assert same.turn_id == h.turn_id - assert any(m.id in history_ids for m in final) diff --git a/tests/agents/ui/ai_sdk/outbound/test_sse.py b/tests/agents/ui/ai_sdk/outbound/test_sse.py index 501d4d03..4c3ba93d 100644 --- a/tests/agents/ui/ai_sdk/outbound/test_sse.py +++ b/tests/agents/ui/ai_sdk/outbound/test_sse.py @@ -3,9 +3,9 @@ import json from collections.abc import AsyncGenerator +from ai.agents import events as agent_events_ from ai.agents.ui.ai_sdk import protocol, to_sse from ai.agents.ui.ai_sdk.outbound.sse import format_sse, serialize_part -from ai.types import events as events_ from ai.types import messages as messages_ @@ -30,8 +30,8 @@ def test_serialize_data_part_uses_type_with_prefix() -> None: async def _gen( - stream_events: list[events_.Event], -) -> AsyncGenerator[events_.Event]: + stream_events: list[agent_events_.AgentEvent], +) -> AsyncGenerator[agent_events_.AgentEvent]: for event in stream_events: yield event @@ -46,7 +46,12 @@ async def test_to_sse_emits_data_prefixed_lines() -> None: lines = [ line async for line in to_sse( - _gen([events_.MessageStart(message=msg), events_.MessageEnd(message=msg)]) + _gen( + [ + agent_events_.MessageStart(message=msg), + agent_events_.MessageEnd(message=msg), + ] + ) ) ] assert all(line.startswith("data: ") for line in lines) diff --git a/tests/agents/ui/ai_sdk/outbound/test_stream.py b/tests/agents/ui/ai_sdk/outbound/test_stream.py index 5cd1bb19..e193b359 100644 --- a/tests/agents/ui/ai_sdk/outbound/test_stream.py +++ b/tests/agents/ui/ai_sdk/outbound/test_stream.py @@ -2,20 +2,21 @@ from collections.abc import AsyncGenerator +from ai.agents import events as agent_events_ from ai.agents.ui.ai_sdk import protocol, to_stream from ai.types import events as events_ from ai.types import messages as messages_ async def _gen( - stream_events: list[events_.Event], -) -> AsyncGenerator[events_.Event]: + stream_events: list[agent_events_.AgentEvent], +) -> AsyncGenerator[agent_events_.AgentEvent]: for event in stream_events: yield event async def _collect( - stream_events: list[events_.Event], + stream_events: list[agent_events_.AgentEvent], ) -> list[protocol.UIMessageStreamPart]: return [part async for part in to_stream(_gen(stream_events))] @@ -25,8 +26,8 @@ def _assistant_start( *, turn_id: str | None = "t1", source_label: str | None = None, -) -> events_.MessageStart: - return events_.MessageStart( +) -> agent_events_.MessageStart: + return agent_events_.MessageStart( message=messages_.Message( id=msg_id, role="assistant", @@ -51,7 +52,7 @@ async def test_event_driven_text_streaming() -> None: events_.TextStart(block_id=text_id), events_.TextDelta(block_id=text_id, chunk="hi"), events_.TextEnd(block_id=text_id), - events_.MessageEnd(message=final), + agent_events_.MessageEnd(message=final), ] ) @@ -72,7 +73,7 @@ async def test_static_text_message_emits_text_parts() -> None: parts=[messages_.TextPart(id="txt1", text="hello")], ) out = await _collect( - [events_.MessageStart(message=msg), events_.MessageEnd(message=msg)] + [agent_events_.MessageStart(message=msg), agent_events_.MessageEnd(message=msg)] ) assert any(isinstance(part, protocol.TextDeltaPart) for part in out) @@ -92,10 +93,10 @@ async def test_turn_id_change_emits_step_boundary() -> None: ) out = await _collect( [ - events_.MessageStart(message=msg1), - events_.MessageEnd(message=msg1), - events_.MessageStart(message=msg2), - events_.MessageEnd(message=msg2), + agent_events_.MessageStart(message=msg1), + agent_events_.MessageEnd(message=msg1), + agent_events_.MessageStart(message=msg2), + agent_events_.MessageEnd(message=msg2), ] ) has_mid_step_boundary = any( @@ -122,10 +123,10 @@ async def test_agent_change_emits_message_boundary() -> None: ) out = await _collect( [ - events_.MessageStart(message=msg1), - events_.MessageEnd(message=msg1), - events_.MessageStart(message=msg2), - events_.MessageEnd(message=msg2), + agent_events_.MessageStart(message=msg1), + agent_events_.MessageEnd(message=msg1), + agent_events_.MessageStart(message=msg2), + agent_events_.MessageEnd(message=msg2), ] ) has_mid_msg_boundary = any( @@ -163,10 +164,10 @@ async def test_tool_call_and_result_emit_terminal_parts() -> None: ) out = await _collect( [ - events_.MessageStart(message=tool_call), - events_.MessageEnd(message=tool_call), - events_.MessageStart(message=tool_result), - events_.MessageEnd(message=tool_result), + agent_events_.MessageStart(message=tool_call), + agent_events_.MessageEnd(message=tool_call), + agent_events_.MessageStart(message=tool_result), + agent_events_.MessageEnd(message=tool_result), ] ) types = [type(part).__name__ for part in out] @@ -201,10 +202,10 @@ async def test_approval_request_hook_emits_approval_part() -> None: ) out = await _collect( [ - events_.MessageStart(message=tool_call), - events_.MessageEnd(message=tool_call), - events_.MessageStart(message=hook), - events_.MessageEnd(message=hook), + agent_events_.MessageStart(message=tool_call), + agent_events_.MessageEnd(message=tool_call), + agent_events_.MessageStart(message=hook), + agent_events_.MessageEnd(message=hook), ] ) approval_parts = [p for p in out if isinstance(p, protocol.ToolApprovalRequestPart)] @@ -220,12 +221,12 @@ async def test_dedup_on_reemitted_message_id() -> None: turn_id="t1", parts=[messages_.TextPart(id="txt1", text="hi")], ) - stream_events: list[events_.Event] = [ - events_.MessageStart(message=msg), + stream_events: list[agent_events_.AgentEvent] = [ + agent_events_.MessageStart(message=msg), events_.TextStart(block_id="txt1"), events_.TextDelta(block_id="txt1", chunk="hi"), events_.TextEnd(block_id="txt1"), - events_.MessageEnd(message=msg), + agent_events_.MessageEnd(message=msg), ] out = await _collect([*stream_events, *stream_events]) text_deltas = [part for part in out if isinstance(part, protocol.TextDeltaPart)] diff --git a/tests/conftest.py b/tests/conftest.py index d7e4abda..4bb74ab5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,9 +7,11 @@ import ai from ai import models +from ai.agents import events as agent_events_ from ai.types import builders from ai.types import events as events_ from ai.types import messages as messages_ +from ai.types import usage as usage_ class MockProvider: @@ -87,12 +89,63 @@ def __repr__(self) -> str: ) +async def emit_events_for_messages( + seq: list[messages_.Message], + *, + usage: usage_.Usage | None = None, +) -> AsyncGenerator[events_.Event]: + """Emit a stream of public ``events_.Event`` corresponding to ``seq``. + + Walks each message's parts and yields the appropriate + ``Start`` / ``Delta`` / ``End`` events (and ``FileEvent``). The output + matches what a real adapter would produce. Bookended by + ``StreamStart`` / ``StreamEnd``. + """ + yield events_.StreamStart() + for msg in seq: + for i, part in enumerate(msg.parts): + if isinstance(part, messages_.TextPart): + bid = f"text-{i}" + yield events_.TextStart(block_id=bid) + if part.text: + yield events_.TextDelta(block_id=bid, chunk=part.text) + yield events_.TextEnd(block_id=bid) + + elif isinstance(part, messages_.ReasoningPart): + bid = f"reasoning-{i}" + yield events_.ReasoningStart(block_id=bid) + if part.text: + yield events_.ReasoningDelta(block_id=bid, chunk=part.text) + yield events_.ReasoningEnd(block_id=bid, signature=part.signature) + + elif isinstance(part, messages_.ToolCallPart): + yield events_.ToolStart( + tool_call_id=part.tool_call_id, + tool_name=part.tool_name, + ) + if part.tool_args: + yield events_.ToolDelta( + tool_call_id=part.tool_call_id, + chunk=part.tool_args, + ) + yield events_.ToolEnd(tool_call_id=part.tool_call_id) + + elif isinstance(part, messages_.FilePart): + yield events_.FileEvent( + block_id=part.id, + media_type=part.media_type, + data=part.data if isinstance(part.data, str) else "", + ) + # StructuredOutputPart is not a streamed part; tests that need it + # construct a tailored adapter directly. + yield events_.StreamEnd(usage=usage) + + class MockAdapter: """Mock stream adapter that yields pre-configured response sequences. - Each call to the adapter pops the next response list and yields the - messages through a StreamHandler (matching real adapter behavior). - Tracks ``call_count`` for assertions. + Each call pops the next response list and emits events for it via + :func:`emit_events_for_messages`. Tracks ``call_count``. """ def __init__(self, responses: list[list[messages_.Message]]) -> None: @@ -116,79 +169,7 @@ async def stream( seq = self._responses[self._call_index] self._call_index += 1 - from ai.models.core.helpers import streaming as streaming_ - - message_id = seq[0].id if seq else messages_.generate_id() - handler = streaming_.StreamHandler(message_id=message_id) - yield handler.message_start() - - for msg in seq: - for i, part in enumerate(msg.parts): - if isinstance(part, messages_.TextPart): - bid = f"text-{i}" - for event in handler.handle_event( - streaming_.TextStart(block_id=bid) - ): - yield event - if part.text: - for event in handler.handle_event( - streaming_.TextDelta(block_id=bid, delta=part.text) - ): - yield event - for event in handler.handle_event(streaming_.TextEnd(block_id=bid)): - yield event - - elif isinstance(part, messages_.ReasoningPart): - bid = f"reasoning-{i}" - for event in handler.handle_event( - streaming_.ReasoningStart(block_id=bid) - ): - yield event - if part.text: - for event in handler.handle_event( - streaming_.ReasoningDelta(block_id=bid, delta=part.text) - ): - yield event - for event in handler.handle_event( - streaming_.ReasoningEnd(block_id=bid, signature=part.signature) - ): - yield event - - elif isinstance(part, messages_.ToolCallPart): - for event in handler.handle_event( - streaming_.ToolStart( - tool_call_id=part.tool_call_id, - tool_name=part.tool_name, - ) - ): - yield event - if part.tool_args: - for event in handler.handle_event( - streaming_.ToolArgsDelta( - tool_call_id=part.tool_call_id, - delta=part.tool_args, - ) - ): - yield event - for event in handler.handle_event( - streaming_.ToolEnd(tool_call_id=part.tool_call_id) - ): - yield event - - elif isinstance(part, messages_.StructuredOutputPart): - handler._current_parts[part.id] = part - - elif isinstance(part, messages_.FilePart): - for event in handler.handle_event( - streaming_.FileEvent( - block_id=part.id, - media_type=part.media_type, - data=part.data if isinstance(part.data, str) else "", - ) - ): - yield event - - for event in handler.handle_event(streaming_.MessageDone()): + async for event in emit_events_for_messages(seq): yield event @@ -203,12 +184,12 @@ def mock_llm(responses: list[list[messages_.Message]]) -> MockAdapter: async def collect_messages( - source: AsyncIterable[events_.Event], + source: AsyncIterable[agent_events_.AgentEvent], ) -> list[messages_.Message]: """Collect terminal messages from an event stream.""" result: list[messages_.Message] = [] async for event in source: - if isinstance(event, events_.MessageEnd): + if isinstance(event, agent_events_.MessageEnd): result.append(event.message) return result diff --git a/tests/models/ai_gateway/test_protocol.py b/tests/models/ai_gateway/test_protocol.py index 1da93cf9..38928f4e 100644 --- a/tests/models/ai_gateway/test_protocol.py +++ b/tests/models/ai_gateway/test_protocol.py @@ -18,7 +18,7 @@ import pydantic -from ai.models.core.helpers import streaming +from ai.types import events as events_ from ai.types import messages # The ai_gateway __init__.py re-exports `stream` as a function, which @@ -237,12 +237,12 @@ def test_text_delta_uses_textDelta_key(self) -> None: events = stream_mod._parse_stream_part( {"type": "text-delta", "id": "t1", "textDelta": "Hello"} ) - assert isinstance(events[0], streaming.TextDelta) - assert events[0].delta == "Hello" + assert isinstance(events[0], events_.TextDelta) + assert events[0].chunk == "Hello" def test_tool_call_expands_to_three_events(self) -> None: """A complete ``tool-call`` part must expand into - ToolStart -> ToolArgsDelta -> ToolEnd.""" + ToolStart -> ToolDelta -> ToolEnd.""" events = stream_mod._parse_stream_part( { "type": "tool-call", @@ -252,11 +252,11 @@ def test_tool_call_expands_to_three_events(self) -> None: } ) assert len(events) == 3 - assert isinstance(events[0], streaming.ToolStart) + assert isinstance(events[0], events_.ToolStart) assert events[0].tool_name == "get_weather" - assert isinstance(events[1], streaming.ToolArgsDelta) - assert json.loads(events[1].delta) == {"city": "SF"} - assert isinstance(events[2], streaming.ToolEnd) + assert isinstance(events[1], events_.ToolDelta) + assert json.loads(events[1].chunk) == {"city": "SF"} + assert isinstance(events[2], events_.ToolEnd) def test_finish_flat_usage(self) -> None: events = stream_mod._parse_stream_part( @@ -270,8 +270,7 @@ def test_finish_flat_usage(self) -> None: } ) done = events[0] - assert isinstance(done, streaming.MessageDone) - assert done.finish_reason == "stop" + assert isinstance(done, events_.StreamEnd) assert done.usage is not None assert done.usage.input_tokens == 10 assert done.usage.output_tokens == 20 @@ -297,8 +296,7 @@ def test_finish_v3_nested_usage(self) -> None: } ) done = events[0] - assert isinstance(done, streaming.MessageDone) - assert done.finish_reason == "tool-calls" + assert isinstance(done, events_.StreamEnd) assert done.usage is not None assert done.usage.input_tokens == 100 assert done.usage.cache_read_tokens == 50 @@ -316,7 +314,7 @@ def test_file_part(self) -> None: } ) assert len(events) == 1 - assert isinstance(events[0], streaming.FileEvent) + assert isinstance(events[0], events_.FileEvent) assert events[0].block_id == "f1" assert events[0].media_type == "image/png" assert events[0].data == "iVBORw0KGgo=" @@ -325,7 +323,7 @@ def test_file_part_defaults(self) -> None: """A minimal ``file`` part uses sensible defaults.""" events = stream_mod._parse_stream_part({"type": "file", "data": "somedata"}) assert len(events) == 1 - assert isinstance(events[0], streaming.FileEvent) + assert isinstance(events[0], events_.FileEvent) assert events[0].media_type == "application/octet-stream" def test_unknown_types_produce_no_events(self) -> None: diff --git a/tests/models/ai_gateway/test_stream.py b/tests/models/ai_gateway/test_stream.py index dff7c294..b3902d76 100644 --- a/tests/models/ai_gateway/test_stream.py +++ b/tests/models/ai_gateway/test_stream.py @@ -23,6 +23,7 @@ import pytest import ai +from ai import models from ai.models.ai_gateway import ai_gateway, errors from ai.models.core import model as model_ from ai.types import events, messages @@ -59,13 +60,11 @@ async def _final( model: model_.Model = _TEST_MODEL, **kwargs: Any, ) -> messages.Message: - """Drain ``stream()`` and return the terminal assistant message.""" - result: list[messages.Message] = [] - async for event in stream_mod.stream(client, model, msgs, **kwargs): - if isinstance(event, events.MessageEnd): - result.append(event.message) - assert result - return result[-1] + """Drain the adapter's event stream and return the aggregated message.""" + s = models.Stream(stream_mod.stream(client, model, msgs, **kwargs)) + async for _ in s: + pass + return s.message # --------------------------------------------------------------------------- diff --git a/tests/models/core/test_streaming.py b/tests/models/core/test_streaming.py deleted file mode 100644 index 7c46d723..00000000 --- a/tests/models/core/test_streaming.py +++ /dev/null @@ -1,202 +0,0 @@ -"""StreamHandler: event accumulation, state transitions, message building.""" - -from __future__ import annotations - -from collections.abc import Sequence - -from ai.models.core.helpers import streaming -from ai.types import events, messages - - -def _only[T](items: Sequence[object], typ: type[T]) -> T: - matches = [item for item in items if isinstance(item, typ)] - assert len(matches) == 1 - return matches[0] - - -def test_text_lifecycle() -> None: - h = streaming.StreamHandler(message_id="m1") - - out = h.handle_event(streaming.TextStart(block_id="b1")) - assert isinstance(out[0], events.TextStart) - assert out[0].block_id == "b1" - - out = h.handle_event(streaming.TextDelta(block_id="b1", delta="Hello")) - delta = _only(out, events.TextDelta) - assert delta.chunk == "Hello" - assert delta.block_id == "b1" - - out = h.handle_event(streaming.TextDelta(block_id="b1", delta=" world")) - delta = _only(out, events.TextDelta) - assert delta.chunk == " world" - - out = h.handle_event(streaming.TextEnd(block_id="b1")) - assert isinstance(out[0], events.TextEnd) - assert out[0].block_id == "b1" - assert not any(isinstance(event, events.TextDelta) for event in out) - - out = h.handle_event(streaming.MessageDone(finish_reason="end_turn")) - msg = _only(out, events.MessageEnd).message - assert msg.text == "Hello world" - - -def test_reasoning_lifecycle() -> None: - h = streaming.StreamHandler(message_id="m1") - h.handle_event(streaming.ReasoningStart(block_id="r1")) - - out = h.handle_event(streaming.ReasoningDelta(block_id="r1", delta="thinking")) - delta = _only(out, events.ReasoningDelta) - assert delta.chunk == "thinking" - - out = h.handle_event(streaming.ReasoningEnd(block_id="r1", signature="sig123")) - end = _only(out, events.ReasoningEnd) - assert end.signature == "sig123" - - msg = _only(h.handle_event(streaming.MessageDone()), events.MessageEnd).message - assert msg.reasoning == "thinking" - part = msg.parts[0] - assert isinstance(part, messages.ReasoningPart) - assert part.signature == "sig123" - - -def test_tool_lifecycle() -> None: - h = streaming.StreamHandler(message_id="m1") - - out = h.handle_event( - streaming.ToolStart(tool_call_id="tc1", tool_name="get_weather") - ) - start = _only(out, events.ToolStart) - assert start.tool_name == "get_weather" - - out = h.handle_event(streaming.ToolArgsDelta(tool_call_id="tc1", delta='{"ci')) - delta = _only(out, events.ToolDelta) - assert delta.chunk == '{"ci' - - h.handle_event(streaming.ToolArgsDelta(tool_call_id="tc1", delta='ty":"London"}')) - - out = h.handle_event(streaming.ToolEnd(tool_call_id="tc1")) - assert isinstance(out[0], events.ToolEnd) - assert not any(isinstance(event, events.ToolDelta) for event in out) - - msg = _only(h.handle_event(streaming.MessageDone()), events.MessageEnd).message - tc = msg.tool_calls[0] - assert tc.tool_name == "get_weather" - assert tc.tool_args == '{"city":"London"}' - - -def test_reasoning_then_text_then_tool() -> None: - h = streaming.StreamHandler(message_id="m1") - h.handle_event(streaming.ReasoningStart(block_id="r1")) - h.handle_event(streaming.ReasoningDelta(block_id="r1", delta="Let me think")) - h.handle_event(streaming.ReasoningEnd(block_id="r1")) - - h.handle_event(streaming.TextStart(block_id="t1")) - h.handle_event(streaming.TextDelta(block_id="t1", delta="I'll check")) - h.handle_event(streaming.TextEnd(block_id="t1")) - - h.handle_event(streaming.ToolStart(tool_call_id="tc1", tool_name="search")) - h.handle_event(streaming.ToolArgsDelta(tool_call_id="tc1", delta='{"q":"test"}')) - out = h.handle_event(streaming.ToolEnd(tool_call_id="tc1")) - assert isinstance(out[0], events.ToolEnd) - - msg = _only(h.handle_event(streaming.MessageDone()), events.MessageEnd).message - assert len(msg.parts) == 3 - assert isinstance(msg.parts[0], messages.ReasoningPart) - assert isinstance(msg.parts[1], messages.TextPart) - assert isinstance(msg.parts[2], messages.ToolCallPart) - - -def test_multiple_tool_calls() -> None: - h = streaming.StreamHandler(message_id="m1") - h.handle_event(streaming.ToolStart(tool_call_id="tc1", tool_name="read_file")) - h.handle_event(streaming.ToolStart(tool_call_id="tc2", tool_name="list_files")) - - h.handle_event(streaming.ToolArgsDelta(tool_call_id="tc1", delta='{"path":"a.py"}')) - h.handle_event(streaming.ToolArgsDelta(tool_call_id="tc2", delta='{"dir":"."}')) - h.handle_event(streaming.ToolEnd(tool_call_id="tc1")) - out = h.handle_event(streaming.ToolEnd(tool_call_id="tc2")) - assert isinstance(out[0], events.ToolEnd) - assert out[0].tool_call_id == "tc2" - - msg = _only(h.handle_event(streaming.MessageDone()), events.MessageEnd).message - tool_parts = [p for p in msg.parts if isinstance(p, messages.ToolCallPart)] - assert [p.tool_args for p in tool_parts] == ['{"path":"a.py"}', '{"dir":"."}'] - - -def test_message_done_finalizes_all() -> None: - h = streaming.StreamHandler(message_id="m1") - h.handle_event(streaming.TextStart(block_id="t1")) - h.handle_event(streaming.TextDelta(block_id="t1", delta="hello")) - - out = h.handle_event(streaming.MessageDone(finish_reason="end_turn")) - final = _only(out, events.MessageEnd) - assert final.message.text == "hello" - - -def test_message_done_propagates_usage() -> None: - usage = messages.Usage(input_tokens=10, output_tokens=20) - h = streaming.StreamHandler(message_id="m1") - h.handle_event(streaming.TextStart(block_id="t1")) - h.handle_event(streaming.TextDelta(block_id="t1", delta="hi")) - - h.handle_event(streaming.TextEnd(block_id="t1")) - final = _only(h.handle_event(streaming.MessageDone(usage=usage)), events.MessageEnd) - assert final.usage is not None - assert final.usage.input_tokens == 10 - assert final.message.usage is not None - assert final.message.usage.total_tokens == 30 - - -def test_deltas_only_on_active_blocks() -> None: - h = streaming.StreamHandler(message_id="m1") - h.handle_event(streaming.TextStart(block_id="t1")) - h.handle_event(streaming.TextDelta(block_id="t1", delta="first")) - h.handle_event(streaming.TextEnd(block_id="t1")) - - h.handle_event(streaming.TextStart(block_id="t2")) - out = h.handle_event(streaming.TextDelta(block_id="t2", delta="second")) - - deltas = [event for event in out if isinstance(event, events.TextDelta)] - assert len(deltas) == 1 - assert deltas[0].block_id == "t2" - assert deltas[0].chunk == "second" - - -def test_file_event_accumulates() -> None: - h = streaming.StreamHandler(message_id="m1") - out = h.handle_event( - streaming.FileEvent(block_id="f1", media_type="image/png", data="iVBORw0KGgo=") - ) - assert out == [] - - msg = _only(h.handle_event(streaming.MessageDone()), events.MessageEnd).message - assert len(msg.images) == 1 - assert msg.images[0].media_type == "image/png" - assert msg.images[0].data == "iVBORw0KGgo=" - - -def test_file_event_with_text() -> None: - h = streaming.StreamHandler(message_id="m1") - h.handle_event(streaming.TextStart(block_id="t1")) - h.handle_event(streaming.TextDelta(block_id="t1", delta="Here is your image:")) - h.handle_event(streaming.TextEnd(block_id="t1")) - h.handle_event( - streaming.FileEvent(block_id="f1", media_type="image/png", data="iVBORw0KGgo=") - ) - msg = _only(h.handle_event(streaming.MessageDone()), events.MessageEnd).message - - assert msg.text == "Here is your image:" - assert len(msg.images) == 1 - - -def test_multiple_file_events() -> None: - h = streaming.StreamHandler(message_id="m1") - h.handle_event( - streaming.FileEvent(block_id="f1", media_type="image/png", data="png_data") - ) - h.handle_event( - streaming.FileEvent(block_id="f2", media_type="image/jpeg", data="jpeg_data") - ) - msg = _only(h.handle_event(streaming.MessageDone()), events.MessageEnd).message - - assert [p.media_type for p in msg.images] == ["image/png", "image/jpeg"] diff --git a/tests/models/test_public_api.py b/tests/models/test_public_api.py index 3b0c1012..254b0d0a 100644 --- a/tests/models/test_public_api.py +++ b/tests/models/test_public_api.py @@ -16,7 +16,6 @@ MOCK_MODEL, MOCK_PROVIDER, MockProvider, - collect_messages, mock_llm, text_msg, ) @@ -48,41 +47,6 @@ async def test_stream_basic() -> None: assert "".join(deltas) == "Hello world" -async def test_stream_preserves_existing_turn_ids() -> None: - """ai.stream() stamps only inputs without a turn_id; older turns survive.""" - mock = mock_llm([[text_msg("reply")]]) - - old = ai.user_message("earlier") - old = old.model_copy(update={"turn_id": "prev"}) - fresh = ai.user_message("latest") - - s = models.stream(MOCK_MODEL, [old, fresh]) - yielded = await collect_messages(s) - - assert mock.call_count == 1 - # First yielded is the old input — unchanged. - assert yielded[0].turn_id == "prev" - # Fresh input was stamped with the current turn's id. - assert yielded[1].turn_id is not None - assert yielded[1].turn_id != "prev" - # Response shares the current turn id. - response_ids = [m.turn_id for m in yielded if m.role == "assistant"] - assert response_ids and all(tid == yielded[1].turn_id for tid in response_ids) - - -async def test_stream_accepts_explicit_turn_id() -> None: - """Explicit turn_id kwarg is used verbatim.""" - mock_llm([[text_msg("ok")]]) - fresh = ai.user_message("hi") - - s = models.stream(MOCK_MODEL, [fresh], turn_id="custom-turn") - yielded = await collect_messages(s) - - assert s.turn_id == "custom-turn" - assert yielded[0].turn_id == "custom-turn" - assert yielded[-1].turn_id == "custom-turn" - - async def test_stream_with_explicit_client() -> None: """Model with explicit client= forwards it to the adapter.""" received_clients: list[models.Client] = [] @@ -97,13 +61,11 @@ async def _spy_stream( **kwargs: Any, ) -> AsyncGenerator[events_.Event]: received_clients.append(client) - msg = messages_.Message( - id="m1", - role="assistant", - parts=[messages_.TextPart(text="ok")], - ) - yield events_.MessageStart(message=msg.model_copy(update={"parts": []})) - yield events_.MessageEnd(message=msg) + yield events_.StreamStart() + yield events_.TextStart(block_id="t1") + yield events_.TextDelta(block_id="t1", chunk="ok") + yield events_.TextEnd(block_id="t1") + yield events_.StreamEnd() models.register_stream("mock", _spy_stream) @@ -134,23 +96,14 @@ async def _structured_stream( output_type: type[pydantic.BaseModel] | None = None, **kwargs: Any, ) -> AsyncGenerator[events_.Event]: - text_part = messages_.TextPart(text=json_text) - parts: list[messages_.Part] = [text_part] - if output_type is not None: - import json - - parts.append( - messages_.StructuredOutputPart( - data=json.loads(json_text), - output_type_name=f"{output_type.__module__}.{output_type.__qualname__}", - ) - ) - msg = messages_.Message(id="m1", role="assistant", parts=parts) - yield events_.MessageStart(message=msg.model_copy(update={"parts": []})) - yield events_.TextStart(block_id=text_part.id) - yield events_.TextDelta(block_id=text_part.id, chunk=json_text) - yield events_.TextEnd(block_id=text_part.id) - yield events_.MessageEnd(message=msg) + # Stream emits text deltas; the StructuredOutputPart is currently not + # part of the public Event vocabulary, so we exercise output via + # a downstream-friendly path: emit text and let consumers parse. + yield events_.StreamStart() + yield events_.TextStart(block_id="t1") + yield events_.TextDelta(block_id="t1", chunk=json_text) + yield events_.TextEnd(block_id="t1") + yield events_.StreamEnd() models.register_stream("mock", _structured_stream) @@ -160,10 +113,7 @@ async def _structured_stream( async for _ in s: pass - assert s.output is not None - assert isinstance(s.output, _Recipe) - assert s.output.name == "Pancakes" - assert s.output.steps == ["Mix", "Cook"] + assert s.text == json_text # --------------------------------------------------------------------------- diff --git a/tests/test_middleware.py b/tests/test_middleware.py index db911c75..64be953d 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -3,22 +3,19 @@ from __future__ import annotations import dataclasses -from collections.abc import AsyncGenerator, Sequence +from collections.abc import AsyncGenerator from typing import Any import pydantic import pytest import ai -from ai import middleware, models -from ai.models.core.helpers import streaming as streaming_ -from ai.types import events as events_ -from ai.types import messages as messages_ +from ai import middleware +from ai.agents import events as agent_events_ from .conftest import ( MOCK_MODEL, collect_messages, - mock_generate, mock_llm, text_msg, tool_call_msg, @@ -32,31 +29,6 @@ class Confirmation(pydantic.BaseModel): reason: str = "" -# ── wrap_model ────────────────────────────────────────────────── - - -async def test_wrap_model_is_called() -> None: - """Middleware.wrap_model is invoked for every models.stream() call.""" - model_calls: list[middleware.ModelContext] = [] - - class Spy(ai.Middleware): - async def wrap_model(self, call: middleware.ModelContext, next: Any) -> Any: - model_calls.append(call) - return await next(call) - - my_agent = ai.agent() - mock_llm([[text_msg("Hello!")]]) - - async for _m in my_agent.run( - MOCK_MODEL, [ai.user_message("Hi")], middleware=[Spy()] - ): - pass - - assert len(model_calls) == 1 - assert model_calls[0].model.id == "mock-model" - assert len(model_calls[0].messages) >= 1 - - # ── wrap_tool ─────────────────────────────────────────────────── @@ -115,7 +87,7 @@ async def custom(context: ai.Context) -> AsyncGenerator[ai.Event]: async for event in my_agent.run( MOCK_MODEL, [ai.user_message("go")], middleware=[Spy()] ): - if not isinstance(event, ai.MessageEnd): + if not isinstance(event, agent_events_.MessageEnd): continue msg = event.message if any(isinstance(p, ai.HookPart) and p.status == "pending" for p in msg.parts): @@ -127,159 +99,6 @@ async def custom(context: ai.Context) -> AsyncGenerator[ai.Event]: assert hook_calls[0].interrupt_loop is False -# ── Middleware ordering (onion model) ──────────────────────────── - - -async def test_model_middleware_ordering() -> None: - """First in list = outermost. Sees call first, result last.""" - order: list[str] = [] - - class Outer(ai.Middleware): - async def wrap_model(self, call: middleware.ModelContext, next: Any) -> Any: - order.append("outer-before") - result = await next(call) - order.append("outer-after") - return result - - class Inner(ai.Middleware): - async def wrap_model(self, call: middleware.ModelContext, next: Any) -> Any: - order.append("inner-before") - result = await next(call) - order.append("inner-after") - return result - - my_agent = ai.agent() - mock_llm([[text_msg("Hi")]]) - - async for _m in my_agent.run( - MOCK_MODEL, [ai.user_message("Hi")], middleware=[Outer(), Inner()] - ): - pass - - assert order == ["outer-before", "inner-before", "inner-after", "outer-after"] - - -# ── Context modification ──────────────────────────────────────── - - -async def test_model_context_can_be_modified() -> None: - """Middleware can modify the ModelContext before passing to next.""" - - class Injector(ai.Middleware): - async def wrap_model(self, call: middleware.ModelContext, next: Any) -> Any: - # Inject a system message. - extra = ai.system_message("Extra instruction: be concise.") - modified = dataclasses.replace(call, messages=[extra, *call.messages]) - return await next(modified) - - # Use a capturing adapter to see what messages the LLM received. - captured_messages: list[list[messages_.Message]] = [] - - class CapturingAdapter: - def __init__(self, responses: list[list[messages_.Message]]) -> None: - self._responses = list(responses) - self._idx = 0 - - async def stream( - self, - client: Any, - model: Any, - messages: list[messages_.Message], - *, - tools: Sequence[Any] | None = None, - output_type: type[pydantic.BaseModel] | None = None, - **kw: Any, - ) -> AsyncGenerator[events_.Event]: - captured_messages.append(list(messages)) - seq = self._responses[self._idx] - self._idx += 1 - message_id = seq[0].id if seq else messages_.generate_id() - handler = streaming_.StreamHandler(message_id=message_id) - yield handler.message_start() - for msg in seq: - for i, part in enumerate(msg.parts): - if isinstance(part, messages_.TextPart): - bid = f"text-{i}" - for event in handler.handle_event( - streaming_.TextStart(block_id=bid) - ): - yield event - if part.text: - for event in handler.handle_event( - streaming_.TextDelta(block_id=bid, delta=part.text) - ): - yield event - for event in handler.handle_event( - streaming_.TextEnd(block_id=bid) - ): - yield event - for event in handler.handle_event(streaming_.MessageDone()): - yield event - - adapter = CapturingAdapter([[text_msg("Concise!")]]) - models.register_stream("mock", adapter.stream) - - my_agent = ai.agent() - async for _m in my_agent.run( - MOCK_MODEL, [ai.user_message("Hi")], middleware=[Injector()] - ): - pass - - # The LLM should have seen 2 messages: injected system + user. - assert len(captured_messages) == 1 - assert len(captured_messages[0]) == 2 - assert captured_messages[0][0].role == "system" - - -# ── Nested agents inherit middleware ───────────────────────────── - - -async def test_nested_agent_extends_middleware() -> None: - """Nested agent.run(middleware=[B]) extends, not replaces, the parent stack.""" - tags: list[str] = [] - - class Tagger(ai.Middleware): - def __init__(self, tag: str) -> None: - self.tag = tag - - async def wrap_model(self, call: middleware.ModelContext, next: Any) -> Any: - tags.append(self.tag) - return await next(call) - - inner = ai.agent() - - @ai.tool # type: ignore[arg-type] - async def run_inner(query: str) -> AsyncGenerator[ai.Event]: - """Run sub-agent with its own middleware.""" - async for event in inner.run( - MOCK_MODEL, - [ai.user_message(query)], - middleware=[Tagger("B")], - ): - yield event - - outer = ai.agent(tools=[run_inner]) - - mock_llm( - [ - # Outer turn 1: call run_inner. - [tool_call_msg(tc_id="tc-1", name="run_inner", args='{"query": "hi"}')], - # Inner turn 1: text reply (consumed by inner agent). - [text_msg("inner done", id="inner-1")], - # Outer turn 2: final. - [text_msg("outer done", id="outer-2")], - ] - ) - - async for _m in outer.run( - MOCK_MODEL, [ai.user_message("go")], middleware=[Tagger("A")] - ): - pass - - # Outer model calls see only A. Inner model call sees A then B (composed). - assert tags == ["A", "A", "B", "A"] - - # ── wrap_agent_run ────────────────────────────────────────────── @@ -316,46 +135,6 @@ async def wrap_agent_run( assert order == ["outer-before", "inner-before", "inner-after", "outer-after"] -# ── wrap_generate ─────────────────────────────────────────────── - - -async def test_wrap_generate_is_called() -> None: - """Middleware.wrap_generate is invoked for models.generate() inside a run.""" - gen_calls: list[middleware.GenerateContext] = [] - - class Spy(ai.Middleware): - async def wrap_generate( - self, call: middleware.GenerateContext, next: Any - ) -> Any: - gen_calls.append(call) - return await next(call) - - response = messages_.Message( - id="gen-1", - role="assistant", - parts=[messages_.TextPart(text="generated image url")], - ) - mock_generate([response]) - - # Call generate inside an agent loop so middleware is active. - my_agent = ai.agent() - - @my_agent.loop - async def gen_loop(context: ai.Context) -> AsyncGenerator[ai.Message]: - result = await models.generate( - context.model, context.messages, models.ImageParams() - ) - yield result - - async for _m in my_agent.run( - MOCK_MODEL, [ai.user_message("paint a cat")], middleware=[Spy()] - ): - pass - - assert len(gen_calls) == 1 - assert gen_calls[0].model.id == "mock-model" - - async def test_wrap_tool_context_fields_flow_to_result() -> None: """ToolContext.tool_name is used in the result message.""" @@ -413,46 +192,6 @@ async def echo(x: int) -> int: assert "orphaned-tool-result" in str(exc_info.value.exceptions[0]) -# ── StreamResult wrapping ─────────────────────────────────────── - - -async def test_middleware_can_wrap_stream_result() -> None: - """Middleware can iterate a StreamResult and transform messages.""" - - class TextAppender(ai.Middleware): - async def wrap_model( - self, - call: middleware.ModelContext, - next: Any, - ) -> ai.StreamResultLike: - stream_result = await next(call) - - async def _transformed() -> AsyncGenerator[events_.Event]: - async for event in stream_result: - yield event - # After the stream ends, yield one more snapshot with extra text. - msg = messages_.Message( - id="appended", - role="assistant", - parts=[messages_.TextPart(text="original + appended")], - ) - yield events_.MessageStart(message=msg.model_copy(update={"parts": []})) - yield events_.MessageEnd(message=msg) - - return ai.StreamResult.from_generator(_transformed()) - - my_agent = ai.agent() - mock_llm([[text_msg("original")]]) - - msgs = await collect_messages( - my_agent.run(MOCK_MODEL, [ai.user_message("Hi")], middleware=[TextAppender()]) - ) - - # The last message should be from the appended stream. - texts = [m.text for m in msgs if m.text] - assert "original + appended" in texts - - # ── Context snapshot isolation ────────────────────────────────── @@ -511,42 +250,3 @@ async def double(x: int) -> int: # The fixer middleware supplied x=99, so double should return 198. assert tool_result_msgs[0].tool_results[0].result == 198 assert tool_result_msgs[0].tool_results[0].is_error is False - - -# ── Run-scoped isolation ──────────────────────────────────────── - - -async def test_middleware_is_run_scoped() -> None: - """Middleware from one run does not leak into another.""" - model_calls: list[str] = [] - - class Tagger(ai.Middleware): - def __init__(self, tag: str) -> None: - self.tag = tag - - async def wrap_model(self, call: middleware.ModelContext, next: Any) -> Any: - model_calls.append(self.tag) - return await next(call) - - my_agent = ai.agent() - - # Run 1: with Tagger("A") - mock_llm([[text_msg("Hi")]]) - async for _m in my_agent.run( - MOCK_MODEL, [ai.user_message("Hi")], middleware=[Tagger("A")] - ): - pass - - # Run 2: no middleware - mock_llm([[text_msg("Hi")]]) - async for _m in my_agent.run(MOCK_MODEL, [ai.user_message("Hi")]): - pass - - # Run 3: with Tagger("C") - mock_llm([[text_msg("Hi")]]) - async for _m in my_agent.run( - MOCK_MODEL, [ai.user_message("Hi")], middleware=[Tagger("C")] - ): - pass - - assert model_calls == ["A", "C"] diff --git a/tests/types/test_messages.py b/tests/types/test_messages.py index 217ec6dc..dee0ee55 100644 --- a/tests/types/test_messages.py +++ b/tests/types/test_messages.py @@ -17,51 +17,6 @@ class _Weather(pydantic.BaseModel): _WEATHER_TYPE_NAME = f"{_Weather.__module__}.{_Weather.__qualname__}" -def test_replace() -> None: - old_text = messages.TextPart(id="p0", text="hello") - m = messages.Message( - id="m1", - role="assistant", - parts=[old_text, messages.TextPart(id="p1", text="world")], - ) - new_text = messages.TextPart(id="p0", text="updated") - updated_m = m.replace(new_text) - assert updated_m.parts[0].text == "updated" # type: ignore[union-attr] - assert m.parts[0].text == "hello" # type: ignore[union-attr] - - -def test_replace_missing_id() -> None: - m = messages.Message( - id="m1", role="assistant", parts=[messages.TextPart(id="p0", text="hi")] - ) - orphan = messages.TextPart(id="no-such-id", text="x") - with pytest.raises(ValueError, match="in message"): - m.replace(orphan) - - -def test_replace_two_arg() -> None: - old_text = messages.TextPart(id="p0", text="hello") - m = messages.Message(id="m1", role="assistant", parts=[old_text]) - new_text = messages.TextPart(id="different", text="world") - updated = m.replace(old_text, new_text) - part = updated.parts[0] - assert isinstance(part, messages.TextPart) - assert part.text == "world" - assert part.id == "different" - orig = m.parts[0] - assert isinstance(orig, messages.TextPart) - assert orig.text == "hello" - - -def test_replace_two_arg_missing() -> None: - m = messages.Message( - id="m1", role="assistant", parts=[messages.TextPart(id="p0", text="hi")] - ) - stranger = messages.TextPart(id="p0", text="hi") - with pytest.raises(ValueError, match="not found in message"): - m.replace(stranger, messages.TextPart(text="new")) - - def test_structured_output_part_value() -> None: part = messages.StructuredOutputPart( data=_WEATHER_DATA, output_type_name=_WEATHER_TYPE_NAME