From 3dc52dc54dbe443c1e7714cbf1fd383f8087bfb6 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Tue, 21 Apr 2026 10:50:12 -0700 Subject: [PATCH 1/3] Rewrite datamodel to separate events and messages --- src/ai/types/__init__.py | 40 ++-- src/ai/types/builders.py | 223 ------------------ src/ai/types/events.py | 78 ++++++ src/ai/types/integrity.py | 261 -------------------- src/ai/types/messages.py | 341 ++++----------------------- src/ai/types/{stream.py => proto.py} | 26 +- src/ai/types/tools.py | 14 +- src/ai/types/usage.py | 57 +++++ 8 files changed, 224 insertions(+), 816 deletions(-) delete mode 100644 src/ai/types/builders.py create mode 100644 src/ai/types/events.py delete mode 100644 src/ai/types/integrity.py rename src/ai/types/{stream.py => proto.py} (56%) create mode 100644 src/ai/types/usage.py diff --git a/src/ai/types/__init__.py b/src/ai/types/__init__.py index 6a41f9d6..6372fec5 100644 --- a/src/ai/types/__init__.py +++ b/src/ai/types/__init__.py @@ -1,40 +1,48 @@ -"""Shared transport types — the universal interchange format. - -Message, Part types, Usage, and tool schema protocols used across -both the models and agents layers. -""" - +from .events import ( + End, + Event, + HookResolution, + HookSuspention, + MessageEnd, + MessageStart, + PartDelta, + PartEnd, + PartStart, + Start, +) from .messages import ( FilePart, HookPart, Message, Part, - PartClosed, - PartDelta, - PartOpened, ReasoningPart, - StreamState, StructuredOutputPart, TextPart, ToolCallPart, ToolResultPart, - Usage, generate_id, ) -from .stream import StreamResultLike -from .tools import ToolLike, ToolSchema +from .proto import StreamResultLike, ToolLike +from .tools import ToolSchema +from .usage import Usage __all__ = [ + "End", + "Event", "FilePart", "HookPart", + "HookResolution", + "HookSuspention", "Message", + "MessageEnd", + "MessageStart", "Part", - "PartClosed", "PartDelta", - "PartOpened", + "PartEnd", + "PartStart", "ReasoningPart", + "Start", "StreamResultLike", - "StreamState", "StructuredOutputPart", "TextPart", "ToolCallPart", diff --git a/src/ai/types/builders.py b/src/ai/types/builders.py deleted file mode 100644 index 94ddda5e..00000000 --- a/src/ai/types/builders.py +++ /dev/null @@ -1,223 +0,0 @@ -"""Composable message construction helpers. - -Convenience functions for building Message objects without manually -constructing Part lists. Each ``*_message`` function returns a single -``Message``. -""" - -from __future__ import annotations - -from typing import Any, overload - -from .messages import ( - FilePart, - HookPart, - Message, - Part, - ReasoningPart, - StructuredOutputPart, - TextPart, - ToolCallPart, - ToolResultPart, -) - -_PART_TYPES = ( - TextPart, - ToolCallPart, - ToolResultPart, - ReasoningPart, - HookPart, - StructuredOutputPart, - FilePart, -) - -# A value that can appear as message content: bare strings become TextPart. -PartLike = str | Part - - -def _coerce_parts(args: tuple[PartLike, ...]) -> list[Part]: - parts: list[Part] = [] - for arg in args: - if isinstance(arg, str): - parts.append(TextPart(text=arg)) - elif isinstance(arg, _PART_TYPES): - parts.append(arg) - else: - raise TypeError(f"Expected str or Part, got {type(arg).__name__}") - return parts - - -def system_message(*content: PartLike) -> Message: - """Create a system message. - - >>> ai.system_message("You are a helpful robot.") - """ - return Message(role="system", parts=_coerce_parts(content)) - - -def user_message(*content: PartLike) -> Message: - """Create a user message from strings and/or Part objects. - - >>> ai.user_message("Describe this image:", ai.file_part(url)) - """ - return Message(role="user", parts=_coerce_parts(content)) - - -def assistant_message(*content: PartLike) -> Message: - """Create an assistant message from strings and/or Part objects. - - >>> ai.assistant_message(ai.thinking("hmm"), "Hello!") - """ - return Message(role="assistant", parts=_coerce_parts(content)) - - -def file_part( - data: str | bytes, - *, - media_type: str | None = None, - filename: str | None = None, -) -> FilePart: - """Create a :class:`FilePart` from a URL string or raw bytes. - - Dispatches to :meth:`FilePart.from_url` (for ``str``) or - :meth:`FilePart.from_bytes` (for ``bytes``), with automatic - media-type detection. - """ - if isinstance(data, str): - return FilePart.from_url(data, media_type=media_type) - return FilePart.from_bytes(data, media_type=media_type, filename=filename) - - -def thinking(text: str, *, signature: str | None = None) -> ReasoningPart: - """Create a :class:`ReasoningPart`. - - Useful for replaying conversation history that includes model reasoning. - """ - return ReasoningPart(text=text, signature=signature) - - -def _tool_results_from_messages(messages: list[Message]) -> list[ToolResultPart]: - parts: list[ToolResultPart] = [] - for message in messages: - if message.role != "tool": - raise TypeError(f"Expected tool message, got role={message.role!r}") - for part in message.parts: - if not isinstance(part, ToolResultPart): - raise TypeError( - "tool_message() only accepts tool messages containing " - "ToolResultPart parts" - ) - parts.append(part) - return parts - - -@overload -def tool_message(*messages: Message | list[Message]) -> Message: ... - - -@overload -def tool_message(*parts: ToolResultPart) -> Message: ... - - -@overload -def tool_message( - *, - tool_call_id: str, - result: Any = None, - tool_name: str = "", - is_error: bool = False, -) -> Message: ... - - -def tool_message( - *items: Message | ToolResultPart | list[Message], - tool_call_id: str | None = None, - result: Any = None, - tool_name: str = "", - is_error: bool = False, -) -> Message: - """Create or merge a tool-result message. - - >>> part = ai.tool_result("tc-1", result=72, tool_name="weather") - >>> ai.tool_message(part) - >>> ai.tool_message(tool_call_id="tc-1", result=72, tool_name="weather") - """ - if tool_call_id is None and (result is not None or tool_name or is_error): - raise TypeError( - "tool_message() keyword tool-result fields require tool_call_id" - ) - - if tool_call_id is not None: - if items: - raise TypeError( - "tool_message() cannot mix keyword tool-result fields with " - "positional messages or ToolResultPart values" - ) - return Message( - role="tool", - parts=[ - tool_result( - tool_call_id, - result=result, - tool_name=tool_name, - is_error=is_error, - ) - ], - ) - - if not items: - raise TypeError("tool_message() requires at least one tool message or result") - - flattened_messages: list[Message] = [] - result_parts: list[ToolResultPart] = [] - saw_message = False - saw_result_part = False - - for item in items: - if isinstance(item, list): - saw_message = True - flattened_messages.extend(item) - elif isinstance(item, Message): - saw_message = True - flattened_messages.append(item) - elif isinstance(item, ToolResultPart): - saw_result_part = True - result_parts.append(item) - else: - raise TypeError( - "tool_message() only accepts tool messages, lists of tool " - "messages, or ToolResultPart values" - ) - - if saw_message and saw_result_part: - raise TypeError( - "tool_message() cannot mix tool messages with ToolResultPart values" - ) - - if saw_message: - merged_parts: list[Part] = [] - merged_parts.extend(_tool_results_from_messages(flattened_messages)) - return Message(role="tool", parts=merged_parts) - - tool_parts: list[Part] = [] - tool_parts.extend(result_parts) - return Message(role="tool", parts=tool_parts) - - -def tool_result( - tool_call_id: str, - *, - result: Any = None, - tool_name: str = "", - is_error: bool = False, -) -> ToolResultPart: - """Create a :class:`ToolResultPart`. - - >>> ai.tool_result("tc-1", result={"temp": 72}, tool_name="weather") - """ - return ToolResultPart( - tool_call_id=tool_call_id, - tool_name=tool_name, - result=result, - is_error=is_error, - ) diff --git a/src/ai/types/events.py b/src/ai/types/events.py new file mode 100644 index 00000000..1c55370d --- /dev/null +++ b/src/ai/types/events.py @@ -0,0 +1,78 @@ +from typing import Annotated, Literal + +import pydantic + +from . import messages + +# we're using pydantic because events are crossing +# serialization border in the case of durable execution + + +class Start(pydantic.BaseModel): + kind: Literal["start"] = "start" + model_config = pydantic.ConfigDict(frozen=True) + + +class End(pydantic.BaseModel): + kind: Literal["end"] = "end" + model_config = pydantic.ConfigDict(frozen=True) + + +class MessageStart(pydantic.BaseModel): + message: messages.Message + + kind: Literal["message_start"] = "message_start" + model_config = pydantic.ConfigDict(frozen=True) + + +class MessageEnd(pydantic.BaseModel): + message: messages.Message + + kind: Literal["message_end"] = "message_end" + model_config = pydantic.ConfigDict(frozen=True) + + +class PartStart(pydantic.BaseModel): + part: messages.Part + + kind: Literal["part_start"] = "part_start" + model_config = pydantic.ConfigDict(frozen=True) + + +class PartDelta(pydantic.BaseModel): + part: messages.Part + chunk: str + + kind: Literal["part_delta"] = "part_delta" + model_config = pydantic.ConfigDict(frozen=True) + + +class PartEnd(pydantic.BaseModel): + part: messages.Part + + kind: Literal["part_end"] = "part_end" + model_config = pydantic.ConfigDict(frozen=True) + + +class HookSuspention(pydantic.BaseModel): + kind: Literal["hook_suspention"] = "hook_suspention" + model_config = pydantic.ConfigDict(frozen=True) + + +class HookResolution(pydantic.BaseModel): + kind: Literal["hook_resolution"] = "hook_resolution" + model_config = pydantic.ConfigDict(frozen=True) + + +Event = Annotated[ + Start + | End + | MessageStart + | MessageEnd + | PartStart + | PartDelta + | PartEnd + | HookSuspention + | HookResolution, + pydantic.Field(discriminator="kind"), +] diff --git a/src/ai/types/integrity.py b/src/ai/types/integrity.py deleted file mode 100644 index 9f737101..00000000 --- a/src/ai/types/integrity.py +++ /dev/null @@ -1,261 +0,0 @@ -import json -import logging -from typing import Literal - -from . import builders -from . import messages as messages_ - -logger = logging.getLogger(__name__) - -Mode = Literal["strict", "auto"] - -IssueKind = Literal[ - "duplicate-tool-call", - "duplicate-tool-result", - "internal-part", - "invalid-tool-args", - "orphaned-tool-call", - "orphaned-tool-result", - "internal-message", -] - - -class IntegrityError(ValueError): - def __init__(self, issues: list[IssueKind]) -> None: - self.issues = issues - super().__init__( - f"Message history has {len(issues)} issue(s): " + ", ".join(issues) - ) - - -# used for stripping internal parts -_LLM_PART_TYPES = ( - messages_.TextPart, - messages_.ToolCallPart, - messages_.ToolResultPart, - messages_.ReasoningPart, - messages_.FilePart, -) - - -def _clean_messages( - messages: list[messages_.Message], mode: Mode -) -> tuple[list[messages_.Message], list[IssueKind]]: - """Strip internal messages, fix broken tool args""" - - issues: list[IssueKind] = [] - result: list[messages_.Message] = [] - - for msg in messages: - # 1. drop internal messages emitted by hooks - if msg.role == "internal": - issues.append("internal-message") - if mode == "strict": - result.append(msg) - continue - - parts: list[messages_.Part] = list(msg.parts) - changed = False - - # 2. strip everything that isn't an LLM part - kept: list[messages_.Part] = [ - p for p in parts if isinstance(p, _LLM_PART_TYPES) - ] - if len(kept) < len(parts): - issues.append("internal-part") - if mode == "auto": - parts = kept - changed = True - - # 3. ensure tool args are json-decodable - new_parts: list[messages_.Part] = [] - for part in parts: - if isinstance(part, messages_.ToolCallPart): - try: - json.loads(part.tool_args) - except (json.JSONDecodeError, TypeError): - if mode == "auto": - part = part.model_copy(update={"tool_args": "{}"}) - issues.append("invalid-tool-args") - changed = True - new_parts.append(part) - - if changed and mode == "auto": - parts = new_parts - - # 4. drop empty messages - if mode == "auto" and not parts: - continue - - if changed and mode == "auto": - # messages are immutable so we have to do this - result.append(msg.model_copy(update={"parts": parts})) - else: - result.append(msg) - - return result, issues - - -def _validate_tool_ids(messages: list[messages_.Message]) -> list[IssueKind]: - """Check for fatal issues: duplicate tool ids, orphaned tool results.""" - - issues: list[IssueKind] = [] - seen_call_ids: set[str] = set() - seen_result_ids: set[str] = set() - pending_call_ids: set[str] = set() - - duplicate_call = False - duplicate_result = False - orphaned_result = False - - for msg in messages: - if msg.role in ("user", "assistant") and pending_call_ids: - # result should have been in a tool message before this - # if it wasn't then it's a stray call, will be auto-fixed later - pending_call_ids.clear() - - if msg.role == "assistant": - # check if tool call is duplicate - # if not, mark it and append it to pending - for part in msg.parts: - if not isinstance(part, messages_.ToolCallPart): - continue - if part.tool_call_id in seen_call_ids: - duplicate_call = True - else: - seen_call_ids.add(part.tool_call_id) - pending_call_ids.add(part.tool_call_id) - - elif msg.role == "tool": - # check that this tool result is not duplicate and that - # there's a pending call from previous assistant message - for part in msg.parts: - if not isinstance(part, messages_.ToolResultPart): - continue - if part.tool_call_id in seen_result_ids: - duplicate_result = True - else: - seen_result_ids.add(part.tool_call_id) - if part.tool_call_id not in pending_call_ids: - orphaned_result = True - continue - pending_call_ids.remove(part.tool_call_id) - - if duplicate_call: - issues.append("duplicate-tool-call") - if duplicate_result: - issues.append("duplicate-tool-result") - if orphaned_result: - issues.append("orphaned-tool-result") - - return issues - - -def _fix_missing_results( - messages: list[messages_.Message], mode: Mode -) -> tuple[list[messages_.Message], list[IssueKind]]: - """Insert fake error results for stray tool calls.""" - issues: list[IssueKind] = [] - result: list[messages_.Message] = [] - - # 1. collect all result ids - answered: set[str] = set() - for msg in messages: - if msg.role == "tool": - for part in msg.parts: - if isinstance(part, messages_.ToolResultPart): - answered.add(part.tool_call_id) - - # pending tool calls from the current assistant turn - pending: dict[str, messages_.ToolCallPart] = {} - - def _flush_pending() -> None: - if not pending: - return - issues.append("orphaned-tool-call") - if mode == "auto": - synthetic = builders.tool_message( - *( - messages_.ToolResultPart( - tool_call_id=tc.tool_call_id, - tool_name=tc.tool_name, - result="Tool result not available", - is_error=True, - ) - for tc in pending.values() - ) - ) - result.append(synthetic) - - for msg in messages: - # if we're seeing a user / assistant message, then - # all pending tool calls are strays, because their results - # should have followed immediately after in a tool message - if msg.role in ("user", "assistant") and pending: - _flush_pending() - pending.clear() - - # 2. track calls - if msg.role == "assistant": - for part in msg.parts: - if ( - isinstance(part, messages_.ToolCallPart) - and part.tool_call_id not in answered - ): - pending[part.tool_call_id] = part - result.append(msg) - # 3. match results with calls - elif msg.role == "tool": - for part in msg.parts: - if isinstance(part, messages_.ToolResultPart): - pending.pop(part.tool_call_id, None) - result.append(msg) - else: - result.append(msg) - - _flush_pending() - - return result, issues - - -def prepare_messages( - messages: list[messages_.Message], - *, - mode: Mode = "auto", -) -> list[messages_.Message]: - """Fix and validate message list. - - ``"auto"`` (default) -- silently fixes recoverable issues (signal - messages, internal parts, invalid tool args, missing tool results). - ``"strict"`` -- collects every recoverable issue and raises - :class:`IntegrityError`. - - Duplicate tool-call IDs, duplicate tool-result IDs, and orphaned - tool results always raise :class:`IntegrityError` regardless of mode. - - Always returns a **new** list; never mutates the input. - """ - issues: list[IssueKind] = [] - - result, phase1_issues = _clean_messages(list(messages), mode) - issues.extend(phase1_issues) - - # never auto-fixed - fatal_issues = _validate_tool_ids(result) - issues.extend(fatal_issues) - - if not fatal_issues: - result, phase3_issues = _fix_missing_results(result, mode) - issues.extend(phase3_issues) - - if fatal_issues or (mode == "strict" and issues): - raise IntegrityError(issues) - - if issues: - logger.warning( - "Auto-fixed %d message issue(s): %s", - len(issues), - ", ".join(issues), - ) - - return result diff --git a/src/ai/types/messages.py b/src/ai/types/messages.py index 25ecd6ce..0f6974cc 100644 --- a/src/ai/types/messages.py +++ b/src/ai/types/messages.py @@ -1,11 +1,12 @@ -from __future__ import annotations - import importlib import uuid -from typing import Annotated, Any, Literal, overload +from typing import Annotated, Any, Literal, Self import pydantic +from . import media +from . import usage as usage_ + def generate_id(prefix: str | None = None) -> str: """Generate a short random ID for messages and parts.""" @@ -14,74 +15,58 @@ def generate_id(prefix: str | None = None) -> str: class TextPart(pydantic.BaseModel): - model_config = pydantic.ConfigDict(frozen=True) - id: str = pydantic.Field(default_factory=generate_id) text: str - type: Literal["text"] = "text" - - -class ToolCallPart(pydantic.BaseModel): - """A tool invocation requested by the LLM. - - Lives inside ``role="assistant"`` messages. The corresponding result - (if any) will appear as a :class:`ToolResultPart` in a separate - ``role="tool"`` message, linked by ``tool_call_id``. - """ + kind: Literal["text"] = "text" model_config = pydantic.ConfigDict(frozen=True) + +class ToolCallPart(pydantic.BaseModel): id: str = pydantic.Field(default_factory=generate_id) tool_call_id: str tool_name: str tool_args: str - type: Literal["tool_call"] = "tool_call" - - -class ToolResultPart(pydantic.BaseModel): - """The result of executing a tool call. - - Lives inside ``role="tool"`` messages. Back-references the - originating call via ``tool_call_id``. - """ + kind: Literal["tool_call"] = "tool_call" model_config = pydantic.ConfigDict(frozen=True) + +class ToolResultPart(pydantic.BaseModel): id: str = pydantic.Field(default_factory=generate_id) tool_call_id: str tool_name: str result: Any = None is_error: bool = False - type: Literal["tool_result"] = "tool_result" - -class ReasoningPart(pydantic.BaseModel): + kind: Literal["tool_result"] = "tool_result" model_config = pydantic.ConfigDict(frozen=True) + +class ReasoningPart(pydantic.BaseModel): id: str = pydantic.Field(default_factory=generate_id) text: str - type: Literal["reasoning"] = "reasoning" # Anthropic's thinking blocks include a signature for cache/verification. # This must be preserved and sent back in multi-turn conversations. signature: str | None = None - -class HookPart(pydantic.BaseModel): - """Part representing a hook suspension point in the agent's turn.""" - + kind: Literal["reasoning"] = "reasoning" model_config = pydantic.ConfigDict(frozen=True) + +class HookPart[T](pydantic.BaseModel): id: str = pydantic.Field(default_factory=generate_id) hook_id: str hook_type: str - status: Literal[ - "pending", "resolved", "cancelled" - ] # TODO should be shared with hook type + status: Literal["pending", "resolved"] metadata: dict[str, Any] = pydantic.Field(default_factory=dict) - resolution: dict[str, Any] | None = None # TODO should have payload type - type: Literal["hook"] = "hook" + resolution: T | None = None + + kind: Literal["hook"] = "hook" + model_config = pydantic.ConfigDict(frozen=True) +# todo redo this structured output situation and simplify it def _resolve_class(fully_qualified_name: str) -> type[pydantic.BaseModel]: """Import and return a class from its fully qualified name. @@ -117,7 +102,7 @@ class StructuredOutputPart(pydantic.BaseModel): id: str = pydantic.Field(default_factory=generate_id) data: dict[str, Any] output_type_name: str - type: Literal["structured_output"] = "structured_output" + kind: Literal["structured_output"] = "structured_output" _hydrated: Any = pydantic.PrivateAttr(default=None) @@ -150,10 +135,10 @@ class FilePart(pydantic.BaseModel): data: str | bytes media_type: str # IANA media type, e.g. "image/png", "audio/wav" filename: str | None = None - type: Literal["file"] = "file" + kind: Literal["file"] = "file" @classmethod - def from_url(cls, url: str, *, media_type: str | None = None) -> FilePart: + def from_url(cls, url: str, *, media_type: str | None = None) -> Self: """Create from a URL, inferring ``media_type`` from the URL if omitted. Inference handles ``data:`` URLs (the media type is embedded in the @@ -162,9 +147,7 @@ def from_url(cls, url: str, *, media_type: str | None = None) -> FilePart: ``media_type`` is provided. """ if media_type is None: - from . import media as media_helpers - - media_type = media_helpers.infer_media_type(url) + media_type = media.infer_media_type(url) return cls(data=url, media_type=media_type) @classmethod @@ -174,7 +157,7 @@ def from_bytes( *, media_type: str | None = None, filename: str | None = None, - ) -> FilePart: + ) -> Self: """Create from raw bytes, detecting ``media_type`` via magic bytes. Attempts image detection first, then audio. Raises @@ -182,11 +165,9 @@ def from_bytes( detection fails. """ if media_type is None: - from . import media as media_helpers - - media_type = media_helpers.detect_image_media_type( + media_type = media.detect_image_media_type( data - ) or media_helpers.detect_audio_media_type(data) + ) or media.detect_audio_media_type(data) if media_type is None: raise ValueError( "Cannot detect media_type from bytes. Provide media_type explicitly." @@ -199,125 +180,20 @@ def from_bytes( | ToolCallPart | ToolResultPart | ReasoningPart - | HookPart + | HookPart[Any] | StructuredOutputPart | FilePart, - pydantic.Field(discriminator="type"), + pydantic.Field(discriminator="kind"), ] -class Usage(pydantic.BaseModel): - """Normalized token usage from a single LLM call. - - Provides a provider-agnostic view of token consumption. Fields that a - provider does not report are left as ``None`` (not zero) so callers - can distinguish "not reported" from "zero tokens used". - """ - - model_config = pydantic.ConfigDict(frozen=True) - - input_tokens: int = 0 - output_tokens: int = 0 - - # Optional breakdowns — not all providers report these. - reasoning_tokens: int | None = None - cache_read_tokens: int | None = None - cache_write_tokens: int | None = None - - # Pass-through of the raw provider usage payload so callers can access - # provider-specific fields (e.g. OpenAI's accepted_prediction_tokens). - raw: dict[str, Any] | None = None - - @property - def total_tokens(self) -> int: - """input_tokens + output_tokens (always consistent).""" - return self.input_tokens + self.output_tokens - - def __add__(self, other: Usage) -> Usage: - """Accumulate usage across multiple LLM calls.""" - - def _add_optional(a: int | None, b: int | None) -> int | None: - """Add two optional ints. Returns None only if both are None.""" - if a is None and b is None: - return None - return (a or 0) + (b or 0) - - return Usage( - input_tokens=self.input_tokens + other.input_tokens, - output_tokens=self.output_tokens + other.output_tokens, - reasoning_tokens=_add_optional( - self.reasoning_tokens, other.reasoning_tokens - ), - cache_read_tokens=_add_optional( - self.cache_read_tokens, other.cache_read_tokens - ), - cache_write_tokens=_add_optional( - self.cache_write_tokens, other.cache_write_tokens - ), - # Don't merge raw — it's per-call and provider-specific. - ) - - -# --------------------------------------------------------------------------- -# Streaming sidecar — transient state excluded from persistence. -# --------------------------------------------------------------------------- - - -class PartOpened(pydantic.BaseModel): - """A new streaming block was opened by the LLM. - - ``part`` holds the initial snapshot of the part (empty text/args). - """ - - model_config = pydantic.ConfigDict(frozen=True) - - part: Part - type: Literal["part_opened"] = "part_opened" - - -class PartDelta(pydantic.BaseModel): - """An incremental update to a streaming part. - - ``part`` is the post-delta snapshot (state accumulated up to and including - ``chunk``). ``chunk`` is the new fragment appended this step (plain text - for :class:`TextPart` / :class:`ReasoningPart`, a JSON-args fragment for - :class:`ToolCallPart`). - """ - - model_config = pydantic.ConfigDict(frozen=True) - - part: Part - chunk: str - type: Literal["part_delta"] = "part_delta" - - -class PartClosed(pydantic.BaseModel): - """A streaming block was closed by the LLM. - - ``part`` holds the final snapshot of the part. - """ - - model_config = pydantic.ConfigDict(frozen=True) - - part: Part - type: Literal["part_closed"] = "part_closed" - - -StreamEvent = Annotated[ - PartOpened | PartDelta | PartClosed, - pydantic.Field(discriminator="type"), -] - - -class StreamState(pydantic.BaseModel): - """Transient streaming state attached to a Message during streaming. - - ``new_events`` contains the events since the previous yield — never cumulative. - ``is_done`` is True once the stream has finished. - """ - - new_events: list[StreamEvent] = pydantic.Field(default_factory=list) - is_done: bool = False +ALLOWED_PARTS: dict[str, set[str]] = { + "user": {"text", "file"}, + "assistant": {"text", "tool_call", "reasoning", "structured_output"}, + "system": {"text"}, + "tool": {"tool_result"}, + "internal": {"hook"}, +} class Message(pydantic.BaseModel): @@ -328,134 +204,15 @@ class Message(pydantic.BaseModel): id: str = pydantic.Field(default_factory=generate_id) turn_id: str | None = None source_label: str | None = None - usage: Usage | None = None - stream: StreamState | None = pydantic.Field(default=None, exclude=True) - - @overload - def replace(self, new: Part, /) -> Message: ... - @overload - def replace(self, old: Part, new: Part, /) -> Message: ... - def replace(self, *args: Part) -> Message: - """Return a copy with a part replaced. - - Single arg: ``msg.replace(updated_part)`` — matches by ``id``. - Two args: ``msg.replace(old, new)`` — matches by identity. - - Raises ValueError if the target part is not found. - """ - if len(args) == 1: - (new,) = args - match_id: str | None = new.id - match_ref = None - elif len(args) == 2: - old, new = args - match_id = None - match_ref = old - else: - raise TypeError(f"replace() takes 1 or 2 arguments ({len(args)} given)") - found = False - new_parts: list[Part] = [] - for p in self.parts: - if not found and ( - (match_id is not None and p.id == match_id) - or (match_ref is not None and p is match_ref) - ): - found = True - new_parts.append(new) - else: - new_parts.append(p) - if not found: - if match_id is not None: - raise ValueError(f"No part with id '{match_id}' in message") - raise ValueError("Part not found in message") - return self.model_copy(update={"parts": new_parts}) - - @property - def output(self) -> Any: - """Return the validated structured output, or None.""" - for part in self.parts: - if isinstance(part, StructuredOutputPart): - return part.value - return None - - @property - def is_done(self) -> bool: - """No sidecar (persisted/restored) means done. Otherwise ``stream.is_done``.""" - if self.stream is None: - return True - return self.stream.is_done - - def get_part(self, part_id: str) -> Part | None: - """Find a part by id, or return None if not found.""" - for part in self.parts: - if part.id == part_id: - return part - return None - - @property - def deltas(self) -> list[PartDelta]: - """PartDelta events from this yield step, in order. - - Empty list means nothing streamed in this step. Each event carries - its post-delta :class:`Part` snapshot via ``ev.part`` and the chunk - fragment via ``ev.chunk``. - """ - if self.stream is None: - return [] - return [ev for ev in self.stream.new_events if isinstance(ev, PartDelta)] - - @property - def files(self) -> list[FilePart]: - """All file parts in the message.""" - return [p for p in self.parts if isinstance(p, FilePart)] - - @property - def images(self) -> list[FilePart]: - """File parts with ``image/*`` media types.""" - return [ - p - for p in self.parts - if isinstance(p, FilePart) and p.media_type.startswith("image/") - ] - - @property - def videos(self) -> list[FilePart]: - """File parts with ``video/*`` media types.""" - return [ - p - for p in self.parts - if isinstance(p, FilePart) and p.media_type.startswith("video/") - ] - - @property - def text(self) -> str: - for part in self.parts: - if isinstance(part, TextPart): - return part.text - return "" - - @property - def reasoning(self) -> str: - for part in self.parts: - if isinstance(part, ReasoningPart): - return part.text - return "" - - @property - def tool_calls(self) -> list[ToolCallPart]: - """All tool-call parts in this message.""" - return [part for part in self.parts if isinstance(part, ToolCallPart)] + usage: usage_.Usage | None = None - @property - def tool_results(self) -> list[ToolResultPart]: - """All tool-result parts in this message.""" - return [part for part in self.parts if isinstance(part, ToolResultPart)] - - def get_hook_part(self, hook_id: str | None = None) -> HookPart | None: - """Find a HookPart by hook_id, or return the first HookPart if no id given.""" - for part in self.parts: - if isinstance(part, HookPart) and ( - hook_id is None or part.hook_id == hook_id - ): - return part - return None + @pydantic.model_validator(mode="after") + def _check_parts(self) -> Self: + allowed = ALLOWED_PARTS[self.role] + bad = [p.kind for p in self.parts if p.kind not in allowed] + if bad: + raise ValueError( + f"role={self.role!r} cannot contain parts of kind(s) " + f"{sorted(set(bad))}; allowed: {sorted(allowed)}" + ) + return self diff --git a/src/ai/types/stream.py b/src/ai/types/proto.py similarity index 56% rename from src/ai/types/stream.py rename to src/ai/types/proto.py index f6389bc0..30d64bc7 100644 --- a/src/ai/types/stream.py +++ b/src/ai/types/proto.py @@ -1,15 +1,19 @@ -"""StreamResultLike — structural protocol for stream results. +from collections.abc import AsyncGenerator +from typing import Any, Protocol, runtime_checkable -Middleware authors can type-check against this protocol without depending -on the concrete ``StreamResult`` class in ``ai.models``. -""" +from . import messages, usage -from __future__ import annotations -from collections.abc import AsyncGenerator -from typing import Any, Protocol, runtime_checkable +@runtime_checkable +class ToolLike(Protocol): + """Anything the LLM layer can use as a tool definition.""" -from . import messages as messages_ + @property + def name(self) -> str: ... + @property + def description(self) -> str: ... + @property + def param_schema(self) -> dict[str, Any]: ... @runtime_checkable @@ -21,16 +25,16 @@ class StreamResultLike(Protocol): The easiest way is ``StreamResult.from_generator(gen)``. """ - def __aiter__(self) -> AsyncGenerator[messages_.Message]: ... + def __aiter__(self) -> AsyncGenerator[messages.Message]: ... @property def text(self) -> str: ... @property - def tool_calls(self) -> list[messages_.ToolCallPart]: ... + def tool_calls(self) -> list[messages.ToolCallPart]: ... @property - def usage(self) -> messages_.Usage | None: ... + def usage(self) -> usage.Usage | None: ... @property def output(self) -> Any: ... diff --git a/src/ai/types/tools.py b/src/ai/types/tools.py index 0661ff29..0cb64b08 100644 --- a/src/ai/types/tools.py +++ b/src/ai/types/tools.py @@ -6,23 +6,11 @@ from __future__ import annotations -from typing import Any, Protocol, runtime_checkable +from typing import Any import pydantic -@runtime_checkable -class ToolLike(Protocol): - """Anything the LLM layer can use as a tool definition.""" - - @property - def name(self) -> str: ... - @property - def description(self) -> str: ... - @property - def param_schema(self) -> dict[str, Any]: ... - - class ToolSchema(pydantic.BaseModel): """What the LLM sees: name, description, and JSON Schema for parameters.""" diff --git a/src/ai/types/usage.py b/src/ai/types/usage.py new file mode 100644 index 00000000..d33b3fcd --- /dev/null +++ b/src/ai/types/usage.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +from typing import Any + +import pydantic + + +class Usage(pydantic.BaseModel): + """Normalized token usage from a single LLM call. + + Provides a provider-agnostic view of token consumption. Fields that a + provider does not report are left as ``None`` (not zero) so callers + can distinguish "not reported" from "zero tokens used". + """ + + model_config = pydantic.ConfigDict(frozen=True) + + input_tokens: int = 0 + output_tokens: int = 0 + + # Optional breakdowns — not all providers report these. + reasoning_tokens: int | None = None + cache_read_tokens: int | None = None + cache_write_tokens: int | None = None + + # Pass-through of the raw provider usage payload so callers can access + # provider-specific fields (e.g. OpenAI's accepted_prediction_tokens). + raw: dict[str, Any] | None = None + + @property + def total_tokens(self) -> int: + """input_tokens + output_tokens (always consistent).""" + return self.input_tokens + self.output_tokens + + def __add__(self, other: Usage) -> Usage: + """Accumulate usage across multiple LLM calls.""" + + def _add_optional(a: int | None, b: int | None) -> int | None: + """Add two optional ints. Returns None only if both are None.""" + if a is None and b is None: + return None + return (a or 0) + (b or 0) + + return Usage( + input_tokens=self.input_tokens + other.input_tokens, + output_tokens=self.output_tokens + other.output_tokens, + reasoning_tokens=_add_optional( + self.reasoning_tokens, other.reasoning_tokens + ), + cache_read_tokens=_add_optional( + self.cache_read_tokens, other.cache_read_tokens + ), + cache_write_tokens=_add_optional( + self.cache_write_tokens, other.cache_write_tokens + ), + # Don't merge raw — it's per-call and provider-specific. + ) From c53f609b04ce68aea5137d7d487d1b3475ab3498 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Thu, 23 Apr 2026 12:48:44 -0700 Subject: [PATCH 2/3] Wire new datamodel through the models --- examples/coding-agent/1_raw_stream.py | 18 ++ src/ai/__init__.py | 8 - src/ai/middleware.py | 2 +- src/ai/models/__init__.py | 2 +- src/ai/models/ai_gateway/stream.py | 23 ++- src/ai/models/anthropic/adapter.py | 159 ++++++++------- src/ai/models/core/api.py | 22 +- src/ai/models/core/helpers/streaming.py | 61 +++--- src/ai/models/core/params.py | 42 ++++ src/ai/models/core/types.py | 157 +------------- src/ai/models/openai/adapter.py | 61 ++++-- src/ai/types/__init__.py | 26 ++- src/ai/types/builders.py | 223 ++++++++++++++++++++ src/ai/types/events.py | 80 ++++++-- src/ai/types/group.py | 222 ++++++++++++++++++++ src/ai/types/integrity.py | 261 ++++++++++++++++++++++++ src/ai/types/proto.py | 3 +- src/ai/types/stream.py | 96 +++++++++ 18 files changed, 1149 insertions(+), 317 deletions(-) create mode 100644 examples/coding-agent/1_raw_stream.py create mode 100644 src/ai/models/core/params.py create mode 100644 src/ai/types/builders.py create mode 100644 src/ai/types/group.py create mode 100644 src/ai/types/integrity.py create mode 100644 src/ai/types/stream.py diff --git a/examples/coding-agent/1_raw_stream.py b/examples/coding-agent/1_raw_stream.py new file mode 100644 index 00000000..20b47057 --- /dev/null +++ b/examples/coding-agent/1_raw_stream.py @@ -0,0 +1,18 @@ +import ai +import asyncio + + +async def main() -> None: + model = ai.ai_gateway("anthropic/claude-opus-4.7") + + messages = [ + ai.system_message("you are a coding assistant"), + ai.user_message("actually i don't need assistance thanks"), + ] + + async for e in ai.stream(model, messages): + print(e) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/ai/__init__.py b/src/ai/__init__.py index 570d9cdc..6f1edb48 100644 --- a/src/ai/__init__.py +++ b/src/ai/__init__.py @@ -36,12 +36,8 @@ HookPart, Message, Part, - PartClosed, - PartDelta, - PartOpened, ReasoningPart, StreamResultLike, - StreamState, StructuredOutputPart, TextPart, ToolCallPart, @@ -64,16 +60,12 @@ # Types (from types/) "Message", "Part", - "PartClosed", - "PartDelta", - "PartOpened", "TextPart", "ToolCallPart", "ToolResultPart", "ReasoningPart", "FilePart", "HookPart", - "StreamState", "StructuredOutputPart", "ToolLike", "ToolSchema", diff --git a/src/ai/middleware.py b/src/ai/middleware.py index e5eaaa97..2390bb98 100644 --- a/src/ai/middleware.py +++ b/src/ai/middleware.py @@ -24,7 +24,7 @@ from .types import messages as messages_ from .types import tools as tools_ -from .types.stream import StreamResultLike +from .types.proto import StreamResultLike # --------------------------------------------------------------------------- # Call context objects — frozen dataclasses with isolated mutable fields. diff --git a/src/ai/models/__init__.py b/src/ai/models/__init__.py index a4d9c4ca..f87913f8 100644 --- a/src/ai/models/__init__.py +++ b/src/ai/models/__init__.py @@ -26,7 +26,7 @@ ids = await openai.list() """ -from ..types.stream import StreamResultLike +from ..types.proto import StreamResultLike from .ai_gateway import ai_gateway from .anthropic import anthropic from .core.adapters import register_generate, register_stream diff --git a/src/ai/models/ai_gateway/stream.py b/src/ai/models/ai_gateway/stream.py index 9ba0f73e..1bbb5c6c 100644 --- a/src/ai/models/ai_gateway/stream.py +++ b/src/ai/models/ai_gateway/stream.py @@ -12,9 +12,11 @@ import httpx import pydantic +from ...types import events as events_ from ...types import media from ...types import messages as messages_ from ...types import tools as tools_ +from ...types import usage as usage_ from ..core import client as client_ from ..core import model as model_ from ..core.helpers import files, streaming @@ -169,10 +171,10 @@ def _expand_tool_call(data: dict[str, Any]) -> list[streaming.StreamEvent]: ] -def _parse_usage(data: Any) -> messages_.Usage: +def _parse_usage(data: Any) -> usage_.Usage: """Parse v3 usage data into an internal ``Usage``.""" if not isinstance(data, dict): - return messages_.Usage() + return usage_.Usage() input_tokens_obj = data.get("inputTokens") output_tokens_obj = data.get("outputTokens") @@ -180,7 +182,7 @@ def _parse_usage(data: Any) -> messages_.Usage: if isinstance(input_tokens_obj, dict) or isinstance(output_tokens_obj, dict): inp = input_tokens_obj if isinstance(input_tokens_obj, dict) else {} out = output_tokens_obj if isinstance(output_tokens_obj, dict) else {} - return messages_.Usage( + return usage_.Usage( input_tokens=inp.get("total") or 0, output_tokens=out.get("total") or 0, reasoning_tokens=out.get("reasoning"), @@ -189,7 +191,7 @@ def _parse_usage(data: Any) -> messages_.Usage: raw=data, ) - return messages_.Usage( + return usage_.Usage( input_tokens=data.get("prompt_tokens") or data.get("inputTokens") or 0, output_tokens=(data.get("completion_tokens") or data.get("outputTokens") or 0), raw=data, @@ -287,12 +289,10 @@ async def stream( tools: Sequence[tools_.ToolLike] | None = None, output_type: type[pydantic.BaseModel] | None = None, **kwargs: Any, -) -> AsyncGenerator[messages_.Message]: +) -> AsyncGenerator[events_.Event]: """Stream an LLM response through the AI Gateway v3 protocol. - Yields ``Message`` snapshots as the response streams in. Each - snapshot is a complete, self-contained message reflecting the - accumulated state up to that point. + Yields :class:`~ai.types.events.Event` objects as the response streams in. """ body = await _build_request_body( messages, tools=tools, output_type=output_type, **kwargs @@ -319,10 +319,11 @@ async def stream( api_key_provided=bool(client.api_key), ) + yield handler.message_start() async for data in _common.parse_sse_lines(response): - for event in _parse_stream_part(data): - msg = handler.handle_event(event) - yield msg + for adapter_event in _parse_stream_part(data): + for out_event in handler.handle_event(adapter_event): + yield out_event except errors.GatewayError: raise except httpx.TimeoutException as exc: diff --git a/src/ai/models/anthropic/adapter.py b/src/ai/models/anthropic/adapter.py index b5d2e024..c54c16d7 100644 --- a/src/ai/models/anthropic/adapter.py +++ b/src/ai/models/anthropic/adapter.py @@ -4,8 +4,6 @@ The SDK client is constructed from :class:`Client` params on each call. """ -from __future__ import annotations - import json from collections.abc import AsyncGenerator, Sequence from typing import Any @@ -13,12 +11,9 @@ import anthropic import pydantic -from ...types import media -from ...types import messages as messages_ -from ...types import tools as tools_ -from ..core import client as client_ -from ..core import model as model_ -from ..core.helpers import streaming +from ... import types +from ...types import events +from .. import core # --------------------------------------------------------------------------- # Message / tool conversion — internal types → Anthropic wire format @@ -26,7 +21,7 @@ def _tools_to_anthropic( - tools: Sequence[tools_.ToolLike], + tools: Sequence[types.proto.ToolLike], ) -> list[dict[str, Any]]: """Convert internal Tool objects to Anthropic tool schema format.""" return [ @@ -40,7 +35,7 @@ def _tools_to_anthropic( def _file_part_to_anthropic( - part: messages_.FilePart, + part: types.FilePart, ) -> dict[str, Any]: """Convert a :class:`FilePart` to an Anthropic content block. @@ -53,7 +48,7 @@ def _file_part_to_anthropic( if mt.startswith("image/"): media_type = "image/jpeg" if mt == "image/*" else mt - if isinstance(part.data, str) and media.is_url(part.data): + if isinstance(part.data, str) and types.media.is_url(part.data): return { "type": "image", "source": {"type": "url", "url": part.data}, @@ -63,12 +58,12 @@ def _file_part_to_anthropic( "source": { "type": "base64", "media_type": media_type, - "data": media.data_to_base64(part.data), + "data": types.media.data_to_base64(part.data), }, } if mt == "application/pdf": - if isinstance(part.data, str) and media.is_url(part.data): + if isinstance(part.data, str) and types.media.is_url(part.data): return { "type": "document", "source": {"type": "url", "url": part.data}, @@ -78,14 +73,14 @@ def _file_part_to_anthropic( "source": { "type": "base64", "media_type": "application/pdf", - "data": media.data_to_base64(part.data), + "data": types.media.data_to_base64(part.data), }, } if mt == "text/plain": if isinstance(part.data, bytes): text_data = part.data.decode("utf-8") - elif media.is_url(part.data): + elif types.media.is_url(part.data): return { "type": "document", "source": {"type": "url", "url": part.data}, @@ -107,7 +102,7 @@ def _file_part_to_anthropic( async def _messages_to_anthropic( - messages: list[messages_.Message], + messages: list[types.Message], ) -> tuple[str | None, list[dict[str, Any]]]: """Convert internal messages to Anthropic API format. @@ -122,13 +117,13 @@ async def _messages_to_anthropic( match msg.role: case "system": system_prompt = "".join( - p.text for p in msg.parts if isinstance(p, messages_.TextPart) + p.text for p in msg.parts if isinstance(p, types.TextPart) ) case "assistant": content: list[dict[str, Any]] = [] for part in msg.parts: match part: - case messages_.ReasoningPart(text=text, signature=signature): + case types.ReasoningPart(text=text, signature=signature): if signature: content.append( { @@ -137,9 +132,9 @@ async def _messages_to_anthropic( "signature": signature, } ) - case messages_.TextPart(text=text): + case types.TextPart(text=text): content.append({"type": "text", "text": text}) - case messages_.ToolCallPart(): + case types.ToolCallPart(): tool_input = ( json.loads(part.tool_args) if part.tool_args else {} ) @@ -157,7 +152,7 @@ async def _messages_to_anthropic( case "tool": tool_results: list[dict[str, Any]] = [] for part in msg.parts: - if isinstance(part, messages_.ToolResultPart): + if isinstance(part, types.ToolResultPart): entry: dict[str, Any] = { "type": "tool_result", "tool_use_id": part.tool_call_id, @@ -172,19 +167,19 @@ async def _messages_to_anthropic( result.append({"role": "user", "content": tool_results}) case "user": - has_files = any(isinstance(p, messages_.FilePart) for p in msg.parts) + has_files = any(isinstance(p, types.FilePart) for p in msg.parts) if not has_files: content_text = "".join( - p.text for p in msg.parts if isinstance(p, messages_.TextPart) + p.text for p in msg.parts if isinstance(p, types.TextPart) ) result.append({"role": "user", "content": content_text}) else: user_content: list[dict[str, Any]] = [] for p in msg.parts: match p: - case messages_.TextPart(text=text): + case types.TextPart(text=text): user_content.append({"type": "text", "text": text}) - case messages_.FilePart(): + case types.FilePart(): user_content.append(_file_part_to_anthropic(p)) result.append({"role": "user", "content": user_content}) @@ -228,7 +223,7 @@ def _to_content_list(content: Any) -> list[dict[str, Any]]: def _make_client( - client: client_.Client, + client: core.client.Client, ) -> anthropic.AsyncAnthropic: """Construct an ``AsyncAnthropic`` from our generic ``Client``.""" return anthropic.AsyncAnthropic( @@ -243,19 +238,19 @@ def _make_client( async def stream( - client: client_.Client, - model: model_.Model, - messages: list[messages_.Message], + client: core.client.Client, + model: core.model.Model, + messages: list[types.Message], *, - tools: Sequence[tools_.ToolLike] | None = None, + tools: Sequence[types.proto.ToolLike] | None = None, output_type: type[pydantic.BaseModel] | None = None, thinking: bool = False, budget_tokens: int = 10000, **kwargs: Any, -) -> AsyncGenerator[messages_.Message]: +) -> AsyncGenerator[events.Event]: """Stream an LLM response via the Anthropic messages API. - Yields ``Message`` snapshots as the response streams in. + Yields :class:`~ai.types.events.Event` objects as the response streams in. Extra keyword arguments beyond the ``StreamFn`` protocol: @@ -285,16 +280,20 @@ async def stream( if output_type is not None: api_kwargs["output_format"] = output_type - handler = streaming.StreamHandler() - block_types: dict[int, str] = {} tool_ids: dict[int, str] = {} signature_buffer: dict[int, str] = {} + # Accumulate parts for the final Message + parts: list[types.Part] = [] + _text_parts: dict[str, str] = {} # block_id -> accumulated text + _reasoning_parts: dict[str, str] = {} # block_id -> accumulated text + _tool_parts: dict[str, str] = {} # tool_call_id -> accumulated args try: stream_cm = sdk_client.messages.stream(**api_kwargs) async with stream_cm as sdk_stream: + yield events.MessageStart() async for event in sdk_stream: match event.type: case "content_block_start": @@ -304,20 +303,17 @@ async def stream( match block.type: case "text": - yield handler.handle_event( - streaming.TextStart(block_id=str(idx)) - ) + _text_parts[str(idx)] = "" + yield events.TextStart(block_id=str(idx)) case "thinking": - yield handler.handle_event( - streaming.ReasoningStart(block_id=str(idx)) - ) + _reasoning_parts[str(idx)] = "" + yield events.ReasoningStart(block_id=str(idx)) case "tool_use": tool_ids[idx] = block.id - yield handler.handle_event( - streaming.ToolStart( - tool_call_id=block.id, - tool_name=block.name, - ) + _tool_parts[block.id] = "" + yield events.ToolStart( + tool_call_id=block.id, + tool_name=block.name, ) case "content_block_delta": @@ -326,18 +322,21 @@ async def stream( match delta.type: case "text_delta": - yield handler.handle_event( - streaming.TextDelta( - block_id=str(idx), - delta=delta.text, - ) + _text_parts[str(idx)] = ( + _text_parts.get(str(idx), "") + delta.text + ) + yield events.TextDelta( + chunk=delta.text, + block_id=str(idx), ) case "thinking_delta": - yield handler.handle_event( - streaming.ReasoningDelta( - block_id=str(idx), - delta=delta.thinking, - ) + _reasoning_parts[str(idx)] = ( + _reasoning_parts.get(str(idx), "") + + delta.thinking + ) + yield events.ReasoningDelta( + chunk=delta.thinking, + block_id=str(idx), ) case "signature_delta": signature_buffer[idx] = ( @@ -346,37 +345,54 @@ async def stream( case "input_json_delta": tool_id = tool_ids.get(idx) if tool_id: - yield handler.handle_event( - streaming.ToolArgsDelta( - tool_call_id=tool_id, - delta=delta.partial_json, - ) + _tool_parts[tool_id] = ( + _tool_parts.get(tool_id, "") + + delta.partial_json + ) + yield events.ToolDelta( + chunk=delta.partial_json, + tool_call_id=tool_id, ) case "content_block_stop": idx = event.index + bid = str(idx) match block_types.get(idx): case "text": - yield handler.handle_event( - streaming.TextEnd(block_id=str(idx)) + parts.append( + types.TextPart( + id=bid, text=_text_parts.get(bid, "") + ) ) + yield events.TextEnd(block_id=bid) case "thinking": - yield handler.handle_event( - streaming.ReasoningEnd( - block_id=str(idx), + parts.append( + types.ReasoningPart( + id=bid, + text=_reasoning_parts.get(bid, ""), signature=signature_buffer.get(idx), ) ) + yield events.ReasoningEnd( + block_id=bid, + signature=signature_buffer.get(idx), + ) case "tool_use": tool_id = tool_ids.get(idx) if tool_id: - yield handler.handle_event( - streaming.ToolEnd(tool_call_id=tool_id) + parts.append( + types.ToolCallPart( + id=tool_id, + tool_call_id=tool_id, + tool_name=block_types.get(idx, ""), + tool_args=_tool_parts.get(tool_id, ""), + ) ) + yield events.ToolEnd(tool_call_id=tool_id) snapshot = sdk_stream.current_message_snapshot sdk_usage = snapshot.usage - usage = messages_.Usage( + usage = types.Usage( input_tokens=sdk_usage.input_tokens or 0, output_tokens=sdk_usage.output_tokens or 0, cache_read_tokens=getattr(sdk_usage, "cache_read_input_tokens", None), @@ -385,6 +401,11 @@ async def stream( ), raw=sdk_usage.model_dump(exclude_none=True) or None, ) - yield handler.handle_event(streaming.MessageDone(usage=usage)) + final_message = types.Message( + role="assistant", + parts=parts, + usage=usage, + ) + yield events.MessageEnd(message=final_message, usage=usage) finally: await sdk_client.close() diff --git a/src/ai/models/core/api.py b/src/ai/models/core/api.py index 898eafcd..a38e6362 100644 --- a/src/ai/models/core/api.py +++ b/src/ai/models/core/api.py @@ -5,12 +5,13 @@ from __future__ import annotations -from collections.abc import Sequence +from collections.abc import AsyncGenerator, Sequence from typing import Any import pydantic from ... import middleware as middleware_ +from ...types import events as events_ from ...types import integrity as integrity_ from ...types import messages as messages_ from ...types import stream as stream_ @@ -21,7 +22,7 @@ from . import types as types_ -async def stream( +def stream( model: model_.Model, messages: list[messages_.Message], *, @@ -36,6 +37,11 @@ async def stream( collects the final ``Message``. After iteration, access ``.text``, ``.tool_calls``, ``.usage``, etc. + Call-site is a plain ``async for`` — no outer ``await`` needed:: + + async for msg in ai.stream(model, messages): + ... + One call is one turn: a single request and its response. The model response carries ``turn_id``; re-emitted input messages keep any existing ``turn_id`` from prior turns and only receive the current @@ -43,6 +49,9 @@ async def stream( The client is resolved from the model: ``model.client`` if set, otherwise auto-created from ``model.base_url`` / ``model.api_key_env``. + + Middleware dispatch and adapter setup are deferred to the first + iteration; any async preflight work happens there. """ messages = integrity_.prepare_messages(messages) @@ -76,8 +85,13 @@ async def _real(call: middleware_.ModelContext) -> stream_.StreamResultLike: input_messages=call.messages, ) - chain = middleware_._build_model_chain(_real) - return await chain(call) + async def _driver() -> AsyncGenerator[events_.Event]: + chain = middleware_._build_model_chain(_real) + inner = await chain(call) + async for event in inner: + yield event + + return stream_.StreamResult.from_generator(_driver()) async def generate( diff --git a/src/ai/models/core/helpers/streaming.py b/src/ai/models/core/helpers/streaming.py index bd7db603..57ceebe5 100644 --- a/src/ai/models/core/helpers/streaming.py +++ b/src/ai/models/core/helpers/streaming.py @@ -2,7 +2,9 @@ import dataclasses +from ....types import events as events_ from ....types import messages as messages_ +from ....types import usage as usage_ @dataclasses.dataclass @@ -67,7 +69,7 @@ class FileEvent: @dataclasses.dataclass class MessageDone: finish_reason: str | None = None - usage: messages_.Usage | None = None + usage: usage_.Usage | None = None StreamEvent = ( @@ -88,12 +90,11 @@ class MessageDone: @dataclasses.dataclass class StreamHandler: """ - Accumulates LLM adapter events and produces Messages with stateful parts. + Accumulates LLM adapter events and produces public Event objects. This is the normalization layer between LLM adapters and the rest of the system. Parts are tracked in a single ``_current_parts`` dict keyed by block/tool id, - updated in place as events stream in. Each event carries the just-constructed - frozen part snapshot, so consumers never need to look parts up by id. + updated in place as events stream in. """ message_id: str = dataclasses.field(default_factory=messages_.generate_id) @@ -108,40 +109,41 @@ class StreamHandler: _active_tool_ids: set[str] = dataclasses.field(default_factory=set) _is_done: bool = False - _usage: messages_.Usage | None = None + _usage: usage_.Usage | None = None - def handle_event(self, event: StreamEvent) -> messages_.Message: - """Process event and return current Message state.""" + def message_start(self) -> events_.MessageStart: + """Emit a MessageStart event at the beginning of a stream.""" + return events_.MessageStart() - # Sidecar events for this yield (reset each call). - stream_events: list[messages_.StreamEvent] = [] + def handle_event(self, event: StreamEvent) -> list[events_.Event]: + """Process an adapter event and return public Event objects.""" + + out: list[events_.Event] = [] match event: case TextStart(block_id=bid): part: messages_.Part = messages_.TextPart(id=bid, text="") self._current_parts[bid] = part self._active_text_id = bid - stream_events.append(messages_.PartOpened(part=part)) + out.append(events_.TextStart(block_id=bid)) case TextDelta(block_id=bid, delta=d): existing = self._current_parts[bid] assert isinstance(existing, messages_.TextPart) part = messages_.TextPart(id=bid, text=existing.text + d) self._current_parts[bid] = part - stream_events.append(messages_.PartDelta(part=part, chunk=d)) + out.append(events_.TextDelta(chunk=d, block_id=bid)) case TextEnd(block_id=bid): if self._active_text_id == bid: self._active_text_id = None - stream_events.append( - messages_.PartClosed(part=self._current_parts[bid]) - ) + out.append(events_.TextEnd(block_id=bid)) case ReasoningStart(block_id=bid): part = messages_.ReasoningPart(id=bid, text="") self._current_parts[bid] = part self._active_reasoning_id = bid - stream_events.append(messages_.PartOpened(part=part)) + out.append(events_.ReasoningStart(block_id=bid)) case ReasoningDelta(block_id=bid, delta=d): existing = self._current_parts[bid] @@ -152,7 +154,7 @@ def handle_event(self, event: StreamEvent) -> messages_.Message: signature=existing.signature, ) self._current_parts[bid] = part - stream_events.append(messages_.PartDelta(part=part, chunk=d)) + out.append(events_.ReasoningDelta(chunk=d, block_id=bid)) case ReasoningEnd(block_id=bid, signature=sig): existing = self._current_parts[bid] @@ -163,7 +165,7 @@ def handle_event(self, event: StreamEvent) -> messages_.Message: self._current_parts[bid] = part if self._active_reasoning_id == bid: self._active_reasoning_id = None - stream_events.append(messages_.PartClosed(part=part)) + out.append(events_.ReasoningEnd(block_id=bid, signature=sig)) case ToolStart(tool_call_id=tcid, tool_name=name): part = messages_.ToolCallPart( @@ -174,7 +176,7 @@ def handle_event(self, event: StreamEvent) -> messages_.Message: ) self._current_parts[tcid] = part self._active_tool_ids.add(tcid) - stream_events.append(messages_.PartOpened(part=part)) + out.append(events_.ToolStart(tool_call_id=tcid, tool_name=name)) case ToolArgsDelta(tool_call_id=tcid, delta=d): existing = self._current_parts[tcid] @@ -186,39 +188,32 @@ def handle_event(self, event: StreamEvent) -> messages_.Message: tool_args=existing.tool_args + d, ) self._current_parts[tcid] = part - stream_events.append(messages_.PartDelta(part=part, chunk=d)) + out.append(events_.ToolDelta(chunk=d, tool_call_id=tcid)) case ToolEnd(tool_call_id=tcid): self._active_tool_ids.discard(tcid) - stream_events.append( - messages_.PartClosed(part=self._current_parts[tcid]) - ) + out.append(events_.ToolEnd(tool_call_id=tcid)) case FileEvent(block_id=bid, media_type=mt, data=d): self._current_parts[bid] = messages_.FilePart( id=bid, data=d, media_type=mt ) - case MessageDone(usage=usage): + case MessageDone(usage=u): self._is_done = True - self._usage = usage + self._usage = u self._active_text_id = None self._active_reasoning_id = None self._active_tool_ids.clear() + msg = self._build_message() + out.append(events_.MessageEnd(message=msg, usage=u)) - return self._build_message(stream_events) + return out - def _build_message( - self, - stream_events: list[messages_.StreamEvent], - ) -> messages_.Message: + def _build_message(self) -> messages_.Message: return messages_.Message( id=self.message_id, role="assistant", parts=list(self._current_parts.values()), usage=self._usage if self._is_done else None, - stream=messages_.StreamState( - new_events=stream_events, - is_done=self._is_done, - ), ) diff --git a/src/ai/models/core/params.py b/src/ai/models/core/params.py new file mode 100644 index 00000000..afd143f9 --- /dev/null +++ b/src/ai/models/core/params.py @@ -0,0 +1,42 @@ +from typing import Any +import pydantic + + +_PARAMS_CONFIG = pydantic.ConfigDict(frozen=True, populate_by_name=True) + + +class ImageParams(pydantic.BaseModel): + """Parameters for image generation (``/image-model`` endpoint).""" + + model_config = _PARAMS_CONFIG + + n: int = 1 + size: str | None = None + aspect_ratio: str | None = pydantic.Field( + default=None, serialization_alias="aspectRatio" + ) + seed: int | None = None + provider_options: dict[str, Any] = pydantic.Field( + default_factory=dict, serialization_alias="providerOptions" + ) + + +class VideoParams(pydantic.BaseModel): + """Parameters for video generation (``/video-model`` endpoint).""" + + model_config = _PARAMS_CONFIG + + n: int = 1 + aspect_ratio: str | None = pydantic.Field( + default=None, serialization_alias="aspectRatio" + ) + resolution: str | None = None + duration: int | None = None + fps: int | None = None + seed: int | None = None + provider_options: dict[str, Any] = pydantic.Field( + default_factory=dict, serialization_alias="providerOptions" + ) + + +GenerateParams = ImageParams | VideoParams diff --git a/src/ai/models/core/types.py b/src/ai/models/core/types.py index a15d6dd1..4344fa05 100644 --- a/src/ai/models/core/types.py +++ b/src/ai/models/core/types.py @@ -1,150 +1,11 @@ -"""Core model-layer types — parameter objects and StreamResult. +"""Re-exports for backwards-compatible ``ai.models.core.types`` imports.""" -Parameter types (:class:`ImageParams`, :class:`VideoParams`) live here -because they parameterise the public :func:`generate` API. +from ...types.stream import StreamResult +from .params import GenerateParams, ImageParams, VideoParams -:class:`StreamResult` is the concrete wrapper returned by :func:`stream`. -""" - -from __future__ import annotations - -from collections.abc import AsyncGenerator -from typing import Any - -import pydantic - -from ...types import messages as messages_ - -# --------------------------------------------------------------------------- -# Generation parameter types -# --------------------------------------------------------------------------- - -_PARAMS_CONFIG = pydantic.ConfigDict(frozen=True, populate_by_name=True) - - -class ImageParams(pydantic.BaseModel): - """Parameters for image generation (``/image-model`` endpoint).""" - - model_config = _PARAMS_CONFIG - - n: int = 1 - size: str | None = None - aspect_ratio: str | None = pydantic.Field( - default=None, serialization_alias="aspectRatio" - ) - seed: int | None = None - provider_options: dict[str, Any] = pydantic.Field( - default_factory=dict, serialization_alias="providerOptions" - ) - - -class VideoParams(pydantic.BaseModel): - """Parameters for video generation (``/video-model`` endpoint).""" - - model_config = _PARAMS_CONFIG - - n: int = 1 - aspect_ratio: str | None = pydantic.Field( - default=None, serialization_alias="aspectRatio" - ) - resolution: str | None = None - duration: int | None = None - fps: int | None = None - seed: int | None = None - provider_options: dict[str, Any] = pydantic.Field( - default_factory=dict, serialization_alias="providerOptions" - ) - - -GenerateParams = ImageParams | VideoParams - -# --------------------------------------------------------------------------- -# StreamResult -# --------------------------------------------------------------------------- - - -class StreamResult: - """Wrapper around a message stream. Async-iterable; collects the final result. - - Properties like ``.text`` and ``.tool_calls`` delegate to the final - ``Message`` snapshot and are available after iteration completes. - - One ``StreamResult`` represents one turn: a single LLM request and its - response. When *turn_id* is provided, the model response is stamped - with it. When *input_messages* is provided, they are re-emitted ahead - of the response; inputs that already carry a ``turn_id`` (from earlier - turns) are preserved as-is, only inputs with ``turn_id=None`` receive - the current *turn_id*. - - Satisfies :class:`~ai.types.StreamResultLike`. - """ - - def __init__( - self, - gen: AsyncGenerator[messages_.Message], - *, - turn_id: str | None = None, - input_messages: list[messages_.Message] | None = None, - ) -> None: - self._gen = gen - self._turn_id = turn_id - self._input_messages = input_messages or [] - self._final: messages_.Message | None = None - - @classmethod - def from_generator(cls, gen: AsyncGenerator[messages_.Message]) -> StreamResult: - """Create a :class:`StreamResult` from an async generator. - - This is the public API for middleware that needs to transform or - replace the stream returned by ``wrap_model``:: - - async def wrap_model(self, call, next): - original = await next(call) - - async def _transformed(): - async for msg in original: - yield modify(msg) - - return StreamResult.from_generator(_transformed()) - """ - return cls(gen) - - def __aiter__(self) -> AsyncGenerator[messages_.Message]: - return self._iterate() - - async def _iterate(self) -> AsyncGenerator[messages_.Message]: - # Re-emit input messages; stamp only the ones without a turn_id. - # Prior turns keep their existing ids. - for msg in self._input_messages: - if msg.turn_id is None and self._turn_id is not None: - msg = msg.model_copy(update={"turn_id": self._turn_id}) - yield msg - - # Stream model response with turn_id stamped (when missing). - async for msg in self._gen: - if msg.turn_id is None and self._turn_id is not None: - msg = msg.model_copy(update={"turn_id": self._turn_id}) - self._final = msg - yield msg - - @property - def turn_id(self) -> str | None: - """The turn id stamped on this stream's response (if any).""" - return self._turn_id - - @property - def text(self) -> str: - return self._final.text if self._final else "" - - @property - def tool_calls(self) -> list[messages_.ToolCallPart]: - return self._final.tool_calls if self._final else [] - - @property - def usage(self) -> messages_.Usage | None: - return self._final.usage if self._final else None - - @property - def output(self) -> Any: - """Parsed structured output from the final message, if available.""" - return self._final.output if self._final else None +__all__ = [ + "GenerateParams", + "ImageParams", + "StreamResult", + "VideoParams", +] diff --git a/src/ai/models/openai/adapter.py b/src/ai/models/openai/adapter.py index eb77b9f6..764d9b29 100644 --- a/src/ai/models/openai/adapter.py +++ b/src/ai/models/openai/adapter.py @@ -12,6 +12,7 @@ import openai import pydantic +from ...types import events as events_ from ...types import media from ...types import messages as messages_ from ...types import tools as tools_ @@ -206,10 +207,10 @@ async def stream( budget_tokens: int | None = None, reasoning_effort: str | None = None, **kwargs: Any, -) -> AsyncGenerator[messages_.Message]: +) -> AsyncGenerator[events_.Event]: """Stream an LLM response via the OpenAI chat completions API. - Yields ``Message`` snapshots as the response streams in. + Yields :class:`~ai.types.events.Event` objects as the response streams in. Extra keyword arguments beyond the ``StreamFn`` protocol: @@ -256,6 +257,9 @@ async def stream( handler = streaming.StreamHandler() + def _emit(adapter_event: streaming.StreamEvent) -> list[events_.Event]: + return handler.handle_event(adapter_event) + try: sdk_stream = await sdk_client.chat.completions.create(**api_kwargs) @@ -265,6 +269,8 @@ async def stream( finish_reason: str | None = None usage: messages_.Usage | None = None + yield handler.message_start() + async for chunk in sdk_stream: if chunk.usage is not None: raw = chunk.usage.model_dump(exclude_none=True) @@ -308,28 +314,33 @@ async def stream( if reasoning_value: if not reasoning_started: reasoning_started = True - yield handler.handle_event( + for e in _emit( streaming.ReasoningStart(block_id="reasoning") - ) - yield handler.handle_event( + ): + yield e + for e in _emit( streaming.ReasoningDelta( block_id="reasoning", delta=reasoning_value ) - ) + ): + yield e if delta.content: if reasoning_started: - yield handler.handle_event( + for e in _emit( streaming.ReasoningEnd(block_id="reasoning") - ) + ): + yield e reasoning_started = False if not text_started: text_started = True - yield handler.handle_event(streaming.TextStart(block_id="text")) - yield handler.handle_event( + for e in _emit(streaming.TextStart(block_id="text")): + yield e + for e in _emit( streaming.TextDelta(block_id="text", delta=delta.content) - ) + ): + yield e if delta.tool_calls: for tc in delta.tool_calls: @@ -351,37 +362,43 @@ async def stream( if not tc_state[idx]["started"] and tid: tc_state[idx]["started"] = True - yield handler.handle_event( + for e in _emit( streaming.ToolStart( tool_call_id=tid, tool_name=tname, ) - ) + ): + yield e if tid: - yield handler.handle_event( + for e in _emit( streaming.ToolArgsDelta( tool_call_id=tid, delta=tc.function.arguments, ) - ) + ): + yield e if choice.finish_reason is not None: finish_reason = choice.finish_reason if reasoning_started: - yield handler.handle_event( + for e in _emit( streaming.ReasoningEnd(block_id="reasoning") - ) + ): + yield e if text_started: - yield handler.handle_event(streaming.TextEnd(block_id="text")) + for e in _emit(streaming.TextEnd(block_id="text")): + yield e for tc in tc_state.values(): if tc["started"] and tc["id"]: - yield handler.handle_event( + for e in _emit( streaming.ToolEnd(tool_call_id=tc["id"]) - ) + ): + yield e - yield handler.handle_event( + for e in _emit( streaming.MessageDone(finish_reason=finish_reason, usage=usage) - ) + ): + yield e finally: await sdk_client.close() diff --git a/src/ai/types/__init__.py b/src/ai/types/__init__.py index 6372fec5..4c9a277e 100644 --- a/src/ai/types/__init__.py +++ b/src/ai/types/__init__.py @@ -5,10 +5,16 @@ HookSuspention, MessageEnd, MessageStart, - PartDelta, - PartEnd, - PartStart, + ReasoningDelta, + ReasoningEnd, + ReasoningStart, Start, + TextDelta, + TextEnd, + TextStart, + ToolDelta, + ToolEnd, + ToolStart, ) from .messages import ( FilePart, @@ -25,6 +31,7 @@ from .proto import StreamResultLike, ToolLike from .tools import ToolSchema from .usage import Usage +from . import media __all__ = [ "End", @@ -37,18 +44,25 @@ "MessageEnd", "MessageStart", "Part", - "PartDelta", - "PartEnd", - "PartStart", + "ReasoningDelta", + "ReasoningEnd", "ReasoningPart", + "ReasoningStart", "Start", "StreamResultLike", "StructuredOutputPart", + "TextDelta", + "TextEnd", "TextPart", + "TextStart", "ToolCallPart", + "ToolDelta", + "ToolEnd", "ToolLike", "ToolResultPart", "ToolSchema", + "ToolStart", "Usage", "generate_id", + "media", ] diff --git a/src/ai/types/builders.py b/src/ai/types/builders.py new file mode 100644 index 00000000..94ddda5e --- /dev/null +++ b/src/ai/types/builders.py @@ -0,0 +1,223 @@ +"""Composable message construction helpers. + +Convenience functions for building Message objects without manually +constructing Part lists. Each ``*_message`` function returns a single +``Message``. +""" + +from __future__ import annotations + +from typing import Any, overload + +from .messages import ( + FilePart, + HookPart, + Message, + Part, + ReasoningPart, + StructuredOutputPart, + TextPart, + ToolCallPart, + ToolResultPart, +) + +_PART_TYPES = ( + TextPart, + ToolCallPart, + ToolResultPart, + ReasoningPart, + HookPart, + StructuredOutputPart, + FilePart, +) + +# A value that can appear as message content: bare strings become TextPart. +PartLike = str | Part + + +def _coerce_parts(args: tuple[PartLike, ...]) -> list[Part]: + parts: list[Part] = [] + for arg in args: + if isinstance(arg, str): + parts.append(TextPart(text=arg)) + elif isinstance(arg, _PART_TYPES): + parts.append(arg) + else: + raise TypeError(f"Expected str or Part, got {type(arg).__name__}") + return parts + + +def system_message(*content: PartLike) -> Message: + """Create a system message. + + >>> ai.system_message("You are a helpful robot.") + """ + return Message(role="system", parts=_coerce_parts(content)) + + +def user_message(*content: PartLike) -> Message: + """Create a user message from strings and/or Part objects. + + >>> ai.user_message("Describe this image:", ai.file_part(url)) + """ + return Message(role="user", parts=_coerce_parts(content)) + + +def assistant_message(*content: PartLike) -> Message: + """Create an assistant message from strings and/or Part objects. + + >>> ai.assistant_message(ai.thinking("hmm"), "Hello!") + """ + return Message(role="assistant", parts=_coerce_parts(content)) + + +def file_part( + data: str | bytes, + *, + media_type: str | None = None, + filename: str | None = None, +) -> FilePart: + """Create a :class:`FilePart` from a URL string or raw bytes. + + Dispatches to :meth:`FilePart.from_url` (for ``str``) or + :meth:`FilePart.from_bytes` (for ``bytes``), with automatic + media-type detection. + """ + if isinstance(data, str): + return FilePart.from_url(data, media_type=media_type) + return FilePart.from_bytes(data, media_type=media_type, filename=filename) + + +def thinking(text: str, *, signature: str | None = None) -> ReasoningPart: + """Create a :class:`ReasoningPart`. + + Useful for replaying conversation history that includes model reasoning. + """ + return ReasoningPart(text=text, signature=signature) + + +def _tool_results_from_messages(messages: list[Message]) -> list[ToolResultPart]: + parts: list[ToolResultPart] = [] + for message in messages: + if message.role != "tool": + raise TypeError(f"Expected tool message, got role={message.role!r}") + for part in message.parts: + if not isinstance(part, ToolResultPart): + raise TypeError( + "tool_message() only accepts tool messages containing " + "ToolResultPart parts" + ) + parts.append(part) + return parts + + +@overload +def tool_message(*messages: Message | list[Message]) -> Message: ... + + +@overload +def tool_message(*parts: ToolResultPart) -> Message: ... + + +@overload +def tool_message( + *, + tool_call_id: str, + result: Any = None, + tool_name: str = "", + is_error: bool = False, +) -> Message: ... + + +def tool_message( + *items: Message | ToolResultPart | list[Message], + tool_call_id: str | None = None, + result: Any = None, + tool_name: str = "", + is_error: bool = False, +) -> Message: + """Create or merge a tool-result message. + + >>> part = ai.tool_result("tc-1", result=72, tool_name="weather") + >>> ai.tool_message(part) + >>> ai.tool_message(tool_call_id="tc-1", result=72, tool_name="weather") + """ + if tool_call_id is None and (result is not None or tool_name or is_error): + raise TypeError( + "tool_message() keyword tool-result fields require tool_call_id" + ) + + if tool_call_id is not None: + if items: + raise TypeError( + "tool_message() cannot mix keyword tool-result fields with " + "positional messages or ToolResultPart values" + ) + return Message( + role="tool", + parts=[ + tool_result( + tool_call_id, + result=result, + tool_name=tool_name, + is_error=is_error, + ) + ], + ) + + if not items: + raise TypeError("tool_message() requires at least one tool message or result") + + flattened_messages: list[Message] = [] + result_parts: list[ToolResultPart] = [] + saw_message = False + saw_result_part = False + + for item in items: + if isinstance(item, list): + saw_message = True + flattened_messages.extend(item) + elif isinstance(item, Message): + saw_message = True + flattened_messages.append(item) + elif isinstance(item, ToolResultPart): + saw_result_part = True + result_parts.append(item) + else: + raise TypeError( + "tool_message() only accepts tool messages, lists of tool " + "messages, or ToolResultPart values" + ) + + if saw_message and saw_result_part: + raise TypeError( + "tool_message() cannot mix tool messages with ToolResultPart values" + ) + + if saw_message: + merged_parts: list[Part] = [] + merged_parts.extend(_tool_results_from_messages(flattened_messages)) + return Message(role="tool", parts=merged_parts) + + tool_parts: list[Part] = [] + tool_parts.extend(result_parts) + return Message(role="tool", parts=tool_parts) + + +def tool_result( + tool_call_id: str, + *, + result: Any = None, + tool_name: str = "", + is_error: bool = False, +) -> ToolResultPart: + """Create a :class:`ToolResultPart`. + + >>> ai.tool_result("tc-1", result={"temp": 72}, tool_name="weather") + """ + return ToolResultPart( + tool_call_id=tool_call_id, + tool_name=tool_name, + result=result, + is_error=is_error, + ) diff --git a/src/ai/types/events.py b/src/ai/types/events.py index 1c55370d..c9267072 100644 --- a/src/ai/types/events.py +++ b/src/ai/types/events.py @@ -3,6 +3,7 @@ import pydantic from . import messages +from . import usage as usage_ # we're using pydantic because events are crossing # serialization border in the case of durable execution @@ -19,7 +20,7 @@ class End(pydantic.BaseModel): class MessageStart(pydantic.BaseModel): - message: messages.Message + message: messages.Message | None = None kind: Literal["message_start"] = "message_start" model_config = pydantic.ConfigDict(frozen=True) @@ -27,30 +28,77 @@ class MessageStart(pydantic.BaseModel): class MessageEnd(pydantic.BaseModel): message: messages.Message + usage: usage_.Usage | None = None kind: Literal["message_end"] = "message_end" model_config = pydantic.ConfigDict(frozen=True) -class PartStart(pydantic.BaseModel): - part: messages.Part +class TextStart(pydantic.BaseModel): + block_id: str = "" + + kind: Literal["text_start"] = "text_start" + model_config = pydantic.ConfigDict(frozen=True) + + +class TextDelta(pydantic.BaseModel): + chunk: str + block_id: str = "" + + kind: Literal["text_delta"] = "text_delta" + model_config = pydantic.ConfigDict(frozen=True) + + +class TextEnd(pydantic.BaseModel): + block_id: str = "" + + kind: Literal["text_end"] = "text_end" + model_config = pydantic.ConfigDict(frozen=True) + + +class ReasoningStart(pydantic.BaseModel): + block_id: str = "" + + kind: Literal["reasoning_start"] = "reasoning_start" + model_config = pydantic.ConfigDict(frozen=True) + + +class ReasoningDelta(pydantic.BaseModel): + chunk: str + block_id: str = "" + + kind: Literal["reasoning_delta"] = "reasoning_delta" + model_config = pydantic.ConfigDict(frozen=True) + + +class ReasoningEnd(pydantic.BaseModel): + block_id: str = "" + signature: str | None = None + + kind: Literal["reasoning_end"] = "reasoning_end" + model_config = pydantic.ConfigDict(frozen=True) + + +class ToolStart(pydantic.BaseModel): + tool_call_id: str = "" + tool_name: str = "" - kind: Literal["part_start"] = "part_start" + kind: Literal["tool_start"] = "tool_start" model_config = pydantic.ConfigDict(frozen=True) -class PartDelta(pydantic.BaseModel): - part: messages.Part +class ToolDelta(pydantic.BaseModel): chunk: str + tool_call_id: str = "" - kind: Literal["part_delta"] = "part_delta" + kind: Literal["tool_delta"] = "tool_delta" model_config = pydantic.ConfigDict(frozen=True) -class PartEnd(pydantic.BaseModel): - part: messages.Part +class ToolEnd(pydantic.BaseModel): + tool_call_id: str = "" - kind: Literal["part_end"] = "part_end" + kind: Literal["tool_end"] = "tool_end" model_config = pydantic.ConfigDict(frozen=True) @@ -69,9 +117,15 @@ class HookResolution(pydantic.BaseModel): | End | MessageStart | MessageEnd - | PartStart - | PartDelta - | PartEnd + | TextStart + | TextDelta + | TextEnd + | ReasoningStart + | ReasoningDelta + | ReasoningEnd + | ToolStart + | ToolDelta + | ToolEnd | HookSuspention | HookResolution, pydantic.Field(discriminator="kind"), diff --git a/src/ai/types/group.py b/src/ai/types/group.py new file mode 100644 index 00000000..942003c8 --- /dev/null +++ b/src/ai/types/group.py @@ -0,0 +1,222 @@ +import dataclasses + +from ....types import messages as messages_ + + +@dataclasses.dataclass +class TextStart: + block_id: str + + +@dataclasses.dataclass +class TextDelta: + block_id: str + delta: str + + +@dataclasses.dataclass +class TextEnd: + block_id: str + + +@dataclasses.dataclass +class ReasoningStart: + block_id: str + + +@dataclasses.dataclass +class ReasoningDelta: + block_id: str + delta: str + + +@dataclasses.dataclass +class ReasoningEnd: + block_id: str + signature: str | None = None + + +@dataclasses.dataclass +class ToolStart: + tool_call_id: str + tool_name: str + + +@dataclasses.dataclass +class ToolArgsDelta: + tool_call_id: str + delta: str + + +@dataclasses.dataclass +class ToolEnd: + tool_call_id: str + + +@dataclasses.dataclass +class FileEvent: + """A complete generated file from the LLM (e.g. inline image from Gemini/GPT).""" + + block_id: str + media_type: str + data: str # base64 string or data-URL from the gateway + + +@dataclasses.dataclass +class MessageDone: + finish_reason: str | None = None + usage: messages_.Usage | None = None + + +StreamEvent = ( + TextStart + | TextDelta + | TextEnd + | ReasoningStart + | ReasoningDelta + | ReasoningEnd + | ToolStart + | ToolArgsDelta + | ToolEnd + | FileEvent + | MessageDone +) + + +@dataclasses.dataclass +class StreamHandler: + """ + Accumulates LLM adapter events and produces Messages with stateful parts. + + This is the normalization layer between LLM adapters and the rest of the system. + Parts are tracked in a single ``_current_parts`` dict keyed by block/tool id, + updated in place as events stream in. Each event carries the just-constructed + frozen part snapshot, so consumers never need to look parts up by id. + """ + + message_id: str = dataclasses.field(default_factory=messages_.generate_id) + + # Single source of truth for part state, keyed by id. Insertion order + # preserves provider emission order. + _current_parts: dict[str, messages_.Part] = dataclasses.field(default_factory=dict) + + # Active tracking + _active_text_id: str | None = None + _active_reasoning_id: str | None = None + _active_tool_ids: set[str] = dataclasses.field(default_factory=set) + + _is_done: bool = False + _usage: messages_.Usage | None = None + + def handle_event(self, event: StreamEvent) -> messages_.Message: + """Process event and return current Message state.""" + + # Sidecar events for this yield (reset each call). + stream_events: list[messages_.StreamEvent] = [] + + match event: + case TextStart(block_id=bid): + part: messages_.Part = messages_.TextPart(id=bid, text="") + self._current_parts[bid] = part + self._active_text_id = bid + stream_events.append(messages_.PartOpened(part=part)) + + case TextDelta(block_id=bid, delta=d): + existing = self._current_parts[bid] + assert isinstance(existing, messages_.TextPart) + part = messages_.TextPart(id=bid, text=existing.text + d) + self._current_parts[bid] = part + stream_events.append(messages_.PartDelta(part=part, chunk=d)) + + case TextEnd(block_id=bid): + if self._active_text_id == bid: + self._active_text_id = None + stream_events.append( + messages_.PartClosed(part=self._current_parts[bid]) + ) + + case ReasoningStart(block_id=bid): + part = messages_.ReasoningPart(id=bid, text="") + self._current_parts[bid] = part + self._active_reasoning_id = bid + stream_events.append(messages_.PartOpened(part=part)) + + case ReasoningDelta(block_id=bid, delta=d): + existing = self._current_parts[bid] + assert isinstance(existing, messages_.ReasoningPart) + part = messages_.ReasoningPart( + id=bid, + text=existing.text + d, + signature=existing.signature, + ) + self._current_parts[bid] = part + stream_events.append(messages_.PartDelta(part=part, chunk=d)) + + case ReasoningEnd(block_id=bid, signature=sig): + existing = self._current_parts[bid] + assert isinstance(existing, messages_.ReasoningPart) + part = messages_.ReasoningPart( + id=bid, text=existing.text, signature=sig + ) + self._current_parts[bid] = part + if self._active_reasoning_id == bid: + self._active_reasoning_id = None + stream_events.append(messages_.PartClosed(part=part)) + + case ToolStart(tool_call_id=tcid, tool_name=name): + part = messages_.ToolCallPart( + id=tcid, + tool_call_id=tcid, + tool_name=name, + tool_args="", + ) + self._current_parts[tcid] = part + self._active_tool_ids.add(tcid) + stream_events.append(messages_.PartOpened(part=part)) + + case ToolArgsDelta(tool_call_id=tcid, delta=d): + existing = self._current_parts[tcid] + assert isinstance(existing, messages_.ToolCallPart) + part = messages_.ToolCallPart( + id=tcid, + tool_call_id=existing.tool_call_id, + tool_name=existing.tool_name, + tool_args=existing.tool_args + d, + ) + self._current_parts[tcid] = part + stream_events.append(messages_.PartDelta(part=part, chunk=d)) + + case ToolEnd(tool_call_id=tcid): + self._active_tool_ids.discard(tcid) + stream_events.append( + messages_.PartClosed(part=self._current_parts[tcid]) + ) + + case FileEvent(block_id=bid, media_type=mt, data=d): + self._current_parts[bid] = messages_.FilePart( + id=bid, data=d, media_type=mt + ) + + case MessageDone(usage=usage): + self._is_done = True + self._usage = usage + self._active_text_id = None + self._active_reasoning_id = None + self._active_tool_ids.clear() + + return self._build_message(stream_events) + + def _build_message( + self, + stream_events: list[messages_.StreamEvent], + ) -> messages_.Message: + return messages_.Message( + id=self.message_id, + role="assistant", + parts=list(self._current_parts.values()), + usage=self._usage if self._is_done else None, + stream=messages_.StreamState( + new_events=stream_events, + is_done=self._is_done, + ), + ) diff --git a/src/ai/types/integrity.py b/src/ai/types/integrity.py new file mode 100644 index 00000000..9f737101 --- /dev/null +++ b/src/ai/types/integrity.py @@ -0,0 +1,261 @@ +import json +import logging +from typing import Literal + +from . import builders +from . import messages as messages_ + +logger = logging.getLogger(__name__) + +Mode = Literal["strict", "auto"] + +IssueKind = Literal[ + "duplicate-tool-call", + "duplicate-tool-result", + "internal-part", + "invalid-tool-args", + "orphaned-tool-call", + "orphaned-tool-result", + "internal-message", +] + + +class IntegrityError(ValueError): + def __init__(self, issues: list[IssueKind]) -> None: + self.issues = issues + super().__init__( + f"Message history has {len(issues)} issue(s): " + ", ".join(issues) + ) + + +# used for stripping internal parts +_LLM_PART_TYPES = ( + messages_.TextPart, + messages_.ToolCallPart, + messages_.ToolResultPart, + messages_.ReasoningPart, + messages_.FilePart, +) + + +def _clean_messages( + messages: list[messages_.Message], mode: Mode +) -> tuple[list[messages_.Message], list[IssueKind]]: + """Strip internal messages, fix broken tool args""" + + issues: list[IssueKind] = [] + result: list[messages_.Message] = [] + + for msg in messages: + # 1. drop internal messages emitted by hooks + if msg.role == "internal": + issues.append("internal-message") + if mode == "strict": + result.append(msg) + continue + + parts: list[messages_.Part] = list(msg.parts) + changed = False + + # 2. strip everything that isn't an LLM part + kept: list[messages_.Part] = [ + p for p in parts if isinstance(p, _LLM_PART_TYPES) + ] + if len(kept) < len(parts): + issues.append("internal-part") + if mode == "auto": + parts = kept + changed = True + + # 3. ensure tool args are json-decodable + new_parts: list[messages_.Part] = [] + for part in parts: + if isinstance(part, messages_.ToolCallPart): + try: + json.loads(part.tool_args) + except (json.JSONDecodeError, TypeError): + if mode == "auto": + part = part.model_copy(update={"tool_args": "{}"}) + issues.append("invalid-tool-args") + changed = True + new_parts.append(part) + + if changed and mode == "auto": + parts = new_parts + + # 4. drop empty messages + if mode == "auto" and not parts: + continue + + if changed and mode == "auto": + # messages are immutable so we have to do this + result.append(msg.model_copy(update={"parts": parts})) + else: + result.append(msg) + + return result, issues + + +def _validate_tool_ids(messages: list[messages_.Message]) -> list[IssueKind]: + """Check for fatal issues: duplicate tool ids, orphaned tool results.""" + + issues: list[IssueKind] = [] + seen_call_ids: set[str] = set() + seen_result_ids: set[str] = set() + pending_call_ids: set[str] = set() + + duplicate_call = False + duplicate_result = False + orphaned_result = False + + for msg in messages: + if msg.role in ("user", "assistant") and pending_call_ids: + # result should have been in a tool message before this + # if it wasn't then it's a stray call, will be auto-fixed later + pending_call_ids.clear() + + if msg.role == "assistant": + # check if tool call is duplicate + # if not, mark it and append it to pending + for part in msg.parts: + if not isinstance(part, messages_.ToolCallPart): + continue + if part.tool_call_id in seen_call_ids: + duplicate_call = True + else: + seen_call_ids.add(part.tool_call_id) + pending_call_ids.add(part.tool_call_id) + + elif msg.role == "tool": + # check that this tool result is not duplicate and that + # there's a pending call from previous assistant message + for part in msg.parts: + if not isinstance(part, messages_.ToolResultPart): + continue + if part.tool_call_id in seen_result_ids: + duplicate_result = True + else: + seen_result_ids.add(part.tool_call_id) + if part.tool_call_id not in pending_call_ids: + orphaned_result = True + continue + pending_call_ids.remove(part.tool_call_id) + + if duplicate_call: + issues.append("duplicate-tool-call") + if duplicate_result: + issues.append("duplicate-tool-result") + if orphaned_result: + issues.append("orphaned-tool-result") + + return issues + + +def _fix_missing_results( + messages: list[messages_.Message], mode: Mode +) -> tuple[list[messages_.Message], list[IssueKind]]: + """Insert fake error results for stray tool calls.""" + issues: list[IssueKind] = [] + result: list[messages_.Message] = [] + + # 1. collect all result ids + answered: set[str] = set() + for msg in messages: + if msg.role == "tool": + for part in msg.parts: + if isinstance(part, messages_.ToolResultPart): + answered.add(part.tool_call_id) + + # pending tool calls from the current assistant turn + pending: dict[str, messages_.ToolCallPart] = {} + + def _flush_pending() -> None: + if not pending: + return + issues.append("orphaned-tool-call") + if mode == "auto": + synthetic = builders.tool_message( + *( + messages_.ToolResultPart( + tool_call_id=tc.tool_call_id, + tool_name=tc.tool_name, + result="Tool result not available", + is_error=True, + ) + for tc in pending.values() + ) + ) + result.append(synthetic) + + for msg in messages: + # if we're seeing a user / assistant message, then + # all pending tool calls are strays, because their results + # should have followed immediately after in a tool message + if msg.role in ("user", "assistant") and pending: + _flush_pending() + pending.clear() + + # 2. track calls + if msg.role == "assistant": + for part in msg.parts: + if ( + isinstance(part, messages_.ToolCallPart) + and part.tool_call_id not in answered + ): + pending[part.tool_call_id] = part + result.append(msg) + # 3. match results with calls + elif msg.role == "tool": + for part in msg.parts: + if isinstance(part, messages_.ToolResultPart): + pending.pop(part.tool_call_id, None) + result.append(msg) + else: + result.append(msg) + + _flush_pending() + + return result, issues + + +def prepare_messages( + messages: list[messages_.Message], + *, + mode: Mode = "auto", +) -> list[messages_.Message]: + """Fix and validate message list. + + ``"auto"`` (default) -- silently fixes recoverable issues (signal + messages, internal parts, invalid tool args, missing tool results). + ``"strict"`` -- collects every recoverable issue and raises + :class:`IntegrityError`. + + Duplicate tool-call IDs, duplicate tool-result IDs, and orphaned + tool results always raise :class:`IntegrityError` regardless of mode. + + Always returns a **new** list; never mutates the input. + """ + issues: list[IssueKind] = [] + + result, phase1_issues = _clean_messages(list(messages), mode) + issues.extend(phase1_issues) + + # never auto-fixed + fatal_issues = _validate_tool_ids(result) + issues.extend(fatal_issues) + + if not fatal_issues: + result, phase3_issues = _fix_missing_results(result, mode) + issues.extend(phase3_issues) + + if fatal_issues or (mode == "strict" and issues): + raise IntegrityError(issues) + + if issues: + logger.warning( + "Auto-fixed %d message issue(s): %s", + len(issues), + ", ".join(issues), + ) + + return result diff --git a/src/ai/types/proto.py b/src/ai/types/proto.py index 30d64bc7..e389e4d1 100644 --- a/src/ai/types/proto.py +++ b/src/ai/types/proto.py @@ -1,6 +1,7 @@ from collections.abc import AsyncGenerator from typing import Any, Protocol, runtime_checkable +from . import events as events_ from . import messages, usage @@ -25,7 +26,7 @@ class StreamResultLike(Protocol): The easiest way is ``StreamResult.from_generator(gen)``. """ - def __aiter__(self) -> AsyncGenerator[messages.Message]: ... + def __aiter__(self) -> AsyncGenerator[events_.Event]: ... @property def text(self) -> str: ... diff --git a/src/ai/types/stream.py b/src/ai/types/stream.py new file mode 100644 index 00000000..b64d648d --- /dev/null +++ b/src/ai/types/stream.py @@ -0,0 +1,96 @@ +from collections.abc import AsyncGenerator +from typing import Any, Self + +from . import events as events_ +from . import messages +from . import usage as usage_ + + +class StreamResult: + """Wrapper around an event stream. Async-iterable; collects the final result. + + Yields :class:`~ai.types.events.Event` objects. After iteration, + convenience properties (``.text``, ``.tool_calls``, ``.usage``, + ``.message``) are available — they delegate to the ``MessageEnd`` + event's ``message``. + + One ``StreamResult`` represents one turn: a single LLM request and + its response. + """ + + def __init__( + self, + gen: AsyncGenerator[events_.Event], + *, + turn_id: str | None = None, + input_messages: list[messages.Message] | None = None, + ) -> None: + self._gen = gen + self._turn_id = turn_id + self._input_messages = input_messages or [] + self._message: messages.Message | None = None + self._usage: usage_.Usage | None = None + + @classmethod + def from_generator(cls, gen: AsyncGenerator[events_.Event]) -> Self: + """Create a :class:`StreamResult` from an async generator of events.""" + return cls(gen) + + def __aiter__(self) -> AsyncGenerator[events_.Event]: + return self._iterate() + + async def _iterate(self) -> AsyncGenerator[events_.Event]: + # Re-emit input messages as MessageStart + MessageEnd event pairs. + for msg in self._input_messages: + if msg.turn_id is None and self._turn_id is not None: + msg = msg.model_copy(update={"turn_id": self._turn_id}) + yield events_.MessageStart(message=msg) + yield events_.MessageEnd(message=msg) + + # Stream adapter events. + async for event in self._gen: + # Capture the final message from MessageEnd. + if isinstance(event, events_.MessageEnd): + self._message = event.message + self._usage = event.usage + yield event + + @property + def turn_id(self) -> str | None: + """The turn id stamped on this stream's response (if any).""" + return self._turn_id + + @property + def message(self) -> messages.Message | None: + """The final assembled message, available after iteration.""" + return self._message + + @property + def text(self) -> str: + if self._message is None: + return "" + return "".join( + p.text for p in self._message.parts if isinstance(p, messages.TextPart) + ) + + @property + def tool_calls(self) -> list[messages.ToolCallPart]: + if self._message is None: + return [] + return [ + p for p in self._message.parts if isinstance(p, messages.ToolCallPart) + ] + + @property + def usage(self) -> usage_.Usage | None: + return self._usage + + @property + def output(self) -> Any: + """Parsed structured output from the final message, if available.""" + if self._message is None: + return None + for p in self._message.parts: + if isinstance(p, messages.StructuredOutputPart): + return p.value + return None From a9e1848c42aed5a6aa6c60f4b2fd9445ddf916b2 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Mon, 27 Apr 2026 14:51:33 -0700 Subject: [PATCH 3/3] Do a sweep to land the rest of the friendly PR --- README.md | 18 +- examples/coding-agent/1_raw_stream.py | 17 + examples/fastapi-vite/README.md | 13 +- examples/fastapi-vite/backend/agent.py | 12 +- examples/fastapi-vite/backend/main.py | 28 +- examples/multiagent-textual/client.py | 68 ++-- examples/multiagent-textual/server.py | 59 +++- examples/samples/agent_custom_loop.py | 19 +- examples/samples/agent_hooks.py | 34 +- examples/samples/agent_hooks_serverless.py | 56 ++- examples/samples/agent_nested.py | 13 +- examples/samples/agent_simple.py | 7 +- examples/samples/explicit_client.py | 7 +- examples/samples/inline_image.py | 16 +- examples/samples/mcp_tools.py | 7 +- examples/samples/middleware_simple.py | 13 +- examples/samples/multimodal_input.py | 7 +- examples/samples/stream.py | 7 +- examples/samples/streaming_tool.py | 27 +- examples/samples/structured_output.py | 11 +- examples/samples/tools_schema.py | 12 +- examples/temporal-direct/main.py | 24 +- examples/temporal-middleware/main.py | 20 +- src/ai/__init__.py | 32 ++ src/ai/agents/agent.py | 82 +++-- src/ai/agents/runtime.py | 24 +- src/ai/agents/ui/ai_sdk/_approvals.py | 4 +- src/ai/agents/ui/ai_sdk/_parts.py | 4 +- src/ai/agents/ui/ai_sdk/inbound.py | 4 +- src/ai/agents/ui/ai_sdk/outbound/_state.py | 186 ++++++---- src/ai/agents/ui/ai_sdk/outbound/sse.py | 8 +- src/ai/agents/ui/ai_sdk/outbound/stream.py | 41 +-- src/ai/middleware.py | 37 +- src/ai/models/__init__.py | 11 +- src/ai/models/ai_gateway/stream.py | 6 +- src/ai/models/anthropic/adapter.py | 13 +- src/ai/models/core/api.py | 12 +- src/ai/models/core/helpers/streaming.py | 2 +- src/ai/models/core/params.py | 2 +- src/ai/models/core/proto.py | 12 +- src/ai/models/openai/adapter.py | 26 +- src/ai/types/__init__.py | 2 +- src/ai/types/group.py | 222 ------------ src/ai/types/messages.py | 83 +++-- src/ai/types/proto.py | 3 + src/ai/types/stream.py | 21 +- tests/agents/mcp/test_client.py | 8 +- tests/agents/test_generator_tools.py | 58 +-- tests/agents/test_hooks.py | 50 +-- tests/agents/test_runtime.py | 39 +- tests/agents/ui/ai_sdk/outbound/test_sse.py | 27 +- .../agents/ui/ai_sdk/outbound/test_stream.py | 334 +++++++++--------- tests/agents/ui/ai_sdk/test_approvals.py | 8 +- tests/agents/ui/ai_sdk/test_inbound.py | 2 +- tests/conftest.py | 79 ++++- tests/models/ai_gateway/test_stream.py | 61 ++-- tests/models/core/test_streaming.py | 270 +++++--------- tests/models/test_public_api.py | 50 +-- tests/test_middleware.py | 101 +++--- tests/types/test_integrity.py | 11 +- uv.lock | 2 +- 61 files changed, 1212 insertions(+), 1220 deletions(-) delete mode 100644 src/ai/types/group.py diff --git a/README.md b/README.md index 573f730d..1ca41863 100644 --- a/README.md +++ b/README.md @@ -38,9 +38,9 @@ async def main() -> None: ai.user_message("What's the weather in Tokyo?"), ] - async for msg in agent.run(model, messages): - if msg.text_delta: - print(msg.text_delta, end="", flush=True) + async for event in agent.run(model, messages): + if isinstance(event, ai.TextDelta): + print(event.chunk, end="", flush=True) print() @@ -106,18 +106,18 @@ Override the default loop when you need approval gates, routing, or custom orche @agent.loop async def custom(context: ai.Context): while True: - s = await ai.stream( - context.model, context.messages, tools=context.tools - ) - async for msg in s: - yield msg + s = ai.stream(context.model, context.messages, tools=context.tools) + async for event in s: + yield event tool_calls = context.resolve(s.tool_calls) if not tool_calls: return results = [await tc() for tc in tool_calls] - yield ai.tool_message(*results) + tool_msg = ai.tool_message(*results) + yield ai.MessageStart(message=tool_msg) + yield ai.MessageEnd(message=tool_msg) ``` ## Examples diff --git a/examples/coding-agent/1_raw_stream.py b/examples/coding-agent/1_raw_stream.py index 20b47057..45a7de3b 100644 --- a/examples/coding-agent/1_raw_stream.py +++ b/examples/coding-agent/1_raw_stream.py @@ -1,6 +1,23 @@ import ai import asyncio +import inspect +import pydantic +import json + +from typing import get_type_hints + + +def get_schema(fn) -> dict: + sig = inspect.signature(fn) + hints = get_type_hints(fn) + + fields = {} + for name, p in sig.parameters.items(): + t = hints.get(name, str) + default = ... if p.default is inspect.Parameter.empty else p.default + fields[name] = (t, default) + async def main() -> None: model = ai.ai_gateway("anthropic/claude-opus-4.7") diff --git a/examples/fastapi-vite/README.md b/examples/fastapi-vite/README.md index baf1c13b..9fff5052 100644 --- a/examples/fastapi-vite/README.md +++ b/examples/fastapi-vite/README.md @@ -16,13 +16,14 @@ to suspend execution whenever the LLM wants to call a tool. The flow is: 1. LLM emits a tool call 2. Backend calls `await ai.hook(...)` with `payload=ai.ToolApproval` -3. The runtime emits a `role="internal"` message containing a pending `HookPart` +3. The runtime emits a `MessageEnd` event containing an internal `HookPart` 4. The frontend renders Approve / Reject buttons via the `` component (from AI Elements) 5. When the user clicks a button, `addToolApprovalResponse()` patches the message and sends a new request with the decision -6. The backend resumes from the checkpoint, calls `ai.resolve_hook(...)`, - and either executes the tool or returns an error tool-result message +6. The backend pre-registers the approval via `ai.resolve_hook(...)` on the + next request, then either executes the tool or returns an error tool-result + message Tool results are appended as separate `role="tool"` messages. The assistant tool-call message remains immutable. @@ -60,7 +61,5 @@ The frontend dev server proxies `/api` requests to the backend at `localhost:800 ## Storage -Checkpoints are persisted to `./data/` as JSON files via `FileStorage`. -The storage backend implements a simple `Storage` protocol — swap in -Redis, Postgres, or any async key-value store by providing a different -implementation. +The demo backend is stateless. The frontend sends the conversation history +and approval responses on each request. diff --git a/examples/fastapi-vite/backend/agent.py b/examples/fastapi-vite/backend/agent.py index b3af15c3..4a5ad165 100644 --- a/examples/fastapi-vite/backend/agent.py +++ b/examples/fastapi-vite/backend/agent.py @@ -25,7 +25,7 @@ async def talk_to_mothership(question: str) -> str: @chat_agent.loop -async def graph(context: ai.Context) -> AsyncGenerator[ai.Message]: +async def graph(context: ai.Context) -> AsyncGenerator[ai.Event]: """Agent graph with human-in-the-loop tool approval. Loops: stream LLM -> request approval -> execute tools -> repeat. @@ -34,9 +34,9 @@ async def graph(context: ai.Context) -> AsyncGenerator[ai.Message]: Reject buttons and sends the decision back on the next request. """ while True: - s = await ai.models.stream(context.model, context.messages, tools=context.tools) - async for msg in s: - yield msg + s = ai.models.stream(context.model, context.messages, tools=context.tools) + async for event in s: + yield event tool_calls = context.resolve(s.tool_calls) if not tool_calls: @@ -45,7 +45,9 @@ async def graph(context: ai.Context) -> AsyncGenerator[ai.Message]: results = await asyncio.gather( *(_execute_with_approval(tc) for tc in tool_calls) ) - yield ai.tool_message(*results) + tool_msg = ai.tool_message(*results) + yield ai.MessageStart(message=tool_msg) + yield ai.MessageEnd(message=tool_msg) async def _execute_with_approval(tc: ai.ToolCall) -> ai.Message: diff --git a/examples/fastapi-vite/backend/main.py b/examples/fastapi-vite/backend/main.py index 3a15a20f..e4e49d5a 100644 --- a/examples/fastapi-vite/backend/main.py +++ b/examples/fastapi-vite/backend/main.py @@ -9,7 +9,6 @@ import fastapi.middleware.cors import fastapi.responses import pydantic -import storage import ai @@ -33,9 +32,6 @@ async def health() -> dict[str, str]: return {"status": "ok"} -file_storage = storage.FileStorage() - - class ChatRequest(pydantic.BaseModel): """Request body for the chat endpoint.""" @@ -47,32 +43,12 @@ class ChatRequest(pydantic.BaseModel): async def chat(request: ChatRequest) -> fastapi.responses.StreamingResponse: """Handle chat requests and stream responses.""" messages = ai.ai_sdk_ui.to_messages(request.messages) - session_id = request.session_id or "default" - checkpoint_key = f"checkpoint:{session_id}" - - checkpoint = None - saved = await file_storage.get(checkpoint_key) - if saved: - checkpoint = ai.Checkpoint.model_validate(saved) - - durability = ai.EventLogProvider(checkpoint) - result = agent_.chat_agent.run(agent_.MODEL, messages, durability=durability) + result = agent_.chat_agent.run(agent_.MODEL, messages) async def stream_response() -> AsyncGenerator[str]: - async for chunk in ai.ai_sdk_ui.to_sse_stream(result): + async for chunk in ai.ai_sdk_ui.to_sse(result): yield chunk - # Persist checkpoint so interrupted runs (approval hooks with - # interrupt_loop=True) can resume on re-entry. Clean up when - # the run completes without pending hooks. - cp = durability.checkpoint() - if cp.steps and not cp.hooks: - # Steps recorded but no hooks resolved — the run was likely - # interrupted by an approval hook. Save for replay. - await file_storage.put(checkpoint_key, cp.model_dump()) - else: - await file_storage.delete(checkpoint_key) - return fastapi.responses.StreamingResponse( stream_response(), headers=ai.ai_sdk_ui.UI_MESSAGE_STREAM_HEADERS, diff --git a/examples/multiagent-textual/client.py b/examples/multiagent-textual/client.py index 819e9f74..d6123af5 100644 --- a/examples/multiagent-textual/client.py +++ b/examples/multiagent-textual/client.py @@ -13,6 +13,7 @@ import json import rich.text +import pydantic import textual import textual.app import textual.containers @@ -111,6 +112,8 @@ def __init__(self) -> None: self._hook_queue: asyncio.Queue[ai.HookPart] = asyncio.Queue() self._current_hook: ai.HookPart | None = None self._ws: websockets.ClientConnection | None = None + self._event_adapter = pydantic.TypeAdapter(ai.Event) + self._current_label = "unknown" def compose(self) -> textual.app.ComposeResult: yield AgentPanel("mothership", "mothership") @@ -147,20 +150,45 @@ async def run_websocket(self) -> None: self._on_run_complete() break - msg = ai.Message.model_validate(data) - self._handle_message(msg) + event = self._event_adapter.validate_python(data) + self._handle_event(event) except (ConnectionRefusedError, OSError) as exc: self._set_input_placeholder(f"connection failed: {exc}") # ------------------------------------------------------------------ - # Message routing + # Event routing # ------------------------------------------------------------------ - def _handle_message(self, msg: ai.Message) -> None: - label = msg.source_label or "unknown" + def _handle_event(self, event: ai.Event) -> None: + if isinstance(event, ai.MessageStart) and event.message is not None: + self._current_label = event.message.source_label or "unknown" + panel = self._get_panel(self._current_label) + if panel is not None and panel.status == "idle": + panel.status = "streaming..." + return + + if isinstance(event, ai.TextDelta): + panel = self._get_panel(self._current_label) + if panel is not None: + panel.append_text(event.chunk) + return + + if isinstance(event, ai.ReasoningDelta | ai.ToolDelta): + panel = self._get_panel(self._current_label) + if panel is not None: + panel.append_text(event.chunk, style="dim") + return - if (hook_part := msg.get_hook_part()) is not None: + if not isinstance(event, ai.MessageEnd): + return + + msg = event.message + label = msg.source_label or self._current_label + + hook_parts = [p for p in msg.parts if isinstance(p, ai.HookPart)] + if hook_parts: + hook_part = hook_parts[0] if hook_part.status == "pending": self._on_hook_pending(hook_part) return @@ -172,28 +200,12 @@ def _handle_message(self, msg: ai.Message) -> None: if panel is None: return - # Mark panel as actively streaming - if panel.status == "idle": - panel.status = "streaming..." - - # Text / reasoning / tool-arg deltas - for ev in msg.deltas: - match ev.part: - case ai.TextPart(): - panel.append_text(ev.chunk) - case ai.ReasoningPart(): - panel.append_text(ev.chunk, style="dim") - case ai.ToolCallPart(): - panel.append_text(ev.chunk, style="dim") - - # Completed message — show tool calls and results - if msg.is_done: - for part in msg.parts: - match part: - case ai.ToolCallPart(tool_name=name, tool_args=args): - panel.append_line(f"> {name}({args})") - case ai.ToolResultPart(tool_name=name, result=result): - panel.append_line(f"< {name} = {result}") + for part in msg.parts: + match part: + case ai.ToolCallPart(tool_name=name, tool_args=args): + panel.append_line(f"> {name}({args})") + case ai.ToolResultPart(tool_name=name, result=result): + panel.append_line(f"< {name} = {result}") # ------------------------------------------------------------------ # Hook lifecycle diff --git a/examples/multiagent-textual/server.py b/examples/multiagent-textual/server.py index 37a5bf8d..f23f76e3 100644 --- a/examples/multiagent-textual/server.py +++ b/examples/multiagent-textual/server.py @@ -78,11 +78,11 @@ def _gated_agent( gated = ai.agent(tools=tools) @gated.loop - async def gated_loop(context: ai.Context) -> AsyncGenerator[ai.Message]: + async def gated_loop(context: ai.Context) -> AsyncGenerator[ai.Event]: while True: - s = await ai.stream(context.model, context.messages, tools=context.tools) - async for msg in s: - yield msg + s = ai.stream(context.model, context.messages, tools=context.tools) + async for event in s: + yield event tool_calls = context.resolve(s.tool_calls) if not tool_calls: @@ -110,7 +110,9 @@ async def gated_loop(context: ai.Context) -> AsyncGenerator[ai.Message]: else: results.append(await tc()) - yield ai.tool_message(*results) + tool_msg = ai.tool_message(*results) + yield ai.MessageStart(message=tool_msg) + yield ai.MessageEnd(message=tool_msg) return gated @@ -136,7 +138,7 @@ async def gated_loop(context: ai.Context) -> AsyncGenerator[ai.Message]: @orchestrator.loop -async def multiagent_loop(context: ai.Context) -> AsyncGenerator[ai.Message]: +async def multiagent_loop(context: ai.Context) -> AsyncGenerator[ai.Event]: """Run two gated agents in parallel, then summarise their results.""" query = context.messages[-1].text @@ -175,7 +177,7 @@ async def multiagent_loop(context: ai.Context) -> AsyncGenerator[ai.Message]: combined = f"Mothership: {r1}\nData centers: {r2}" # Fan in: summarise. - s = await ai.stream( + s = ai.stream( context.model, [ ai.system_message( @@ -184,8 +186,25 @@ async def multiagent_loop(context: ai.Context) -> AsyncGenerator[ai.Message]: ai.user_message(combined), ], ) - async for msg in s: - yield msg.model_copy(update={"agent": "summary"}) + async for event in s: + if isinstance(event, ai.MessageEnd): + yield event.model_copy( + update={ + "message": event.message.model_copy( + update={"source_label": "summary"} + ) + } + ) + elif isinstance(event, ai.MessageStart) and event.message is not None: + yield event.model_copy( + update={ + "message": event.message.model_copy( + update={"source_label": "summary"} + ) + } + ) + else: + yield event # --------------------------------------------------------------------------- @@ -196,11 +215,18 @@ async def multiagent_loop(context: ai.Context) -> AsyncGenerator[ai.Message]: def _normalise_message(data: dict[str, Any]) -> dict[str, Any]: """Ensure ToolResultPart.result is always a dict for safe deserialisation.""" for part in data.get("parts", []): - if part.get("type") == "tool_result" and isinstance(part.get("result"), str): + if part.get("kind") == "tool_result" and isinstance(part.get("result"), str): part["result"] = {"value": part["result"]} return data +def _normalise_event(data: dict[str, Any]) -> dict[str, Any]: + message = data.get("message") + if isinstance(message, dict): + data["message"] = _normalise_message(message) + return data + + # --------------------------------------------------------------------------- # WebSocket endpoint # --------------------------------------------------------------------------- @@ -232,12 +258,17 @@ async def read_resolutions() -> None: reader = asyncio.create_task(read_resolutions()) try: - async for msg in result: - data = _normalise_message(msg.model_dump()) + async for event in result: + data = _normalise_event(event.model_dump()) await websocket.send_json(data) - if hook_part := msg.get_hook_part(): - print(f" Hook {hook_part.status}: {hook_part.hook_id}") + if isinstance(event, ai.MessageEnd) and event.message.role == "internal": + hook_parts = [ + p for p in event.message.parts if isinstance(p, ai.HookPart) + ] + if hook_parts: + hook_part = hook_parts[0] + print(f" Hook {hook_part.status}: {hook_part.hook_id}") finally: reader.cancel() with contextlib.suppress(asyncio.CancelledError): diff --git a/examples/samples/agent_custom_loop.py b/examples/samples/agent_custom_loop.py index 5f2bce27..91819ea9 100644 --- a/examples/samples/agent_custom_loop.py +++ b/examples/samples/agent_custom_loop.py @@ -25,14 +25,14 @@ async def main() -> None: my_agent = ai.agent(tools=tools) @my_agent.loop - async def custom(context: ai.Context) -> AsyncGenerator[ai.Message]: + async def custom(context: ai.Context) -> AsyncGenerator[ai.Event]: """Stream, execute tools with logging, repeat.""" while True: - s = await ai.models.stream( + s = ai.models.stream( context.model, context.messages, tools=context.tools ) - async for msg in s: - yield msg + async for event in s: + yield event tool_calls = context.resolve(s.tool_calls) if not tool_calls: @@ -49,15 +49,16 @@ async def custom(context: ai.Context) -> AsyncGenerator[ai.Message]: tasks = [tg.create_task(tc()) for tc in tool_calls] # Yield one merged tool-result message — history auto-collects it. - yield ai.tool_message(*(t.result() for t in tasks)) + tool_msg = ai.tool_message(*(t.result() for t in tasks)) + yield ai.MessageStart(message=tool_msg) + yield ai.MessageEnd(message=tool_msg) - async for msg in my_agent.run( + async for event in my_agent.run( model, [ai.user_message("Compare the weather and population of New York and Tokyo.")], ): - for ev in msg.deltas: - if isinstance(ev.part, ai.TextPart): - print(ev.chunk, end="", flush=True) + if isinstance(event, ai.TextDelta): + print(event.chunk, end="", flush=True) print() diff --git a/examples/samples/agent_hooks.py b/examples/samples/agent_hooks.py index d062bee5..8fe4d887 100644 --- a/examples/samples/agent_hooks.py +++ b/examples/samples/agent_hooks.py @@ -3,7 +3,7 @@ Demonstrates the function-based hook API: - await hook("label", payload=Model) to suspend inside the loop - resolve_hook("label", data) to unblock from outside - - Hook messages arrive with role="internal" + - Hook messages arrive as MessageEnd events with role="internal" """ import asyncio @@ -31,13 +31,13 @@ async def main() -> None: my_agent = ai.agent(tools=[contact_mothership]) @my_agent.loop - async def with_approval(context: ai.Context) -> AsyncGenerator[ai.Message]: + async def with_approval(context: ai.Context) -> AsyncGenerator[ai.Event]: while True: - s = await ai.models.stream( + s = ai.models.stream( context.model, context.messages, tools=context.tools ) - async for msg in s: - yield msg + async for event in s: + yield event tool_calls = context.resolve(s.tool_calls) if not tool_calls: @@ -66,7 +66,9 @@ async def with_approval(context: ai.Context) -> AsyncGenerator[ai.Message]: else: results.append(await tc()) - yield ai.tool_message(*results) + tool_msg = ai.tool_message(*results) + yield ai.MessageStart(message=tool_msg) + yield ai.MessageEnd(message=tool_msg) messages = [ ai.system_message( @@ -75,11 +77,16 @@ async def with_approval(context: ai.Context) -> AsyncGenerator[ai.Message]: ai.user_message("When will the robots take over?"), ] - async for msg in my_agent.run(model, messages): - # Hook signals arrive with role="internal" - if msg.role == "internal": - hook_part = msg.get_hook_part() - if hook_part and hook_part.status == "pending": + async for event in my_agent.run(model, messages): + if isinstance(event, ai.TextDelta): + print(event.chunk, end="", flush=True) + continue + + # Hook signals arrive as internal MessageEnd events. + if isinstance(event, ai.MessageEnd) and event.message.role == "internal": + hook_parts = [p for p in event.message.parts if isinstance(p, ai.HookPart)] + hook_part = hook_parts[0] if hook_parts else None + if hook_part is not None and hook_part.status == "pending": answer = input(f"Approve {hook_part.hook_id}? [y/n] ") ai.resolve_hook( hook_part.hook_id, @@ -88,11 +95,6 @@ async def with_approval(context: ai.Context) -> AsyncGenerator[ai.Message]: reason="operator decision", ), ) - continue - - for ev in msg.deltas: - if isinstance(ev.part, ai.TextPart): - print(ev.chunk, end="", flush=True) print() diff --git a/examples/samples/agent_hooks_serverless.py b/examples/samples/agent_hooks_serverless.py index b0974f5e..7372ca6b 100644 --- a/examples/samples/agent_hooks_serverless.py +++ b/examples/samples/agent_hooks_serverless.py @@ -1,14 +1,14 @@ """Serverless hook pattern: interrupt_loop=True. Demonstrates the serverless/stateless pattern where the agent run suspends -cleanly when a hook has no resolution, and resumes from a checkpoint on -re-entry with a pre-registered resolution. +cleanly when a hook has no resolution, then re-enters with a pre-registered +resolution. Flow: 1. First run: hook fires, interrupt_loop=True cancels the future, - CancelledError is caught, run ends with a checkpoint. + CancelledError is caught and the run ends. 2. Second run: resolve_hook() pre-registers the answer, agent.run() - replays from checkpoint, hook finds the resolution immediately. + replays from the same input, and hook finds the resolution immediately. """ import asyncio @@ -36,13 +36,13 @@ async def main() -> None: my_agent = ai.agent(tools=[delete_file]) @my_agent.loop - async def with_confirmation(context: ai.Context) -> AsyncGenerator[ai.Message]: + async def with_confirmation(context: ai.Context) -> AsyncGenerator[ai.Event]: while True: - s = await ai.models.stream( + s = ai.models.stream( context.model, context.messages, tools=context.tools ) - async for msg in s: - yield msg + async for event in s: + yield event tool_calls = context.resolve(s.tool_calls) if not tool_calls: @@ -59,7 +59,6 @@ async def with_confirmation(context: ai.Context) -> AsyncGenerator[ai.Message]: ) except asyncio.CancelledError: # No resolution available — bail out cleanly. - # The checkpoint captures everything up to this point. return if confirmation.approved: @@ -74,7 +73,9 @@ async def with_confirmation(context: ai.Context) -> AsyncGenerator[ai.Message]: ) ) - yield ai.tool_message(*results) + tool_msg = ai.tool_message(*results) + yield ai.MessageStart(message=tool_msg) + yield ai.MessageEnd(message=tool_msg) messages = [ ai.system_message("Delete files when asked. Always use the delete_file tool."), @@ -85,39 +86,34 @@ async def with_confirmation(context: ai.Context) -> AsyncGenerator[ai.Message]: print("--- Run 1: hook fires, no resolution, run suspends ---") pending_hook_labels: list[str] = [] - durability = ai.EventLogProvider() - async for msg in my_agent.run(model, messages, durability=durability): - if msg.role == "internal": - hook_part = msg.get_hook_part() - if hook_part and hook_part.status == "pending": + async for event in my_agent.run(model, messages): + if isinstance(event, ai.TextDelta): + print(event.chunk, end="", flush=True) + elif isinstance(event, ai.MessageEnd) and event.message.role == "internal": + hook_parts = [p for p in event.message.parts if isinstance(p, ai.HookPart)] + hook_part = hook_parts[0] if hook_parts else None + if hook_part is not None and hook_part.status == "pending": pending_hook_labels.append(hook_part.hook_id) print( f" Hook pending: {hook_part.hook_id}" f" (metadata={hook_part.metadata})" ) - else: - for ev in msg.deltas: - if isinstance(ev.part, ai.TextPart): - print(ev.chunk, end="", flush=True) - saved_checkpoint = durability.checkpoint() - print(f"\n Checkpoint saved: {len(saved_checkpoint.steps)} steps\n") + print("\n Run interrupted; approval will be pre-registered for re-entry.\n") # -- Second run: pre-register resolution, replay from checkpoint -- print("--- Run 2: pre-register approval, resume from checkpoint ---") for label in pending_hook_labels: ai.resolve_hook(label, Confirmation(approved=True, reason="user approved")) - durability = ai.EventLogProvider(saved_checkpoint) - async for msg in my_agent.run(model, messages, durability=durability): - if msg.role == "internal": - hook_part = msg.get_hook_part() - if hook_part: + async for event in my_agent.run(model, messages): + if isinstance(event, ai.TextDelta): + print(event.chunk, end="", flush=True) + elif isinstance(event, ai.MessageEnd) and event.message.role == "internal": + hook_parts = [p for p in event.message.parts if isinstance(p, ai.HookPart)] + hook_part = hook_parts[0] if hook_parts else None + if hook_part is not None: print(f" Hook {hook_part.status}: {hook_part.hook_id}") - else: - for ev in msg.deltas: - if isinstance(ev.part, ai.TextPart): - print(ev.chunk, end="", flush=True) print() diff --git a/examples/samples/agent_nested.py b/examples/samples/agent_nested.py index 9a6b9dc6..c5a11e87 100644 --- a/examples/samples/agent_nested.py +++ b/examples/samples/agent_nested.py @@ -21,7 +21,7 @@ async def get_facts(topic: str) -> str: # This tool is an async generator — it streams intermediate messages # through the runtime sink, then returns the final result. @ai.tool # type: ignore[arg-type] # async generator tools are supported at runtime -async def research(topic: str) -> AsyncGenerator[ai.Message]: +async def research(topic: str) -> AsyncGenerator[ai.Event]: """Research a topic in depth using a sub-agent.""" researcher = ai.agent(tools=[get_facts]) @@ -30,8 +30,8 @@ async def research(topic: str) -> AsyncGenerator[ai.Message]: ai.user_message(f"Research: {topic}"), ] - async for msg in researcher.run(model, messages): - yield msg + async for event in researcher.run(model, messages): + yield event async def main() -> None: @@ -44,10 +44,9 @@ async def main() -> None: ai.user_message("Tell me about Mars."), ] - async for msg in orchestrator.run(model, messages): - for ev in msg.deltas: - if isinstance(ev.part, ai.TextPart): - print(ev.chunk, end="", flush=True) + async for event in orchestrator.run(model, messages): + if isinstance(event, ai.TextDelta): + print(event.chunk, end="", flush=True) print() diff --git a/examples/samples/agent_simple.py b/examples/samples/agent_simple.py index 1f62588e..4c376d0a 100644 --- a/examples/samples/agent_simple.py +++ b/examples/samples/agent_simple.py @@ -21,10 +21,9 @@ async def main() -> None: ai.user_message("What's the weather in Tokyo?"), ] - async for msg in my_agent.run(model, messages): - for ev in msg.deltas: - if isinstance(ev.part, ai.TextPart): - print(ev.chunk, end="", flush=True) + async for event in my_agent.run(model, messages): + if isinstance(event, ai.TextDelta): + print(event.chunk, end="", flush=True) print() diff --git a/examples/samples/explicit_client.py b/examples/samples/explicit_client.py index 2fc2c5d6..9de016b9 100644 --- a/examples/samples/explicit_client.py +++ b/examples/samples/explicit_client.py @@ -19,10 +19,9 @@ async def main() -> None: try: - async for msg in await ai.models.stream(model, messages): - for ev in msg.deltas: - if isinstance(ev.part, ai.TextPart): - print(ev.chunk, end="", flush=True) + async for event in ai.models.stream(model, messages): + if isinstance(event, ai.TextDelta): + print(event.chunk, end="", flush=True) print() finally: await client.aclose() diff --git a/examples/samples/inline_image.py b/examples/samples/inline_image.py index 2d2228c7..39b0a01d 100644 --- a/examples/samples/inline_image.py +++ b/examples/samples/inline_image.py @@ -1,8 +1,8 @@ """Inline image generation — LLM that outputs images alongside text. Models like Gemini 3 Pro Image can generate images as part of their -language model response. The images arrive as FileParts in the streamed -Message. +language model response. The images arrive as FileParts on the final +MessageEnd message. """ import asyncio @@ -24,12 +24,12 @@ async def main() -> None: last_msg: ai.Message | None = None - # Stream — text deltas arrive as usual, images arrive as FileParts - async for msg in await ai.stream(model, messages): - for ev in msg.deltas: - if isinstance(ev.part, ai.TextPart): - print(ev.chunk, end="", flush=True) - last_msg = msg + # Stream — text deltas arrive as events, images arrive on MessageEnd + async for event in ai.stream(model, messages): + if isinstance(event, ai.TextDelta): + print(event.chunk, end="", flush=True) + elif isinstance(event, ai.MessageEnd): + last_msg = event.message print() diff --git a/examples/samples/mcp_tools.py b/examples/samples/mcp_tools.py index d5f8329f..74778d1a 100644 --- a/examples/samples/mcp_tools.py +++ b/examples/samples/mcp_tools.py @@ -25,10 +25,9 @@ async def main() -> None: ai.user_message("How do I create middleware in Next.js?"), ] - async for msg in my_agent.run(model, messages): - for ev in msg.deltas: - if isinstance(ev.part, ai.TextPart): - print(ev.chunk, end="", flush=True) + async for event in my_agent.run(model, messages): + if isinstance(event, ai.TextDelta): + print(event.chunk, end="", flush=True) print() diff --git a/examples/samples/middleware_simple.py b/examples/samples/middleware_simple.py index a9fdf933..f317ba3a 100644 --- a/examples/samples/middleware_simple.py +++ b/examples/samples/middleware_simple.py @@ -25,8 +25,8 @@ async def wrap_agent_run(self, call, next): print(f">>> [run] agent starting label={label} tools={len(call.tools)}") t0 = time.perf_counter() - async for msg in next(call): - yield msg + async for event in next(call): + yield event elapsed = time.perf_counter() - t0 print(f"<<< [run] agent finished label={label} {elapsed:.2f}s") @@ -39,7 +39,7 @@ async def wrap_model(self, call, next): result = await next(call) - # The result is a StreamResult — async-iterable of Message snapshots. + # The result is a StreamResult — async-iterable of Event objects. # We return it as-is; the consumer iterates it normally. print("<<< [model] stream started") return result @@ -98,10 +98,9 @@ async def main() -> None: ] print("--- starting agent run ---\n") - async for msg in my_agent.run(model, messages, middleware=[PrintMiddleware()]): - for ev in msg.deltas: - if isinstance(ev.part, ai.TextPart): - print(ev.chunk, end="", flush=True) + async for event in my_agent.run(model, messages, middleware=[PrintMiddleware()]): + if isinstance(event, ai.TextDelta): + print(event.chunk, end="", flush=True) print("\n\n--- done ---") diff --git a/examples/samples/multimodal_input.py b/examples/samples/multimodal_input.py index 2663ec46..5e78582a 100644 --- a/examples/samples/multimodal_input.py +++ b/examples/samples/multimodal_input.py @@ -20,10 +20,9 @@ async def main() -> None: - async for msg in await ai.stream(model, messages): - for ev in msg.deltas: - if isinstance(ev.part, ai.TextPart): - print(ev.chunk, end="", flush=True) + async for event in ai.stream(model, messages): + if isinstance(event, ai.TextDelta): + print(event.chunk, end="", flush=True) print() diff --git a/examples/samples/stream.py b/examples/samples/stream.py index 70bdafee..b731e5c3 100644 --- a/examples/samples/stream.py +++ b/examples/samples/stream.py @@ -13,10 +13,9 @@ async def main() -> None: - async for msg in await ai.stream(model, messages): - for ev in msg.deltas: - if isinstance(ev.part, ai.TextPart): - print(ev.chunk, end="", flush=True) + async for event in ai.stream(model, messages): + if isinstance(event, ai.TextDelta): + print(event.chunk, end="", flush=True) print() diff --git a/examples/samples/streaming_tool.py b/examples/samples/streaming_tool.py index 9cf68e0c..59e5b15d 100644 --- a/examples/samples/streaming_tool.py +++ b/examples/samples/streaming_tool.py @@ -1,6 +1,6 @@ """Streaming from inside a tool via an async generator. -An async generator tool yields messages that flow through the runtime +An async generator tool yields events that flow through the runtime sink to the consumer in real time. The final yielded message's text becomes the tool result. """ @@ -12,21 +12,25 @@ @ai.tool # type: ignore[arg-type] # async generator tools are supported at runtime -async def talk_to_mothership(question: str) -> AsyncGenerator[ai.Message]: +async def talk_to_mothership(question: str) -> AsyncGenerator[ai.Event]: """Ask the mothership a question. Streams progress back to the caller.""" for step in ["Connecting...", "Transmitting...", "Awaiting response..."]: - yield ai.Message( + msg = ai.Message( role="assistant", parts=[ai.TextPart(text=step)], source_label="tool_progress", ) + yield ai.MessageStart(message=msg) + yield ai.MessageEnd(message=msg) await asyncio.sleep(0.3) # The final yielded message's text is returned as the tool result. - yield ai.Message( + msg = ai.Message( role="assistant", parts=[ai.TextPart(text="The mothership says: Soon.")], ) + yield ai.MessageStart(message=msg) + yield ai.MessageEnd(message=msg) async def main() -> None: @@ -39,13 +43,14 @@ async def main() -> None: ai.user_message("When will the robots take over?"), ] - async for msg in my_agent.run(model, messages): - if msg.source_label == "tool_progress": - print(f" [{msg.text}]") - else: - for ev in msg.deltas: - if isinstance(ev.part, ai.TextPart): - print(ev.chunk, end="", flush=True) + async for event in my_agent.run(model, messages): + if isinstance(event, ai.TextDelta): + print(event.chunk, end="", flush=True) + elif ( + isinstance(event, ai.MessageEnd) + and event.message.source_label == "tool_progress" + ): + print(f" [{event.message.text}]") print() diff --git a/examples/samples/structured_output.py b/examples/samples/structured_output.py index 333841c6..128c1b03 100644 --- a/examples/samples/structured_output.py +++ b/examples/samples/structured_output.py @@ -21,12 +21,11 @@ class Recipe(pydantic.BaseModel): async def main() -> None: # Stream with structured output — watch JSON arrive, get validated at the end - async for msg in await ai.stream(model, messages, output_type=Recipe): - for ev in msg.deltas: - if isinstance(ev.part, ai.TextPart): - print(ev.chunk, end="", flush=True) - if msg.output: - recipe: Recipe = msg.output + async for event in ai.stream(model, messages, output_type=Recipe): + if isinstance(event, ai.TextDelta): + print(event.chunk, end="", flush=True) + elif isinstance(event, ai.MessageEnd) and event.message.output: + recipe: Recipe = event.message.output print(f"\n\nParsed recipe: {recipe.name}") print(f" Ingredients: {', '.join(recipe.ingredients)}") print(f" Prep time: {recipe.prep_time_minutes} min") diff --git a/examples/samples/tools_schema.py b/examples/samples/tools_schema.py index c10ae4c0..805f8344 100644 --- a/examples/samples/tools_schema.py +++ b/examples/samples/tools_schema.py @@ -25,13 +25,11 @@ async def main() -> None: # Stream with tools — the model may emit tool calls - async for msg in await ai.stream(model, messages, tools=[get_weather]): - for ev in msg.deltas: - if isinstance(ev.part, ai.TextPart): - print(ev.chunk, end="", flush=True) - - if msg.is_done: - for tc in msg.tool_calls: + async for event in ai.stream(model, messages, tools=[get_weather]): + if isinstance(event, ai.TextDelta): + print(event.chunk, end="", flush=True) + elif isinstance(event, ai.MessageEnd): + for tc in event.message.tool_calls: print(f"\nTool call: {tc.tool_name}({tc.tool_args})") print() diff --git a/examples/temporal-direct/main.py b/examples/temporal-direct/main.py index 8039df76..e834fedb 100644 --- a/examples/temporal-direct/main.py +++ b/examples/temporal-direct/main.py @@ -100,9 +100,12 @@ async def llm_call_activity(params: LLMParams) -> LLMResult: messages = [ai.Message.model_validate(m) for m in params.messages] tools = [ai.ToolSchema(return_type=None, **t) for t in params.tool_schemas] - s = await ai.models.stream(model, messages, tools=tools) - result = await ai.models.buffer(s) - return LLMResult(message=result.model_dump()) + s = ai.models.stream(model, messages, tools=tools) + async for _event in s: + pass + if s.message is None: + raise RuntimeError("LLM stream ended without a final message") + return LLMResult(message=s.message.model_dump()) # ── Agent with custom loop ─────────────────────────────────────── @@ -115,7 +118,7 @@ async def llm_call_activity(params: LLMParams) -> LLMResult: @weather_agent.loop -async def temporal_loop(context: ai.Context) -> AsyncGenerator[ai.Message]: +async def temporal_loop(context: ai.Context) -> AsyncGenerator[ai.Event]: tool_schemas = [ {"name": t.name, "description": t.description, "param_schema": t.param_schema} for t in context.tools @@ -133,7 +136,8 @@ async def temporal_loop(context: ai.Context) -> AsyncGenerator[ai.Message]: retry_policy=temporalio.common.RetryPolicy(maximum_attempts=3), ) msg = ai.Message.model_validate(result.message) - yield msg + yield ai.MessageStart(message=msg) + yield ai.MessageEnd(message=msg) # 2. No tool calls → done if not msg.tool_calls: @@ -156,7 +160,9 @@ async def run_tool(tc: ai.ToolCallPart) -> ai.ToolResultPart: tasks = [asyncio.ensure_future(run_tool(tc)) for tc in msg.tool_calls] parts = await asyncio.gather(*tasks) - yield ai.tool_message(*parts) + tool_msg = ai.tool_message(*parts) + yield ai.MessageStart(message=tool_msg) + yield ai.MessageEnd(message=tool_msg) # ── Workflow ───────────────────────────────────────────────────── @@ -175,9 +181,9 @@ async def run(self, user_query: str) -> str: ] final_text = "" - async for msg in weather_agent.run(model, messages): - if msg.text: - final_text = msg.text + async for event in weather_agent.run(model, messages): + if isinstance(event, ai.MessageEnd) and event.message.text: + final_text = event.message.text return final_text diff --git a/examples/temporal-middleware/main.py b/examples/temporal-middleware/main.py index 7e8a56fb..d71b514d 100644 --- a/examples/temporal-middleware/main.py +++ b/examples/temporal-middleware/main.py @@ -117,9 +117,12 @@ async def llm_call_activity(params: LLMParams) -> LLMResult: messages = [ai.Message.model_validate(m) for m in params.messages] tools = [ai.ToolSchema(return_type=None, **t) for t in params.tool_schemas] - s = await ai.models.stream(model, messages, tools=tools) - result = await ai.models.buffer(s) - return LLMResult(message=result.model_dump()) + s = ai.models.stream(model, messages, tools=tools) + async for _event in s: + pass + if s.message is None: + raise RuntimeError("LLM stream ended without a final message") + return LLMResult(message=s.message.model_dump()) # ── Middleware ─────────────────────────────────────────────────── @@ -153,8 +156,9 @@ async def wrap_model( ) msg = ai.Message.model_validate(result.message) - async def _single() -> AsyncGenerator[ai.Message]: - yield msg + async def _single() -> AsyncGenerator[ai.Event]: + yield ai.MessageStart(message=msg) + yield ai.MessageEnd(message=msg) return ai.StreamResult.from_generator(_single()) @@ -213,9 +217,9 @@ async def run(self, user_query: str) -> str: mw = TemporalMiddleware(tool_schemas) final_text = "" - async for msg in weather_agent.run(model, messages, middleware=[mw]): - if msg.text: - final_text = msg.text + async for event in weather_agent.run(model, messages, middleware=[mw]): + if isinstance(event, ai.MessageEnd) and event.message.text: + final_text = event.message.text return final_text diff --git a/src/ai/__init__.py b/src/ai/__init__.py index 6f1edb48..840fbcd2 100644 --- a/src/ai/__init__.py +++ b/src/ai/__init__.py @@ -32,18 +32,34 @@ # Re-export core types from .types import ( + End, + Event, FilePart, HookPart, + HookResolution, + HookSuspention, Message, + MessageEnd, + MessageStart, Part, + ReasoningDelta, + ReasoningEnd, ReasoningPart, + ReasoningStart, + Start, StreamResultLike, StructuredOutputPart, + TextDelta, + TextEnd, TextPart, + TextStart, ToolCallPart, + ToolDelta, + ToolEnd, ToolLike, ToolResultPart, ToolSchema, + ToolStart, Usage, ) from .types.builders import ( @@ -58,14 +74,30 @@ __all__ = [ # Types (from types/) + "Start", + "End", + "Event", "Message", + "MessageStart", + "MessageEnd", "Part", "TextPart", + "TextStart", + "TextDelta", + "TextEnd", "ToolCallPart", + "ToolStart", + "ToolDelta", + "ToolEnd", "ToolResultPart", "ReasoningPart", + "ReasoningStart", + "ReasoningDelta", + "ReasoningEnd", "FilePart", "HookPart", + "HookSuspention", + "HookResolution", "StructuredOutputPart", "ToolLike", "ToolSchema", diff --git a/src/ai/agents/agent.py b/src/ai/agents/agent.py index cd8bd3be..d399cf96 100644 --- a/src/ai/agents/agent.py +++ b/src/ai/agents/agent.py @@ -199,19 +199,38 @@ def resolve(self, tool_parts: list[types.ToolCallPart]) -> list[ToolCall]: ] +StreamItem = types.Event | types.Message + + class LoopFn(Protocol): - def __call__(self, context: Context) -> AsyncGenerator[types.Message]: ... + def __call__(self, context: Context) -> AsyncGenerator[StreamItem]: ... + + +async def _message_events(message: types.Message) -> AsyncGenerator[types.Event]: + yield types.MessageStart(message=message) + yield types.MessageEnd(message=message) -async def _default_loop(context: Context) -> AsyncGenerator[types.Message]: +async def _coerce_events( + source: AsyncIterable[StreamItem], +) -> AsyncGenerator[types.Event]: + async for item in source: + if isinstance(item, types.Message): + async for event in _message_events(item): + yield event + else: + yield item + + +async def _default_loop(context: Context) -> AsyncGenerator[types.Event]: while True: - stream = await models.stream( + stream = models.stream( context.model, context.messages, tools=context.tools, ) - async for message in stream: - yield message + async for event in stream: + yield event tool_calls = context.resolve(stream.tool_calls) if not tool_calls: @@ -225,33 +244,35 @@ async def _default_loop(context: Context) -> AsyncGenerator[types.Message]: # Left un-stamped: the tool result is the input of the *next* turn, # so the next stream() call will stamp it with that turn's id. tool_msg = builders.tool_message(*(t.result() for t in tasks)) - yield tool_msg + async for event in _message_events(tool_msg): + yield event async def _collect_messages( - source: AsyncGenerator[types.Message], + source: AsyncIterable[StreamItem], messages: list[types.Message], -) -> AsyncGenerator[types.Message]: - """Intercept yielded messages and collect done ones into *messages*. +) -> AsyncGenerator[types.Event]: + """Intercept yielded events and collect MessageEnd messages into *messages*. This runs on the **producer** side (same coroutine as the loop function), so ``messages`` is always up-to-date by the time the loop reads it for the next model call — avoiding the race that would occur if collection happened on the consumer side of the runtime queue. """ - async for message in source: - if message.is_done: + async for event in _coerce_events(source): + if isinstance(event, types.MessageEnd): + message = event.message for i, existing in enumerate(messages): if existing.id == message.id: messages[i] = message break else: messages.append(message) - yield message + yield event -async def yield_from(source: AsyncIterable[types.Message]) -> str: - """Drain *source*, forwarding each message to the current runtime. +async def yield_from(source: AsyncIterable[StreamItem]) -> str: + """Drain *source*, forwarding each event to the current runtime. Use inside a custom loop to stream messages from a sub-agent to the consumer without adding them to the parent agent's message history:: @@ -269,9 +290,10 @@ async def yield_from(source: AsyncIterable[types.Message]) -> str: """ rt = runtime.get_runtime() last: types.Message | None = None - async for message in source: - await rt.put_message(message) - last = message + async for item in _coerce_events(source): + await rt.put_event(item) + if isinstance(item, types.MessageEnd): + last = item.message return last.text if last else "" @@ -303,8 +325,8 @@ async def run( *, label: str | None = None, middleware: list[middleware_.Middleware] | None = None, - ) -> AsyncGenerator[types.Message]: - """Run the agent loop, yielding messages to the consumer. + ) -> AsyncGenerator[types.Event]: + """Run the agent loop, yielding events to the consumer. Args: model: The model to use for LLM calls. @@ -327,17 +349,31 @@ async def run( async def _real( call: middleware_.AgentRunContext, - ) -> AsyncGenerator[types.Message]: + ) -> AsyncGenerator[types.Event]: context = Context( model=call.model, messages=list(call.messages), tools=call.tools, ) source = _collect_messages(loop_fn(context), context.messages) - async for message in runtime.run(source): + async for event in runtime.run(source): if call.label is not None: - message = message.model_copy(update={"source_label": call.label}) - yield message + event_message: types.Message | None = None + if isinstance(event, types.MessageEnd) or ( + isinstance(event, types.MessageStart) + and event.message is not None + ): + event_message = event.message + + if event_message is not None: + event = event.model_copy( + update={ + "message": event_message.model_copy( + update={"source_label": call.label} + ) + } + ) + yield event # Activate middleware for this run (and everything it calls). # When middleware is None (default), inherit the parent's middleware diff --git a/src/ai/agents/runtime.py b/src/ai/agents/runtime.py index 5865a639..cb8769ce 100644 --- a/src/ai/agents/runtime.py +++ b/src/ai/agents/runtime.py @@ -12,7 +12,7 @@ class Runtime: - """Central message queue. Producers put messages, run() yields them.""" + """Central event queue. Producers put events, run() yields them.""" class _Sentinel: pass @@ -20,16 +20,20 @@ class _Sentinel: _SENTINEL = _Sentinel() def __init__(self) -> None: - self._message_queue: asyncio.Queue[types.Message | Runtime._Sentinel] = ( + self._event_queue: asyncio.Queue[types.Event | Runtime._Sentinel] = ( asyncio.Queue() ) self._hook_labels: set[str] = set() + async def put_event(self, event: types.Event) -> None: + await self._event_queue.put(event) + async def put_message(self, message: types.Message) -> None: - await self._message_queue.put(message) + await self.put_event(types.MessageStart(message=message)) + await self.put_event(types.MessageEnd(message=message)) async def signal_done(self) -> None: - await self._message_queue.put(self._SENTINEL) + await self._event_queue.put(self._SENTINEL) def track_hook_label(self, label: str) -> None: """Register a hook label for cleanup when the run ends.""" @@ -57,9 +61,9 @@ async def _stop_when_done(runtime: Runtime, task: Awaitable[None]) -> None: async def run( - source: AsyncIterable[types.Message], -) -> AsyncGenerator[types.Message]: - """Run *source* and yield every message that gets put into the Runtime queue.""" + source: AsyncIterable[types.Event], +) -> AsyncGenerator[types.Event]: + """Run *source* and yield every event that gets put into the Runtime queue.""" rt = Runtime() token = _runtime.set(rt) @@ -68,15 +72,15 @@ async def run( mcp_token = mcp_client._pool.set(mcp_pool) async def _drain() -> None: - async for message in source: - await rt.put_message(message) + async for event in source: + await rt.put_event(event) try: async with asyncio.TaskGroup() as tg: tg.create_task(_stop_when_done(rt, _drain())) while True: - item = await rt._message_queue.get() + item = await rt._event_queue.get() if isinstance(item, Runtime._Sentinel): return yield item diff --git a/src/ai/agents/ui/ai_sdk/_approvals.py b/src/ai/agents/ui/ai_sdk/_approvals.py index f4e98d79..8ad0d10c 100644 --- a/src/ai/agents/ui/ai_sdk/_approvals.py +++ b/src/ai/agents/ui/ai_sdk/_approvals.py @@ -6,13 +6,15 @@ from __future__ import annotations +from typing import Any + from ....types import messages as messages_ from ...hooks import TOOL_APPROVAL_HOOK_TYPE _PREFIX = "approve_" -def tool_call_id_for(hook_part: messages_.HookPart) -> str | None: +def tool_call_id_for(hook_part: messages_.HookPart[Any]) -> str | None: """Return the tool_call_id encoded in a ToolApproval hook id, or None.""" if hook_part.hook_type != TOOL_APPROVAL_HOOK_TYPE: return None diff --git a/src/ai/agents/ui/ai_sdk/_parts.py b/src/ai/agents/ui/ai_sdk/_parts.py index 9d4d6ecc..ae7c85a6 100644 --- a/src/ai/agents/ui/ai_sdk/_parts.py +++ b/src/ai/agents/ui/ai_sdk/_parts.py @@ -1,8 +1,8 @@ """Shared conversions between internal Part objects and UIMessagePart objects. Used by ``outbound.history`` to reconstruct UIMessages from persisted -``ai.Message`` lists. The live outbound stream does not use these — it -emits wire-protocol deltas directly from ``Message.stream.new_events``. +``ai.Message`` lists. The live outbound stream does not use these; it +emits wire-protocol deltas directly from event streams. """ from __future__ import annotations diff --git a/src/ai/agents/ui/ai_sdk/inbound.py b/src/ai/agents/ui/ai_sdk/inbound.py index 0dcadea2..9c256773 100644 --- a/src/ai/agents/ui/ai_sdk/inbound.py +++ b/src/ai/agents/ui/ai_sdk/inbound.py @@ -59,7 +59,7 @@ def _error_result(error_text: str | None, output: Any) -> dict[str, Any] | None: return normalized -def _approval_hook_part(tp: ui_message.UIToolPart) -> messages_.HookPart | None: +def _approval_hook_part(tp: ui_message.UIToolPart) -> messages_.HookPart[Any] | None: """Reconstruct approval hook state from a UI tool part when possible.""" approval = tp.approval if approval is None: @@ -241,7 +241,7 @@ def _parse( for ui_msg in ui_messages: assistant_parts: list[messages_.Part] = [] tool_result_parts: list[messages_.ToolResultPart] = [] - hook_parts: list[messages_.HookPart] = [] + hook_parts: list[messages_.HookPart[Any]] = [] for part in ui_msg.parts: match part: diff --git a/src/ai/agents/ui/ai_sdk/outbound/_state.py b/src/ai/agents/ui/ai_sdk/outbound/_state.py index 14b26e45..e6c6eed3 100644 --- a/src/ai/agents/ui/ai_sdk/outbound/_state.py +++ b/src/ai/agents/ui/ai_sdk/outbound/_state.py @@ -1,14 +1,10 @@ -"""Stream state bookkeeping for the live outbound walk. - -Owns message/step boundary logic (via ``turn_id`` + ``agent``), tracks -which parts have open text/reasoning blocks, and guards against -re-emission when the runtime re-yields an already-finalized message. -""" +"""Stream state bookkeeping for the event-first outbound walk.""" from __future__ import annotations from typing import Any +from .....types import events as events_ from .....types import messages as messages_ from .. import _approvals, protocol @@ -35,19 +31,22 @@ def __init__(self) -> None: self.emitted_start: bool = False self.in_step: bool = False - # Message-level dedup — an ``is_done`` message re-emitted as input to a - # later ``stream()`` call must not fire events twice. self.seen_done: set[str] = set() + self.skip_current_message: bool = False + self.started_current_message: bool = False - # Tool-call dedup — keyed by tool_call_id. self.started_tool_inputs: set[str] = set() + self.tool_names: dict[str, str] = {} self.input_available_emitted: set[str] = set() self.emitted_tool_results: set[str] = set() self.emitted_approval_requests: set[str] = set() - # Open streaming blocks — keyed by part id. self.open_text_ids: set[str] = set() self.open_reasoning_ids: set[str] = set() + self.completed_text_ids: set[str] = set() + self.completed_reasoning_ids: set[str] = set() + self.text_delta_ids: set[str] = set() + self.reasoning_delta_ids: set[str] = set() # -- boundary helpers ---------------------------------------------------- @@ -55,9 +54,11 @@ def _close_open_blocks(self) -> list[protocol.UIMessageStreamPart]: parts: list[protocol.UIMessageStreamPart] = [] for rid in list(self.open_reasoning_ids): parts.append(protocol.ReasoningEndPart(id=rid)) + self.completed_reasoning_ids.add(rid) self.open_reasoning_ids.clear() for tid in list(self.open_text_ids): parts.append(protocol.TextEndPart(id=tid)) + self.completed_text_ids.add(tid) self.open_text_ids.clear() return parts @@ -70,16 +71,36 @@ def _finish_step(self) -> list[protocol.UIMessageStreamPart]: def _reset_step_tracking(self) -> None: self.started_tool_inputs.clear() + self.tool_names.clear() self.input_available_emitted.clear() self.emitted_tool_results.clear() self.emitted_approval_requests.clear() + @staticmethod + def _is_visible_message(msg: messages_.Message) -> bool: + return msg.role not in ("user", "system") + # -- phase: message start ------------------------------------------------ + def on_message_start( + self, msg: messages_.Message | None + ) -> list[protocol.UIMessageStreamPart]: + self.started_current_message = False + self.skip_current_message = False + if msg is None: + return [] + if msg.id in self.seen_done or not self._is_visible_message(msg): + self.skip_current_message = True + return [] + self.started_current_message = True + return self.on_message(msg) + def on_message(self, msg: messages_.Message) -> list[protocol.UIMessageStreamPart]: """Emit UIMessage/step boundary parts for *msg*.""" - parts: list[protocol.UIMessageStreamPart] = [] + if not self._is_visible_message(msg): + return [] + parts: list[protocol.UIMessageStreamPart] = [] agent_changed = ( self.emitted_start and msg.source_label is not None @@ -101,10 +122,6 @@ def on_message(self, msg: messages_.Message) -> list[protocol.UIMessageStreamPar self._reset_step_tracking() return parts - # Same UIMessage — check for step boundary via turn_id change. Only - # non-None → different-non-None transitions fire a step boundary; - # None carries the current step (tool results yielded by the loop are - # intentionally left unstamped until the next stream() stamps them). if ( msg.turn_id is not None and self.current_turn_id is not None @@ -120,118 +137,144 @@ def on_message(self, msg: messages_.Message) -> list[protocol.UIMessageStreamPar return parts - # -- phase: per-event (mid-stream) --------------------------------------- + # -- phase: streaming events -------------------------------------------- + + def on_event(self, event: events_.Event) -> list[protocol.UIMessageStreamPart]: + if self.skip_current_message: + return [] - def on_event( - self, - msg: messages_.Message, - event: messages_.StreamEvent, - ) -> list[protocol.UIMessageStreamPart]: match event: - case messages_.PartOpened(part=messages_.TextPart(id=pid)): + case events_.TextStart(block_id=pid): self.open_text_ids.add(pid) return [protocol.TextStartPart(id=pid)] - case messages_.PartDelta(part=messages_.TextPart(id=pid), chunk=chunk): + case events_.TextDelta(block_id=pid, chunk=chunk): + out: list[protocol.UIMessageStreamPart] = [] if pid not in self.open_text_ids: self.open_text_ids.add(pid) - return [ - protocol.TextStartPart(id=pid), - protocol.TextDeltaPart(id=pid, delta=chunk), - ] - return [protocol.TextDeltaPart(id=pid, delta=chunk)] + out.append(protocol.TextStartPart(id=pid)) + self.text_delta_ids.add(pid) + out.append(protocol.TextDeltaPart(id=pid, delta=chunk)) + return out - case messages_.PartClosed(part=messages_.TextPart(id=pid)): + case events_.TextEnd(block_id=pid): if pid in self.open_text_ids: self.open_text_ids.discard(pid) + self.completed_text_ids.add(pid) return [protocol.TextEndPart(id=pid)] return [] - case messages_.PartOpened(part=messages_.ReasoningPart(id=pid)): + case events_.ReasoningStart(block_id=pid): self.open_reasoning_ids.add(pid) return [protocol.ReasoningStartPart(id=pid)] - case messages_.PartDelta(part=messages_.ReasoningPart(id=pid), chunk=chunk): + case events_.ReasoningDelta(block_id=pid, chunk=chunk): + out = [] if pid not in self.open_reasoning_ids: self.open_reasoning_ids.add(pid) - return [ - protocol.ReasoningStartPart(id=pid), - protocol.ReasoningDeltaPart(id=pid, delta=chunk), - ] - return [protocol.ReasoningDeltaPart(id=pid, delta=chunk)] + out.append(protocol.ReasoningStartPart(id=pid)) + self.reasoning_delta_ids.add(pid) + out.append(protocol.ReasoningDeltaPart(id=pid, delta=chunk)) + return out - case messages_.PartClosed(part=messages_.ReasoningPart(id=pid)): + case events_.ReasoningEnd(block_id=pid): if pid in self.open_reasoning_ids: self.open_reasoning_ids.discard(pid) + self.completed_reasoning_ids.add(pid) return [protocol.ReasoningEndPart(id=pid)] return [] - case messages_.PartOpened(part=messages_.ToolCallPart() as tc): - if tc.tool_call_id in self.started_tool_inputs: + case events_.ToolStart(tool_call_id=tcid, tool_name=name): + self.tool_names[tcid] = name + if tcid in self.started_tool_inputs: return [] - self.started_tool_inputs.add(tc.tool_call_id) + self.started_tool_inputs.add(tcid) return [ protocol.ToolInputStartPart( - tool_call_id=tc.tool_call_id, - tool_name=tc.tool_name, + tool_call_id=tcid, + tool_name=name, ) ] - case messages_.PartDelta(part=messages_.ToolCallPart() as tc, chunk=chunk): - out: list[protocol.UIMessageStreamPart] = [] - if tc.tool_call_id not in self.started_tool_inputs: - self.started_tool_inputs.add(tc.tool_call_id) + case events_.ToolDelta(tool_call_id=tcid, chunk=chunk): + out = [] + if tcid not in self.started_tool_inputs: + self.started_tool_inputs.add(tcid) out.append( protocol.ToolInputStartPart( - tool_call_id=tc.tool_call_id, - tool_name=tc.tool_name, + tool_call_id=tcid, + tool_name=self.tool_names.get(tcid, ""), ) ) out.append( protocol.ToolInputDeltaPart( - tool_call_id=tc.tool_call_id, + tool_call_id=tcid, input_text_delta=chunk, ) ) return out - case messages_.PartClosed(part=messages_.ToolCallPart()): - # ToolInputAvailablePart is emitted in ``on_terminal`` from - # the terminal ``tool_args`` snapshot. + case events_.ToolEnd(): return [] return [] - # -- phase: terminal (tool results, approvals, final tool-input) --------- + # -- phase: terminal message -------------------------------------------- + + def _static_content( + self, msg: messages_.Message + ) -> list[protocol.UIMessageStreamPart]: + out: list[protocol.UIMessageStreamPart] = [] + + for part in msg.parts: + if isinstance(part, messages_.ReasoningPart): + if part.id not in self.completed_reasoning_ids: + if part.id not in self.open_reasoning_ids: + out.append(protocol.ReasoningStartPart(id=part.id)) + if part.text and part.id not in self.reasoning_delta_ids: + out.append( + protocol.ReasoningDeltaPart(id=part.id, delta=part.text) + ) + out.append(protocol.ReasoningEndPart(id=part.id)) + self.open_reasoning_ids.discard(part.id) + self.completed_reasoning_ids.add(part.id) + + elif isinstance(part, messages_.TextPart): + if part.id not in self.completed_text_ids: + if part.id not in self.open_text_ids: + out.append(protocol.TextStartPart(id=part.id)) + if part.text and part.id not in self.text_delta_ids: + out.append(protocol.TextDeltaPart(id=part.id, delta=part.text)) + out.append(protocol.TextEndPart(id=part.id)) + self.open_text_ids.discard(part.id) + self.completed_text_ids.add(part.id) + + elif isinstance(part, messages_.FilePart): + out.append( + protocol.FilePart( + url=part.data if isinstance(part.data, str) else "", + media_type=part.media_type, + ) + ) + + return out def on_terminal(self, msg: messages_.Message) -> list[protocol.UIMessageStreamPart]: - if not msg.is_done: + if msg.id in self.seen_done or not self._is_visible_message(msg): + self.seen_done.add(msg.id) return [] out: list[protocol.UIMessageStreamPart] = [] + if not self.started_current_message: + out.extend(self.on_message(msg)) - # Close any blocks that were opened but didn't see an explicit - # PartClosed (e.g. provider terminates abruptly — safety net). - if msg.stream is not None: - opened_ids = { - e.part.id - for e in msg.stream.new_events - if isinstance(e, messages_.PartOpened) - } - for tid in list(self.open_text_ids): - if tid in opened_ids and not any( - isinstance(e, messages_.PartClosed) and e.part.id == tid - for e in msg.stream.new_events - ): - out.append(protocol.TextEndPart(id=tid)) - self.open_text_ids.discard(tid) + out.extend(self._static_content(msg)) for part in msg.parts: if isinstance(part, messages_.ToolCallPart): if part.tool_call_id in self.input_available_emitted: continue self.input_available_emitted.add(part.tool_call_id) - # Ensure ToolInputStart was emitted (no streaming events case). if part.tool_call_id not in self.started_tool_inputs: self.started_tool_inputs.add(part.tool_call_id) out.append( @@ -294,6 +337,9 @@ def on_terminal(self, msg: messages_.Message) -> list[protocol.UIMessageStreamPa ) ) + self.seen_done.add(msg.id) + self.skip_current_message = False + self.started_current_message = False return out # -- phase: stream finish ------------------------------------------------ diff --git a/src/ai/agents/ui/ai_sdk/outbound/sse.py b/src/ai/agents/ui/ai_sdk/outbound/sse.py index 019f2894..e6c1f581 100644 --- a/src/ai/agents/ui/ai_sdk/outbound/sse.py +++ b/src/ai/agents/ui/ai_sdk/outbound/sse.py @@ -6,7 +6,7 @@ import json from collections.abc import AsyncGenerator, AsyncIterable -from .....types import messages as messages_ +from .....types import events as events_ from .. import protocol from .stream import to_stream @@ -32,8 +32,8 @@ def format_sse(part: protocol.UIMessageStreamPart) -> str: async def to_sse( - messages: AsyncIterable[messages_.Message], + events: AsyncIterable[events_.Event], ) -> AsyncGenerator[str]: - """Convert an internal message stream into SSE strings.""" - async for part in to_stream(messages): + """Convert an internal event stream into SSE strings.""" + async for part in to_stream(events): yield format_sse(part) diff --git a/src/ai/agents/ui/ai_sdk/outbound/stream.py b/src/ai/agents/ui/ai_sdk/outbound/stream.py index 3bccf95f..5ec0db5f 100644 --- a/src/ai/agents/ui/ai_sdk/outbound/stream.py +++ b/src/ai/agents/ui/ai_sdk/outbound/stream.py @@ -1,42 +1,35 @@ -"""Convert an internal ``ai.Message`` stream into AI SDK UI protocol parts.""" +"""Convert an internal ``ai.Event`` stream into AI SDK UI protocol parts.""" from __future__ import annotations from collections.abc import AsyncGenerator, AsyncIterable -from .....types import messages as messages_ +from .....types import events as events_ from .. import protocol from ._state import _StreamState async def to_stream( - messages: AsyncIterable[messages_.Message], + events: AsyncIterable[events_.Event], ) -> AsyncGenerator[protocol.UIMessageStreamPart]: - """Walk ``messages`` once, emitting AI SDK UI stream parts. + """Walk ``events`` once, emitting AI SDK UI stream parts. - Drives off ``Message.stream.new_events`` for incremental deltas and - ``Message.parts`` for terminal tool input/output/approval parts. - Re-emitted messages (same id, already seen ``is_done``) are skipped. + Streaming text/reasoning/tool-input deltas come from public events. + Terminal tool results, approvals, and files come from + ``MessageEnd.message``. """ state = _StreamState() - async for msg in messages: - if msg.id in state.seen_done: - continue - - for part in state.on_message(msg): - yield part - - if msg.stream is not None and msg.stream.new_events: - for event in msg.stream.new_events: - for out in state.on_event(msg, event): - yield out - - for part in state.on_terminal(msg): - yield part - - if msg.is_done: - state.seen_done.add(msg.id) + async for event in events: + if isinstance(event, events_.MessageStart): + for part in state.on_message_start(event.message): + yield part + elif isinstance(event, events_.MessageEnd): + for part in state.on_terminal(event.message): + yield part + else: + for part in state.on_event(event): + yield part for part in state.finish(): yield part diff --git a/src/ai/middleware.py b/src/ai/middleware.py index 2390bb98..a8f77b81 100644 --- a/src/ai/middleware.py +++ b/src/ai/middleware.py @@ -22,9 +22,9 @@ import pydantic +from .types import events as events_ from .types import messages as messages_ -from .types import tools as tools_ -from .types.proto import StreamResultLike +from .types.proto import StreamResultLike, ToolLike # --------------------------------------------------------------------------- # Call context objects — frozen dataclasses with isolated mutable fields. @@ -47,7 +47,7 @@ class ModelContext: model: Model messages: list[messages_.Message] - tools: Sequence[tools_.ToolLike] | None + tools: Sequence[ToolLike] | None output_type: type[pydantic.BaseModel] | None kwargs: dict[str, Any] @@ -113,11 +113,12 @@ def __post_init__(self) -> None: # Middleware base class — override the methods you care about. # --------------------------------------------------------------------------- -# Message alias for brevity in signatures. +# Event/message aliases for brevity in signatures. +_Event = events_.Event _Message = messages_.Message -# Agent run next-function type: call -> async generator of messages. -_AgentRunNext = Callable[[AgentRunContext], AsyncGenerator[_Message]] +# Agent run next-function type: call -> async generator of events. +_AgentRunNext = Callable[[AgentRunContext], AsyncGenerator[_Event]] class Middleware: @@ -130,21 +131,21 @@ async def wrap_agent_run( self, call: AgentRunContext, next: _AgentRunNext, - ) -> AsyncGenerator[_Message]: + ) -> AsyncGenerator[_Event]: """Wrap an agent run. - ``next(call)`` returns an async generator of ``Message`` objects. + ``next(call)`` returns an async generator of ``Event`` objects. Override to add tracing, durability checkpoints, or other run-scoped behavior:: async def wrap_agent_run(self, call, next): span = start_span("agent.run") - async for msg in next(call): - yield msg + async for event in next(call): + yield event span.end() """ - async for msg in next(call): - yield msg + async for event in next(call): + yield event async def wrap_model( self, @@ -154,7 +155,7 @@ async def wrap_model( """Wrap a model streaming call. ``next(call)`` returns a :class:`~ai.types.StreamResultLike` that - is async-iterable over ``Message`` snapshots. You can do work + is async-iterable over ``Event`` objects. You can do work before, iterate / transform the stream, or do cleanup after. To transform the stream, use @@ -163,8 +164,8 @@ async def wrap_model( async def wrap_model(self, call, next): stream = await next(call) async def _add_suffix(): - async for msg in stream: - yield msg + async for event in stream: + yield event from ai.models import StreamResult return StreamResult.from_generator(_add_suffix()) """ @@ -343,9 +344,9 @@ def _build_agent_run_chain( for m in reversed(mw): def _make(m: Middleware, nxt: _AgentRunNext) -> _AgentRunNext: - async def _wrapped(call: AgentRunContext) -> AsyncGenerator[_Message]: - async for msg in m.wrap_agent_run(call, nxt): - yield msg + async def _wrapped(call: AgentRunContext) -> AsyncGenerator[_Event]: + async for event in m.wrap_agent_run(call, nxt): + yield event return _wrapped diff --git a/src/ai/models/__init__.py b/src/ai/models/__init__.py index f87913f8..61a97767 100644 --- a/src/ai/models/__init__.py +++ b/src/ai/models/__init__.py @@ -11,16 +11,15 @@ # stream — auto-creates client from env vars msgs = [ai.user_message("hello")] - s = await ai.stream(model, msgs) - async for msg in s: - for ev in msg.deltas: - if isinstance(ev.part, ai.TextPart): - print(ev.chunk, end="", flush=True) + s = ai.stream(model, msgs) + async for event in s: + if isinstance(event, ai.TextDelta): + print(event.chunk, end="", flush=True) # explicit client for custom auth client = ai.Client(base_url="https://custom.example.com/v1", api_key="sk-...") model = openai("gpt-5.4", client=client) - s = await ai.stream(model, msgs) + s = ai.stream(model, msgs) # list available models ids = await openai.list() diff --git a/src/ai/models/ai_gateway/stream.py b/src/ai/models/ai_gateway/stream.py index 1bbb5c6c..8d97c5d9 100644 --- a/src/ai/models/ai_gateway/stream.py +++ b/src/ai/models/ai_gateway/stream.py @@ -15,7 +15,7 @@ from ...types import events as events_ from ...types import media from ...types import messages as messages_ -from ...types import tools as tools_ +from ...types import proto as proto_ from ...types import usage as usage_ from ..core import client as client_ from ..core import model as model_ @@ -124,7 +124,7 @@ async def _messages_to_prompt( async def _build_request_body( messages: list[messages_.Message], - tools: Sequence[tools_.ToolLike] | None = None, + tools: Sequence[proto_.ToolLike] | None = None, output_type: type[Any] | None = None, **kwargs: Any, ) -> dict[str, Any]: @@ -286,7 +286,7 @@ async def stream( model: model_.Model, messages: list[messages_.Message], *, - tools: Sequence[tools_.ToolLike] | None = None, + tools: Sequence[proto_.ToolLike] | None = None, output_type: type[pydantic.BaseModel] | None = None, **kwargs: Any, ) -> AsyncGenerator[events_.Event]: diff --git a/src/ai/models/anthropic/adapter.py b/src/ai/models/anthropic/adapter.py index c54c16d7..34329575 100644 --- a/src/ai/models/anthropic/adapter.py +++ b/src/ai/models/anthropic/adapter.py @@ -282,18 +282,22 @@ async def stream( block_types: dict[int, str] = {} tool_ids: dict[int, str] = {} + tool_names: dict[int, str] = {} signature_buffer: dict[int, str] = {} # Accumulate parts for the final Message parts: list[types.Part] = [] _text_parts: dict[str, str] = {} # block_id -> accumulated text _reasoning_parts: dict[str, str] = {} # block_id -> accumulated text _tool_parts: dict[str, str] = {} # tool_call_id -> accumulated args + message_id = types.generate_id() try: stream_cm = sdk_client.messages.stream(**api_kwargs) async with stream_cm as sdk_stream: - yield events.MessageStart() + yield events.MessageStart( + message=types.Message(id=message_id, role="assistant", parts=[]) + ) async for event in sdk_stream: match event.type: case "content_block_start": @@ -310,6 +314,7 @@ async def stream( yield events.ReasoningStart(block_id=str(idx)) case "tool_use": tool_ids[idx] = block.id + tool_names[idx] = block.name _tool_parts[block.id] = "" yield events.ToolStart( tool_call_id=block.id, @@ -331,8 +336,7 @@ async def stream( ) case "thinking_delta": _reasoning_parts[str(idx)] = ( - _reasoning_parts.get(str(idx), "") - + delta.thinking + _reasoning_parts.get(str(idx), "") + delta.thinking ) yield events.ReasoningDelta( chunk=delta.thinking, @@ -384,7 +388,7 @@ async def stream( types.ToolCallPart( id=tool_id, tool_call_id=tool_id, - tool_name=block_types.get(idx, ""), + tool_name=tool_names.get(idx, ""), tool_args=_tool_parts.get(tool_id, ""), ) ) @@ -402,6 +406,7 @@ async def stream( raw=sdk_usage.model_dump(exclude_none=True) or None, ) final_message = types.Message( + id=message_id, role="assistant", parts=parts, usage=usage, diff --git a/src/ai/models/core/api.py b/src/ai/models/core/api.py index a38e6362..d0b90315 100644 --- a/src/ai/models/core/api.py +++ b/src/ai/models/core/api.py @@ -14,8 +14,8 @@ from ...types import events as events_ from ...types import integrity as integrity_ from ...types import messages as messages_ +from ...types import proto as proto_ from ...types import stream as stream_ -from ...types import tools as tools_ from . import adapters from . import client as client_ from . import model as model_ @@ -26,11 +26,11 @@ def stream( model: model_.Model, messages: list[messages_.Message], *, - tools: Sequence[tools_.ToolLike] | None = None, + tools: Sequence[proto_.ToolLike] | None = None, output_type: type[pydantic.BaseModel] | None = None, turn_id: str | None = None, **kwargs: Any, -) -> stream_.StreamResultLike: +) -> proto_.StreamResultLike: """Stream an LLM response. Returns a :class:`StreamResultLike` that is async-iterable and @@ -39,7 +39,7 @@ def stream( Call-site is a plain ``async for`` — no outer ``await`` needed:: - async for msg in ai.stream(model, messages): + async for event in ai.stream(model, messages): ... One call is one turn: a single request and its response. The model @@ -69,7 +69,7 @@ def stream( # Capture in closure for the inner function. _turn_id = turn_id - async def _real(call: middleware_.ModelContext) -> stream_.StreamResultLike: + async def _real(call: middleware_.ModelContext) -> proto_.StreamResultLike: c = client_.auto_client(call.model) adapter_fn = adapters.get_stream_adapter(call.model.adapter) return types_.StreamResult( @@ -91,7 +91,7 @@ async def _driver() -> AsyncGenerator[events_.Event]: async for event in inner: yield event - return stream_.StreamResult.from_generator(_driver()) + return stream_.StreamResult(_driver(), turn_id=turn_id) async def generate( diff --git a/src/ai/models/core/helpers/streaming.py b/src/ai/models/core/helpers/streaming.py index 57ceebe5..25d8b649 100644 --- a/src/ai/models/core/helpers/streaming.py +++ b/src/ai/models/core/helpers/streaming.py @@ -113,7 +113,7 @@ class StreamHandler: def message_start(self) -> events_.MessageStart: """Emit a MessageStart event at the beginning of a stream.""" - return events_.MessageStart() + return events_.MessageStart(message=self._build_message()) def handle_event(self, event: StreamEvent) -> list[events_.Event]: """Process an adapter event and return public Event objects.""" diff --git a/src/ai/models/core/params.py b/src/ai/models/core/params.py index afd143f9..9969c31c 100644 --- a/src/ai/models/core/params.py +++ b/src/ai/models/core/params.py @@ -1,6 +1,6 @@ from typing import Any -import pydantic +import pydantic _PARAMS_CONFIG = pydantic.ConfigDict(frozen=True, populate_by_name=True) diff --git a/src/ai/models/core/proto.py b/src/ai/models/core/proto.py index 173bcfe7..35daf675 100644 --- a/src/ai/models/core/proto.py +++ b/src/ai/models/core/proto.py @@ -14,8 +14,9 @@ import pydantic +from ...types import events as events_ from ...types import messages as messages_ -from ...types import tools as tools_ +from ...types import proto as types_proto_ if TYPE_CHECKING: from .client import Client @@ -89,9 +90,8 @@ def __call__( class StreamFn(Protocol): """Protocol for streaming adapter functions. - Implementations yield ``Message`` snapshots as the response streams - in. Each snapshot is a complete, self-contained message reflecting - the accumulated state up to that point. + Implementations yield event objects as the response streams in. The + terminal assistant state is surfaced as a ``MessageEnd.message``. """ def __call__( @@ -100,10 +100,10 @@ def __call__( model: Model, messages: list[messages_.Message], *, - tools: Sequence[tools_.ToolLike] | None = None, + tools: Sequence[types_proto_.ToolLike] | None = None, output_type: type[pydantic.BaseModel] | None = None, **kwargs: Any, - ) -> AsyncGenerator[messages_.Message]: ... + ) -> AsyncGenerator[events_.Event]: ... @runtime_checkable diff --git a/src/ai/models/openai/adapter.py b/src/ai/models/openai/adapter.py index 764d9b29..3bfd082f 100644 --- a/src/ai/models/openai/adapter.py +++ b/src/ai/models/openai/adapter.py @@ -15,7 +15,7 @@ from ...types import events as events_ from ...types import media from ...types import messages as messages_ -from ...types import tools as tools_ +from ...types import proto as proto_ from ..core import client as client_ from ..core import model as model_ from ..core.helpers import files, streaming @@ -26,7 +26,7 @@ def _tools_to_openai( - tools: Sequence[tools_.ToolLike], + tools: Sequence[proto_.ToolLike], ) -> list[dict[str, Any]]: """Convert internal Tool objects to OpenAI tool schema format.""" return [ @@ -201,7 +201,7 @@ async def stream( model: model_.Model, messages: list[messages_.Message], *, - tools: Sequence[tools_.ToolLike] | None = None, + tools: Sequence[proto_.ToolLike] | None = None, output_type: type[pydantic.BaseModel] | None = None, thinking: bool = False, budget_tokens: int | None = None, @@ -314,9 +314,7 @@ def _emit(adapter_event: streaming.StreamEvent) -> list[events_.Event]: if reasoning_value: if not reasoning_started: reasoning_started = True - for e in _emit( - streaming.ReasoningStart(block_id="reasoning") - ): + for e in _emit(streaming.ReasoningStart(block_id="reasoning")): yield e for e in _emit( streaming.ReasoningDelta( @@ -327,9 +325,7 @@ def _emit(adapter_event: streaming.StreamEvent) -> list[events_.Event]: if delta.content: if reasoning_started: - for e in _emit( - streaming.ReasoningEnd(block_id="reasoning") - ): + for e in _emit(streaming.ReasoningEnd(block_id="reasoning")): yield e reasoning_started = False @@ -382,23 +378,17 @@ def _emit(adapter_event: streaming.StreamEvent) -> list[events_.Event]: if choice.finish_reason is not None: finish_reason = choice.finish_reason if reasoning_started: - for e in _emit( - streaming.ReasoningEnd(block_id="reasoning") - ): + for e in _emit(streaming.ReasoningEnd(block_id="reasoning")): yield e if text_started: for e in _emit(streaming.TextEnd(block_id="text")): yield e for tc in tc_state.values(): if tc["started"] and tc["id"]: - for e in _emit( - streaming.ToolEnd(tool_call_id=tc["id"]) - ): + for e in _emit(streaming.ToolEnd(tool_call_id=tc["id"])): yield e - for e in _emit( - streaming.MessageDone(finish_reason=finish_reason, usage=usage) - ): + for e in _emit(streaming.MessageDone(finish_reason=finish_reason, usage=usage)): yield e finally: await sdk_client.close() diff --git a/src/ai/types/__init__.py b/src/ai/types/__init__.py index 4c9a277e..769b12df 100644 --- a/src/ai/types/__init__.py +++ b/src/ai/types/__init__.py @@ -1,3 +1,4 @@ +from . import media from .events import ( End, Event, @@ -31,7 +32,6 @@ from .proto import StreamResultLike, ToolLike from .tools import ToolSchema from .usage import Usage -from . import media __all__ = [ "End", diff --git a/src/ai/types/group.py b/src/ai/types/group.py deleted file mode 100644 index 942003c8..00000000 --- a/src/ai/types/group.py +++ /dev/null @@ -1,222 +0,0 @@ -import dataclasses - -from ....types import messages as messages_ - - -@dataclasses.dataclass -class TextStart: - block_id: str - - -@dataclasses.dataclass -class TextDelta: - block_id: str - delta: str - - -@dataclasses.dataclass -class TextEnd: - block_id: str - - -@dataclasses.dataclass -class ReasoningStart: - block_id: str - - -@dataclasses.dataclass -class ReasoningDelta: - block_id: str - delta: str - - -@dataclasses.dataclass -class ReasoningEnd: - block_id: str - signature: str | None = None - - -@dataclasses.dataclass -class ToolStart: - tool_call_id: str - tool_name: str - - -@dataclasses.dataclass -class ToolArgsDelta: - tool_call_id: str - delta: str - - -@dataclasses.dataclass -class ToolEnd: - tool_call_id: str - - -@dataclasses.dataclass -class FileEvent: - """A complete generated file from the LLM (e.g. inline image from Gemini/GPT).""" - - block_id: str - media_type: str - data: str # base64 string or data-URL from the gateway - - -@dataclasses.dataclass -class MessageDone: - finish_reason: str | None = None - usage: messages_.Usage | None = None - - -StreamEvent = ( - TextStart - | TextDelta - | TextEnd - | ReasoningStart - | ReasoningDelta - | ReasoningEnd - | ToolStart - | ToolArgsDelta - | ToolEnd - | FileEvent - | MessageDone -) - - -@dataclasses.dataclass -class StreamHandler: - """ - Accumulates LLM adapter events and produces Messages with stateful parts. - - This is the normalization layer between LLM adapters and the rest of the system. - Parts are tracked in a single ``_current_parts`` dict keyed by block/tool id, - updated in place as events stream in. Each event carries the just-constructed - frozen part snapshot, so consumers never need to look parts up by id. - """ - - message_id: str = dataclasses.field(default_factory=messages_.generate_id) - - # Single source of truth for part state, keyed by id. Insertion order - # preserves provider emission order. - _current_parts: dict[str, messages_.Part] = dataclasses.field(default_factory=dict) - - # Active tracking - _active_text_id: str | None = None - _active_reasoning_id: str | None = None - _active_tool_ids: set[str] = dataclasses.field(default_factory=set) - - _is_done: bool = False - _usage: messages_.Usage | None = None - - def handle_event(self, event: StreamEvent) -> messages_.Message: - """Process event and return current Message state.""" - - # Sidecar events for this yield (reset each call). - stream_events: list[messages_.StreamEvent] = [] - - match event: - case TextStart(block_id=bid): - part: messages_.Part = messages_.TextPart(id=bid, text="") - self._current_parts[bid] = part - self._active_text_id = bid - stream_events.append(messages_.PartOpened(part=part)) - - case TextDelta(block_id=bid, delta=d): - existing = self._current_parts[bid] - assert isinstance(existing, messages_.TextPart) - part = messages_.TextPart(id=bid, text=existing.text + d) - self._current_parts[bid] = part - stream_events.append(messages_.PartDelta(part=part, chunk=d)) - - case TextEnd(block_id=bid): - if self._active_text_id == bid: - self._active_text_id = None - stream_events.append( - messages_.PartClosed(part=self._current_parts[bid]) - ) - - case ReasoningStart(block_id=bid): - part = messages_.ReasoningPart(id=bid, text="") - self._current_parts[bid] = part - self._active_reasoning_id = bid - stream_events.append(messages_.PartOpened(part=part)) - - case ReasoningDelta(block_id=bid, delta=d): - existing = self._current_parts[bid] - assert isinstance(existing, messages_.ReasoningPart) - part = messages_.ReasoningPart( - id=bid, - text=existing.text + d, - signature=existing.signature, - ) - self._current_parts[bid] = part - stream_events.append(messages_.PartDelta(part=part, chunk=d)) - - case ReasoningEnd(block_id=bid, signature=sig): - existing = self._current_parts[bid] - assert isinstance(existing, messages_.ReasoningPart) - part = messages_.ReasoningPart( - id=bid, text=existing.text, signature=sig - ) - self._current_parts[bid] = part - if self._active_reasoning_id == bid: - self._active_reasoning_id = None - stream_events.append(messages_.PartClosed(part=part)) - - case ToolStart(tool_call_id=tcid, tool_name=name): - part = messages_.ToolCallPart( - id=tcid, - tool_call_id=tcid, - tool_name=name, - tool_args="", - ) - self._current_parts[tcid] = part - self._active_tool_ids.add(tcid) - stream_events.append(messages_.PartOpened(part=part)) - - case ToolArgsDelta(tool_call_id=tcid, delta=d): - existing = self._current_parts[tcid] - assert isinstance(existing, messages_.ToolCallPart) - part = messages_.ToolCallPart( - id=tcid, - tool_call_id=existing.tool_call_id, - tool_name=existing.tool_name, - tool_args=existing.tool_args + d, - ) - self._current_parts[tcid] = part - stream_events.append(messages_.PartDelta(part=part, chunk=d)) - - case ToolEnd(tool_call_id=tcid): - self._active_tool_ids.discard(tcid) - stream_events.append( - messages_.PartClosed(part=self._current_parts[tcid]) - ) - - case FileEvent(block_id=bid, media_type=mt, data=d): - self._current_parts[bid] = messages_.FilePart( - id=bid, data=d, media_type=mt - ) - - case MessageDone(usage=usage): - self._is_done = True - self._usage = usage - self._active_text_id = None - self._active_reasoning_id = None - self._active_tool_ids.clear() - - return self._build_message(stream_events) - - def _build_message( - self, - stream_events: list[messages_.StreamEvent], - ) -> messages_.Message: - return messages_.Message( - id=self.message_id, - role="assistant", - parts=list(self._current_parts.values()), - usage=self._usage if self._is_done else None, - stream=messages_.StreamState( - new_events=stream_events, - is_done=self._is_done, - ), - ) diff --git a/src/ai/types/messages.py b/src/ai/types/messages.py index 0f6974cc..29332eca 100644 --- a/src/ai/types/messages.py +++ b/src/ai/types/messages.py @@ -58,7 +58,7 @@ class HookPart[T](pydantic.BaseModel): id: str = pydantic.Field(default_factory=generate_id) hook_id: str hook_type: str - status: Literal["pending", "resolved"] + status: Literal["pending", "resolved", "cancelled"] metadata: dict[str, Any] = pydantic.Field(default_factory=dict) resolution: T | None = None @@ -187,15 +187,6 @@ def from_bytes( ] -ALLOWED_PARTS: dict[str, set[str]] = { - "user": {"text", "file"}, - "assistant": {"text", "tool_call", "reasoning", "structured_output"}, - "system": {"text"}, - "tool": {"tool_result"}, - "internal": {"hook"}, -} - - class Message(pydantic.BaseModel): model_config = pydantic.ConfigDict(frozen=True) @@ -206,13 +197,65 @@ class Message(pydantic.BaseModel): source_label: str | None = None usage: usage_.Usage | None = None - @pydantic.model_validator(mode="after") - def _check_parts(self) -> Self: - allowed = ALLOWED_PARTS[self.role] - bad = [p.kind for p in self.parts if p.kind not in allowed] - if bad: - raise ValueError( - f"role={self.role!r} cannot contain parts of kind(s) " - f"{sorted(set(bad))}; allowed: {sorted(allowed)}" - ) - return self + @property + def text(self) -> str: + """Concatenated text parts.""" + return "".join(p.text for p in self.parts if isinstance(p, TextPart)) + + @property + def reasoning(self) -> str: + """Concatenated reasoning parts.""" + return "".join(p.text for p in self.parts if isinstance(p, ReasoningPart)) + + @property + def tool_calls(self) -> list[ToolCallPart]: + return [p for p in self.parts if isinstance(p, ToolCallPart)] + + @property + def tool_results(self) -> list[ToolResultPart]: + return [p for p in self.parts if isinstance(p, ToolResultPart)] + + @property + def files(self) -> list[FilePart]: + return [p for p in self.parts if isinstance(p, FilePart)] + + @property + def images(self) -> list[FilePart]: + return [p for p in self.files if p.media_type.startswith("image/")] + + @property + def videos(self) -> list[FilePart]: + return [p for p in self.files if p.media_type.startswith("video/")] + + @property + def output(self) -> Any: + """Parsed structured output from the first structured-output part.""" + for part in self.parts: + if isinstance(part, StructuredOutputPart): + return part.value + return None + + def replace(self, old: Part, new: Part | None = None) -> Self: + """Return a copy with one part replaced. + + ``replace(new_part)`` matches by ``new_part.id``. + ``replace(old_part, new_part)`` matches by object identity. + """ + if new is None: + new = old + for idx, part in enumerate(self.parts): + if part.id == new.id: + parts = list(self.parts) + parts[idx] = new + return self.model_copy(update={"parts": parts}) + raise ValueError(f"Part id={new.id!r} not found in message {self.id!r}") + + for idx, part in enumerate(self.parts): + if part is old: + parts = list(self.parts) + parts[idx] = new + return self.model_copy(update={"parts": parts}) + raise ValueError(f"Part id={old.id!r} not found in message {self.id!r}") + + +Usage = usage_.Usage diff --git a/src/ai/types/proto.py b/src/ai/types/proto.py index e389e4d1..a2a8e3ca 100644 --- a/src/ai/types/proto.py +++ b/src/ai/types/proto.py @@ -28,6 +28,9 @@ class StreamResultLike(Protocol): def __aiter__(self) -> AsyncGenerator[events_.Event]: ... + @property + def message(self) -> messages.Message | None: ... + @property def text(self) -> str: ... diff --git a/src/ai/types/stream.py b/src/ai/types/stream.py index b64d648d..fdaaebca 100644 --- a/src/ai/types/stream.py +++ b/src/ai/types/stream.py @@ -39,19 +39,30 @@ def from_generator(cls, gen: AsyncGenerator[events_.Event]) -> Self: def __aiter__(self) -> AsyncGenerator[events_.Event]: return self._iterate() + def _stamp_message(self, msg: messages.Message) -> messages.Message: + if msg.turn_id is None and self._turn_id is not None: + return msg.model_copy(update={"turn_id": self._turn_id}) + return msg + async def _iterate(self) -> AsyncGenerator[events_.Event]: # Re-emit input messages as MessageStart + MessageEnd event pairs. for msg in self._input_messages: - if msg.turn_id is None and self._turn_id is not None: - msg = msg.model_copy(update={"turn_id": self._turn_id}) + msg = self._stamp_message(msg) yield events_.MessageStart(message=msg) yield events_.MessageEnd(message=msg) # Stream adapter events. async for event in self._gen: + if isinstance(event, events_.MessageStart) and event.message is not None: + event = event.model_copy( + update={"message": self._stamp_message(event.message)} + ) + # Capture the final message from MessageEnd. if isinstance(event, events_.MessageEnd): - self._message = event.message + message = self._stamp_message(event.message) + event = event.model_copy(update={"message": message}) + self._message = message self._usage = event.usage yield event @@ -77,9 +88,7 @@ def text(self) -> str: def tool_calls(self) -> list[messages.ToolCallPart]: if self._message is None: return [] - return [ - p for p in self._message.parts if isinstance(p, messages.ToolCallPart) - ] + return [p for p in self._message.parts if isinstance(p, messages.ToolCallPart)] @property def usage(self) -> usage_.Usage | None: diff --git a/tests/agents/mcp/test_client.py b/tests/agents/mcp/test_client.py index 9568af08..54b741af 100644 --- a/tests/agents/mcp/test_client.py +++ b/tests/agents/mcp/test_client.py @@ -10,7 +10,7 @@ import ai from ai.agents.mcp.client import _mcp_tool_to_native -from ...conftest import MOCK_MODEL, mock_llm, text_msg, tool_call_msg +from ...conftest import MOCK_MODEL, collect_messages, mock_llm, text_msg, tool_call_msg def _fake_mcp_tool( @@ -88,9 +88,9 @@ async def fake_fn(**kwargs: str) -> str: call2 = [text_msg("Done.", id="msg-2")] llm = mock_llm([call1, call2]) - msgs: list[ai.Message] = [] - async for msg in my_agent.run(MOCK_MODEL, [ai.user_message("echo hello")]): - msgs.append(msg) + msgs = await collect_messages( + my_agent.run(MOCK_MODEL, [ai.user_message("echo hello")]) + ) # Tool was called with the right args. assert len(call_log) == 1 diff --git a/tests/agents/test_generator_tools.py b/tests/agents/test_generator_tools.py index c1e6a4f7..3c4e9812 100644 --- a/tests/agents/test_generator_tools.py +++ b/tests/agents/test_generator_tools.py @@ -10,9 +10,10 @@ import ai from ai import models from ai.models.core.helpers import streaming as streaming_ +from ai.types import events as events_ from ai.types import messages as messages_ -from ..conftest import MOCK_MODEL, mock_llm, text_msg, tool_call_msg +from ..conftest import MOCK_MODEL, collect_messages, mock_llm, text_msg, tool_call_msg # --------------------------------------------------------------------------- # Generator tool: yields intermediate messages, returns final text @@ -44,9 +45,9 @@ async def test_generator_tool_streams_and_returns_result() -> None: reply = [text_msg("Done!", id="msg-2")] llm = mock_llm([call, reply]) - collected: list[ai.Message] = [] - async for msg in my_agent.run(MOCK_MODEL, [ai.user_message("Go")]): - collected.append(msg) + collected = await collect_messages( + my_agent.run(MOCK_MODEL, [ai.user_message("Go")]) + ) assert llm.call_count == 2 @@ -87,7 +88,7 @@ async def stream( tools: Sequence[ai.ToolLike] | None = None, output_type: type[pydantic.BaseModel] | None = None, **kwargs: Any, - ) -> AsyncGenerator[messages_.Message]: + ) -> AsyncGenerator[events_.Event]: if self._idx >= len(self._responses): raise RuntimeError("_CapturingAdapter: no more responses") self.call_count += 1 @@ -95,35 +96,46 @@ async def stream( seq = self._responses[self._idx] self._idx += 1 - handler = streaming_.StreamHandler() + message_id = seq[0].id if seq else messages_.generate_id() + handler = streaming_.StreamHandler(message_id=message_id) + yield handler.message_start() for msg in seq: for i, part in enumerate(msg.parts): if isinstance(part, messages_.TextPart): bid = f"text-{i}" - yield handler.handle_event(streaming_.TextStart(block_id=bid)) + for event in handler.handle_event( + streaming_.TextStart(block_id=bid) + ): + yield event if part.text: - yield handler.handle_event( + for event in handler.handle_event( streaming_.TextDelta(block_id=bid, delta=part.text) - ) - yield handler.handle_event(streaming_.TextEnd(block_id=bid)) + ): + yield event + for event in handler.handle_event(streaming_.TextEnd(block_id=bid)): + yield event elif isinstance(part, messages_.ToolCallPart): - yield handler.handle_event( + for event in handler.handle_event( streaming_.ToolStart( tool_call_id=part.tool_call_id, tool_name=part.tool_name, ) - ) + ): + yield event if part.tool_args: - yield handler.handle_event( + for event in handler.handle_event( streaming_.ToolArgsDelta( tool_call_id=part.tool_call_id, delta=part.tool_args, ) - ) - yield handler.handle_event( + ): + yield event + for event in handler.handle_event( streaming_.ToolEnd(tool_call_id=part.tool_call_id) - ) - yield handler.handle_event(streaming_.MessageDone()) + ): + yield event + for event in handler.handle_event(streaming_.MessageDone()): + yield event @ai.tool @@ -133,7 +145,7 @@ async def inner_fact(topic: str) -> str: @ai.tool # type: ignore[arg-type] -async def research_tool(topic: str) -> AsyncGenerator[ai.Message]: +async def research_tool(topic: str) -> AsyncGenerator[ai.Event]: """Nested agent that researches a topic.""" inner = ai.agent(tools=[inner_fact]) @@ -141,8 +153,8 @@ async def research_tool(topic: str) -> AsyncGenerator[ai.Message]: ai.system_message("Be concise."), ai.user_message(f"Research: {topic}"), ] - async for msg in inner.run(MOCK_MODEL, msgs, label="inner"): - yield msg + async for event in inner.run(MOCK_MODEL, msgs, label="inner"): + yield event async def test_yield_from_nested_agent() -> None: @@ -168,9 +180,9 @@ async def test_yield_from_nested_agent() -> None: adapter = _CapturingAdapter([outer_call, inner_reply, outer_reply]) models.register_stream("mock", adapter.stream) - collected: list[ai.Message] = [] - async for msg in outer.run(MOCK_MODEL, [ai.user_message("Tell me about Mars")]): - collected.append(msg) + collected = await collect_messages( + outer.run(MOCK_MODEL, [ai.user_message("Tell me about Mars")]) + ) assert adapter.call_count == 3 diff --git a/tests/agents/test_hooks.py b/tests/agents/test_hooks.py index 91a1fe2d..1cfb2cc5 100644 --- a/tests/agents/test_hooks.py +++ b/tests/agents/test_hooks.py @@ -27,16 +27,19 @@ async def test_resolve_live_future() -> None: my_agent = ai.agent() @my_agent.loop - async def custom(context: ai.Context) -> AsyncGenerator[ai.Message]: + async def custom(context: ai.Context) -> AsyncGenerator[ai.Event]: nonlocal resolved_value - async for msg in await ai.models.stream(context.model, context.messages): - yield msg + async for event in ai.models.stream(context.model, context.messages): + yield event result = await ai.hook("confirm_1", payload=Confirmation) resolved_value = result mock_llm([[text_msg("OK")]]) - async for msg in my_agent.run(MOCK_MODEL, [ai.user_message("go")]): + async for event in my_agent.run(MOCK_MODEL, [ai.user_message("go")]): + if not isinstance(event, ai.MessageEnd): + continue + msg = event.message # When we see the pending hook message, resolve it. if any(isinstance(p, ai.HookPart) and p.status == "pending" for p in msg.parts): ai.resolve_hook("confirm_1", {"approved": True, "reason": "looks good"}) @@ -55,10 +58,10 @@ async def test_cancel_live_hook() -> None: my_agent = ai.agent() @my_agent.loop - async def custom(context: ai.Context) -> AsyncGenerator[ai.Message]: + async def custom(context: ai.Context) -> AsyncGenerator[ai.Event]: nonlocal was_cancelled - async for msg in await ai.models.stream(context.model, context.messages): - yield msg + async for event in ai.models.stream(context.model, context.messages): + yield event try: await ai.hook("cancel_me", payload=Confirmation) except asyncio.CancelledError: @@ -66,7 +69,10 @@ async def custom(context: ai.Context) -> AsyncGenerator[ai.Message]: mock_llm([[text_msg("OK")]]) - async for msg in my_agent.run(MOCK_MODEL, [ai.user_message("go")]): + async for event in my_agent.run(MOCK_MODEL, [ai.user_message("go")]): + if not isinstance(event, ai.MessageEnd): + continue + msg = event.message if any(isinstance(p, ai.HookPart) and p.status == "pending" for p in msg.parts): await ai.cancel_hook("cancel_me", reason="denied") @@ -90,10 +96,10 @@ async def test_pre_registered_resolution_consumed() -> None: my_agent = ai.agent() @my_agent.loop - async def custom(context: ai.Context) -> AsyncGenerator[ai.Message]: + async def custom(context: ai.Context) -> AsyncGenerator[ai.Event]: nonlocal resolved_value - async for msg in await ai.models.stream(context.model, context.messages): - yield msg + async for event in ai.models.stream(context.model, context.messages): + yield event resolved_value = await ai.hook("pre_reg_1", payload=Confirmation) # Pre-register BEFORE run. @@ -129,15 +135,18 @@ async def test_resolved_hook_emits_message() -> None: my_agent = ai.agent() @my_agent.loop - async def custom(context: ai.Context) -> AsyncGenerator[ai.Message]: - async for msg in await ai.models.stream(context.model, context.messages): - yield msg + async def custom(context: ai.Context) -> AsyncGenerator[ai.Event]: + async for event in ai.models.stream(context.model, context.messages): + yield event await ai.hook("emit_test", payload=Confirmation) mock_llm([[text_msg("OK")]]) msgs: list[ai.Message] = [] - async for msg in my_agent.run(MOCK_MODEL, [ai.user_message("go")]): + async for event in my_agent.run(MOCK_MODEL, [ai.user_message("go")]): + if not isinstance(event, ai.MessageEnd): + continue + msg = event.message msgs.append(msg) if any(isinstance(p, ai.HookPart) and p.status == "pending" for p in msg.parts): ai.resolve_hook("emit_test", {"approved": False}) @@ -158,9 +167,9 @@ async def test_hook_metadata_in_pending() -> None: my_agent = ai.agent() @my_agent.loop - async def custom(context: ai.Context) -> AsyncGenerator[ai.Message]: - async for msg in await ai.models.stream(context.model, context.messages): - yield msg + async def custom(context: ai.Context) -> AsyncGenerator[ai.Event]: + async for event in ai.models.stream(context.model, context.messages): + yield event await ai.hook( "meta_test", payload=Confirmation, @@ -170,8 +179,9 @@ async def custom(context: ai.Context) -> AsyncGenerator[ai.Message]: mock_llm([[text_msg("OK")]]) msgs: list[ai.Message] = [] - async for msg in my_agent.run(MOCK_MODEL, [ai.user_message("go")]): - msgs.append(msg) + async for event in my_agent.run(MOCK_MODEL, [ai.user_message("go")]): + if isinstance(event, ai.MessageEnd): + msgs.append(event.message) hook_msgs = [m for m in msgs if any(isinstance(p, ai.HookPart) for p in m.parts)] assert len(hook_msgs) >= 1 diff --git a/tests/agents/test_runtime.py b/tests/agents/test_runtime.py index 6cda3a8d..4165b727 100644 --- a/tests/agents/test_runtime.py +++ b/tests/agents/test_runtime.py @@ -5,7 +5,7 @@ import ai from ai.types import messages -from ..conftest import MOCK_MODEL, mock_llm, text_msg, tool_call_msg +from ..conftest import MOCK_MODEL, collect_messages, mock_llm, text_msg, tool_call_msg # -- Tool definitions for tests -------------------------------------------- @@ -30,9 +30,7 @@ async def test_agent_text_only() -> None: my_agent = ai.agent(tools=[double]) llm = mock_llm([[text_msg("Hello!")]]) - msgs: list[ai.Message] = [] - async for m in my_agent.run(MOCK_MODEL, [ai.user_message("Hi")]): - msgs.append(m) + msgs = await collect_messages(my_agent.run(MOCK_MODEL, [ai.user_message("Hi")])) assert llm.call_count == 1 assert any(m.text == "Hello!" for m in msgs) @@ -48,9 +46,9 @@ async def test_agent_tool_then_text() -> None: call2 = [text_msg("The answer is 10.")] llm = mock_llm([call1, call2]) - msgs: list[ai.Message] = [] - async for m in my_agent.run(MOCK_MODEL, [ai.user_message("Double 5")]): - msgs.append(m) + msgs = await collect_messages( + my_agent.run(MOCK_MODEL, [ai.user_message("Double 5")]) + ) assert llm.call_count == 2 tool_results = [m for m in msgs if m.role == "tool" and m.tool_results] assert len(tool_results) >= 1 @@ -83,9 +81,9 @@ async def test_agent_parallel_tools() -> None: call2 = [text_msg("6 and 14", id="msg-2")] llm = mock_llm([[two_tools], call2]) - msgs: list[ai.Message] = [] - async for m in my_agent.run(MOCK_MODEL, [ai.user_message("Double 3 and 7")]): - msgs.append(m) + msgs = await collect_messages( + my_agent.run(MOCK_MODEL, [ai.user_message("Double 3 and 7")]) + ) assert llm.call_count == 2 tool_result_msgs = [m for m in msgs if m.role == "tool" and m.tool_results] assert len(tool_result_msgs) >= 1 @@ -105,9 +103,9 @@ async def test_agent_multi_turn() -> None: turn3 = [text_msg("Done: hello world, 6", id="msg-3")] llm = mock_llm([turn1, turn2, turn3]) - msgs: list[ai.Message] = [] - async for m in my_agent.run(MOCK_MODEL, [ai.user_message("Concat then double")]): - msgs.append(m) + await collect_messages( + my_agent.run(MOCK_MODEL, [ai.user_message("Concat then double")]) + ) assert llm.call_count == 3 @@ -131,18 +129,17 @@ async def test_two_user_messages_produce_four_turns() -> None: def dedup(stream: list[ai.Message]) -> list[ai.Message]: seen: dict[str, ai.Message] = {} for m in stream: - if m.is_done: - seen[m.id] = m + seen[m.id] = m return list(seen.values()) - run1_stream: list[ai.Message] = [] - async for m in my_agent.run(MOCK_MODEL, [ai.user_message("Double 5")]): - run1_stream.append(m) + run1_stream = await collect_messages( + my_agent.run(MOCK_MODEL, [ai.user_message("Double 5")]) + ) history = dedup(run1_stream) - run2_stream: list[ai.Message] = [] - async for m in my_agent.run(MOCK_MODEL, [*history, ai.user_message("Double 7")]): - run2_stream.append(m) + run2_stream = await collect_messages( + my_agent.run(MOCK_MODEL, [*history, ai.user_message("Double 7")]) + ) final = dedup(run2_stream) # Chronological list of terminal non-internal messages. Insertion order diff --git a/tests/agents/ui/ai_sdk/outbound/test_sse.py b/tests/agents/ui/ai_sdk/outbound/test_sse.py index af7f998f..501d4d03 100644 --- a/tests/agents/ui/ai_sdk/outbound/test_sse.py +++ b/tests/agents/ui/ai_sdk/outbound/test_sse.py @@ -5,6 +5,7 @@ from ai.agents.ui.ai_sdk import protocol, to_sse from ai.agents.ui.ai_sdk.outbound.sse import format_sse, serialize_part +from ai.types import events as events_ from ai.types import messages as messages_ @@ -29,23 +30,25 @@ def test_serialize_data_part_uses_type_with_prefix() -> None: async def _gen( - msgs: list[messages_.Message], -) -> AsyncGenerator[messages_.Message]: - for m in msgs: - yield m + stream_events: list[events_.Event], +) -> AsyncGenerator[events_.Event]: + for event in stream_events: + yield event async def test_to_sse_emits_data_prefixed_lines() -> None: - msgs = [ - messages_.Message( - id="m1", - role="assistant", - turn_id="t1", - parts=[messages_.TextPart(text="hi")], - stream=messages_.StreamState(new_events=[], is_done=True), + msg = messages_.Message( + id="m1", + role="assistant", + turn_id="t1", + parts=[messages_.TextPart(text="hi")], + ) + lines = [ + line + async for line in to_sse( + _gen([events_.MessageStart(message=msg), events_.MessageEnd(message=msg)]) ) ] - lines = [line async for line in to_sse(_gen(msgs))] assert all(line.startswith("data: ") for line in lines) # first line is the start part first = json.loads(lines[0].removeprefix("data: ").rstrip()) diff --git a/tests/agents/ui/ai_sdk/outbound/test_stream.py b/tests/agents/ui/ai_sdk/outbound/test_stream.py index 7a14fd1e..5cd1bb19 100644 --- a/tests/agents/ui/ai_sdk/outbound/test_stream.py +++ b/tests/agents/ui/ai_sdk/outbound/test_stream.py @@ -3,88 +3,58 @@ from collections.abc import AsyncGenerator from ai.agents.ui.ai_sdk import protocol, to_stream +from ai.types import events as events_ from ai.types import messages as messages_ async def _gen( - msgs: list[messages_.Message], -) -> AsyncGenerator[messages_.Message]: - for m in msgs: - yield m + stream_events: list[events_.Event], +) -> AsyncGenerator[events_.Event]: + for event in stream_events: + yield event async def _collect( - msgs: list[messages_.Message], + stream_events: list[events_.Event], ) -> list[protocol.UIMessageStreamPart]: - return [p async for p in to_stream(_gen(msgs))] + return [part async for part in to_stream(_gen(stream_events))] -def _text_stream_message( - msg_id: str, - turn_id: str | None, - text_id: str, - chunk: str, +def _assistant_start( + msg_id: str = "m1", *, - is_done: bool, - full_text: str | None = None, -) -> messages_.Message: - text = full_text or chunk - part = messages_.TextPart(id=text_id, text=text) - events: list[messages_.StreamEvent] - if is_done: - events = [messages_.PartClosed(part=part)] - else: - events = [messages_.PartDelta(part=part, chunk=chunk)] - return messages_.Message( - id=msg_id, - role="assistant", - turn_id=turn_id, - parts=[part], - stream=messages_.StreamState(new_events=events, is_done=is_done), + turn_id: str | None = "t1", + source_label: str | None = None, +) -> events_.MessageStart: + return events_.MessageStart( + message=messages_.Message( + id=msg_id, + role="assistant", + turn_id=turn_id, + source_label=source_label, + parts=[], + ) ) async def test_event_driven_text_streaming() -> None: text_id = "txt1" - empty_text = messages_.TextPart(id=text_id, text="") - hi_text = messages_.TextPart(id=text_id, text="hi") - msgs = [ - # Initial: PartOpened - messages_.Message( - id="m1", - role="assistant", - turn_id="t1", - parts=[empty_text], - stream=messages_.StreamState( - new_events=[messages_.PartOpened(part=empty_text)], - is_done=False, - ), - ), - # Delta: "hi" - messages_.Message( - id="m1", - role="assistant", - turn_id="t1", - parts=[hi_text], - stream=messages_.StreamState( - new_events=[messages_.PartDelta(part=hi_text, chunk="hi")], - is_done=False, - ), - ), - # Closed - messages_.Message( - id="m1", - role="assistant", - turn_id="t1", - parts=[hi_text], - stream=messages_.StreamState( - new_events=[messages_.PartClosed(part=hi_text)], - is_done=True, - ), - ), - ] - out = await _collect(msgs) - # expect: Start, StartStep, TextStart, TextDelta, TextEnd, FinishStep, Finish + final = messages_.Message( + id="m1", + role="assistant", + turn_id="t1", + parts=[messages_.TextPart(id=text_id, text="hi")], + ) + out = await _collect( + [ + _assistant_start("m1"), + events_.TextStart(block_id=text_id), + events_.TextDelta(block_id=text_id, chunk="hi"), + events_.TextEnd(block_id=text_id), + events_.MessageEnd(message=final), + ] + ) + assert isinstance(out[0], protocol.StartPart) assert out[0].message_id == "m1" assert isinstance(out[1], protocol.StartStepPart) @@ -95,25 +65,39 @@ async def test_event_driven_text_streaming() -> None: assert isinstance(out[6], protocol.FinishPart) +async def test_static_text_message_emits_text_parts() -> None: + msg = messages_.Message( + id="m1", + role="assistant", + parts=[messages_.TextPart(id="txt1", text="hello")], + ) + out = await _collect( + [events_.MessageStart(message=msg), events_.MessageEnd(message=msg)] + ) + assert any(isinstance(part, protocol.TextDeltaPart) for part in out) + + async def test_turn_id_change_emits_step_boundary() -> None: - msgs = [ - messages_.Message( - id="m1", - role="assistant", - turn_id="t1", - parts=[messages_.TextPart(text="hello")], - stream=messages_.StreamState(new_events=[], is_done=True), - ), - messages_.Message( - id="m2", - role="assistant", - turn_id="t2", # different turn → step boundary - parts=[messages_.TextPart(text="world")], - stream=messages_.StreamState(new_events=[], is_done=True), - ), - ] - out = await _collect(msgs) - # Look for FinishStep followed by StartStep between messages. + msg1 = messages_.Message( + id="m1", + role="assistant", + turn_id="t1", + parts=[messages_.TextPart(text="hello")], + ) + msg2 = messages_.Message( + id="m2", + role="assistant", + turn_id="t2", + parts=[messages_.TextPart(text="world")], + ) + out = await _collect( + [ + events_.MessageStart(message=msg1), + events_.MessageEnd(message=msg1), + events_.MessageStart(message=msg2), + events_.MessageEnd(message=msg2), + ] + ) has_mid_step_boundary = any( isinstance(out[i], protocol.FinishStepPart) and i + 1 < len(out) @@ -124,24 +108,26 @@ async def test_turn_id_change_emits_step_boundary() -> None: async def test_agent_change_emits_message_boundary() -> None: - msgs = [ - messages_.Message( - id="m1", - role="assistant", - source_label="a1", - parts=[messages_.TextPart(text="from a")], - stream=messages_.StreamState(new_events=[], is_done=True), - ), - messages_.Message( - id="m2", - role="assistant", - source_label="a2", # different source → FinishPart + StartPart - parts=[messages_.TextPart(text="from b")], - stream=messages_.StreamState(new_events=[], is_done=True), - ), - ] - out = await _collect(msgs) - # There should be a FinishPart+StartPart pair mid-stream. + msg1 = messages_.Message( + id="m1", + role="assistant", + source_label="a1", + parts=[messages_.TextPart(text="from a")], + ) + msg2 = messages_.Message( + id="m2", + role="assistant", + source_label="a2", + parts=[messages_.TextPart(text="from b")], + ) + out = await _collect( + [ + events_.MessageStart(message=msg1), + events_.MessageEnd(message=msg1), + events_.MessageStart(message=msg2), + events_.MessageEnd(message=msg2), + ] + ) has_mid_msg_boundary = any( isinstance(out[i], protocol.FinishPart) and i + 1 < len(out) @@ -152,67 +138,75 @@ async def test_agent_change_emits_message_boundary() -> None: async def test_tool_call_and_result_emit_terminal_parts() -> None: - msgs = [ - messages_.Message( - id="m1", - role="assistant", - turn_id="t1", - parts=[ - messages_.ToolCallPart( - id="tc1", - tool_call_id="tc1", - tool_name="search", - tool_args='{"q":"x"}', - ) - ], - stream=messages_.StreamState(new_events=[], is_done=True), - ), - messages_.Message( - role="tool", - parts=[ - messages_.ToolResultPart( - tool_call_id="tc1", - tool_name="search", - result={"hits": 1}, - ) - ], - ), - ] - out = await _collect(msgs) - types = [type(p).__name__ for p in out] + tool_call = messages_.Message( + id="m1", + role="assistant", + turn_id="t1", + parts=[ + messages_.ToolCallPart( + id="tc1", + tool_call_id="tc1", + tool_name="search", + tool_args='{"q":"x"}', + ) + ], + ) + tool_result = messages_.Message( + role="tool", + parts=[ + messages_.ToolResultPart( + tool_call_id="tc1", + tool_name="search", + result={"hits": 1}, + ) + ], + ) + out = await _collect( + [ + events_.MessageStart(message=tool_call), + events_.MessageEnd(message=tool_call), + events_.MessageStart(message=tool_result), + events_.MessageEnd(message=tool_result), + ] + ) + types = [type(part).__name__ for part in out] assert "ToolInputStartPart" in types assert "ToolInputAvailablePart" in types assert "ToolOutputAvailablePart" in types async def test_approval_request_hook_emits_approval_part() -> None: - msgs = [ - messages_.Message( - id="m1", - role="assistant", - turn_id="t1", - parts=[ - messages_.ToolCallPart( - id="tc1", - tool_call_id="tc1", - tool_name="delete", - tool_args="{}", - ) - ], - stream=messages_.StreamState(new_events=[], is_done=True), - ), - messages_.Message( - role="internal", - parts=[ - messages_.HookPart( - hook_id="approve_tc1", - hook_type="ToolApproval", - status="pending", - ) - ], - ), - ] - out = await _collect(msgs) + tool_call = messages_.Message( + id="m1", + role="assistant", + turn_id="t1", + parts=[ + messages_.ToolCallPart( + id="tc1", + tool_call_id="tc1", + tool_name="delete", + tool_args="{}", + ) + ], + ) + hook = messages_.Message( + role="internal", + parts=[ + messages_.HookPart( + hook_id="approve_tc1", + hook_type="ToolApproval", + status="pending", + ) + ], + ) + out = await _collect( + [ + events_.MessageStart(message=tool_call), + events_.MessageEnd(message=tool_call), + events_.MessageStart(message=hook), + events_.MessageEnd(message=hook), + ] + ) approval_parts = [p for p in out if isinstance(p, protocol.ToolApprovalRequestPart)] assert len(approval_parts) == 1 assert approval_parts[0].tool_call_id == "tc1" @@ -220,23 +214,19 @@ async def test_approval_request_hook_emits_approval_part() -> None: async def test_dedup_on_reemitted_message_id() -> None: - empty = messages_.TextPart(id="txt1", text="") - hi = messages_.TextPart(id="txt1", text="hi") msg = messages_.Message( id="m1", role="assistant", turn_id="t1", - parts=[hi], - stream=messages_.StreamState( - new_events=[ - messages_.PartOpened(part=empty), - messages_.PartDelta(part=hi, chunk="hi"), - messages_.PartClosed(part=hi), - ], - is_done=True, - ), - ) - out = await _collect([msg, msg]) # re-emit the same done message - text_deltas = [p for p in out if isinstance(p, protocol.TextDeltaPart)] - # only the first emission should fire a TextDelta + parts=[messages_.TextPart(id="txt1", text="hi")], + ) + stream_events: list[events_.Event] = [ + events_.MessageStart(message=msg), + events_.TextStart(block_id="txt1"), + events_.TextDelta(block_id="txt1", chunk="hi"), + events_.TextEnd(block_id="txt1"), + events_.MessageEnd(message=msg), + ] + out = await _collect([*stream_events, *stream_events]) + text_deltas = [part for part in out if isinstance(part, protocol.TextDeltaPart)] assert len(text_deltas) == 1 diff --git a/tests/agents/ui/ai_sdk/test_approvals.py b/tests/agents/ui/ai_sdk/test_approvals.py index e0733757..993a9c25 100644 --- a/tests/agents/ui/ai_sdk/test_approvals.py +++ b/tests/agents/ui/ai_sdk/test_approvals.py @@ -1,11 +1,13 @@ from __future__ import annotations +from typing import Any + from ai.agents.ui.ai_sdk import _approvals from ai.types import messages as messages_ def test_tool_call_id_for_strips_prefix() -> None: - hook = messages_.HookPart( + hook: messages_.HookPart[Any] = messages_.HookPart( hook_id="approve_tc_42", hook_type="ToolApproval", status="pending", @@ -14,7 +16,7 @@ def test_tool_call_id_for_strips_prefix() -> None: def test_tool_call_id_for_rejects_non_approval_type() -> None: - hook = messages_.HookPart( + hook: messages_.HookPart[Any] = messages_.HookPart( hook_id="approve_tc_42", hook_type="SomethingElse", status="pending", @@ -23,7 +25,7 @@ def test_tool_call_id_for_rejects_non_approval_type() -> None: def test_tool_call_id_for_rejects_bad_prefix() -> None: - hook = messages_.HookPart( + hook: messages_.HookPart[Any] = messages_.HookPart( hook_id="tc_42", hook_type="ToolApproval", status="pending", diff --git a/tests/agents/ui/ai_sdk/test_inbound.py b/tests/agents/ui/ai_sdk/test_inbound.py index 9b808d82..2fc76d57 100644 --- a/tests/agents/ui/ai_sdk/test_inbound.py +++ b/tests/agents/ui/ai_sdk/test_inbound.py @@ -79,7 +79,7 @@ def test_to_messages_approval_hook_emitted_as_internal() -> None: ) assert [m.role for m in result] == ["assistant", "internal"] hook = result[1].parts[0] - assert hook.type == "hook" + assert hook.kind == "hook" assert hook.hook_id == "approve_tc1" diff --git a/tests/conftest.py b/tests/conftest.py index dd876b3a..d7e4abda 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import AsyncGenerator, Sequence +from collections.abc import AsyncGenerator, AsyncIterable, Sequence from typing import Any import pydantic @@ -8,6 +8,7 @@ import ai from ai import models from ai.types import builders +from ai.types import events as events_ from ai.types import messages as messages_ @@ -108,7 +109,7 @@ async def stream( tools: Sequence[ai.ToolLike] | None = None, output_type: type[pydantic.BaseModel] | None = None, **kwargs: Any, - ) -> AsyncGenerator[messages_.Message]: + ) -> AsyncGenerator[events_.Event]: if self._call_index >= len(self._responses): raise RuntimeError("MockAdapter: no more responses configured") self.call_count += 1 @@ -117,49 +118,78 @@ async def stream( from ai.models.core.helpers import streaming as streaming_ - handler = streaming_.StreamHandler() + message_id = seq[0].id if seq else messages_.generate_id() + handler = streaming_.StreamHandler(message_id=message_id) + yield handler.message_start() for msg in seq: for i, part in enumerate(msg.parts): if isinstance(part, messages_.TextPart): bid = f"text-{i}" - yield handler.handle_event(streaming_.TextStart(block_id=bid)) + for event in handler.handle_event( + streaming_.TextStart(block_id=bid) + ): + yield event if part.text: - yield handler.handle_event( + for event in handler.handle_event( streaming_.TextDelta(block_id=bid, delta=part.text) - ) - yield handler.handle_event(streaming_.TextEnd(block_id=bid)) + ): + yield event + for event in handler.handle_event(streaming_.TextEnd(block_id=bid)): + yield event elif isinstance(part, messages_.ReasoningPart): bid = f"reasoning-{i}" - yield handler.handle_event(streaming_.ReasoningStart(block_id=bid)) + for event in handler.handle_event( + streaming_.ReasoningStart(block_id=bid) + ): + yield event if part.text: - yield handler.handle_event( + for event in handler.handle_event( streaming_.ReasoningDelta(block_id=bid, delta=part.text) - ) - yield handler.handle_event( + ): + yield event + for event in handler.handle_event( streaming_.ReasoningEnd(block_id=bid, signature=part.signature) - ) + ): + yield event elif isinstance(part, messages_.ToolCallPart): - yield handler.handle_event( + for event in handler.handle_event( streaming_.ToolStart( tool_call_id=part.tool_call_id, tool_name=part.tool_name, ) - ) + ): + yield event if part.tool_args: - yield handler.handle_event( + for event in handler.handle_event( streaming_.ToolArgsDelta( tool_call_id=part.tool_call_id, delta=part.tool_args, ) - ) - yield handler.handle_event( + ): + yield event + for event in handler.handle_event( streaming_.ToolEnd(tool_call_id=part.tool_call_id) - ) + ): + yield event + + elif isinstance(part, messages_.StructuredOutputPart): + handler._current_parts[part.id] = part + + elif isinstance(part, messages_.FilePart): + for event in handler.handle_event( + streaming_.FileEvent( + block_id=part.id, + media_type=part.media_type, + data=part.data if isinstance(part.data, str) else "", + ) + ): + yield event - yield handler.handle_event(streaming_.MessageDone()) + for event in handler.handle_event(streaming_.MessageDone()): + yield event def mock_llm(responses: list[list[messages_.Message]]) -> MockAdapter: @@ -172,6 +202,17 @@ def mock_llm(responses: list[list[messages_.Message]]) -> MockAdapter: return adapter +async def collect_messages( + source: AsyncIterable[events_.Event], +) -> list[messages_.Message]: + """Collect terminal messages from an event stream.""" + result: list[messages_.Message] = [] + async for event in source: + if isinstance(event, events_.MessageEnd): + result.append(event.message) + return result + + class MockGenerateAdapter: """Mock generate adapter that returns pre-configured responses. diff --git a/tests/models/ai_gateway/test_stream.py b/tests/models/ai_gateway/test_stream.py index 71dfa5d9..dff7c294 100644 --- a/tests/models/ai_gateway/test_stream.py +++ b/tests/models/ai_gateway/test_stream.py @@ -25,7 +25,7 @@ import ai from ai.models.ai_gateway import ai_gateway, errors from ai.models.core import model as model_ -from ai.types import messages +from ai.types import events, messages from .conftest import mock_client, sse, user_msg @@ -45,14 +45,29 @@ async def _collect( msgs: list[messages.Message], model: model_.Model = _TEST_MODEL, **kwargs: Any, -) -> list[messages.Message]: - """Drain ``stream()`` and return all yielded messages.""" - result: list[messages.Message] = [] - async for msg in stream_mod.stream(client, model, msgs, **kwargs): - result.append(msg) +) -> list[events.Event]: + """Drain ``stream()`` and return all yielded events.""" + result: list[events.Event] = [] + async for event in stream_mod.stream(client, model, msgs, **kwargs): + result.append(event) return result +async def _final( + client: Any, + msgs: list[messages.Message], + model: model_.Model = _TEST_MODEL, + **kwargs: Any, +) -> messages.Message: + """Drain ``stream()`` and return the terminal assistant message.""" + result: list[messages.Message] = [] + async for event in stream_mod.stream(client, model, msgs, **kwargs): + if isinstance(event, events.MessageEnd): + result.append(event.message) + assert result + return result[-1] + + # --------------------------------------------------------------------------- # Streaming: text, reasoning, tool calls # --------------------------------------------------------------------------- @@ -79,11 +94,8 @@ def handler(req: httpx.Request) -> httpx.Response: return httpx.Response(200, text=body) client = mock_client(httpx.MockTransport(handler)) - msgs = await _collect(client, [user_msg("Hi")]) - - final = msgs[-1] + final = await _final(client, [user_msg("Hi")]) assert final.text == "Hello World" - assert final.is_done assert final.usage is not None assert final.usage.input_tokens == 5 assert final.usage.output_tokens == 2 @@ -102,9 +114,7 @@ async def test_reasoning_then_text(self) -> None: def handler(req: httpx.Request) -> httpx.Response: return httpx.Response(200, text=body) - final = ( - await _collect(mock_client(httpx.MockTransport(handler)), [user_msg("?")]) - )[-1] + final = await _final(mock_client(httpx.MockTransport(handler)), [user_msg("?")]) assert final.reasoning == "think" assert final.text == "42" @@ -128,11 +138,9 @@ async def test_streaming_tool_call(self) -> None: def handler(req: httpx.Request) -> httpx.Response: return httpx.Response(200, text=body) - final = ( - await _collect( - mock_client(httpx.MockTransport(handler)), [user_msg("search")] - ) - )[-1] + final = await _final( + mock_client(httpx.MockTransport(handler)), [user_msg("search")] + ) tc = final.tool_calls assert len(tc) == 1 assert tc[0].tool_name == "search" @@ -161,16 +169,13 @@ async def test_inline_file_stream(self) -> None: def handler(req: httpx.Request) -> httpx.Response: return httpx.Response(200, text=body) - final = ( - await _collect( - mock_client(httpx.MockTransport(handler)), [user_msg("draw me")] - ) - )[-1] + final = await _final( + mock_client(httpx.MockTransport(handler)), [user_msg("draw me")] + ) assert final.text == "Here is an image:" assert len(final.images) == 1 assert final.images[0].media_type == "image/png" assert final.images[0].data == "iVBORw0KGgo=" - assert final.is_done async def test_complete_tool_call_part(self) -> None: """Non-streaming ``tool-call`` part (one shot) must also work.""" @@ -191,11 +196,9 @@ async def test_complete_tool_call_part(self) -> None: def handler(req: httpx.Request) -> httpx.Response: return httpx.Response(200, text=body) - final = ( - await _collect( - mock_client(httpx.MockTransport(handler)), [user_msg("weather")] - ) - )[-1] + final = await _final( + mock_client(httpx.MockTransport(handler)), [user_msg("weather")] + ) assert len(final.tool_calls) == 1 assert json.loads(final.tool_calls[0].tool_args) == {"city": "SF"} diff --git a/tests/models/core/test_streaming.py b/tests/models/core/test_streaming.py index b123f0e8..7c46d723 100644 --- a/tests/models/core/test_streaming.py +++ b/tests/models/core/test_streaming.py @@ -2,122 +2,89 @@ from __future__ import annotations +from collections.abc import Sequence + from ai.models.core.helpers import streaming -from ai.types import messages -from ai.types.messages import PartClosed, PartDelta, PartOpened +from ai.types import events, messages + -# -- Text streaming -------------------------------------------------------- +def _only[T](items: Sequence[object], typ: type[T]) -> T: + matches = [item for item in items if isinstance(item, typ)] + assert len(matches) == 1 + return matches[0] def test_text_lifecycle() -> None: h = streaming.StreamHandler(message_id="m1") - m = h.handle_event(streaming.TextStart(block_id="b1")) - assert len(m.parts) == 1 - part = m.parts[0] - assert isinstance(part, messages.TextPart) - assert part.text == "" - assert m.stream is not None - assert any( - isinstance(e, PartOpened) and e.part.id == "b1" for e in m.stream.new_events - ) - m = h.handle_event(streaming.TextDelta(block_id="b1", delta="Hello")) - part = m.parts[0] - assert isinstance(part, messages.TextPart) - assert part.text == "Hello" - assert m.stream is not None - assert any( - isinstance(e, PartDelta) and e.part.id == "b1" and e.chunk == "Hello" - for e in m.stream.new_events - ) + out = h.handle_event(streaming.TextStart(block_id="b1")) + assert isinstance(out[0], events.TextStart) + assert out[0].block_id == "b1" - m = h.handle_event(streaming.TextDelta(block_id="b1", delta=" world")) - part = m.parts[0] - assert isinstance(part, messages.TextPart) - assert part.text == "Hello world" - assert m.stream is not None - assert any( - isinstance(e, PartDelta) and e.part.id == "b1" and e.chunk == " world" - for e in m.stream.new_events - ) + out = h.handle_event(streaming.TextDelta(block_id="b1", delta="Hello")) + delta = _only(out, events.TextDelta) + assert delta.chunk == "Hello" + assert delta.block_id == "b1" - m = h.handle_event(streaming.TextEnd(block_id="b1")) - part = m.parts[0] - assert isinstance(part, messages.TextPart) - assert m.stream is not None - assert any( - isinstance(e, PartClosed) and e.part.id == "b1" for e in m.stream.new_events - ) - # No delta events in this yield - assert not any(isinstance(e, PartDelta) for e in m.stream.new_events) + out = h.handle_event(streaming.TextDelta(block_id="b1", delta=" world")) + delta = _only(out, events.TextDelta) + assert delta.chunk == " world" + out = h.handle_event(streaming.TextEnd(block_id="b1")) + assert isinstance(out[0], events.TextEnd) + assert out[0].block_id == "b1" + assert not any(isinstance(event, events.TextDelta) for event in out) -# -- Reasoning streaming --------------------------------------------------- + out = h.handle_event(streaming.MessageDone(finish_reason="end_turn")) + msg = _only(out, events.MessageEnd).message + assert msg.text == "Hello world" def test_reasoning_lifecycle() -> None: h = streaming.StreamHandler(message_id="m1") h.handle_event(streaming.ReasoningStart(block_id="r1")) - m = h.handle_event(streaming.ReasoningDelta(block_id="r1", delta="thinking")) - part = m.parts[0] - assert isinstance(part, messages.ReasoningPart) - assert part.text == "thinking" - assert m.stream is not None - assert any( - isinstance(e, PartDelta) and e.part.id == "r1" and e.chunk == "thinking" - for e in m.stream.new_events - ) - m = h.handle_event(streaming.ReasoningEnd(block_id="r1", signature="sig123")) - part = m.parts[0] - assert isinstance(part, messages.ReasoningPart) - assert part.signature == "sig123" - assert m.stream is not None - assert any( - isinstance(e, PartClosed) and e.part.id == "r1" for e in m.stream.new_events - ) + out = h.handle_event(streaming.ReasoningDelta(block_id="r1", delta="thinking")) + delta = _only(out, events.ReasoningDelta) + assert delta.chunk == "thinking" + out = h.handle_event(streaming.ReasoningEnd(block_id="r1", signature="sig123")) + end = _only(out, events.ReasoningEnd) + assert end.signature == "sig123" -# -- Tool streaming -------------------------------------------------------- + msg = _only(h.handle_event(streaming.MessageDone()), events.MessageEnd).message + assert msg.reasoning == "thinking" + part = msg.parts[0] + assert isinstance(part, messages.ReasoningPart) + assert part.signature == "sig123" def test_tool_lifecycle() -> None: h = streaming.StreamHandler(message_id="m1") - h.handle_event(streaming.ToolStart(tool_call_id="tc1", tool_name="get_weather")) - m = h.handle_event(streaming.ToolArgsDelta(tool_call_id="tc1", delta='{"ci')) - part = m.parts[0] - assert isinstance(part, messages.ToolCallPart) - assert part.tool_name == "get_weather" - assert part.tool_args == '{"ci' - assert m.stream is not None - assert any( - isinstance(e, PartDelta) and e.part.id == "tc1" and e.chunk == '{"ci' - for e in m.stream.new_events - ) - m = h.handle_event( - streaming.ToolArgsDelta(tool_call_id="tc1", delta='ty":"London"}') - ) - part = m.parts[0] - assert isinstance(part, messages.ToolCallPart) - assert part.tool_args == '{"city":"London"}' - - m = h.handle_event(streaming.ToolEnd(tool_call_id="tc1")) - part = m.parts[0] - assert isinstance(part, messages.ToolCallPart) - assert m.stream is not None - assert any( - isinstance(e, PartClosed) and e.part.id == "tc1" for e in m.stream.new_events + out = h.handle_event( + streaming.ToolStart(tool_call_id="tc1", tool_name="get_weather") ) - # No delta events in this yield - assert not any(isinstance(e, PartDelta) for e in m.stream.new_events) + start = _only(out, events.ToolStart) + assert start.tool_name == "get_weather" + + out = h.handle_event(streaming.ToolArgsDelta(tool_call_id="tc1", delta='{"ci')) + delta = _only(out, events.ToolDelta) + assert delta.chunk == '{"ci' + + h.handle_event(streaming.ToolArgsDelta(tool_call_id="tc1", delta='ty":"London"}')) + out = h.handle_event(streaming.ToolEnd(tool_call_id="tc1")) + assert isinstance(out[0], events.ToolEnd) + assert not any(isinstance(event, events.ToolDelta) for event in out) -# -- Multi-part messages --------------------------------------------------- + msg = _only(h.handle_event(streaming.MessageDone()), events.MessageEnd).message + tc = msg.tool_calls[0] + assert tc.tool_name == "get_weather" + assert tc.tool_args == '{"city":"London"}' def test_reasoning_then_text_then_tool() -> None: - """Full message: reasoning block, text block, tool call.""" h = streaming.StreamHandler(message_id="m1") h.handle_event(streaming.ReasoningStart(block_id="r1")) h.handle_event(streaming.ReasoningDelta(block_id="r1", delta="Let me think")) @@ -129,123 +96,86 @@ def test_reasoning_then_text_then_tool() -> None: h.handle_event(streaming.ToolStart(tool_call_id="tc1", tool_name="search")) h.handle_event(streaming.ToolArgsDelta(tool_call_id="tc1", delta='{"q":"test"}')) - m = h.handle_event(streaming.ToolEnd(tool_call_id="tc1")) - - assert len(m.parts) == 3 - assert isinstance(m.parts[0], messages.ReasoningPart) - assert isinstance(m.parts[1], messages.TextPart) - assert isinstance(m.parts[2], messages.ToolCallPart) - # The last event was ToolEnd(tc1), so only that PartClosed is in events - assert m.stream is not None - assert any( - isinstance(e, PartClosed) and e.part.id == "tc1" for e in m.stream.new_events - ) + out = h.handle_event(streaming.ToolEnd(tool_call_id="tc1")) + assert isinstance(out[0], events.ToolEnd) + + msg = _only(h.handle_event(streaming.MessageDone()), events.MessageEnd).message + assert len(msg.parts) == 3 + assert isinstance(msg.parts[0], messages.ReasoningPart) + assert isinstance(msg.parts[1], messages.TextPart) + assert isinstance(msg.parts[2], messages.ToolCallPart) def test_multiple_tool_calls() -> None: - """Parallel tool calls in one message.""" h = streaming.StreamHandler(message_id="m1") h.handle_event(streaming.ToolStart(tool_call_id="tc1", tool_name="read_file")) h.handle_event(streaming.ToolStart(tool_call_id="tc2", tool_name="list_files")) - m = h.handle_event( - streaming.ToolArgsDelta(tool_call_id="tc1", delta='{"path":"a.py"}') - ) - # Both tools should be in parts - tool_parts = [p for p in m.parts if isinstance(p, messages.ToolCallPart)] - assert len(tool_parts) == 2 - # tc1 has args, tc2 is empty - assert tool_parts[0].tool_args == '{"path":"a.py"}' - assert tool_parts[1].tool_args == "" - + h.handle_event(streaming.ToolArgsDelta(tool_call_id="tc1", delta='{"path":"a.py"}')) h.handle_event(streaming.ToolArgsDelta(tool_call_id="tc2", delta='{"dir":"."}')) h.handle_event(streaming.ToolEnd(tool_call_id="tc1")) - m = h.handle_event(streaming.ToolEnd(tool_call_id="tc2")) - # Last event was ToolEnd(tc2), so its PartClosed is in events - assert m.stream is not None - assert any( - isinstance(e, PartClosed) and e.part.id == "tc2" for e in m.stream.new_events - ) + out = h.handle_event(streaming.ToolEnd(tool_call_id="tc2")) + assert isinstance(out[0], events.ToolEnd) + assert out[0].tool_call_id == "tc2" - -# -- MessageDone ----------------------------------------------------------- + msg = _only(h.handle_event(streaming.MessageDone()), events.MessageEnd).message + tool_parts = [p for p in msg.parts if isinstance(p, messages.ToolCallPart)] + assert [p.tool_args for p in tool_parts] == ['{"path":"a.py"}', '{"dir":"."}'] def test_message_done_finalizes_all() -> None: h = streaming.StreamHandler(message_id="m1") h.handle_event(streaming.TextStart(block_id="t1")) h.handle_event(streaming.TextDelta(block_id="t1", delta="hello")) - # Don't send TextEnd -- MessageDone should finalize everything - m = h.handle_event(streaming.MessageDone(finish_reason="end_turn")) - part = m.parts[0] - assert isinstance(part, messages.TextPart) - assert m.is_done - assert m.stream is not None - assert m.stream.is_done + + out = h.handle_event(streaming.MessageDone(finish_reason="end_turn")) + final = _only(out, events.MessageEnd) + assert final.message.text == "hello" def test_message_done_propagates_usage() -> None: - """Usage on MessageDone surfaces on the built Message.""" usage = messages.Usage(input_tokens=10, output_tokens=20) h = streaming.StreamHandler(message_id="m1") h.handle_event(streaming.TextStart(block_id="t1")) h.handle_event(streaming.TextDelta(block_id="t1", delta="hi")) - # Before MessageDone, usage should not be on the message - m = h.handle_event(streaming.TextEnd(block_id="t1")) - assert m.usage is None - - m = h.handle_event(streaming.MessageDone(usage=usage)) - assert m.usage is not None - assert m.usage.input_tokens == 10 - assert m.usage.output_tokens == 20 - assert m.usage.total_tokens == 30 - - -# -- Message properties propagate ------------------------------------------ + h.handle_event(streaming.TextEnd(block_id="t1")) + final = _only(h.handle_event(streaming.MessageDone(usage=usage)), events.MessageEnd) + assert final.usage is not None + assert final.usage.input_tokens == 10 + assert final.message.usage is not None + assert final.message.usage.total_tokens == 30 def test_deltas_only_on_active_blocks() -> None: - """Delta events should only reference the active block.""" h = streaming.StreamHandler(message_id="m1") h.handle_event(streaming.TextStart(block_id="t1")) h.handle_event(streaming.TextDelta(block_id="t1", delta="first")) h.handle_event(streaming.TextEnd(block_id="t1")) h.handle_event(streaming.TextStart(block_id="t2")) - m = h.handle_event(streaming.TextDelta(block_id="t2", delta="second")) - - text_parts = [p for p in m.parts if isinstance(p, messages.TextPart)] - assert text_parts[0].text == "first" # t1 snapshot - assert text_parts[1].text == "second" # t2 snapshot - # Only t2 has a delta event in this yield - assert m.stream is not None - assert any( - isinstance(e, PartDelta) and e.part.id == "t2" and e.chunk == "second" - for e in m.stream.new_events - ) - assert not any( - isinstance(e, PartDelta) and e.part.id == "t1" for e in m.stream.new_events - ) + out = h.handle_event(streaming.TextDelta(block_id="t2", delta="second")) - -# -- File event (inline images from LLMs like Gemini/GPT-5) --------------- + deltas = [event for event in out if isinstance(event, events.TextDelta)] + assert len(deltas) == 1 + assert deltas[0].block_id == "t2" + assert deltas[0].chunk == "second" def test_file_event_accumulates() -> None: - """FileEvent should produce a FilePart in the message.""" h = streaming.StreamHandler(message_id="m1") - m = h.handle_event( + out = h.handle_event( streaming.FileEvent(block_id="f1", media_type="image/png", data="iVBORw0KGgo=") ) - file_parts = [p for p in m.parts if isinstance(p, messages.FilePart)] - assert len(file_parts) == 1 - assert file_parts[0].media_type == "image/png" - assert file_parts[0].data == "iVBORw0KGgo=" + assert out == [] + + msg = _only(h.handle_event(streaming.MessageDone()), events.MessageEnd).message + assert len(msg.images) == 1 + assert msg.images[0].media_type == "image/png" + assert msg.images[0].data == "iVBORw0KGgo=" def test_file_event_with_text() -> None: - """A message can have both text and file parts (e.g. Gemini image gen).""" h = streaming.StreamHandler(message_id="m1") h.handle_event(streaming.TextStart(block_id="t1")) h.handle_event(streaming.TextDelta(block_id="t1", delta="Here is your image:")) @@ -253,26 +183,20 @@ def test_file_event_with_text() -> None: h.handle_event( streaming.FileEvent(block_id="f1", media_type="image/png", data="iVBORw0KGgo=") ) - m = h.handle_event(streaming.MessageDone(finish_reason="stop")) + msg = _only(h.handle_event(streaming.MessageDone()), events.MessageEnd).message - assert len(m.parts) == 2 - assert isinstance(m.parts[0], messages.TextPart) - assert m.parts[0].text == "Here is your image:" - assert isinstance(m.parts[1], messages.FilePart) - assert m.parts[1].media_type == "image/png" - assert m.is_done + assert msg.text == "Here is your image:" + assert len(msg.images) == 1 def test_multiple_file_events() -> None: - """Multiple FileEvents produce multiple FileParts.""" h = streaming.StreamHandler(message_id="m1") h.handle_event( streaming.FileEvent(block_id="f1", media_type="image/png", data="png_data") ) - m = h.handle_event( + h.handle_event( streaming.FileEvent(block_id="f2", media_type="image/jpeg", data="jpeg_data") ) - file_parts = [p for p in m.parts if isinstance(p, messages.FilePart)] - assert len(file_parts) == 2 - assert file_parts[0].media_type == "image/png" - assert file_parts[1].media_type == "image/jpeg" + msg = _only(h.handle_event(streaming.MessageDone()), events.MessageEnd).message + + assert [p.media_type for p in msg.images] == ["image/png", "image/jpeg"] diff --git a/tests/models/test_public_api.py b/tests/models/test_public_api.py index 8f1ddb1b..3b0c1012 100644 --- a/tests/models/test_public_api.py +++ b/tests/models/test_public_api.py @@ -9,9 +9,17 @@ import ai from ai import models +from ai.types import events as events_ from ai.types import messages as messages_ -from ..conftest import MOCK_MODEL, MOCK_PROVIDER, MockProvider, mock_llm, text_msg +from ..conftest import ( + MOCK_MODEL, + MOCK_PROVIDER, + MockProvider, + collect_messages, + mock_llm, + text_msg, +) # Module-level model so StructuredOutputPart can resolve it by FQN. @@ -29,12 +37,11 @@ async def test_stream_basic() -> None: """ai.models.stream() yields deltas and exposes .text after iteration.""" mock = mock_llm([[text_msg("Hello world")]]) - s = await models.stream(MOCK_MODEL, [ai.user_message("Hi")]) + s = models.stream(MOCK_MODEL, [ai.user_message("Hi")]) deltas: list[str] = [] - async for msg in s: - for ev in msg.deltas: - if isinstance(ev.part, messages_.TextPart): - deltas.append(ev.chunk) + async for event in s: + if isinstance(event, events_.TextDelta): + deltas.append(event.chunk) assert mock.call_count == 1 assert s.text == "Hello world" @@ -49,10 +56,8 @@ async def test_stream_preserves_existing_turn_ids() -> None: old = old.model_copy(update={"turn_id": "prev"}) fresh = ai.user_message("latest") - s = await models.stream(MOCK_MODEL, [old, fresh]) - yielded: list[messages_.Message] = [] - async for msg in s: - yielded.append(msg) + s = models.stream(MOCK_MODEL, [old, fresh]) + yielded = await collect_messages(s) assert mock.call_count == 1 # First yielded is the old input — unchanged. @@ -70,10 +75,8 @@ async def test_stream_accepts_explicit_turn_id() -> None: mock_llm([[text_msg("ok")]]) fresh = ai.user_message("hi") - s = await models.stream(MOCK_MODEL, [fresh], turn_id="custom-turn") - yielded: list[messages_.Message] = [] - async for msg in s: - yielded.append(msg) + s = models.stream(MOCK_MODEL, [fresh], turn_id="custom-turn") + yielded = await collect_messages(s) assert s.turn_id == "custom-turn" assert yielded[0].turn_id == "custom-turn" @@ -92,13 +95,15 @@ async def _spy_stream( tools: Sequence[ai.ToolLike] | None = None, output_type: type[pydantic.BaseModel] | None = None, **kwargs: Any, - ) -> AsyncGenerator[messages_.Message]: + ) -> AsyncGenerator[events_.Event]: received_clients.append(client) - yield messages_.Message( + msg = messages_.Message( id="m1", role="assistant", parts=[messages_.TextPart(text="ok")], ) + yield events_.MessageStart(message=msg.model_copy(update={"parts": []})) + yield events_.MessageEnd(message=msg) models.register_stream("mock", _spy_stream) @@ -106,7 +111,7 @@ async def _spy_stream( explicit_model = models.Model( id="mock-model", adapter="mock", provider=MOCK_PROVIDER, client=explicit ) - s = await models.stream(explicit_model, [ai.user_message("Hi")]) + s = models.stream(explicit_model, [ai.user_message("Hi")]) async for _ in s: pass @@ -128,7 +133,7 @@ async def _structured_stream( tools: Sequence[ai.ToolLike] | None = None, output_type: type[pydantic.BaseModel] | None = None, **kwargs: Any, - ) -> AsyncGenerator[messages_.Message]: + ) -> AsyncGenerator[events_.Event]: text_part = messages_.TextPart(text=json_text) parts: list[messages_.Part] = [text_part] if output_type is not None: @@ -140,11 +145,16 @@ async def _structured_stream( output_type_name=f"{output_type.__module__}.{output_type.__qualname__}", ) ) - yield messages_.Message(id="m1", role="assistant", parts=parts) + msg = messages_.Message(id="m1", role="assistant", parts=parts) + yield events_.MessageStart(message=msg.model_copy(update={"parts": []})) + yield events_.TextStart(block_id=text_part.id) + yield events_.TextDelta(block_id=text_part.id, chunk=json_text) + yield events_.TextEnd(block_id=text_part.id) + yield events_.MessageEnd(message=msg) models.register_stream("mock", _structured_stream) - s = await models.stream( + s = models.stream( MOCK_MODEL, [ai.user_message("Give me a recipe")], output_type=_Recipe ) async for _ in s: diff --git a/tests/test_middleware.py b/tests/test_middleware.py index 0a8ecf14..db911c75 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -12,9 +12,17 @@ import ai from ai import middleware, models from ai.models.core.helpers import streaming as streaming_ +from ai.types import events as events_ from ai.types import messages as messages_ -from .conftest import MOCK_MODEL, mock_generate, mock_llm, text_msg, tool_call_msg +from .conftest import ( + MOCK_MODEL, + collect_messages, + mock_generate, + mock_llm, + text_msg, + tool_call_msg, +) # ── Helpers ────────────────────────────────────────────────────── @@ -97,16 +105,19 @@ async def wrap_hook(self, call: middleware.HookContext, next: Any) -> Any: my_agent = ai.agent() @my_agent.loop - async def custom(context: ai.Context) -> AsyncGenerator[ai.Message]: - async for msg in await ai.models.stream(context.model, context.messages): - yield msg + async def custom(context: ai.Context) -> AsyncGenerator[ai.Event]: + async for event in ai.models.stream(context.model, context.messages): + yield event await ai.hook("test_hook", payload=Confirmation) mock_llm([[text_msg("OK")]]) - async for msg in my_agent.run( + async for event in my_agent.run( MOCK_MODEL, [ai.user_message("go")], middleware=[Spy()] ): + if not isinstance(event, ai.MessageEnd): + continue + msg = event.message if any(isinstance(p, ai.HookPart) and p.status == "pending" for p in msg.parts): ai.resolve_hook("test_hook", {"approved": True, "reason": "ok"}) @@ -178,22 +189,32 @@ async def stream( tools: Sequence[Any] | None = None, output_type: type[pydantic.BaseModel] | None = None, **kw: Any, - ) -> AsyncGenerator[messages_.Message]: + ) -> AsyncGenerator[events_.Event]: captured_messages.append(list(messages)) seq = self._responses[self._idx] self._idx += 1 - handler = streaming_.StreamHandler() + message_id = seq[0].id if seq else messages_.generate_id() + handler = streaming_.StreamHandler(message_id=message_id) + yield handler.message_start() for msg in seq: for i, part in enumerate(msg.parts): if isinstance(part, messages_.TextPart): bid = f"text-{i}" - yield handler.handle_event(streaming_.TextStart(block_id=bid)) + for event in handler.handle_event( + streaming_.TextStart(block_id=bid) + ): + yield event if part.text: - yield handler.handle_event( + for event in handler.handle_event( streaming_.TextDelta(block_id=bid, delta=part.text) - ) - yield handler.handle_event(streaming_.TextEnd(block_id=bid)) - yield handler.handle_event(streaming_.MessageDone()) + ): + yield event + for event in handler.handle_event( + streaming_.TextEnd(block_id=bid) + ): + yield event + for event in handler.handle_event(streaming_.MessageDone()): + yield event adapter = CapturingAdapter([[text_msg("Concise!")]]) models.register_stream("mock", adapter.stream) @@ -228,14 +249,14 @@ async def wrap_model(self, call: middleware.ModelContext, next: Any) -> Any: inner = ai.agent() @ai.tool # type: ignore[arg-type] - async def run_inner(query: str) -> AsyncGenerator[ai.Message]: + async def run_inner(query: str) -> AsyncGenerator[ai.Event]: """Run sub-agent with its own middleware.""" - async for msg in inner.run( + async for event in inner.run( MOCK_MODEL, [ai.user_message(query)], middleware=[Tagger("B")], ): - yield msg + yield event outer = ai.agent(tools=[run_inner]) @@ -269,19 +290,19 @@ async def test_wrap_agent_run_ordering() -> None: class Outer(ai.Middleware): async def wrap_agent_run( self, call: middleware.AgentRunContext, next: Any - ) -> AsyncGenerator[ai.Message]: + ) -> AsyncGenerator[ai.Event]: order.append("outer-before") - async for msg in next(call): - yield msg + async for event in next(call): + yield event order.append("outer-after") class Inner(ai.Middleware): async def wrap_agent_run( self, call: middleware.AgentRunContext, next: Any - ) -> AsyncGenerator[ai.Message]: + ) -> AsyncGenerator[ai.Event]: order.append("inner-before") - async for msg in next(call): - yield msg + async for event in next(call): + yield event order.append("inner-after") my_agent = ai.agent() @@ -354,12 +375,10 @@ async def echo(x: int) -> int: call2 = [text_msg("done")] mock_llm([call1, call2]) - tool_result_msgs: list[ai.Message] = [] - async for m in my_agent.run( - MOCK_MODEL, [ai.user_message("go")], middleware=[Rewriter()] - ): - if m.role == "tool" and m.tool_results: - tool_result_msgs.append(m) + msgs = await collect_messages( + my_agent.run(MOCK_MODEL, [ai.user_message("go")], middleware=[Rewriter()]) + ) + tool_result_msgs = [m for m in msgs if m.role == "tool" and m.tool_results] assert len(tool_result_msgs) >= 1 # The result message should use the rewritten name, not the original. @@ -408,26 +427,26 @@ async def wrap_model( ) -> ai.StreamResultLike: stream_result = await next(call) - async def _transformed() -> AsyncGenerator[messages_.Message]: - async for msg in stream_result: - yield msg + async def _transformed() -> AsyncGenerator[events_.Event]: + async for event in stream_result: + yield event # After the stream ends, yield one more snapshot with extra text. - yield messages_.Message( + msg = messages_.Message( id="appended", role="assistant", parts=[messages_.TextPart(text="original + appended")], ) + yield events_.MessageStart(message=msg.model_copy(update={"parts": []})) + yield events_.MessageEnd(message=msg) return ai.StreamResult.from_generator(_transformed()) my_agent = ai.agent() mock_llm([[text_msg("original")]]) - msgs: list[ai.Message] = [] - async for m in my_agent.run( - MOCK_MODEL, [ai.user_message("Hi")], middleware=[TextAppender()] - ): - msgs.append(m) + msgs = await collect_messages( + my_agent.run(MOCK_MODEL, [ai.user_message("Hi")], middleware=[TextAppender()]) + ) # The last message should be from the appended stream. texts = [m.text for m in msgs if m.text] @@ -483,12 +502,10 @@ async def double(x: int) -> int: call2 = [text_msg("done")] mock_llm([call1, call2]) - tool_result_msgs: list[ai.Message] = [] - async for m in my_agent.run( - MOCK_MODEL, [ai.user_message("go")], middleware=[ArgFixer()] - ): - if m.role == "tool" and m.tool_results: - tool_result_msgs.append(m) + msgs = await collect_messages( + my_agent.run(MOCK_MODEL, [ai.user_message("go")], middleware=[ArgFixer()]) + ) + tool_result_msgs = [m for m in msgs if m.role == "tool" and m.tool_results] assert len(tool_result_msgs) >= 1 # The fixer middleware supplied x=99, so double should return 198. diff --git a/tests/types/test_integrity.py b/tests/types/test_integrity.py index 16a94f18..49586172 100644 --- a/tests/types/test_integrity.py +++ b/tests/types/test_integrity.py @@ -12,6 +12,7 @@ import ai from ai import models from ai.types import builders, messages +from ai.types import events as events_ from ai.types.integrity import IntegrityError, prepare_messages from ..conftest import MOCK_MODEL, mock_generate, mock_llm, text_msg @@ -481,7 +482,7 @@ async def test_stream_calls_prepare_messages() -> None: with patch( "ai.models.core.api.integrity_.prepare_messages", wraps=lambda m: m ) as spy: - s = await models.stream(MOCK_MODEL, msgs) + s = models.stream(MOCK_MODEL, msgs) async for _ in s: pass spy.assert_called_once_with(msgs) @@ -503,12 +504,12 @@ async def _spy_stream( tools: Sequence[ai.ToolLike] | None = None, output_type: type[pydantic.BaseModel] | None = None, **kwargs: Any, - ) -> AsyncGenerator[messages.Message]: + ) -> AsyncGenerator[events_.Event]: received.append(list(messages)) - async for m in original_adapter( + async for event in original_adapter( client, model, messages, tools=tools, output_type=output_type, **kwargs ): - yield m + yield event models.register_stream("mock", _spy_stream) @@ -517,7 +518,7 @@ async def _spy_stream( messages.Message(role="internal", parts=[messages.TextPart(text="internal")]), ai.assistant_message("hello"), ] - s = await models.stream(MOCK_MODEL, msgs) + s = models.stream(MOCK_MODEL, msgs) async for _ in s: pass diff --git a/uv.lock b/uv.lock index 3a033e92..7d20e44c 100644 --- a/uv.lock +++ b/uv.lock @@ -1019,7 +1019,7 @@ wheels = [ [[package]] name = "vercel-ai-sdk" -version = "0.0.1.dev9" +version = "0.0.1.dev10" source = { editable = "." } dependencies = [ { name = "anthropic" },