From bca1679ab84de32e9eeb83251d27bc7134a534e3 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Thu, 16 Apr 2026 13:02:02 -0700 Subject: [PATCH 01/10] Move the ai sdk ui adapter into agents and split its code --- src/ai/adapters/ai_sdk_ui/__init__.py | 12 - src/ai/adapters/ai_sdk_ui/adapter.py | 606 ------------------ src/ai/{adapters => agents/ui}/__init__.py | 0 src/ai/agents/ui/ai_sdk/__init__.py | 42 ++ src/ai/agents/ui/ai_sdk/inbound.py | 437 +++++++++++++ src/ai/agents/ui/ai_sdk/message_to_ui.py | 456 +++++++++++++ src/ai/agents/ui/ai_sdk/outbound.py | 303 +++++++++ .../ui/ai_sdk}/protocol.py | 0 .../ui/ai_sdk}/ui_message.py | 0 9 files changed, 1238 insertions(+), 618 deletions(-) delete mode 100644 src/ai/adapters/ai_sdk_ui/__init__.py delete mode 100644 src/ai/adapters/ai_sdk_ui/adapter.py rename src/ai/{adapters => agents/ui}/__init__.py (100%) create mode 100644 src/ai/agents/ui/ai_sdk/__init__.py create mode 100644 src/ai/agents/ui/ai_sdk/inbound.py create mode 100644 src/ai/agents/ui/ai_sdk/message_to_ui.py create mode 100644 src/ai/agents/ui/ai_sdk/outbound.py rename src/ai/{adapters/ai_sdk_ui => agents/ui/ai_sdk}/protocol.py (100%) rename src/ai/{adapters/ai_sdk_ui => agents/ui/ai_sdk}/ui_message.py (100%) diff --git a/src/ai/adapters/ai_sdk_ui/__init__.py b/src/ai/adapters/ai_sdk_ui/__init__.py deleted file mode 100644 index 4b8bf175..00000000 --- a/src/ai/adapters/ai_sdk_ui/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -from .adapter import filter_by_label, to_messages, to_sse_stream, to_ui_message_stream -from .protocol import UI_MESSAGE_STREAM_HEADERS -from .ui_message import UIMessage - -__all__ = [ - "to_ui_message_stream", - "filter_by_label", - "to_sse_stream", - "to_messages", - "UIMessage", - "UI_MESSAGE_STREAM_HEADERS", -] diff --git a/src/ai/adapters/ai_sdk_ui/adapter.py b/src/ai/adapters/ai_sdk_ui/adapter.py deleted file mode 100644 index 99a20a82..00000000 --- a/src/ai/adapters/ai_sdk_ui/adapter.py +++ /dev/null @@ -1,606 +0,0 @@ -""" 
-Reference: https://ai-sdk.dev/docs/ai-sdk-ui/stream-protocol -""" - -from __future__ import annotations - -import dataclasses -import json -import logging -from collections.abc import AsyncGenerator, AsyncIterable -from typing import Any - -from ...agents import hooks -from ...agents.hooks import TOOL_APPROVAL_HOOK_TYPE -from ...types import messages as messages_ -from . import protocol, ui_message - -logger = logging.getLogger(__name__) - -# ============================================================================ -# Serialization utilities -# ============================================================================ - - -def _to_camel_case(snake_str: str) -> str: - """Convert snake_case to camelCase.""" - components = snake_str.split("_") - return components[0] + "".join(x.title() for x in components[1:]) - - -def serialize_part(part: protocol.UIMessageStreamPart) -> str: - """Serialize a stream part to JSON with camelCase keys.""" - d = dataclasses.asdict(part) - if isinstance(part, protocol.DataPart): - # DataPart's wire type is computed (``data-{data_type}``); replace - # the raw ``data_type`` field with the protocol ``type`` key. - d["type"] = part.type - del d["data_type"] - camel_dict = {_to_camel_case(k): v for k, v in d.items() if v is not None} - return json.dumps(camel_dict) - - -def format_sse(part: protocol.UIMessageStreamPart) -> str: - """Format a stream part as an SSE data line.""" - return f"data: {serialize_part(part)}\n\n" - - -# ============================================================================ -# Internal Message → UI Message Stream Conversion -# ============================================================================ - - -class _StreamState: - """Tracks state for UI message stream event sequencing. - - Encapsulates the mutable state needed to properly sequence events - (reasoning blocks, text blocks, steps, tool calls) when converting - an internal message stream to the AI SDK UI protocol. 
- """ - - def __init__(self) -> None: - self.text_id: str | None = None - self.reasoning_id: str | None = None - self.label: str | None = None - self.message_id: str | None = None - self.emitted_start: bool = False - self.in_step: bool = False - self.started_tool_calls: set[str] = set() - self.emitted_tool_results: set[str] = set() - self.pending_tool_calls: set[str] = set() - self.emitted_approval_requests: set[str] = set() - - def close_open_blocks(self) -> list[protocol.UIMessageStreamPart]: - """Close any open reasoning/text blocks, returning parts to emit.""" - parts: list[protocol.UIMessageStreamPart] = [] - if self.reasoning_id: - parts.append(protocol.ReasoningEndPart(id=self.reasoning_id)) - self.reasoning_id = None - if self.text_id: - parts.append(protocol.TextEndPart(id=self.text_id)) - self.text_id = None - return parts - - def finish_step(self) -> list[protocol.UIMessageStreamPart]: - """Close open blocks and finish the current step if active.""" - parts = self.close_open_blocks() - if self.in_step: - parts.append(protocol.FinishStepPart()) - self.in_step = False - return parts - - def reset_tool_tracking(self) -> None: - """Reset tool tracking sets (for new message/agent boundaries).""" - self.started_tool_calls = set() - self.emitted_tool_results = set() - self.pending_tool_calls = set() - self.emitted_approval_requests = set() - - def begin_message( - self, msg: messages_.Message - ) -> list[protocol.UIMessageStreamPart]: - """Handle message/step boundaries, returning parts to emit. - - Decides whether to start a new message (first message or agent switch) - or a new step (same stream, different message ID), closing any open - blocks and steps as needed. 
- """ - parts: list[protocol.UIMessageStreamPart] = [] - is_new_message = self.message_id is not None and msg.id != self.message_id - - if not self.emitted_start or (msg.label and msg.label != self.label): - # First message or label change (new agent) - parts.extend(self.finish_step()) - if self.emitted_start: - parts.append(protocol.FinishPart(finish_reason="stop")) - - parts.append(protocol.StartPart(message_id=msg.id)) - parts.append(protocol.StartStepPart()) - self.emitted_start = True - self.in_step = True - self.label = msg.label - self.message_id = msg.id - self.reset_tool_tracking() - elif is_new_message: - # New message ID within the same stream = new step - parts.extend(self.finish_step()) - parts.append(protocol.StartStepPart()) - self.in_step = True - self.message_id = msg.id - - return parts - - -def _tool_call_id_from_approval_hook( - hook_part: messages_.HookPart, -) -> str | None: - """Extract tool_call_id from a ToolApproval HookPart. - - Returns the tool_call_id if this is a ToolApproval hook whose hook_id - follows the ``approve_{tool_call_id}`` convention, otherwise None. - """ - if hook_part.hook_type != TOOL_APPROVAL_HOOK_TYPE: - return None - prefix = "approve_" - if hook_part.hook_id.startswith(prefix): - return hook_part.hook_id[len(prefix) :] - return None - - -def _is_tool_approval_hook_message(msg: messages_.Message) -> bool: - """True if this message contains only ToolApproval HookParts.""" - if not msg.parts: - return False - return all( - isinstance(p, messages_.HookPart) - and _tool_call_id_from_approval_hook(p) is not None - for p in msg.parts - ) - - -async def to_ui_message_stream( - messages: AsyncIterable[messages_.Message], -) -> AsyncGenerator[protocol.UIMessageStreamPart]: - """ - Convert a proto_sdk message stream into AI SDK UI message stream parts. - - This adapter transforms the internal message format into the AI SDK - protocol that can be consumed by useChat and other AI SDK UI hooks. 
- """ - state = _StreamState() - - async for msg in messages: - # Tool-result messages (role="tool") are emitted by the Runtime - # as separate Message objects (with their own auto-generated id). - # To the frontend they belong to the *same* step as the tool - # call, so we pin the message id to avoid a spurious step boundary. - if msg.role == "tool" and state.message_id: - msg = msg.model_copy(update={"id": state.message_id}) - - # Tool-approval hook messages are emitted by the Runtime as - # separate Message objects (with their own id). To the frontend - # they belong to the *same* step as the tool call, so we pin - # the message id to avoid creating a spurious step boundary. - if _is_tool_approval_hook_message(msg) and state.message_id: - msg = msg.model_copy(update={"id": state.message_id}) - - for part in state.begin_message(msg): - yield part - - # Handle reasoning streaming (deltas) - reasoning comes before text - if delta := msg.reasoning_delta: - if not state.reasoning_id: - state.reasoning_id = messages_.generate_id("reasoning") - yield protocol.ReasoningStartPart(id=state.reasoning_id) - yield protocol.ReasoningDeltaPart(id=state.reasoning_id, delta=delta) - - # Handle text streaming (deltas) - if delta := msg.text_delta: - # Close reasoning block when text starts (reasoning precedes text) - if state.reasoning_id: - yield protocol.ReasoningEndPart(id=state.reasoning_id) - state.reasoning_id = None - - if not state.text_id: - state.text_id = messages_.generate_id("text") - yield protocol.TextStartPart(id=state.text_id) - yield protocol.TextDeltaPart(id=state.text_id, delta=delta) - - # Handle streaming tool call arguments - for tool_delta in msg.tool_deltas: - if tool_delta.tool_call_id not in state.started_tool_calls: - state.started_tool_calls.add(tool_delta.tool_call_id) - yield protocol.ToolInputStartPart( - tool_call_id=tool_delta.tool_call_id, - tool_name=tool_delta.tool_name, - ) - yield protocol.ToolInputDeltaPart( - 
tool_call_id=tool_delta.tool_call_id, - input_text_delta=tool_delta.args_delta, - ) - - # Handle completed messages - if msg.is_done: - had_active_text = state.text_id is not None - for part in state.close_open_blocks(): - yield part - - # Scan for new pending tool calls or tool results - has_new_pending_tools = any( - isinstance(p, messages_.ToolCallPart) - and p.tool_call_id not in state.pending_tool_calls - for p in msg.parts - ) - has_new_tool_results = any( - isinstance(p, messages_.ToolResultPart) - and p.tool_call_id not in state.emitted_tool_results - for p in msg.parts - ) - - # Process parts in passes: - # 1. Text and pending tool calls (from assistant messages) - # 2. Tool results (from tool messages) - - # Pass 1: Text and pending tool inputs - for msg_part in msg.parts: - match msg_part: - case messages_.TextPart(text=text) if ( - text - and not had_active_text - and not has_new_pending_tools - and not has_new_tool_results - ): - text_id = messages_.generate_id("text") - yield protocol.TextStartPart(id=text_id) - yield protocol.TextEndPart(id=text_id) - case messages_.ToolCallPart( - tool_call_id=tc_id, - tool_name=name, - tool_args=args, - ): - if tc_id not in state.started_tool_calls: - state.started_tool_calls.add(tc_id) - yield protocol.ToolInputStartPart( - tool_call_id=tc_id, - tool_name=name, - ) - if tc_id not in state.pending_tool_calls: - state.pending_tool_calls.add(tc_id) - yield protocol.ToolInputAvailablePart( - tool_call_id=tc_id, - tool_name=name, - input=args, - ) - - # Pass 2: Tool results - if has_new_tool_results: - for msg_part in msg.parts: - if ( - isinstance(msg_part, messages_.ToolResultPart) - and msg_part.tool_call_id not in state.emitted_tool_results - ): - state.emitted_tool_results.add(msg_part.tool_call_id) - state.pending_tool_calls.discard(msg_part.tool_call_id) - yield protocol.ToolOutputAvailablePart( - tool_call_id=msg_part.tool_call_id, - output=msg_part.result, - ) - - # Pass 3: Hook-based tool approvals - for 
msg_part in msg.parts: - if not isinstance(msg_part, messages_.HookPart): - continue - approval_tc_id = _tool_call_id_from_approval_hook(msg_part) - if approval_tc_id is None: - continue - - if msg_part.status == "pending": - if approval_tc_id not in state.emitted_approval_requests: - state.emitted_approval_requests.add(approval_tc_id) - yield protocol.ToolApprovalRequestPart( - approval_id=msg_part.hook_id, - tool_call_id=approval_tc_id, - ) - elif msg_part.status == "resolved": - resolution = msg_part.resolution or {} - if not resolution.get("granted", False): - yield protocol.ToolOutputDeniedPart( - tool_call_id=approval_tc_id, - ) - elif msg_part.status == "cancelled": - yield protocol.ToolOutputErrorPart( - tool_call_id=approval_tc_id, - error_text="Hook cancelled", - ) - - # Final cleanup - for part in state.finish_step(): - yield part - if state.emitted_start: - yield protocol.FinishPart(finish_reason="stop") - - -async def filter_by_label( - messages: AsyncIterable[messages_.Message], - label: str | None = None, -) -> AsyncGenerator[messages_.Message]: - """Filter a message stream to a single agent label. - - If label is provided, only messages with that label pass through. - If label is None, auto-locks to whichever label arrives first. 
- """ - async for msg in messages: - if label is None: - label = msg.label - if msg.label == label: - yield msg - - -async def to_sse_stream( - messages: AsyncIterable[messages_.Message], -) -> AsyncGenerator[str]: - """Convert a proto_sdk message stream directly into SSE-formatted strings.""" - async for part in to_ui_message_stream(messages): - yield format_sse(part) - - -# ============================================================================ -# Tool conversion helpers -# ============================================================================ - -_TOOL_RESULT_STATES: frozenset[str] = frozenset({"output-available"}) -_TOOL_ERROR_STATES: frozenset[str] = frozenset({"output-error", "output-denied"}) - - -def _is_tool_completed(state: ui_message.UIToolInvocationState) -> bool: - """Return True if the tool invocation state indicates a completed tool.""" - return state in _TOOL_RESULT_STATES or state in _TOOL_ERROR_STATES - - -def _is_tool_error(state: ui_message.UIToolInvocationState) -> bool: - """Return True if the tool invocation state indicates an error.""" - return state in _TOOL_ERROR_STATES - - -def _normalize_tool_args(tool_input: str | dict[str, Any] | None) -> str: - """Normalize tool input (JSON string, dict, or None) to a JSON string.""" - match tool_input: - case str(): - return tool_input - case dict(): - return json.dumps(tool_input) - case _: - return "{}" - - -def _normalize_tool_result(output: Any) -> dict[str, Any] | None: - """Normalize tool output to dict format for internal ToolResultPart. - - The internal ToolResultPart.result expects dict | None, but AI SDK - output can be any type. Wrap non-dict results for compatibility. - """ - if output is None: - return None - return output if isinstance(output, dict) else {"value": output} - - -def to_messages( - ui_messages: list[ui_message.UIMessage], -) -> list[messages_.Message]: - """Convert AI SDK v6 UI messages to internal Message format. 
- - As a side-effect, tool parts in ``approval-responded`` state trigger - ``ToolApproval.resolve()`` so the agent loop can resume execution - without the caller needing to handle approval routing explicitly. - - When approvals are resolved, the trailing assistant message is - automatically stripped to avoid sending duplicate tool-use content - to the LLM on re-entry. - - Args: - ui_messages: List of UIMessage objects from the AI SDK v6 frontend. - - Returns: - List of internal Message objects ready for use with the runtime. - """ - result: list[messages_.Message] = [] - resolved_any_approval = False - - for ui_msg in ui_messages: - # For assistant messages, separate tool calls from tool results. - assistant_parts: list[messages_.Part] = [] - tool_result_parts: list[messages_.ToolResultPart] = [] - - for part in ui_msg.parts: - match part: - case ui_message.UITextPart(text=text) if text: - assistant_parts.append(messages_.TextPart(text=text)) - - case ui_message.UIReasoningPart(reasoning=reasoning): - assistant_parts.append(messages_.ReasoningPart(text=reasoning)) - - case ui_message.UIToolInvocationPart() as inv: - # Legacy tool-invocation type — always create the call part - tool_args = json.dumps(inv.args) if inv.args else "{}" - assistant_parts.append( - messages_.ToolCallPart( - tool_call_id=inv.tool_invocation_id, - tool_name=inv.tool_name, - tool_args=tool_args, - ) - ) - if _is_tool_completed(inv.state): - tool_result_parts.append( - messages_.ToolResultPart( - tool_call_id=inv.tool_invocation_id, - tool_name=inv.tool_name, - result=inv.result, - is_error=_is_tool_error(inv.state), - ) - ) - - case ui_message.UIToolPart() as tp: - # Dynamic tool-{toolName} type (e.g., "tool-get_weather") - assistant_parts.append( - messages_.ToolCallPart( - tool_call_id=tp.tool_call_id, - tool_name=tp.tool_name, - tool_args=_normalize_tool_args(tp.input), - ) - ) - if _is_tool_completed(tp.state): - tool_result_parts.append( - messages_.ToolResultPart( - 
tool_call_id=tp.tool_call_id, - tool_name=tp.tool_name, - result=_normalize_tool_result(tp.output), - is_error=_is_tool_error(tp.state), - ) - ) - # Side-effect: resolve ToolApproval hooks from approval - # responses so the agent loop can resume execution. - if ( - tp.state == "approval-responded" - and tp.approval is not None - and tp.approval.approved is not None - ): - hooks.resolve_hook( - tp.approval.id, - { - "granted": tp.approval.approved, - "reason": tp.approval.reason, - }, - ) - resolved_any_approval = True - - case ui_message.UIFilePart() as fp: - assistant_parts.append( - messages_.FilePart( - data=fp.url, - media_type=fp.media_type, - filename=fp.filename, - ) - ) - - case ( - ui_message.UIStepStartPart() - | ui_message.UISourceUrlPart() - | ui_message.UISourceDocumentPart() - ): - pass # Skip unsupported/boundary parts - - # Validate user/system messages have content - OpenAI requires it. - if ui_msg.role in ("user", "system") and not assistant_parts: - raise ValueError( - f"Message '{ui_msg.id}' has role '{ui_msg.role}' but no content. " - "User and system messages require non-empty content." - ) - - # The UI sends one assistant message per conversation turn, but a - # single turn may span multiple default-loop iterations (e.g. - # [text, tool_call, tool_result, text, tool_call, tool_result, text]). - # LLM APIs expect one message per iteration, so split into - # assistant + tool message pairs at tool-result boundaries. - if ui_msg.role == "assistant": - result.extend( - _split_assistant_parts( - assistant_parts, tool_result_parts, msg_id=ui_msg.id - ) - ) - else: - result.append( - messages_.Message( - id=ui_msg.id, - role=ui_msg.role, - parts=assistant_parts, - ) - ) - - # When resuming after approvals were resolved above, the frontend - # sends the full history including the assistant message from the - # interrupted run. Strip it to avoid duplicate tool-use content. 
- if resolved_any_approval and result and result[-1].role == "assistant": - logger.info("Stripping trailing assistant message (approvals resolved)") - result = result[:-1] - - return result - - -def _split_assistant_parts( - parts: list[messages_.Part], - tool_results: list[messages_.ToolResultPart], - msg_id: str, -) -> list[messages_.Message]: - """Split assistant parts into assistant + tool message pairs. - - The UI sends one big assistant message per turn, but internally each - loop iteration produces an assistant message (with tool calls) followed - by a tool message (with results). This reconstructs that structure. - - Returns a list of Messages: alternating assistant and tool messages, - split at tool-call boundaries when results are available. - """ - # Index tool results by their tool_call_id for lookup - results_by_id = {tr.tool_call_id: tr for tr in tool_results} - - messages: list[messages_.Message] = [] - current: list[messages_.Part] = [] - pending_results: list[messages_.ToolResultPart] = [] - - for part in parts: - current.append(part) - - # When we see a ToolCallPart that has a result, accumulate it - if ( - isinstance(part, messages_.ToolCallPart) - and part.tool_call_id in results_by_id - ): - pending_results.append(results_by_id[part.tool_call_id]) - - # If there are pending results and more parts follow, we need to split. - # Walk again, splitting at boundaries where all accumulated tool calls - # have results and a non-tool part follows. 
- if not pending_results: - # No completed tools — single assistant message - if current: - messages.append( - messages_.Message(role="assistant", parts=current, id=msg_id) - ) - return messages - - # Re-walk to split at tool-call boundaries - messages = [] - current = [] - current_results: list[messages_.ToolResultPart] = [] - seen_tool_call = False - - for part in parts: - # If we had a completed tool call group and now see a non-tool part, - # split here - if ( - seen_tool_call - and current_results - and not isinstance(part, messages_.ToolCallPart) - ): - messages.append( - messages_.Message(role="assistant", parts=current, id=msg_id) - ) - messages.append(messages_.Message(role="tool", parts=list(current_results))) - current = [] - current_results = [] - seen_tool_call = False - - current.append(part) - - if isinstance(part, messages_.ToolCallPart): - seen_tool_call = True - if part.tool_call_id in results_by_id: - current_results.append(results_by_id[part.tool_call_id]) - - # Flush remaining - if current: - messages.append(messages_.Message(role="assistant", parts=current, id=msg_id)) - if current_results: - messages.append(messages_.Message(role="tool", parts=list(current_results))) - - return messages diff --git a/src/ai/adapters/__init__.py b/src/ai/agents/ui/__init__.py similarity index 100% rename from src/ai/adapters/__init__.py rename to src/ai/agents/ui/__init__.py diff --git a/src/ai/agents/ui/ai_sdk/__init__.py b/src/ai/agents/ui/ai_sdk/__init__.py new file mode 100644 index 00000000..8bb4bcaf --- /dev/null +++ b/src/ai/agents/ui/ai_sdk/__init__.py @@ -0,0 +1,42 @@ +# inbound: UI -> internal +from .inbound import ( + ApprovalResponse, + extract_approvals, + normalize_ui_messages, + ui_to_messages, +) + +# outbound: internal -> SSE stream +from .outbound import filter_by_label, stream_to_sse, stream_to_ui + +# message_to_ui: internal -> UI format (persistence/history) +from .message_to_ui import ( + UIMessageBuilder, + messages_to_ui, + parts_to_ui, 
+ ui_parts_to_dicts, +) + +# data models +from .protocol import UI_MESSAGE_STREAM_HEADERS +from .ui_message import UIMessage + +__all__ = [ + # inbound + "ui_to_messages", + "normalize_ui_messages", + "extract_approvals", + "ApprovalResponse", + # outbound + "stream_to_ui", + "stream_to_sse", + "filter_by_label", + # message_to_ui + "messages_to_ui", + "parts_to_ui", + "ui_parts_to_dicts", + "UIMessageBuilder", + # data models + "UIMessage", + "UI_MESSAGE_STREAM_HEADERS", +] diff --git a/src/ai/agents/ui/ai_sdk/inbound.py b/src/ai/agents/ui/ai_sdk/inbound.py new file mode 100644 index 00000000..e8f32d70 --- /dev/null +++ b/src/ai/agents/ui/ai_sdk/inbound.py @@ -0,0 +1,437 @@ +""" +Inbound: UI -> internal message conversion. + +Converts AI SDK v6 UIMessages into internal Message objects +for use with the runtime/agent loop. +""" + +from __future__ import annotations + +import json +import logging +from typing import Any, NamedTuple + +from ...types import messages as messages_ +from . import ui_message + +logger = logging.getLogger(__name__) + + +_TOOL_RESULT_STATES: frozenset[str] = frozenset({"output-available"}) +_TOOL_ERROR_STATES: frozenset[str] = frozenset({"output-error", "output-denied"}) + + +def _is_tool_completed(state: ui_message.UIToolInvocationState) -> bool: + """Return True if the tool invocation state indicates a completed tool.""" + return state in _TOOL_RESULT_STATES or state in _TOOL_ERROR_STATES + + +def _is_tool_error(state: ui_message.UIToolInvocationState) -> bool: + """Return True if the tool invocation state indicates an error.""" + return state in _TOOL_ERROR_STATES + + +def _normalize_tool_args(tool_input: str | dict[str, Any] | None) -> str: + """Normalize tool input (JSON string, dict, or None) to a JSON string.""" + match tool_input: + case str(): + return tool_input + case dict(): + return json.dumps(tool_input) + case _: + return "{}" + + +def _normalize_tool_result(output: Any) -> dict[str, Any] | None: + """Normalize tool output to 
dict format for internal ToolResultPart. + + The internal ToolResultPart.result expects dict | None, but AI SDK + output can be any type. Wrap non-dict results for compatibility. + """ + if output is None: + return None + return output if isinstance(output, dict) else {"value": output} + + +def _error_result(error_text: str | None, output: Any) -> dict[str, Any] | None: + """Normalize an error-state tool result.""" + normalized = _normalize_tool_result(output) + if error_text: + if normalized is None: + return {"error": error_text} + if isinstance(normalized, dict) and "error" not in normalized: + return {"error": error_text, **normalized} + return normalized + + +def _approval_signal_part(tp: ui_message.UIToolPart) -> messages_.HookPart | None: + """Reconstruct approval signal state from a UI tool part when possible.""" + approval = tp.approval + if approval is None: + return None + + if tp.state == "approval-requested": + return messages_.HookPart( + hook_id=approval.id, + hook_type="ToolApproval", + status="pending", + ) + + if tp.state == "approval-responded" and approval.approved is not None: + return messages_.HookPart( + hook_id=approval.id, + hook_type="ToolApproval", + status="resolved", + resolution={ + "granted": approval.approved, + "reason": approval.reason, + }, + ) + + if tp.state == "output-denied": + return messages_.HookPart( + hook_id=approval.id, + hook_type="ToolApproval", + status="resolved", + resolution={ + "granted": False, + "reason": approval.reason, + }, + ) + + return None + + +# ============================================================================ +# Approval extraction +# ============================================================================ + + +class ApprovalResponse(NamedTuple): + """Extracted approval response from a UIToolPart in approval-responded state.""" + + hook_id: str + granted: bool + reason: str | None + + +def extract_approvals( + ui_messages: list[ui_message.UIMessage], +) -> list[ApprovalResponse]: + 
"""Extract approval responses from UI messages. + + Walks UIMessages looking for UIToolParts in ``approval-responded`` state + and returns the approval data as a list. Pure function -- does not + resolve hooks or trigger any side-effects. + + Args: + ui_messages: List of UIMessage objects from the AI SDK v6 frontend. + + Returns: + List of ApprovalResponse tuples with hook_id, granted, and reason. + """ + approvals: list[ApprovalResponse] = [] + for ui_msg in ui_messages: + for part in ui_msg.parts: + if not isinstance(part, ui_message.UIToolPart): + continue + if ( + part.state == "approval-responded" + and part.approval is not None + and part.approval.approved is not None + ): + approvals.append( + ApprovalResponse( + hook_id=part.approval.id, + granted=part.approval.approved, + reason=part.approval.reason, + ) + ) + return approvals + + +# ============================================================================ +# UI message normalization (heal stale tool states) +# ============================================================================ + + +def normalize_ui_messages( + ui_messages: list[ui_message.UIMessage], +) -> list[ui_message.UIMessage]: + """Heal stale tool-part states from previously persisted assistant history. + + Tool parts may be stored in transient states (e.g. ``"call"``) if the + stream was interrupted. This normalizes them to consistent terminal + states based on what data is actually present. 
+ """ + normalized: list[ui_message.UIMessage] = [] + for message in ui_messages: + new_parts = [] + changed = False + for part in message.parts: + part_type = getattr(part, "type", None) + state = getattr(part, "state", None) + if isinstance(part_type, str) and part_type.startswith("tool-"): + output = getattr(part, "output", None) + approval = getattr(part, "approval", None) + approved = approval.approved if approval is not None else None + error_text = getattr(part, "error_text", None) + + next_state = state + if output is not None: + if state == "output-error" or error_text is not None: + next_state = "output-error" + elif state == "output-denied" or approved is False: + next_state = "output-denied" + else: + next_state = "output-available" + elif state == "call": + next_state = "input-available" + + if next_state != state: + part = part.model_copy(update={"state": next_state}) + changed = True + + new_parts.append(part) + + normalized.append( + message.model_copy(update={"parts": new_parts}) if changed else message + ) + return normalized + + +# ============================================================================ +# UI -> internal message conversion +# ============================================================================ + + +def ui_to_messages( + ui_messages: list[ui_message.UIMessage], +) -> list[messages_.Message]: + """Convert AI SDK v6 UI messages to internal Message format. + + This is a pure data transformation. It does not resolve hooks or + trigger any side-effects. Use ``extract_approvals()`` separately + to obtain approval responses for hook resolution. + + When the last message is an assistant message that contains + approval-responded tool parts, it is automatically stripped to + avoid sending duplicate tool-use content to the LLM on re-entry. + The caller should check ``extract_approvals()`` to determine + whether this stripping occurred. + + Args: + ui_messages: List of UIMessage objects from the AI SDK v6 frontend. 
+ + Returns: + List of internal Message objects ready for use with the runtime. + """ + result: list[messages_.Message] = [] + has_approval_responses = False + + for ui_msg in ui_messages: + assistant_parts: list[messages_.Part] = [] + tool_result_parts: list[messages_.ToolResultPart] = [] + signal_parts: list[messages_.HookPart] = [] + + for part in ui_msg.parts: + match part: + case ui_message.UITextPart(text=text) if text: + assistant_parts.append(messages_.TextPart(text=text)) + + case ui_message.UIReasoningPart(reasoning=reasoning): + assistant_parts.append(messages_.ReasoningPart(text=reasoning)) + + case ui_message.UIToolInvocationPart() as inv: + # Legacy tool-invocation type + tool_args = json.dumps(inv.args) if inv.args else "{}" + assistant_parts.append( + messages_.ToolCallPart( + tool_call_id=inv.tool_invocation_id, + tool_name=inv.tool_name, + tool_args=tool_args, + ) + ) + if _is_tool_completed(inv.state): + tool_result_parts.append( + messages_.ToolResultPart( + tool_call_id=inv.tool_invocation_id, + tool_name=inv.tool_name, + result=inv.result, + is_error=_is_tool_error(inv.state), + ) + ) + + case ui_message.UIToolPart() as tp: + # Dynamic tool-{toolName} type (e.g., "tool-get_weather") + assistant_parts.append( + messages_.ToolCallPart( + tool_call_id=tp.tool_call_id, + tool_name=tp.tool_name, + tool_args=_normalize_tool_args(tp.input), + ) + ) + approval_signal = _approval_signal_part(tp) + if approval_signal is not None: + signal_parts.append(approval_signal) + + if tp.state in _TOOL_RESULT_STATES: + tool_result_parts.append( + messages_.ToolResultPart( + tool_call_id=tp.tool_call_id, + tool_name=tp.tool_name, + result=_normalize_tool_result(tp.output), + is_error=False, + ) + ) + elif tp.state == "output-error": + tool_result_parts.append( + messages_.ToolResultPart( + tool_call_id=tp.tool_call_id, + tool_name=tp.tool_name, + result=_error_result(tp.error_text, tp.output), + is_error=True, + ) + ) + if tp.state == "approval-responded": + 
has_approval_responses = True + + case ui_message.UIFilePart() as fp: + assistant_parts.append( + messages_.FilePart( + data=fp.url, + media_type=fp.media_type, + filename=fp.filename, + ) + ) + + case ( + ui_message.UIStepStartPart() + | ui_message.UISourceUrlPart() + | ui_message.UISourceDocumentPart() + ): + pass # Skip unsupported/boundary parts + + # Validate user/system messages have content - OpenAI requires it. + if ui_msg.role in ("user", "system") and not assistant_parts: + raise ValueError( + f"Message '{ui_msg.id}' has role '{ui_msg.role}' but no content. " + "User and system messages require non-empty content." + ) + + # The UI sends one assistant message per conversation turn, but a + # single turn may span multiple loop iterations (e.g. + # [text, tool_call, tool_result, text, tool_call, tool_result, text]). + # LLM APIs expect one message per iteration, so split into + # assistant + tool message pairs at tool-result boundaries. + if ui_msg.role == "assistant": + split_messages = _split_assistant_parts( + assistant_parts, tool_result_parts, msg_id=ui_msg.id + ) + result.extend(split_messages) + if signal_parts: + result.extend( + messages_.Message( + id=ui_msg.id, + role="signal", + parts=[part], + ) + for part in signal_parts + ) + else: + result.append( + messages_.Message( + id=ui_msg.id, + role=ui_msg.role, + parts=assistant_parts, + ) + ) + + # When resuming after approval responses, the frontend sends the full + # history including the assistant message from the interrupted run. + # Strip it to avoid sending duplicate tool-use content to the LLM. 
+ if has_approval_responses and result and result[-1].role == "assistant": + logger.info("Stripping trailing assistant message (approval responses present)") + result = result[:-1] + + return result + + +def _split_assistant_parts( + parts: list[messages_.Part], + tool_results: list[messages_.ToolResultPart], + msg_id: str, +) -> list[messages_.Message]: + """Split assistant parts into assistant + tool message pairs. + + The UI sends one big assistant message per turn, but internally each + loop iteration produces an assistant message (with tool calls) followed + by a tool message (with results). This reconstructs that structure. + + Returns a list of Messages: alternating assistant and tool messages, + split at tool-call boundaries when results are available. + """ + # Index tool results by their tool_call_id for lookup + results_by_id = {tr.tool_call_id: tr for tr in tool_results} + + messages: list[messages_.Message] = [] + current: list[messages_.Part] = [] + pending_results: list[messages_.ToolResultPart] = [] + + for part in parts: + current.append(part) + + # When we see a ToolCallPart that has a result, accumulate it + if ( + isinstance(part, messages_.ToolCallPart) + and part.tool_call_id in results_by_id + ): + pending_results.append(results_by_id[part.tool_call_id]) + + # If there are pending results and more parts follow, we need to split. + # Walk again, splitting at boundaries where all accumulated tool calls + # have results and a non-tool part follows. 
+ if not pending_results: + # No completed tools -- single assistant message + if current: + messages.append( + messages_.Message(role="assistant", parts=current, id=msg_id) + ) + return messages + + # Re-walk to split at tool-call boundaries + messages = [] + current = [] + current_results: list[messages_.ToolResultPart] = [] + seen_tool_call = False + + for part in parts: + # If we had a completed tool call group and now see a non-tool part, + # split here + if ( + seen_tool_call + and current_results + and not isinstance(part, messages_.ToolCallPart) + ): + messages.append( + messages_.Message(role="assistant", parts=current, id=msg_id) + ) + messages.append(messages_.Message(role="tool", parts=list(current_results))) + current = [] + current_results = [] + seen_tool_call = False + + current.append(part) + + if isinstance(part, messages_.ToolCallPart): + seen_tool_call = True + if part.tool_call_id in results_by_id: + current_results.append(results_by_id[part.tool_call_id]) + + # Flush remaining + if current: + messages.append(messages_.Message(role="assistant", parts=current, id=msg_id)) + if current_results: + messages.append(messages_.Message(role="tool", parts=list(current_results))) + + return messages diff --git a/src/ai/agents/ui/ai_sdk/message_to_ui.py b/src/ai/agents/ui/ai_sdk/message_to_ui.py new file mode 100644 index 00000000..6778504e --- /dev/null +++ b/src/ai/agents/ui/ai_sdk/message_to_ui.py @@ -0,0 +1,456 @@ +""" +Message to UI: internal Message -> UI format conversion. + +Converts internal Message objects back into AI SDK v6 UI format, +both for persistence (DB storage, GET endpoints) and for +accumulating streaming assistant turns. +""" + +from __future__ import annotations + +import json +import logging +from typing import Any + +from ...types import messages as messages_ +from . 
import ui_message + +logger = logging.getLogger(__name__) + + +# ============================================================================ +# Part-level conversions +# ============================================================================ + + +def parts_to_ui(parts: list[messages_.Part]) -> list[dict[str, Any]]: + """Convert internal Part objects to UI-compatible dicts. + + The frontend expects the AI SDK UI protocol shape (``type``, ``text``, + ``toolCallId``, ``toolName``, ``input``, ``output``, ``state``, etc.) + which differs from the internal model. + + Handles: TextPart, ReasoningPart, ToolCallPart, ToolResultPart, FilePart. + ToolResultParts are returned as standalone dicts (with ``output`` and + terminal state); callers that want merged tool-call+result dicts should + use ``UIMessageBuilder`` instead. + """ + result: list[dict[str, Any]] = [] + for part in parts: + if isinstance(part, messages_.TextPart): + if part.text: + result.append({"type": "text", "text": part.text}) + elif isinstance(part, messages_.ReasoningPart): + if part.text: + result.append({"type": "reasoning", "reasoning": part.text}) + elif isinstance(part, messages_.ToolCallPart): + result.append( + { + "type": f"tool-{part.tool_name}", + "toolCallId": part.tool_call_id, + "toolName": part.tool_name, + "state": "input-available", + "input": _normalize_tool_input(part.tool_args), + } + ) + elif isinstance(part, messages_.ToolResultPart): + state = "output-error" if part.is_error else "output-available" + result.append( + { + "type": f"tool-{part.tool_name}", + "toolCallId": part.tool_call_id, + "toolName": part.tool_name, + "state": state, + "output": part.result, + } + ) + elif isinstance(part, messages_.FilePart): + entry: dict[str, Any] = { + "type": "file", + "mediaType": part.media_type, + "url": part.data if isinstance(part.data, str) else "", + } + if part.filename: + entry["filename"] = part.filename + result.append(entry) + return result + + +def ui_parts_to_dicts( + parts: 
list[ui_message.UIMessagePart], +) -> list[dict[str, Any]]: + """Serialize UIMessage parts to plain dicts for DB storage.""" + return [ + part.model_dump() if hasattr(part, "model_dump") else dict(part) # type: ignore[call-overload] + for part in parts + ] + + +# ============================================================================ +# Message-level conversion (batch, for history loading) +# ============================================================================ + + +def messages_to_ui( + messages: list[messages_.Message], +) -> list[ui_message.UIMessage]: + """Convert internal Messages to UIMessages. + + This is the inverse of ``inbound.ui_to_messages()``. It merges + consecutive assistant + tool + signal message groups back into single + UIMessages with the correct tool states, producing the format + expected by the AI SDK frontend and suitable for DB persistence. + + User/system messages are converted directly. Assistant messages + are accumulated until a non assistant/tool/signal message is seen (or + the list ends), merging tool results and approval signals into the + preceding assistant's tool-call parts. + """ + result: list[ui_message.UIMessage] = [] + + i = 0 + while i < len(messages): + msg = messages[i] + + if msg.role in ("user", "system"): + result.append( + ui_message.UIMessage( + id=msg.id, + role=msg.role, + parts=_internal_parts_to_ui_parts(msg.parts), + ) + ) + i += 1 + continue + + if msg.role == "assistant": + # Accumulate: merge this assistant message with any following + # tool/signal messages, and possibly more assistant+tool pairs that + # belong to the same UI turn. 
+ ui_parts: list[ui_message.UIMessagePart] = [] + turn_id = msg.id + + while i < len(messages) and messages[i].role in ( + "assistant", + "tool", + "signal", + ): + current = messages[i] + if current.role == "assistant": + ui_parts.extend(_internal_parts_to_ui_parts(current.parts)) + elif current.role == "tool": + _merge_tool_results(ui_parts, current.parts) + elif current.role == "signal": + _merge_signal_parts(ui_parts, current.parts) + i += 1 + + result.append( + ui_message.UIMessage( + id=turn_id, + role="assistant", + parts=ui_parts, + ) + ) + continue + + # Skip signal, tool (orphaned), or unknown roles + i += 1 + + return result + + +def _internal_parts_to_ui_parts( + parts: list[messages_.Part], +) -> list[ui_message.UIMessagePart]: + """Convert internal Part objects to UIMessagePart objects.""" + result: list[ui_message.UIMessagePart] = [] + for part in parts: + if isinstance(part, messages_.TextPart) and part.text: + result.append(ui_message.UITextPart(type="text", text=part.text)) + elif isinstance(part, messages_.ReasoningPart) and part.text: + result.append( + ui_message.UIReasoningPart(type="reasoning", reasoning=part.text) + ) + elif isinstance(part, messages_.ToolCallPart): + tool_input = _normalize_tool_input(part.tool_args) + result.append( + ui_message.UIToolPart( + type=f"tool-{part.tool_name}", + tool_call_id=part.tool_call_id, + state="input-available", + input=tool_input, + ) + ) + elif isinstance(part, messages_.FilePart): + result.append( + ui_message.UIFilePart( + type="file", + media_type=part.media_type, + url=part.data if isinstance(part.data, str) else "", + filename=part.filename, + ) + ) + return result + + +def _merge_tool_results( + ui_parts: list[ui_message.UIMessagePart], + tool_parts: list[messages_.Part], +) -> None: + """Merge tool result parts into existing UI tool-call parts in-place. + + Finds the matching UIToolPart by tool_call_id and updates its state + and output to reflect the tool result. 
+ """ + # Index existing tool parts by tool_call_id + tool_index: dict[str, int] = {} + for idx, ui_part in enumerate(ui_parts): + if isinstance(ui_part, ui_message.UIToolPart): + tool_index[ui_part.tool_call_id] = idx + + for part in tool_parts: + if not isinstance(part, messages_.ToolResultPart): + continue + idx = tool_index.get(part.tool_call_id) + if idx is None: + continue + existing = ui_parts[idx] + if not isinstance(existing, ui_message.UIToolPart): + continue + if existing.state == "output-denied": + continue + state = "output-error" if part.is_error else "output-available" + ui_parts[idx] = existing.model_copy( + update={"state": state, "output": part.result} + ) + + +def _merge_signal_parts( + ui_parts: list[ui_message.UIMessagePart], + signal_parts: list[messages_.Part], +) -> None: + """Merge HookPart approval state into existing UI tool-call parts.""" + tool_index: dict[str, int] = {} + for idx, ui_part in enumerate(ui_parts): + if isinstance(ui_part, ui_message.UIToolPart): + tool_index[ui_part.tool_call_id] = idx + + for part in signal_parts: + if not isinstance(part, messages_.HookPart): + continue + + tool_call_id = _tool_call_id_from_approval_id(part.hook_id) + if tool_call_id is None: + continue + + idx = tool_index.get(tool_call_id) + if idx is None: + continue + + existing = ui_parts[idx] + if not isinstance(existing, ui_message.UIToolPart): + continue + + updates: dict[str, Any] = {} + if part.status == "pending": + updates["state"] = "approval-requested" + updates["approval"] = {"id": part.hook_id} + elif part.status == "resolved": + resolution = part.resolution or {} + updates["approval"] = { + "id": part.hook_id, + "approved": resolution.get("granted"), + "reason": resolution.get("reason"), + } + if resolution.get("granted", False): + updates["state"] = "approval-responded" + else: + updates["state"] = "output-denied" + updates["output"] = None + elif part.status == "cancelled": + updates["state"] = "output-error" + updates["error_text"] 
= "Hook cancelled" + + if updates: + ui_parts[idx] = existing.model_copy(update=updates) + + +# ============================================================================ +# UIMessageBuilder: streaming accumulator for assistant turns +# ============================================================================ + + +class UIMessageBuilder: + """Accumulate streaming runtime messages into a single UI assistant message. + + Processes internal Message objects as they arrive from the agent loop + and builds up a single UIMessage with all parts in the AI SDK UI format. + Handles text, reasoning, tool calls, tool results, and approval signals. + + Usage:: + + builder = UIMessageBuilder() + async for msg in agent_stream: + builder.ingest(msg) + ui_msg = builder.build() + """ + + def __init__(self, message_id: str | None = None) -> None: + self.message_id = message_id + self.parts: list[dict[str, Any]] = [] + self._tool_indexes: dict[str, int] = {} + + @classmethod + def from_ui_message(cls, message: ui_message.UIMessage) -> UIMessageBuilder: + """Seed the builder from an existing UI assistant message (for resume).""" + builder = cls(message_id=message.id) + builder.parts = ui_parts_to_dicts(message.parts) + for index, part in enumerate(builder.parts): + part_type = part.get("type") + tool_call_id = part.get("toolCallId") + if ( + isinstance(part_type, str) + and part_type.startswith("tool-") + and isinstance(tool_call_id, str) + ): + builder._tool_indexes[tool_call_id] = index + return builder + + def ingest(self, message: messages_.Message) -> None: + """Consume one runtime message. 
+ + Routes by role: + - ``assistant`` (done): appends text, reasoning, and tool-call parts + - ``tool``: updates existing tool parts with results + - ``signal``: updates tool parts with approval state + """ + if message.role == "assistant" and message.is_done: + self._ingest_assistant(message) + elif message.role == "tool": + self._ingest_tool(message) + elif message.role == "signal": + self._ingest_signal(message) + + def build(self) -> ui_message.UIMessage | None: + """Return the accumulated UIMessage, or None if nothing was ingested.""" + if not self.parts or self.message_id is None: + return None + # Parse the accumulated dicts back into typed UIMessageParts + parsed_parts: list[ui_message.UIMessagePart] = [] + for part_dict in self.parts: + parsed = ui_message._parse_ui_part(part_dict) + if parsed is not None: + parsed_parts.append(parsed) + return ui_message.UIMessage( + id=self.message_id, + role="assistant", + parts=parsed_parts, + ) + + @property + def raw_parts(self) -> list[dict[str, Any]]: + """Access the accumulated parts as raw dicts (for direct DB storage).""" + return self.parts + + # -- Private ingest handlers -- + + def _ingest_assistant(self, message: messages_.Message) -> None: + if self.message_id is None: + self.message_id = message.id + for part in message.parts: + if isinstance(part, messages_.ReasoningPart) and part.text: + candidate = {"type": "reasoning", "reasoning": part.text} + if self.parts[-1:] != [candidate]: + self.parts.append(candidate) + elif isinstance(part, messages_.TextPart) and part.text: + candidate = {"type": "text", "text": part.text} + if self.parts[-1:] != [candidate]: + self.parts.append(candidate) + elif isinstance(part, messages_.ToolCallPart): + if part.tool_call_id in self._tool_indexes: + continue + self._tool_indexes[part.tool_call_id] = len(self.parts) + self.parts.append( + { + "type": f"tool-{part.tool_name}", + "toolCallId": part.tool_call_id, + "toolName": part.tool_name, + "state": "input-available", + 
"input": _normalize_tool_input(part.tool_args), + } + ) + + def _ingest_tool(self, message: messages_.Message) -> None: + for part in message.parts: + if not isinstance(part, messages_.ToolResultPart): + continue + index = self._tool_indexes.get(part.tool_call_id) + if index is None: + continue + tool_part = dict(self.parts[index]) + if tool_part.get("state") != "output-denied": + tool_part["state"] = ( + "output-error" if part.is_error else "output-available" + ) + tool_part["output"] = part.result + self.parts[index] = tool_part + + def _ingest_signal(self, message: messages_.Message) -> None: + hook_part = message.get_hook_part() + if hook_part is None: + return + tool_call_id = _tool_call_id_from_approval_id(hook_part.hook_id) + if tool_call_id is None: + return + index = self._tool_indexes.get(tool_call_id) + if index is None: + return + + tool_part = dict(self.parts[index]) + if hook_part.status == "pending": + tool_part["state"] = "approval-requested" + tool_part["approval"] = {"id": hook_part.hook_id} + elif hook_part.status == "resolved": + resolution = hook_part.resolution or {} + tool_part["approval"] = { + "id": hook_part.hook_id, + "approved": resolution.get("granted"), + "reason": resolution.get("reason"), + } + if resolution.get("granted", False): + tool_part["state"] = "approval-responded" + else: + tool_part["state"] = "output-denied" + elif hook_part.status == "cancelled": + tool_part["state"] = "output-error" + tool_part["errorText"] = "Hook cancelled" + self.parts[index] = tool_part + + +# ============================================================================ +# Shared helpers +# ============================================================================ + + +def _tool_call_id_from_approval_id(approval_id: str) -> str | None: + """Extract the tool_call_id from a ToolApproval hook label. + + E.g. ``"approve_tc_abc123"`` -> ``"tc_abc123"``. 
+ """ + prefix = "approve_" + if approval_id.startswith(prefix): + return approval_id[len(prefix) :] + return None + + +def _normalize_tool_input(raw: str) -> str | dict[str, Any]: + """Normalize tool input for the UI's accepted string-or-dict shape. + + Tries to parse the JSON string into a dict. Falls back to the + raw string if parsing fails or the result isn't a dict. + """ + try: + parsed = json.loads(raw) + except (json.JSONDecodeError, TypeError): + return raw + return parsed if isinstance(parsed, dict) else raw diff --git a/src/ai/agents/ui/ai_sdk/outbound.py b/src/ai/agents/ui/ai_sdk/outbound.py new file mode 100644 index 00000000..b5494f8a --- /dev/null +++ b/src/ai/agents/ui/ai_sdk/outbound.py @@ -0,0 +1,303 @@ +""" +Outbound: internal message stream -> AI SDK UI stream. + +Converts the internal runtime stream into AI SDK UI stream protocol parts +and optionally serializes them as SSE payloads. +""" + +from __future__ import annotations + +import dataclasses +import json +from collections.abc import AsyncGenerator, AsyncIterable + +from ...agents.hooks import TOOL_APPROVAL_HOOK_TYPE +from ...types import messages as messages_ +from . 
import protocol + + +def _to_camel_case(snake_str: str) -> str: + """Convert snake_case to camelCase.""" + components = snake_str.split("_") + return components[0] + "".join(x.title() for x in components[1:]) + + +def serialize_part(part: protocol.UIMessageStreamPart) -> str: + """Serialize a stream part to JSON with camelCase keys.""" + d = dataclasses.asdict(part) + if isinstance(part, protocol.DataPart): + d["type"] = part.type + del d["data_type"] + camel_dict = {_to_camel_case(k): v for k, v in d.items() if v is not None} + return json.dumps(camel_dict) + + +def format_sse(part: protocol.UIMessageStreamPart) -> str: + """Format a stream part as an SSE data line.""" + return f"data: {serialize_part(part)}\n\n" + + +class _StreamState: + """Tracks state for UI message stream event sequencing.""" + + def __init__(self) -> None: + self.text_id: str | None = None + self.reasoning_id: str | None = None + self.label: str | None = None + self.message_id: str | None = None + self.emitted_start = False + self.in_step = False + self.started_tool_calls: set[str] = set() + self.emitted_tool_results: set[str] = set() + self.pending_tool_calls: set[str] = set() + self.emitted_approval_requests: set[str] = set() + + def close_open_blocks(self) -> list[protocol.UIMessageStreamPart]: + parts: list[protocol.UIMessageStreamPart] = [] + if self.reasoning_id: + parts.append(protocol.ReasoningEndPart(id=self.reasoning_id)) + self.reasoning_id = None + if self.text_id: + parts.append(protocol.TextEndPart(id=self.text_id)) + self.text_id = None + return parts + + def finish_step(self) -> list[protocol.UIMessageStreamPart]: + parts = self.close_open_blocks() + if self.in_step: + parts.append(protocol.FinishStepPart()) + self.in_step = False + return parts + + def reset_tool_tracking(self) -> None: + self.started_tool_calls = set() + self.emitted_tool_results = set() + self.pending_tool_calls = set() + self.emitted_approval_requests = set() + + def begin_message(self, msg: 
messages_.Message) -> list[protocol.UIMessageStreamPart]: + parts: list[protocol.UIMessageStreamPart] = [] + is_new_message = self.message_id is not None and msg.id != self.message_id + + if not self.emitted_start or (msg.label and msg.label != self.label): + parts.extend(self.finish_step()) + if self.emitted_start: + parts.append(protocol.FinishPart(finish_reason="stop")) + + parts.append(protocol.StartPart(message_id=msg.id)) + parts.append(protocol.StartStepPart()) + self.emitted_start = True + self.in_step = True + self.label = msg.label + self.message_id = msg.id + self.reset_tool_tracking() + elif is_new_message: + parts.extend(self.finish_step()) + parts.append(protocol.StartStepPart()) + self.in_step = True + self.message_id = msg.id + + return parts + + +def _tool_call_id_from_approval_hook(hook_part: messages_.HookPart) -> str | None: + """Extract tool_call_id from a ToolApproval HookPart.""" + if hook_part.hook_type != TOOL_APPROVAL_HOOK_TYPE: + return None + prefix = "approve_" + if hook_part.hook_id.startswith(prefix): + return hook_part.hook_id[len(prefix) :] + return None + + +def _is_tool_approval_hook_message(msg: messages_.Message) -> bool: + """True if this message contains only ToolApproval HookParts.""" + if not msg.parts: + return False + return all( + isinstance(p, messages_.HookPart) + and _tool_call_id_from_approval_hook(p) is not None + for p in msg.parts + ) + + +def _tool_error_text(part: messages_.ToolResultPart) -> str: + """Best-effort error text for failed tool executions.""" + if isinstance(part.result, str) and part.result: + return part.result + if isinstance(part.result, dict): + for key in ("error", "message", "detail"): + value = part.result.get(key) + if isinstance(value, str) and value: + return value + return "Tool execution failed" + + +async def to_ui_message_stream( + messages: AsyncIterable[messages_.Message], +) -> AsyncGenerator[protocol.UIMessageStreamPart]: + """ + Convert an internal message stream into AI SDK UI 
stream parts. + """ + state = _StreamState() + + async for msg in messages: + if msg.role == "tool" and state.message_id: + msg = msg.model_copy(update={"id": state.message_id}) + + if _is_tool_approval_hook_message(msg) and state.message_id: + msg = msg.model_copy(update={"id": state.message_id}) + + for part in state.begin_message(msg): + yield part + + if delta := msg.reasoning_delta: + if not state.reasoning_id: + state.reasoning_id = messages_.generate_id("reasoning") + yield protocol.ReasoningStartPart(id=state.reasoning_id) + yield protocol.ReasoningDeltaPart(id=state.reasoning_id, delta=delta) + + if delta := msg.text_delta: + if state.reasoning_id: + yield protocol.ReasoningEndPart(id=state.reasoning_id) + state.reasoning_id = None + + if not state.text_id: + state.text_id = messages_.generate_id("text") + yield protocol.TextStartPart(id=state.text_id) + yield protocol.TextDeltaPart(id=state.text_id, delta=delta) + + for tool_delta in msg.tool_deltas: + if tool_delta.tool_call_id not in state.started_tool_calls: + state.started_tool_calls.add(tool_delta.tool_call_id) + yield protocol.ToolInputStartPart( + tool_call_id=tool_delta.tool_call_id, + tool_name=tool_delta.tool_name, + ) + yield protocol.ToolInputDeltaPart( + tool_call_id=tool_delta.tool_call_id, + input_text_delta=tool_delta.args_delta, + ) + + if msg.is_done: + had_active_text = state.text_id is not None + for part in state.close_open_blocks(): + yield part + + has_new_pending_tools = any( + isinstance(p, messages_.ToolCallPart) + and p.tool_call_id not in state.pending_tool_calls + for p in msg.parts + ) + has_new_tool_results = any( + isinstance(p, messages_.ToolResultPart) + and p.tool_call_id not in state.emitted_tool_results + for p in msg.parts + ) + + for msg_part in msg.parts: + match msg_part: + case messages_.TextPart(text=text) if ( + text + and not had_active_text + and not has_new_pending_tools + and not has_new_tool_results + ): + text_id = messages_.generate_id("text") + yield 
protocol.TextStartPart(id=text_id) + # No text block was open when the message completed, so this text was + # never sent as deltas; emit it whole so the block is not empty. + yield protocol.TextDeltaPart(id=text_id, delta=text) + yield protocol.TextEndPart(id=text_id) + case messages_.ToolCallPart( + tool_call_id=tc_id, + tool_name=name, + tool_args=args, + ): + if tc_id not in state.started_tool_calls: + state.started_tool_calls.add(tc_id) + yield protocol.ToolInputStartPart( + tool_call_id=tc_id, + tool_name=name, + ) + if tc_id not in state.pending_tool_calls: + state.pending_tool_calls.add(tc_id) + yield protocol.ToolInputAvailablePart( + tool_call_id=tc_id, + tool_name=name, + input=args, + ) + + if has_new_tool_results: + for msg_part in msg.parts: + if not isinstance(msg_part, messages_.ToolResultPart): + continue + if msg_part.tool_call_id in state.emitted_tool_results: + continue + + state.emitted_tool_results.add(msg_part.tool_call_id) + state.pending_tool_calls.discard(msg_part.tool_call_id) + + if msg_part.is_error: + yield protocol.ToolOutputErrorPart( + tool_call_id=msg_part.tool_call_id, + error_text=_tool_error_text(msg_part), + ) + else: + yield protocol.ToolOutputAvailablePart( + tool_call_id=msg_part.tool_call_id, + output=msg_part.result, + ) + + for msg_part in msg.parts: + if not isinstance(msg_part, messages_.HookPart): + continue + approval_tc_id = _tool_call_id_from_approval_hook(msg_part) + if approval_tc_id is None: + continue + + if msg_part.status == "pending": + if approval_tc_id not in state.emitted_approval_requests: + state.emitted_approval_requests.add(approval_tc_id) + yield protocol.ToolApprovalRequestPart( + approval_id=msg_part.hook_id, + tool_call_id=approval_tc_id, + ) + elif msg_part.status == "resolved": + resolution = msg_part.resolution or {} + if not resolution.get("granted", False): + yield protocol.ToolOutputDeniedPart( + tool_call_id=approval_tc_id, + ) + elif msg_part.status == "cancelled": + yield protocol.ToolOutputErrorPart( + tool_call_id=approval_tc_id, + error_text="Hook cancelled", + ) + + for part in state.finish_step(): + yield part + if state.emitted_start: + yield
protocol.FinishPart(finish_reason="stop") + + +async def filter_by_label( + messages: AsyncIterable[messages_.Message], + label: str | None = None, +) -> AsyncGenerator[messages_.Message]: + """Filter a message stream to a single agent label.""" + async for msg in messages: + if label is None: + label = msg.label + if msg.label == label: + yield msg + + +async def to_sse_stream( + messages: AsyncIterable[messages_.Message], +) -> AsyncGenerator[str]: + """Convert an internal message stream into SSE strings.""" + async for part in to_ui_message_stream(messages): + yield format_sse(part) + + +# Backward-compatible aliases for the current package surface. +stream_to_ui = to_ui_message_stream +stream_to_sse = to_sse_stream diff --git a/src/ai/adapters/ai_sdk_ui/protocol.py b/src/ai/agents/ui/ai_sdk/protocol.py similarity index 100% rename from src/ai/adapters/ai_sdk_ui/protocol.py rename to src/ai/agents/ui/ai_sdk/protocol.py diff --git a/src/ai/adapters/ai_sdk_ui/ui_message.py b/src/ai/agents/ui/ai_sdk/ui_message.py similarity index 100% rename from src/ai/adapters/ai_sdk_ui/ui_message.py rename to src/ai/agents/ui/ai_sdk/ui_message.py From d1deca64bff7b5ac17d6cc9f675a0009dee84fc9 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Fri, 17 Apr 2026 13:21:03 -0700 Subject: [PATCH 02/10] Gather streaming state in a sidecar on ai.Message --- src/ai/__init__.py | 10 +- src/ai/agents/agent.py | 21 +++- src/ai/agents/hooks.py | 10 +- src/ai/models/core/api.py | 14 ++- src/ai/models/core/helpers/streaming.py | 52 ++++------ src/ai/models/core/types.py | 22 +++- src/ai/types/__init__.py | 10 +- src/ai/types/integrity.py | 8 +- src/ai/types/messages.py | 129 ++++++++++++++++-------- tests/agents/test_generator_tools.py | 12 +-- tests/agents/test_runtime.py | 2 - tests/conftest.py | 5 +- tests/models/core/test_streaming.py | 85 +++++++++++----- tests/models/test_public_api.py | 4 +- tests/types/test_integrity.py | 36 +++---- 15 files changed, 268 insertions(+), 152 deletions(-) 
diff --git a/src/ai/__init__.py b/src/ai/__init__.py index 00a7c589..8c07b4b9 100644 --- a/src/ai/__init__.py +++ b/src/ai/__init__.py @@ -36,8 +36,11 @@ FilePart, HookPart, Message, + MessageStreamState, Part, - PartState, + PartClosed, + PartDelta, + PartOpened, ReasoningPart, StreamResultLike, StructuredOutputPart, @@ -62,8 +65,11 @@ __all__ = [ # Types (from types/) "Message", + "MessageStreamState", "Part", - "PartState", + "PartClosed", + "PartDelta", + "PartOpened", "TextPart", "ToolCallPart", "ToolResultPart", diff --git a/src/ai/agents/agent.py b/src/ai/agents/agent.py index c7d73a9d..99275312 100644 --- a/src/ai/agents/agent.py +++ b/src/ai/agents/agent.py @@ -188,6 +188,7 @@ class Context(pydantic.BaseModel): model: models.Model messages: list[types.Message] tools: list[Tool[..., Any]] + run_id: str | None = None model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) @@ -206,7 +207,10 @@ def __call__(self, context: Context) -> AsyncGenerator[types.Message]: ... async def _default_loop(context: Context) -> AsyncGenerator[types.Message]: while True: stream = await models.stream( - context.model, context.messages, tools=context.tools + context.model, + context.messages, + tools=context.tools, + run_id=context.run_id, ) async for message in stream: yield message @@ -220,7 +224,10 @@ async def _default_loop(context: Context) -> AsyncGenerator[types.Message]: tasks = [tg.create_task(tc()) for tc in tool_calls] # Yield one merged tool-result message — history auto-collects it. - yield builders.tool_message(*(t.result() for t in tasks)) + tool_msg = builders.tool_message(*(t.result() for t in tasks)) + if context.run_id is not None: + tool_msg = tool_msg.model_copy(update={"run_id": context.run_id}) + yield tool_msg async def _collect_messages( @@ -311,6 +318,8 @@ async def run( First in the list = outermost. Middleware wraps model calls, tool calls, hooks, and the run itself. 
""" + run_id = types.generate_id("run") + call = middleware_.AgentRunContext( model=model, messages=messages, @@ -327,11 +336,17 @@ async def _real( model=call.model, messages=list(call.messages), tools=call.tools, + run_id=run_id, ) source = _collect_messages(loop_fn(context), context.messages) async for message in runtime.run(source): + updates: dict[str, Any] = {} + if message.run_id is None: + updates["run_id"] = run_id if call.label is not None: - message = message.model_copy(update={"label": call.label}) + updates["agent"] = call.label + if updates: + message = message.model_copy(update=updates) yield message # Activate middleware for this run (and everything it calls). diff --git a/src/ai/agents/hooks.py b/src/ai/agents/hooks.py index 871a2a9d..6205dcb0 100644 --- a/src/ai/agents/hooks.py +++ b/src/ai/agents/hooks.py @@ -125,7 +125,7 @@ async def _hook_impl(call: middleware_.HookContext) -> pydantic.BaseModel: # Emit pending signal message. await rt.put_message( messages_.Message( - role="signal", + role="internal", parts=[ messages_.HookPart( hook_id=label, @@ -150,10 +150,10 @@ async def _hook_impl(call: middleware_.HookContext) -> pydantic.BaseModel: # Clean up live registry. _live_hooks.pop(label, None) - # Emit resolved signal message. + # Emit resolved internal message. await rt.put_message( messages_.Message( - role="signal", + role="internal", parts=[ messages_.HookPart( hook_id=label, @@ -229,10 +229,10 @@ async def cancel_hook(label: str, *, reason: str | None = None) -> None: future, hook_metadata, rt = _live_hooks.pop(label) future.cancel(reason) - # Emit cancelled signal message. + # Emit cancelled internal message. 
await rt.put_message( messages_.Message( - role="signal", + role="internal", parts=[ messages_.HookPart( hook_id=label, diff --git a/src/ai/models/core/api.py b/src/ai/models/core/api.py index d79e5ca0..14a31d65 100644 --- a/src/ai/models/core/api.py +++ b/src/ai/models/core/api.py @@ -27,6 +27,7 @@ async def stream( *, tools: Sequence[tools_.ToolLike] | None = None, output_type: type[pydantic.BaseModel] | None = None, + run_id: str | None = None, **kwargs: Any, ) -> stream_.StreamResultLike: """Stream an LLM response. @@ -35,11 +36,17 @@ async def stream( collects the final ``Message``. After iteration, access ``.text``, ``.tool_calls``, ``.usage``, etc. + Every yielded message (re-emitted inputs + model response) carries + ``run_id``. If *run_id* is not provided, one is generated. + The client is resolved from the model: ``model.client`` if set, otherwise auto-created from ``model.base_url`` / ``model.api_key_env``. """ messages = integrity_.prepare_messages(messages) + if run_id is None: + run_id = messages_.generate_id("run") + call = middleware_.ModelContext( model=model, messages=messages, @@ -48,6 +55,9 @@ async def stream( kwargs=kwargs, ) + # Capture in closure for the inner function. 
+ _run_id = run_id + async def _real(call: middleware_.ModelContext) -> stream_.StreamResultLike: c = client_.auto_client(call.model) adapter_fn = adapters.get_stream_adapter(call.model.adapter) @@ -59,7 +69,9 @@ async def _real(call: middleware_.ModelContext) -> stream_.StreamResultLike: tools=call.tools, output_type=call.output_type, **call.kwargs, - ) + ), + run_id=_run_id, + input_messages=call.messages, ) chain = middleware_._build_model_chain(_real) diff --git a/src/ai/models/core/helpers/streaming.py b/src/ai/models/core/helpers/streaming.py index 1fa91441..855a7a99 100644 --- a/src/ai/models/core/helpers/streaming.py +++ b/src/ai/models/core/helpers/streaming.py @@ -122,50 +122,54 @@ class StreamHandler: def handle_event(self, event: StreamEvent) -> messages_.Message: """Process event and return current Message state.""" - # Current deltas (reset each call) - text_delta: str | None = None - reasoning_delta: str | None = None - tool_deltas: dict[str, str] = {} # tool_call_id -> delta + # Sidecar events for this yield (reset each call). 
+ stream_events: list[messages_.StreamEvent] = [] match event: case TextStart(block_id=bid): self._text_blocks[bid] = "" self._active_text_id = bid + stream_events.append(messages_.PartOpened(part_id=bid)) case TextDelta(block_id=bid, delta=d): self._text_blocks[bid] += d - text_delta = d + stream_events.append(messages_.PartDelta(part_id=bid, chunk=d)) case TextEnd(block_id=bid): if self._active_text_id == bid: self._active_text_id = None + stream_events.append(messages_.PartClosed(part_id=bid)) case ReasoningStart(block_id=bid): self._reasoning_blocks[bid] = ("", None) self._active_reasoning_id = bid + stream_events.append(messages_.PartOpened(part_id=bid)) case ReasoningDelta(block_id=bid, delta=d): text, sig = self._reasoning_blocks[bid] self._reasoning_blocks[bid] = (text + d, sig) - reasoning_delta = d + stream_events.append(messages_.PartDelta(part_id=bid, chunk=d)) case ReasoningEnd(block_id=bid, signature=sig): text, _ = self._reasoning_blocks[bid] self._reasoning_blocks[bid] = (text, sig) if self._active_reasoning_id == bid: self._active_reasoning_id = None + stream_events.append(messages_.PartClosed(part_id=bid)) case ToolStart(tool_call_id=tcid, tool_name=name): self._tool_calls[tcid] = (name, "") self._active_tool_ids.add(tcid) + stream_events.append(messages_.PartOpened(part_id=tcid)) case ToolArgsDelta(tool_call_id=tcid, delta=d): name, args = self._tool_calls[tcid] self._tool_calls[tcid] = (name, args + d) - tool_deltas[tcid] = d + stream_events.append(messages_.PartDelta(part_id=tcid, chunk=d)) case ToolEnd(tool_call_id=tcid): self._active_tool_ids.discard(tcid) + stream_events.append(messages_.PartClosed(part_id=tcid)) case FileEvent(block_id=bid, media_type=mt, data=d): self._files[bid] = (mt, d) @@ -177,52 +181,30 @@ def handle_event(self, event: StreamEvent) -> messages_.Message: self._active_reasoning_id = None self._active_tool_ids.clear() - return self._build_message(text_delta, reasoning_delta, tool_deltas) + return 
self._build_message(stream_events) def _build_message( self, - text_delta: str | None, - reasoning_delta: str | None, - tool_deltas: dict[str, str], + stream_events: list[messages_.StreamEvent], ) -> messages_.Message: parts: list[messages_.Part] = [] # Reasoning parts first (like thinking blocks) for bid, (text, sig) in self._reasoning_blocks.items(): - is_active = bid == self._active_reasoning_id - parts.append( - messages_.ReasoningPart( - id=bid, - text=text, - signature=sig, - state="streaming" if is_active else "done", - delta=reasoning_delta if is_active else None, - ) - ) + parts.append(messages_.ReasoningPart(id=bid, text=text, signature=sig)) # Text parts for bid, text in self._text_blocks.items(): - is_active = bid == self._active_text_id - parts.append( - messages_.TextPart( - id=bid, - text=text, - state="streaming" if is_active else "done", - delta=text_delta if is_active else None, - ) - ) + parts.append(messages_.TextPart(id=bid, text=text)) # Tool call parts for tcid, (name, args) in self._tool_calls.items(): - is_active = tcid in self._active_tool_ids parts.append( messages_.ToolCallPart( id=tcid, tool_call_id=tcid, tool_name=name, tool_args=args, - state="streaming" if is_active else "done", - args_delta=tool_deltas.get(tcid), ) ) @@ -235,6 +217,10 @@ def _build_message( role="assistant", parts=parts, usage=self._usage if self._is_done else None, + stream=messages_.MessageStreamState( + events=stream_events, + is_done=self._is_done, + ), ) diff --git a/src/ai/models/core/types.py b/src/ai/models/core/types.py index eff6e5af..29222cf2 100644 --- a/src/ai/models/core/types.py +++ b/src/ai/models/core/types.py @@ -69,11 +69,23 @@ class StreamResult: Properties like ``.text`` and ``.tool_calls`` delegate to the final ``Message`` snapshot and are available after iteration completes. + When *run_id* is provided, every yielded message is stamped with it. 
+ When *input_messages* is provided, they are re-emitted (with *run_id*) + before the model response stream. + Satisfies :class:`~ai.types.StreamResultLike`. """ - def __init__(self, gen: AsyncGenerator[messages_.Message]) -> None: + def __init__( + self, + gen: AsyncGenerator[messages_.Message], + *, + run_id: str | None = None, + input_messages: list[messages_.Message] | None = None, + ) -> None: self._gen = gen + self._run_id = run_id + self._input_messages = input_messages or [] self._final: messages_.Message | None = None @classmethod @@ -98,7 +110,15 @@ def __aiter__(self) -> AsyncGenerator[messages_.Message]: return self._iterate() async def _iterate(self) -> AsyncGenerator[messages_.Message]: + # Re-emit input messages with run_id stamped. + for msg in self._input_messages: + stamped = msg.model_copy(update={"run_id": self._run_id}) + yield stamped + + # Stream model response with run_id stamped. async for msg in self._gen: + if self._run_id is not None: + msg = msg.model_copy(update={"run_id": self._run_id}) self._final = msg yield msg diff --git a/src/ai/types/__init__.py b/src/ai/types/__init__.py index aa7a9435..4ce2d699 100644 --- a/src/ai/types/__init__.py +++ b/src/ai/types/__init__.py @@ -8,8 +8,11 @@ FilePart, HookPart, Message, + MessageStreamState, Part, - PartState, + PartClosed, + PartDelta, + PartOpened, ReasoningPart, StructuredOutputPart, TextPart, @@ -26,8 +29,11 @@ "FilePart", "HookPart", "Message", + "MessageStreamState", "Part", - "PartState", + "PartClosed", + "PartDelta", + "PartOpened", "ReasoningPart", "StreamResultLike", "StructuredOutputPart", diff --git a/src/ai/types/integrity.py b/src/ai/types/integrity.py index f2b8ca87..9f737101 100644 --- a/src/ai/types/integrity.py +++ b/src/ai/types/integrity.py @@ -16,7 +16,7 @@ "invalid-tool-args", "orphaned-tool-call", "orphaned-tool-result", - "signal-message", + "internal-message", ] @@ -47,9 +47,9 @@ def _clean_messages( result: list[messages_.Message] = [] for msg in messages: - # 1. 
drop signal messages emitted by hooks - if msg.role == "signal": - issues.append("signal-message") + # 1. drop internal messages emitted by hooks + if msg.role == "internal": + issues.append("internal-message") if mode == "strict": result.append(msg) continue diff --git a/src/ai/types/messages.py b/src/ai/types/messages.py index e0adc8ef..b1ebfb9e 100644 --- a/src/ai/types/messages.py +++ b/src/ai/types/messages.py @@ -13,19 +13,12 @@ def generate_id(prefix: str | None = None) -> str: return f"{prefix}_{raw}" if prefix else raw -# Streaming state for parts -PartState = Literal["streaming", "done"] - - class TextPart(pydantic.BaseModel): model_config = pydantic.ConfigDict(frozen=True) id: str = pydantic.Field(default_factory=generate_id) text: str type: Literal["text"] = "text" - # Streaming state - state: PartState | None = None # None = finalized/restored from storage - delta: str | None = None # Current delta, None when not actively streaming class ToolCallPart(pydantic.BaseModel): @@ -43,9 +36,6 @@ class ToolCallPart(pydantic.BaseModel): tool_name: str tool_args: str type: Literal["tool_call"] = "tool_call" - # Streaming state (for args streaming) - state: PartState | None = None - args_delta: str | None = None # Delta for tool_args class ToolResultPart(pydantic.BaseModel): @@ -74,9 +64,6 @@ class ReasoningPart(pydantic.BaseModel): # Anthropic's thinking blocks include a signature for cache/verification. # This must be preserved and sent back in multi-turn conversations. signature: str | None = None - # Streaming state - state: PartState | None = None - delta: str | None = None class HookPart(pydantic.BaseModel): @@ -279,14 +266,60 @@ class ToolDelta(pydantic.BaseModel): args_delta: str +# --------------------------------------------------------------------------- +# Streaming sidecar — transient state excluded from persistence. 
+# --------------------------------------------------------------------------- + + +class PartOpened(pydantic.BaseModel): + model_config = pydantic.ConfigDict(frozen=True) + + part_id: str + type: Literal["part_opened"] = "part_opened" + + +class PartDelta(pydantic.BaseModel): + model_config = pydantic.ConfigDict(frozen=True) + + part_id: str + chunk: str + type: Literal["part_delta"] = "part_delta" + + +class PartClosed(pydantic.BaseModel): + model_config = pydantic.ConfigDict(frozen=True) + + part_id: str + type: Literal["part_closed"] = "part_closed" + + +StreamEvent = Annotated[ + PartOpened | PartDelta | PartClosed, + pydantic.Field(discriminator="type"), +] + + +class MessageStreamState(pydantic.BaseModel): + """Transient streaming state attached to a Message during streaming. + + ``events`` contains the events since the previous yield — never cumulative. + ``is_done`` is True once the stream has finished. + """ + + events: list[StreamEvent] = pydantic.Field(default_factory=list) + is_done: bool = False + + class Message(pydantic.BaseModel): model_config = pydantic.ConfigDict(frozen=True) - role: Literal["user", "assistant", "system", "tool", "signal"] + role: Literal["user", "assistant", "system", "tool", "internal"] parts: list[Part] id: str = pydantic.Field(default_factory=generate_id) - label: str | None = None + run_id: str | None = None + agent: str | None = None usage: Usage | None = None + stream: MessageStreamState | None = pydantic.Field(default=None, exclude=True) @overload def replace(self, new: Part, /) -> Message: ... @@ -337,44 +370,58 @@ def output(self) -> Any: @property def is_done(self) -> bool: - """Message is done when all parts are done (or have no streaming state).""" - for part in self.parts: - if ( - isinstance(part, (TextPart, ReasoningPart, ToolCallPart)) - and part.state == "streaming" - ): - return False - return True + """No sidecar (persisted/restored) means done. 
Otherwise ``stream.is_done``.""" + if self.stream is None: + return True + return self.stream.is_done + + def _parts_by_id(self) -> dict[str, Part]: + return {p.id: p for p in self.parts} @property def text_delta(self) -> str: - """Get current text delta from parts.""" - for part in self.parts: - if isinstance(part, TextPart) and part.delta: - return part.delta + """Derive from ``stream.events`` — first PartDelta whose part is TextPart.""" + if self.stream is None: + return "" + parts_map = self._parts_by_id() + for ev in self.stream.events: + if isinstance(ev, PartDelta): + part = parts_map.get(ev.part_id) + if isinstance(part, TextPart): + return ev.chunk return "" @property def reasoning_delta(self) -> str: - """Get current reasoning delta from parts.""" - for part in self.parts: - if isinstance(part, ReasoningPart) and part.delta: - return part.delta + """First PartDelta whose part is a ReasoningPart.""" + if self.stream is None: + return "" + parts_map = self._parts_by_id() + for ev in self.stream.events: + if isinstance(ev, PartDelta): + part = parts_map.get(ev.part_id) + if isinstance(part, ReasoningPart): + return ev.chunk return "" @property def tool_deltas(self) -> list[ToolDelta]: - """Get current tool deltas from parts.""" - deltas = [] - for part in self.parts: - if isinstance(part, ToolCallPart) and part.args_delta: - deltas.append( - ToolDelta( - tool_call_id=part.tool_call_id, - tool_name=part.tool_name, - args_delta=part.args_delta, + """Derive from ``stream.events`` — PartDeltas whose parts are ToolCallPart.""" + if self.stream is None: + return [] + parts_map = self._parts_by_id() + deltas: list[ToolDelta] = [] + for ev in self.stream.events: + if isinstance(ev, PartDelta): + part = parts_map.get(ev.part_id) + if isinstance(part, ToolCallPart): + deltas.append( + ToolDelta( + tool_call_id=part.tool_call_id, + tool_name=part.tool_name, + args_delta=ev.chunk, + ) ) - ) return deltas @property diff --git a/tests/agents/test_generator_tools.py 
b/tests/agents/test_generator_tools.py index 37259b18..6969d03d 100644 --- a/tests/agents/test_generator_tools.py +++ b/tests/agents/test_generator_tools.py @@ -24,12 +24,12 @@ async def progress_tool(query: str) -> AsyncGenerator[ai.Message]: """Tool that streams progress, then returns a final answer.""" yield ai.Message( role="assistant", - parts=[messages_.TextPart(text="Working...", state="done")], - label="progress", + parts=[messages_.TextPart(text="Working...")], + agent="progress", ) yield ai.Message( role="assistant", - parts=[messages_.TextPart(text=f"Answer for {query}", state="done")], + parts=[messages_.TextPart(text=f"Answer for {query}")], ) @@ -51,7 +51,7 @@ async def test_generator_tool_streams_and_returns_result() -> None: assert llm.call_count == 2 # Intermediate progress message was forwarded to consumer. - progress = [m for m in collected if m.label == "progress"] + progress = [m for m in collected if m.agent == "progress"] assert len(progress) == 1 assert progress[0].text == "Working..." @@ -175,7 +175,7 @@ async def test_yield_from_nested_agent() -> None: assert adapter.call_count == 3 # Inner messages were forwarded to the consumer with label="inner". - inner_msgs = [m for m in collected if m.label == "inner"] + inner_msgs = [m for m in collected if m.agent == "inner"] assert len(inner_msgs) > 0 # The outer LLM's second call (index 2) must NOT contain any inner @@ -187,7 +187,7 @@ async def test_yield_from_nested_agent() -> None: # Specifically: no inner assistant text or inner tool results leaked. for m in outer_turn2_msgs: - assert m.label is None or m.label != "inner" + assert m.agent is None or m.agent != "inner" if m.role == "assistant": # This must be the outer tool-call message, not inner text. 
assert len(m.tool_calls) > 0 diff --git a/tests/agents/test_runtime.py b/tests/agents/test_runtime.py index aa08d1c3..101505b3 100644 --- a/tests/agents/test_runtime.py +++ b/tests/agents/test_runtime.py @@ -72,13 +72,11 @@ async def test_agent_parallel_tools() -> None: tool_call_id="tc-1", tool_name="double", tool_args='{"x": 3}', - state="done", ), messages.ToolCallPart( tool_call_id="tc-2", tool_name="double", tool_args='{"x": 7}', - state="done", ), ], ) diff --git a/tests/conftest.py b/tests/conftest.py index 1e8529f1..dd876b3a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -215,10 +215,8 @@ def text_msg( text: str, *, id: str = "msg-1", - state: messages_.PartState | None = "done", - delta: str | None = None, ) -> messages_.Message: - part: messages_.Part = messages_.TextPart(text=text, state=state, delta=delta) + part: messages_.Part = messages_.TextPart(text=text) return messages_.Message(id=id, role="assistant", parts=[part]) @@ -234,7 +232,6 @@ def tool_call_msg( tool_call_id=tc_id, tool_name=name, tool_args=args, - state="done", ) return messages_.Message(id=id, role="assistant", parts=[part]) diff --git a/tests/models/core/test_streaming.py b/tests/models/core/test_streaming.py index 7cc7c3a2..7490a934 100644 --- a/tests/models/core/test_streaming.py +++ b/tests/models/core/test_streaming.py @@ -4,6 +4,7 @@ from ai.models.core.helpers import streaming from ai.types import messages +from ai.types.messages import PartClosed, PartDelta, PartOpened # -- Text streaming -------------------------------------------------------- @@ -14,27 +15,37 @@ def test_text_lifecycle() -> None: assert len(m.parts) == 1 part = m.parts[0] assert isinstance(part, messages.TextPart) - assert part.state == "streaming" assert part.text == "" + assert m.stream is not None + assert any(isinstance(e, PartOpened) and e.part_id == "b1" for e in m.stream.events) m = h.handle_event(streaming.TextDelta(block_id="b1", delta="Hello")) part = m.parts[0] assert isinstance(part, 
messages.TextPart) assert part.text == "Hello" - assert part.delta == "Hello" - assert part.state == "streaming" + assert m.stream is not None + assert any( + isinstance(e, PartDelta) and e.part_id == "b1" and e.chunk == "Hello" + for e in m.stream.events + ) m = h.handle_event(streaming.TextDelta(block_id="b1", delta=" world")) part = m.parts[0] assert isinstance(part, messages.TextPart) assert part.text == "Hello world" - assert part.delta == " world" + assert m.stream is not None + assert any( + isinstance(e, PartDelta) and e.part_id == "b1" and e.chunk == " world" + for e in m.stream.events + ) m = h.handle_event(streaming.TextEnd(block_id="b1")) part = m.parts[0] assert isinstance(part, messages.TextPart) - assert part.state == "done" - assert part.delta is None + assert m.stream is not None + assert any(isinstance(e, PartClosed) and e.part_id == "b1" for e in m.stream.events) + # No delta events in this yield + assert not any(isinstance(e, PartDelta) for e in m.stream.events) # -- Reasoning streaming --------------------------------------------------- @@ -47,13 +58,18 @@ def test_reasoning_lifecycle() -> None: part = m.parts[0] assert isinstance(part, messages.ReasoningPart) assert part.text == "thinking" - assert part.state == "streaming" + assert m.stream is not None + assert any( + isinstance(e, PartDelta) and e.part_id == "r1" and e.chunk == "thinking" + for e in m.stream.events + ) m = h.handle_event(streaming.ReasoningEnd(block_id="r1", signature="sig123")) part = m.parts[0] assert isinstance(part, messages.ReasoningPart) - assert part.state == "done" assert part.signature == "sig123" + assert m.stream is not None + assert any(isinstance(e, PartClosed) and e.part_id == "r1" for e in m.stream.events) # -- Tool streaming -------------------------------------------------------- @@ -67,8 +83,11 @@ def test_tool_lifecycle() -> None: assert isinstance(part, messages.ToolCallPart) assert part.tool_name == "get_weather" assert part.tool_args == '{"ci' - assert 
part.state == "streaming" - assert part.args_delta == '{"ci' + assert m.stream is not None + assert any( + isinstance(e, PartDelta) and e.part_id == "tc1" and e.chunk == '{"ci' + for e in m.stream.events + ) m = h.handle_event( streaming.ToolArgsDelta(tool_call_id="tc1", delta='ty":"London"}') @@ -80,8 +99,12 @@ def test_tool_lifecycle() -> None: m = h.handle_event(streaming.ToolEnd(tool_call_id="tc1")) part = m.parts[0] assert isinstance(part, messages.ToolCallPart) - assert part.state == "done" - assert part.args_delta is None + assert m.stream is not None + assert any( + isinstance(e, PartClosed) and e.part_id == "tc1" for e in m.stream.events + ) + # No delta events in this yield + assert not any(isinstance(e, PartDelta) for e in m.stream.events) # -- Multi-part messages --------------------------------------------------- @@ -106,12 +129,10 @@ def test_reasoning_then_text_then_tool() -> None: assert isinstance(m.parts[0], messages.ReasoningPart) assert isinstance(m.parts[1], messages.TextPart) assert isinstance(m.parts[2], messages.ToolCallPart) - assert all( - p.state == "done" - for p in m.parts - if isinstance( - p, (messages.TextPart, messages.ToolCallPart, messages.ReasoningPart) - ) + # The last event was ToolEnd(tc1), so only that PartClosed is in events + assert m.stream is not None + assert any( + isinstance(e, PartClosed) and e.part_id == "tc1" for e in m.stream.events ) @@ -134,12 +155,10 @@ def test_multiple_tool_calls() -> None: h.handle_event(streaming.ToolArgsDelta(tool_call_id="tc2", delta='{"dir":"."}')) h.handle_event(streaming.ToolEnd(tool_call_id="tc1")) m = h.handle_event(streaming.ToolEnd(tool_call_id="tc2")) - assert all( - p.state == "done" - for p in m.parts - if isinstance( - p, (messages.TextPart, messages.ToolCallPart, messages.ReasoningPart) - ) + # Last event was ToolEnd(tc2), so its PartClosed is in events + assert m.stream is not None + assert any( + isinstance(e, PartClosed) and e.part_id == "tc2" for e in m.stream.events ) @@ 
-154,8 +173,9 @@ def test_message_done_finalizes_all() -> None: m = h.handle_event(streaming.MessageDone(finish_reason="end_turn")) part = m.parts[0] assert isinstance(part, messages.TextPart) - assert part.state == "done" assert m.is_done + assert m.stream is not None + assert m.stream.is_done def test_message_done_propagates_usage() -> None: @@ -180,7 +200,7 @@ def test_message_done_propagates_usage() -> None: def test_deltas_only_on_active_blocks() -> None: - """Delta should be None on inactive blocks, present only on active.""" + """Delta events should only reference the active block.""" h = streaming.StreamHandler(message_id="m1") h.handle_event(streaming.TextStart(block_id="t1")) h.handle_event(streaming.TextDelta(block_id="t1", delta="first")) @@ -190,8 +210,17 @@ def test_deltas_only_on_active_blocks() -> None: m = h.handle_event(streaming.TextDelta(block_id="t2", delta="second")) text_parts = [p for p in m.parts if isinstance(p, messages.TextPart)] - assert text_parts[0].delta is None # t1 is done - assert text_parts[1].delta == "second" # t2 is active + assert text_parts[0].text == "first" # t1 snapshot + assert text_parts[1].text == "second" # t2 snapshot + # Only t2 has a delta event in this yield + assert m.stream is not None + assert any( + isinstance(e, PartDelta) and e.part_id == "t2" and e.chunk == "second" + for e in m.stream.events + ) + assert not any( + isinstance(e, PartDelta) and e.part_id == "t1" for e in m.stream.events + ) # -- File event (inline images from LLMs like Gemini/GPT-5) --------------- diff --git a/tests/models/test_public_api.py b/tests/models/test_public_api.py index 7ca4b1b1..3ab0502c 100644 --- a/tests/models/test_public_api.py +++ b/tests/models/test_public_api.py @@ -57,7 +57,7 @@ async def _spy_stream( yield messages_.Message( id="m1", role="assistant", - parts=[messages_.TextPart(text="ok", state="done")], + parts=[messages_.TextPart(text="ok")], ) models.register_stream("mock", _spy_stream) @@ -89,7 +89,7 @@ async def 
_structured_stream( output_type: type[pydantic.BaseModel] | None = None, **kwargs: Any, ) -> AsyncGenerator[messages_.Message]: - text_part = messages_.TextPart(text=json_text, state="done") + text_part = messages_.TextPart(text=json_text) parts: list[messages_.Part] = [text_part] if output_type is not None: import json diff --git a/tests/types/test_integrity.py b/tests/types/test_integrity.py index ac308752..16a94f18 100644 --- a/tests/types/test_integrity.py +++ b/tests/types/test_integrity.py @@ -93,14 +93,14 @@ def test_idempotent() -> None: # --------------------------------------------------------------------------- -# Signal messages +# Internal messages # --------------------------------------------------------------------------- -def test_drops_signal_messages() -> None: +def test_drops_internal_messages() -> None: msgs = [ builders.user_message("hi"), - messages.Message(role="signal", parts=[messages.TextPart(text="internal")]), + messages.Message(role="internal", parts=[messages.TextPart(text="internal")]), builders.assistant_message("hello"), ] result = prepare_messages(msgs) @@ -109,13 +109,13 @@ def test_drops_signal_messages() -> None: assert result[1].role == "assistant" -def test_signal_strict_raises() -> None: +def test_internal_strict_raises() -> None: msgs = [ - messages.Message(role="signal", parts=[messages.TextPart(text="x")]), + messages.Message(role="internal", parts=[messages.TextPart(text="x")]), ] with pytest.raises(IntegrityError) as exc_info: prepare_messages(msgs, mode="strict") - assert "signal-message" in exc_info.value.issues + assert "internal-message" in exc_info.value.issues # --------------------------------------------------------------------------- @@ -346,7 +346,7 @@ def test_complete_tool_flow_unchanged() -> None: def test_strict_collects_all_issues() -> None: msgs = [ - messages.Message(role="signal", parts=[messages.TextPart(text="x")]), + messages.Message(role="internal", parts=[messages.TextPart(text="x")]), 
messages.Message( role="assistant", parts=[ @@ -358,13 +358,13 @@ def test_strict_collects_all_issues() -> None: with pytest.raises(IntegrityError) as exc_info: prepare_messages(msgs, mode="strict") issues = exc_info.value.issues - assert "signal-message" in issues + assert "internal-message" in issues assert "internal-part" in issues def test_strict_keeps_recoverable_issues_when_history_is_corrupt() -> None: msgs = [ - messages.Message(role="signal", parts=[messages.TextPart(text="x")]), + messages.Message(role="internal", parts=[messages.TextPart(text="x")]), builders.user_message("go"), messages.Message( role="assistant", @@ -380,7 +380,7 @@ def test_strict_keeps_recoverable_issues_when_history_is_corrupt() -> None: ] with pytest.raises(IntegrityError) as exc_info: prepare_messages(msgs, mode="strict") - assert "signal-message" in exc_info.value.issues + assert "internal-message" in exc_info.value.issues assert "duplicate-tool-call" in exc_info.value.issues @@ -487,8 +487,8 @@ async def test_stream_calls_prepare_messages() -> None: spy.assert_called_once_with(msgs) -async def test_stream_sanitizes_signal_messages() -> None: - """Signal messages are stripped before reaching the adapter.""" +async def test_stream_sanitizes_internal_messages() -> None: + """Internal messages are stripped before reaching the adapter.""" received: list[list[messages.Message]] = [] mock = mock_llm([[text_msg("ok")]]) @@ -514,17 +514,17 @@ async def _spy_stream( msgs = [ ai.user_message("hi"), - messages.Message(role="signal", parts=[messages.TextPart(text="internal")]), + messages.Message(role="internal", parts=[messages.TextPart(text="internal")]), ai.assistant_message("hello"), ] s = await models.stream(MOCK_MODEL, msgs) async for _ in s: pass - # The adapter should have received only 2 messages (signal stripped) + # The adapter should have received only 2 messages (internal stripped) assert len(received) == 1 assert len(received[0]) == 2 - assert all(m.role != "signal" for m in 
received[0]) + assert all(m.role != "internal" for m in received[0]) async def test_generate_calls_prepare_messages() -> None: @@ -543,8 +543,8 @@ async def test_generate_calls_prepare_messages() -> None: spy.assert_called_once_with(msgs) -async def test_generate_sanitizes_signal_messages() -> None: - """Signal messages are stripped before reaching generate adapter.""" +async def test_generate_sanitizes_internal_messages() -> None: + """Internal messages are stripped before reaching generate adapter.""" received: list[list[messages.Message]] = [] sentinel = messages.Message( role="assistant", @@ -564,7 +564,7 @@ async def _spy_gen( msgs = [ ai.user_message("A cat"), - messages.Message(role="signal", parts=[messages.TextPart(text="internal")]), + messages.Message(role="internal", parts=[messages.TextPart(text="internal")]), ] await models.generate(MOCK_MODEL, msgs, models.ImageParams(n=1)) From c52224f5003b3a6a585a8ef26fd4a3a802d8f576 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Fri, 17 Apr 2026 13:22:14 -0700 Subject: [PATCH 03/10] Update examples to match the new api --- examples/fastapi-vite/README.md | 2 +- examples/multiagent-textual/README.md | 2 +- examples/multiagent-textual/client.py | 4 ++-- examples/multiagent-textual/server.py | 2 +- examples/samples/agent_hooks.py | 6 +++--- examples/samples/agent_hooks_serverless.py | 4 ++-- examples/samples/streaming_tool.py | 8 ++++---- examples/samples/tools_schema.py | 4 ++-- 8 files changed, 16 insertions(+), 16 deletions(-) diff --git a/examples/fastapi-vite/README.md b/examples/fastapi-vite/README.md index ae25200b..baf1c13b 100644 --- a/examples/fastapi-vite/README.md +++ b/examples/fastapi-vite/README.md @@ -16,7 +16,7 @@ to suspend execution whenever the LLM wants to call a tool. The flow is: 1. LLM emits a tool call 2. Backend calls `await ai.hook(...)` with `payload=ai.ToolApproval` -3. The runtime emits a `role="signal"` message containing a pending `HookPart` +3. 
The runtime emits a `role="internal"` message containing a pending `HookPart` 4. The frontend renders Approve / Reject buttons via the `` component (from AI Elements) 5. When the user clicks a button, `addToolApprovalResponse()` patches diff --git a/examples/multiagent-textual/README.md b/examples/multiagent-textual/README.md index c4dcc71d..5126e7de 100644 --- a/examples/multiagent-textual/README.md +++ b/examples/multiagent-textual/README.md @@ -10,7 +10,7 @@ The current implementation uses: - `ai.agent(...)` for each branch and the orchestrator - `await ai.hook(...)` for branch-specific approvals - `ai.yield_from(...)` to forward nested agent output into the outer run -- `role="signal"` messages for hook state updates over the WebSocket +- `role="internal"` messages for hook state updates over the WebSocket ## Setup diff --git a/examples/multiagent-textual/client.py b/examples/multiagent-textual/client.py index 154d4c52..ec246a6a 100644 --- a/examples/multiagent-textual/client.py +++ b/examples/multiagent-textual/client.py @@ -158,7 +158,7 @@ async def run_websocket(self) -> None: # ------------------------------------------------------------------ def _handle_message(self, msg: ai.Message) -> None: - label = msg.label or "unknown" + label = msg.agent or "unknown" if (hook_part := msg.get_hook_part()) is not None: if hook_part.status == "pending": @@ -190,7 +190,7 @@ def _handle_message(self, msg: ai.Message) -> None: if msg.is_done: for part in msg.parts: match part: - case ai.ToolCallPart(tool_name=name, tool_args=args, state="done"): + case ai.ToolCallPart(tool_name=name, tool_args=args): panel.append_line(f"> {name}({args})") case ai.ToolResultPart(tool_name=name, result=result): panel.append_line(f"< {name} = {result}") diff --git a/examples/multiagent-textual/server.py b/examples/multiagent-textual/server.py index a35935f0..37a5bf8d 100644 --- a/examples/multiagent-textual/server.py +++ b/examples/multiagent-textual/server.py @@ -185,7 +185,7 @@ async def 
multiagent_loop(context: ai.Context) -> AsyncGenerator[ai.Message]: ], ) async for msg in s: - yield msg.model_copy(update={"label": "summary"}) + yield msg.model_copy(update={"agent": "summary"}) # --------------------------------------------------------------------------- diff --git a/examples/samples/agent_hooks.py b/examples/samples/agent_hooks.py index 7180311b..aaec1ffc 100644 --- a/examples/samples/agent_hooks.py +++ b/examples/samples/agent_hooks.py @@ -3,7 +3,7 @@ Demonstrates the function-based hook API: - await hook("label", payload=Model) to suspend inside the loop - resolve_hook("label", data) to unblock from outside - - Hook messages arrive with role="signal" + - Hook messages arrive with role="internal" """ import asyncio @@ -76,8 +76,8 @@ async def with_approval(context: ai.Context) -> AsyncGenerator[ai.Message]: ] async for msg in my_agent.run(model, messages): - # Hook signals arrive with role="signal" - if msg.role == "signal": + # Hook signals arrive with role="internal" + if msg.role == "internal": hook_part = msg.get_hook_part() if hook_part and hook_part.status == "pending": answer = input(f"Approve {hook_part.hook_id}? 
[y/n] ") diff --git a/examples/samples/agent_hooks_serverless.py b/examples/samples/agent_hooks_serverless.py index 23717689..5b0110e8 100644 --- a/examples/samples/agent_hooks_serverless.py +++ b/examples/samples/agent_hooks_serverless.py @@ -87,7 +87,7 @@ async def with_confirmation(context: ai.Context) -> AsyncGenerator[ai.Message]: durability = ai.EventLogProvider() async for msg in my_agent.run(model, messages, durability=durability): - if msg.role == "signal": + if msg.role == "internal": hook_part = msg.get_hook_part() if hook_part and hook_part.status == "pending": pending_hook_labels.append(hook_part.hook_id) @@ -108,7 +108,7 @@ async def with_confirmation(context: ai.Context) -> AsyncGenerator[ai.Message]: durability = ai.EventLogProvider(saved_checkpoint) async for msg in my_agent.run(model, messages, durability=durability): - if msg.role == "signal": + if msg.role == "internal": hook_part = msg.get_hook_part() if hook_part: print(f" Hook {hook_part.status}: {hook_part.hook_id}") diff --git a/examples/samples/streaming_tool.py b/examples/samples/streaming_tool.py index 5bf8fedd..07d87f5c 100644 --- a/examples/samples/streaming_tool.py +++ b/examples/samples/streaming_tool.py @@ -17,15 +17,15 @@ async def talk_to_mothership(question: str) -> AsyncGenerator[ai.Message]: for step in ["Connecting...", "Transmitting...", "Awaiting response..."]: yield ai.Message( role="assistant", - parts=[ai.TextPart(text=step, state="done")], - label="tool_progress", + parts=[ai.TextPart(text=step)], + agent="tool_progress", ) await asyncio.sleep(0.3) # The final yielded message's text is returned as the tool result. 
yield ai.Message( role="assistant", - parts=[ai.TextPart(text="The mothership says: Soon.", state="done")], + parts=[ai.TextPart(text="The mothership says: Soon.")], ) @@ -40,7 +40,7 @@ async def main() -> None: ] async for msg in my_agent.run(model, messages): - if msg.label == "tool_progress": + if msg.agent == "tool_progress": print(f" [{msg.text}]") elif msg.text_delta: print(msg.text_delta, end="", flush=True) diff --git a/examples/samples/tools_schema.py b/examples/samples/tools_schema.py index 0c13ad42..256e98e0 100644 --- a/examples/samples/tools_schema.py +++ b/examples/samples/tools_schema.py @@ -29,8 +29,8 @@ async def main() -> None: if msg.text_delta: print(msg.text_delta, end="", flush=True) - for tc in msg.tool_calls: - if tc.state == "done": + if msg.is_done: + for tc in msg.tool_calls: print(f"\nTool call: {tc.tool_name}({tc.tool_args})") print() From 48435d603a1b388d1a8473317e84f31f417cf216 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Fri, 17 Apr 2026 14:11:21 -0700 Subject: [PATCH 04/10] Fix stale models usage in examples --- examples/samples/agent_nested.py | 2 +- examples/samples/agent_simple.py | 2 +- examples/samples/image_edit.py | 2 +- examples/samples/inline_image.py | 2 +- examples/samples/mcp_tools.py | 2 +- examples/samples/multimodal_input.py | 2 +- examples/samples/streaming_tool.py | 2 +- examples/samples/structured_output.py | 2 +- examples/samples/tools_schema.py | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/samples/agent_nested.py b/examples/samples/agent_nested.py index b4a2f535..df0b5b54 100644 --- a/examples/samples/agent_nested.py +++ b/examples/samples/agent_nested.py @@ -5,7 +5,7 @@ import ai -model = ai.model("ai-gateway", "anthropic/claude-sonnet-4") +model = ai.ai_gateway("anthropic/claude-sonnet-4") @ai.tool diff --git a/examples/samples/agent_simple.py b/examples/samples/agent_simple.py index 6a90b888..b0fc3485 100644 --- a/examples/samples/agent_simple.py +++ 
b/examples/samples/agent_simple.py @@ -12,7 +12,7 @@ async def get_weather(city: str) -> str: async def main() -> None: - model = ai.model("ai-gateway", "anthropic/claude-sonnet-4") + model = ai.ai_gateway("anthropic/claude-sonnet-4") my_agent = ai.agent(tools=[get_weather]) diff --git a/examples/samples/image_edit.py b/examples/samples/image_edit.py index ac07fc4f..f99780a1 100644 --- a/examples/samples/image_edit.py +++ b/examples/samples/image_edit.py @@ -11,7 +11,7 @@ import ai -model = ai.model("ai-gateway", "openai/gpt-image-1") +model = ai.ai_gateway("openai/gpt-image-1") async def main() -> None: diff --git a/examples/samples/inline_image.py b/examples/samples/inline_image.py index 2b2fdcc4..3eb195ab 100644 --- a/examples/samples/inline_image.py +++ b/examples/samples/inline_image.py @@ -11,7 +11,7 @@ import ai -model = ai.model("ai-gateway", "google/gemini-3-pro-image") +model = ai.ai_gateway("google/gemini-3-pro-image") messages = [ ai.system_message( diff --git a/examples/samples/mcp_tools.py b/examples/samples/mcp_tools.py index 6925fd0a..76a84d5b 100644 --- a/examples/samples/mcp_tools.py +++ b/examples/samples/mcp_tools.py @@ -8,7 +8,7 @@ async def main() -> None: - model = ai.model("ai-gateway", "anthropic/claude-sonnet-4") + model = ai.ai_gateway("anthropic/claude-sonnet-4") context7_tools: list[ai.Tool[..., Any]] = await ai.mcp.get_http_tools( "https://mcp.context7.com/mcp", diff --git a/examples/samples/multimodal_input.py b/examples/samples/multimodal_input.py index 07ce4935..690213da 100644 --- a/examples/samples/multimodal_input.py +++ b/examples/samples/multimodal_input.py @@ -5,7 +5,7 @@ import ai -model = ai.model("ai-gateway", "anthropic/claude-sonnet-4") +model = ai.ai_gateway("anthropic/claude-sonnet-4") # Load a local image file (replace with your own path). 
image_path = pathlib.Path("sample_image.jpg") diff --git a/examples/samples/streaming_tool.py b/examples/samples/streaming_tool.py index 07d87f5c..c53a58c3 100644 --- a/examples/samples/streaming_tool.py +++ b/examples/samples/streaming_tool.py @@ -30,7 +30,7 @@ async def talk_to_mothership(question: str) -> AsyncGenerator[ai.Message]: async def main() -> None: - model = ai.model("ai-gateway", "anthropic/claude-sonnet-4") + model = ai.ai_gateway("anthropic/claude-sonnet-4") my_agent = ai.agent(tools=[talk_to_mothership]) diff --git a/examples/samples/structured_output.py b/examples/samples/structured_output.py index 11928053..87b36e5f 100644 --- a/examples/samples/structured_output.py +++ b/examples/samples/structured_output.py @@ -6,7 +6,7 @@ import ai -model = ai.model("ai-gateway", "anthropic/claude-sonnet-4") +model = ai.ai_gateway("anthropic/claude-sonnet-4") class Recipe(pydantic.BaseModel): diff --git a/examples/samples/tools_schema.py b/examples/samples/tools_schema.py index 256e98e0..ac714820 100644 --- a/examples/samples/tools_schema.py +++ b/examples/samples/tools_schema.py @@ -4,7 +4,7 @@ import ai -model = ai.model("ai-gateway", "anthropic/claude-sonnet-4") +model = ai.ai_gateway("anthropic/claude-sonnet-4") # Define a tool schema — anything matching the ToolLike protocol works. 
get_weather = ai.ToolSchema( From c371aeaaba9f17a30d2b52744091b656b9751611 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Fri, 17 Apr 2026 14:33:25 -0700 Subject: [PATCH 05/10] Replace run_id with turn_id and fix semantics to match actual turns --- src/ai/__init__.py | 5 +-- src/ai/agents/agent.py | 16 ++------ src/ai/models/core/api.py | 16 ++++---- src/ai/models/core/types.py | 32 ++++++++++----- src/ai/types/messages.py | 2 +- src/ai/types/stream.py | 3 ++ tests/agents/test_runtime.py | 71 +++++++++++++++++++++++++++++++++ tests/models/test_public_api.py | 39 ++++++++++++++++++ 8 files changed, 148 insertions(+), 36 deletions(-) diff --git a/src/ai/__init__.py b/src/ai/__init__.py index 8c07b4b9..f70034a6 100644 --- a/src/ai/__init__.py +++ b/src/ai/__init__.py @@ -1,5 +1,4 @@ -from . import adapters, middleware, models -from .adapters import ai_sdk_ui +from . import middleware, models from .agents import ( TOOL_APPROVAL_HOOK_TYPE, Agent, @@ -127,6 +126,4 @@ "middleware", # Submodules "mcp", - "ai_sdk_ui", - "adapters", ] diff --git a/src/ai/agents/agent.py b/src/ai/agents/agent.py index 99275312..37a6153f 100644 --- a/src/ai/agents/agent.py +++ b/src/ai/agents/agent.py @@ -188,7 +188,6 @@ class Context(pydantic.BaseModel): model: models.Model messages: list[types.Message] tools: list[Tool[..., Any]] - run_id: str | None = None model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) @@ -210,7 +209,6 @@ async def _default_loop(context: Context) -> AsyncGenerator[types.Message]: context.model, context.messages, tools=context.tools, - run_id=context.run_id, ) async for message in stream: yield message @@ -224,9 +222,9 @@ async def _default_loop(context: Context) -> AsyncGenerator[types.Message]: tasks = [tg.create_task(tc()) for tc in tool_calls] # Yield one merged tool-result message — history auto-collects it. + # Left un-stamped: the tool result is the input of the *next* turn, + # so the next stream() call will stamp it with that turn's id. 
tool_msg = builders.tool_message(*(t.result() for t in tasks)) - if context.run_id is not None: - tool_msg = tool_msg.model_copy(update={"run_id": context.run_id}) yield tool_msg @@ -318,8 +316,6 @@ async def run( First in the list = outermost. Middleware wraps model calls, tool calls, hooks, and the run itself. """ - run_id = types.generate_id("run") - call = middleware_.AgentRunContext( model=model, messages=messages, @@ -336,17 +332,11 @@ async def _real( model=call.model, messages=list(call.messages), tools=call.tools, - run_id=run_id, ) source = _collect_messages(loop_fn(context), context.messages) async for message in runtime.run(source): - updates: dict[str, Any] = {} - if message.run_id is None: - updates["run_id"] = run_id if call.label is not None: - updates["agent"] = call.label - if updates: - message = message.model_copy(update=updates) + message = message.model_copy(update={"agent": call.label}) yield message # Activate middleware for this run (and everything it calls). diff --git a/src/ai/models/core/api.py b/src/ai/models/core/api.py index 14a31d65..898eafcd 100644 --- a/src/ai/models/core/api.py +++ b/src/ai/models/core/api.py @@ -27,7 +27,7 @@ async def stream( *, tools: Sequence[tools_.ToolLike] | None = None, output_type: type[pydantic.BaseModel] | None = None, - run_id: str | None = None, + turn_id: str | None = None, **kwargs: Any, ) -> stream_.StreamResultLike: """Stream an LLM response. @@ -36,16 +36,18 @@ async def stream( collects the final ``Message``. After iteration, access ``.text``, ``.tool_calls``, ``.usage``, etc. - Every yielded message (re-emitted inputs + model response) carries - ``run_id``. If *run_id* is not provided, one is generated. + One call is one turn: a single request and its response. The model + response carries ``turn_id``; re-emitted input messages keep any + existing ``turn_id`` from prior turns and only receive the current + one when unstamped. If *turn_id* is not provided, one is generated. 
The client is resolved from the model: ``model.client`` if set, otherwise auto-created from ``model.base_url`` / ``model.api_key_env``. """ messages = integrity_.prepare_messages(messages) - if run_id is None: - run_id = messages_.generate_id("run") + if turn_id is None: + turn_id = messages_.generate_id("turn") call = middleware_.ModelContext( model=model, @@ -56,7 +58,7 @@ async def stream( ) # Capture in closure for the inner function. - _run_id = run_id + _turn_id = turn_id async def _real(call: middleware_.ModelContext) -> stream_.StreamResultLike: c = client_.auto_client(call.model) @@ -70,7 +72,7 @@ async def _real(call: middleware_.ModelContext) -> stream_.StreamResultLike: output_type=call.output_type, **call.kwargs, ), - run_id=_run_id, + turn_id=_turn_id, input_messages=call.messages, ) diff --git a/src/ai/models/core/types.py b/src/ai/models/core/types.py index 29222cf2..a15d6dd1 100644 --- a/src/ai/models/core/types.py +++ b/src/ai/models/core/types.py @@ -69,9 +69,12 @@ class StreamResult: Properties like ``.text`` and ``.tool_calls`` delegate to the final ``Message`` snapshot and are available after iteration completes. - When *run_id* is provided, every yielded message is stamped with it. - When *input_messages* is provided, they are re-emitted (with *run_id*) - before the model response stream. + One ``StreamResult`` represents one turn: a single LLM request and its + response. When *turn_id* is provided, the model response is stamped + with it. When *input_messages* is provided, they are re-emitted ahead + of the response; inputs that already carry a ``turn_id`` (from earlier + turns) are preserved as-is, only inputs with ``turn_id=None`` receive + the current *turn_id*. Satisfies :class:`~ai.types.StreamResultLike`. 
""" @@ -80,11 +83,11 @@ def __init__( self, gen: AsyncGenerator[messages_.Message], *, - run_id: str | None = None, + turn_id: str | None = None, input_messages: list[messages_.Message] | None = None, ) -> None: self._gen = gen - self._run_id = run_id + self._turn_id = turn_id self._input_messages = input_messages or [] self._final: messages_.Message | None = None @@ -110,18 +113,25 @@ def __aiter__(self) -> AsyncGenerator[messages_.Message]: return self._iterate() async def _iterate(self) -> AsyncGenerator[messages_.Message]: - # Re-emit input messages with run_id stamped. + # Re-emit input messages; stamp only the ones without a turn_id. + # Prior turns keep their existing ids. for msg in self._input_messages: - stamped = msg.model_copy(update={"run_id": self._run_id}) - yield stamped + if msg.turn_id is None and self._turn_id is not None: + msg = msg.model_copy(update={"turn_id": self._turn_id}) + yield msg - # Stream model response with run_id stamped. + # Stream model response with turn_id stamped (when missing). 
async for msg in self._gen: - if self._run_id is not None: - msg = msg.model_copy(update={"run_id": self._run_id}) + if msg.turn_id is None and self._turn_id is not None: + msg = msg.model_copy(update={"turn_id": self._turn_id}) self._final = msg yield msg + @property + def turn_id(self) -> str | None: + """The turn id stamped on this stream's response (if any).""" + return self._turn_id + @property def text(self) -> str: return self._final.text if self._final else "" diff --git a/src/ai/types/messages.py b/src/ai/types/messages.py index b1ebfb9e..5ced280a 100644 --- a/src/ai/types/messages.py +++ b/src/ai/types/messages.py @@ -316,7 +316,7 @@ class Message(pydantic.BaseModel): role: Literal["user", "assistant", "system", "tool", "internal"] parts: list[Part] id: str = pydantic.Field(default_factory=generate_id) - run_id: str | None = None + turn_id: str | None = None agent: str | None = None usage: Usage | None = None stream: MessageStreamState | None = pydantic.Field(default=None, exclude=True) diff --git a/src/ai/types/stream.py b/src/ai/types/stream.py index 3fa43a99..f6389bc0 100644 --- a/src/ai/types/stream.py +++ b/src/ai/types/stream.py @@ -34,3 +34,6 @@ def usage(self) -> messages_.Usage | None: ... @property def output(self) -> Any: ... + + @property + def turn_id(self) -> str | None: ... 
diff --git a/tests/agents/test_runtime.py b/tests/agents/test_runtime.py index 101505b3..6cda3a8d 100644 --- a/tests/agents/test_runtime.py +++ b/tests/agents/test_runtime.py @@ -109,3 +109,74 @@ async def test_agent_multi_turn() -> None: async for m in my_agent.run(MOCK_MODEL, [ai.user_message("Concat then double")]): msgs.append(m) assert llm.call_count == 3 + + +# -- turn_id semantics: one turn per LLM round-trip ------------------------- + + +async def test_two_user_messages_produce_four_turns() -> None: + """Two agent.run invocations, each with a tool call + final reply, + produce four distinct turn ids; history from the first run keeps its + original turn ids when fed into the second run.""" + my_agent = ai.agent(tools=[double]) + + # Run 1: tool call, then text. + r1_turn1 = [tool_call_msg(tc_id="tc-1", name="double", args='{"x": 5}', id="m-1a")] + r1_turn2 = [text_msg("Ten.", id="m-1b")] + # Run 2: tool call, then text. + r2_turn1 = [tool_call_msg(tc_id="tc-2", name="double", args='{"x": 7}', id="m-2a")] + r2_turn2 = [text_msg("Fourteen.", id="m-2b")] + mock_llm([r1_turn1, r1_turn2, r2_turn1, r2_turn2]) + + def dedup(stream: list[ai.Message]) -> list[ai.Message]: + seen: dict[str, ai.Message] = {} + for m in stream: + if m.is_done: + seen[m.id] = m + return list(seen.values()) + + run1_stream: list[ai.Message] = [] + async for m in my_agent.run(MOCK_MODEL, [ai.user_message("Double 5")]): + run1_stream.append(m) + history = dedup(run1_stream) + + run2_stream: list[ai.Message] = [] + async for m in my_agent.run(MOCK_MODEL, [*history, ai.user_message("Double 7")]): + run2_stream.append(m) + final = dedup(run2_stream) + + # Chronological list of terminal non-internal messages. Insertion order + # of ``dedup`` reflects the order they first appeared in the stream. + chronological = [m for m in final if m.role != "internal"] + assert len(chronological) == 8 + + # Expected shape: four turns, each a (input, assistant) pair. 
+ # turn 1: user → assistant (tool call) + # turn 2: tool → assistant (text) + # turn 3: user → assistant (tool call) + # turn 4: tool → assistant (text) + expected_roles = [ + ("user", "assistant"), + ("tool", "assistant"), + ("user", "assistant"), + ("tool", "assistant"), + ] + pairs = [(chronological[2 * i], chronological[2 * i + 1]) for i in range(4)] + for (left, right), (expected_left, expected_right) in zip( + pairs, expected_roles, strict=True + ): + assert (left.role, right.role) == (expected_left, expected_right) + # Both messages in the pair share the same turn_id. + assert left.turn_id is not None + assert left.turn_id == right.turn_id + + # The four turn_ids are all distinct. + turn_ids = [left.turn_id for left, _ in pairs] + assert len(set(turn_ids)) == 4 + + # Run 1's history survives untouched into run 2. + history_ids = {h.id for h in history} + for h in history: + same = next(m for m in final if m.id == h.id) + assert same.turn_id == h.turn_id + assert any(m.id in history_ids for m in final) diff --git a/tests/models/test_public_api.py b/tests/models/test_public_api.py index 3ab0502c..46319b67 100644 --- a/tests/models/test_public_api.py +++ b/tests/models/test_public_api.py @@ -40,6 +40,45 @@ async def test_stream_basic() -> None: assert "".join(deltas) == "Hello world" +async def test_stream_preserves_existing_turn_ids() -> None: + """ai.stream() stamps only inputs without a turn_id; older turns survive.""" + mock = mock_llm([[text_msg("reply")]]) + + old = ai.user_message("earlier") + old = old.model_copy(update={"turn_id": "prev"}) + fresh = ai.user_message("latest") + + s = await models.stream(MOCK_MODEL, [old, fresh]) + yielded: list[messages_.Message] = [] + async for msg in s: + yielded.append(msg) + + assert mock.call_count == 1 + # First yielded is the old input — unchanged. + assert yielded[0].turn_id == "prev" + # Fresh input was stamped with the current turn's id. 
+ assert yielded[1].turn_id is not None + assert yielded[1].turn_id != "prev" + # Response shares the current turn id. + response_ids = [m.turn_id for m in yielded if m.role == "assistant"] + assert response_ids and all(tid == yielded[1].turn_id for tid in response_ids) + + +async def test_stream_accepts_explicit_turn_id() -> None: + """Explicit turn_id kwarg is used verbatim.""" + mock_llm([[text_msg("ok")]]) + fresh = ai.user_message("hi") + + s = await models.stream(MOCK_MODEL, [fresh], turn_id="custom-turn") + yielded: list[messages_.Message] = [] + async for msg in s: + yielded.append(msg) + + assert s.turn_id == "custom-turn" + assert yielded[0].turn_id == "custom-turn" + assert yielded[-1].turn_id == "custom-turn" + + async def test_stream_with_explicit_client() -> None: """Model with explicit client= forwards it to the adapter.""" received_clients: list[models.Client] = [] From 9021a72e3d593428a770f73fdf170c4575e5de23 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Fri, 17 Apr 2026 16:16:34 -0700 Subject: [PATCH 06/10] Port and clean up the ai sdk adapter --- src/ai/agents/ui/__init__.py | 4 +- src/ai/agents/ui/ai_sdk/__init__.py | 41 +- src/ai/agents/ui/ai_sdk/_approvals.py | 31 + src/ai/agents/ui/ai_sdk/_parts.py | 141 ++++ src/ai/agents/ui/ai_sdk/inbound.py | 190 ++--- src/ai/agents/ui/ai_sdk/message_to_ui.py | 456 ------------ src/ai/agents/ui/ai_sdk/outbound.py | 303 -------- src/ai/agents/ui/ai_sdk/outbound/__init__.py | 7 + src/ai/agents/ui/ai_sdk/outbound/_state.py | 310 ++++++++ src/ai/agents/ui/ai_sdk/outbound/history.py | 66 ++ src/ai/agents/ui/ai_sdk/outbound/sse.py | 39 + src/ai/agents/ui/ai_sdk/outbound/stream.py | 43 ++ src/ai/agents/ui/ai_sdk/ui_message.py | 2 +- tests/adapters/ai_sdk_ui/test_adapter.py | 693 ------------------ tests/{adapters => agents/ui}/__init__.py | 0 .../ui/ai_sdk}/__init__.py | 0 tests/agents/ui/ai_sdk/outbound/__init__.py | 0 .../agents/ui/ai_sdk/outbound/test_history.py | 126 ++++ 
tests/agents/ui/ai_sdk/outbound/test_sse.py | 52 ++ .../agents/ui/ai_sdk/outbound/test_stream.py | 236 ++++++ tests/agents/ui/ai_sdk/test_approvals.py | 64 ++ tests/agents/ui/ai_sdk/test_inbound.py | 146 ++++ tests/agents/ui/ai_sdk/test_parts.py | 105 +++ 23 files changed, 1461 insertions(+), 1594 deletions(-) create mode 100644 src/ai/agents/ui/ai_sdk/_approvals.py create mode 100644 src/ai/agents/ui/ai_sdk/_parts.py delete mode 100644 src/ai/agents/ui/ai_sdk/message_to_ui.py delete mode 100644 src/ai/agents/ui/ai_sdk/outbound.py create mode 100644 src/ai/agents/ui/ai_sdk/outbound/__init__.py create mode 100644 src/ai/agents/ui/ai_sdk/outbound/_state.py create mode 100644 src/ai/agents/ui/ai_sdk/outbound/history.py create mode 100644 src/ai/agents/ui/ai_sdk/outbound/sse.py create mode 100644 src/ai/agents/ui/ai_sdk/outbound/stream.py delete mode 100644 tests/adapters/ai_sdk_ui/test_adapter.py rename tests/{adapters => agents/ui}/__init__.py (100%) rename tests/{adapters/ai_sdk_ui => agents/ui/ai_sdk}/__init__.py (100%) create mode 100644 tests/agents/ui/ai_sdk/outbound/__init__.py create mode 100644 tests/agents/ui/ai_sdk/outbound/test_history.py create mode 100644 tests/agents/ui/ai_sdk/outbound/test_sse.py create mode 100644 tests/agents/ui/ai_sdk/outbound/test_stream.py create mode 100644 tests/agents/ui/ai_sdk/test_approvals.py create mode 100644 tests/agents/ui/ai_sdk/test_inbound.py create mode 100644 tests/agents/ui/ai_sdk/test_parts.py diff --git a/src/ai/agents/ui/__init__.py b/src/ai/agents/ui/__init__.py index db357d09..2c272a79 100644 --- a/src/ai/agents/ui/__init__.py +++ b/src/ai/agents/ui/__init__.py @@ -1,5 +1,5 @@ """Downstream adapters — protocol bridges for the agent loop output.""" -from . import ai_sdk_ui +from . 
import ai_sdk -__all__ = ["ai_sdk_ui"] +__all__ = ["ai_sdk"] diff --git a/src/ai/agents/ui/ai_sdk/__init__.py b/src/ai/agents/ui/ai_sdk/__init__.py index 8bb4bcaf..8f283199 100644 --- a/src/ai/agents/ui/ai_sdk/__init__.py +++ b/src/ai/agents/ui/ai_sdk/__init__.py @@ -1,42 +1,23 @@ -# inbound: UI -> internal +"""AI SDK UI adapter — ``ai.Messages`` in, ``ai.Messages`` out, SSE on the wire.""" + from .inbound import ( ApprovalResponse, + apply_approvals, extract_approvals, - normalize_ui_messages, - ui_to_messages, -) - -# outbound: internal -> SSE stream -from .outbound import filter_by_label, stream_to_sse, stream_to_ui - -# message_to_ui: internal -> UI format (persistence/history) -from .message_to_ui import ( - UIMessageBuilder, - messages_to_ui, - parts_to_ui, - ui_parts_to_dicts, + to_messages, ) - -# data models +from .outbound import to_sse, to_stream, to_ui_messages from .protocol import UI_MESSAGE_STREAM_HEADERS from .ui_message import UIMessage __all__ = [ - # inbound - "ui_to_messages", - "normalize_ui_messages", - "extract_approvals", "ApprovalResponse", - # outbound - "stream_to_ui", - "stream_to_sse", - "filter_by_label", - # message_to_ui - "messages_to_ui", - "parts_to_ui", - "ui_parts_to_dicts", - "UIMessageBuilder", - # data models "UIMessage", "UI_MESSAGE_STREAM_HEADERS", + "apply_approvals", + "extract_approvals", + "to_messages", + "to_sse", + "to_stream", + "to_ui_messages", ] diff --git a/src/ai/agents/ui/ai_sdk/_approvals.py b/src/ai/agents/ui/ai_sdk/_approvals.py new file mode 100644 index 00000000..f4e98d79 --- /dev/null +++ b/src/ai/agents/ui/ai_sdk/_approvals.py @@ -0,0 +1,31 @@ +"""Approval-prefix linkage between ToolApproval hooks and tool calls. + +TODO(datamodel-rework §4): once ``HookPart.target_tool_call_id`` is added, +delete this module and replace call sites with direct field access. 
+""" + +from __future__ import annotations + +from ....types import messages as messages_ +from ...hooks import TOOL_APPROVAL_HOOK_TYPE + +_PREFIX = "approve_" + + +def tool_call_id_for(hook_part: messages_.HookPart) -> str | None: + """Return the tool_call_id encoded in a ToolApproval hook id, or None.""" + if hook_part.hook_type != TOOL_APPROVAL_HOOK_TYPE: + return None + if hook_part.hook_id.startswith(_PREFIX): + return hook_part.hook_id[len(_PREFIX) :] + return None + + +def is_tool_approval_message(msg: messages_.Message) -> bool: + """True if every part of ``msg`` is a ToolApproval HookPart.""" + if not msg.parts: + return False + return all( + isinstance(p, messages_.HookPart) and tool_call_id_for(p) is not None + for p in msg.parts + ) diff --git a/src/ai/agents/ui/ai_sdk/_parts.py b/src/ai/agents/ui/ai_sdk/_parts.py new file mode 100644 index 00000000..7c198cbb --- /dev/null +++ b/src/ai/agents/ui/ai_sdk/_parts.py @@ -0,0 +1,141 @@ +"""Shared conversions between internal Part objects and UIMessagePart objects. + +Used by ``outbound.history`` to reconstruct UIMessages from persisted +``ai.Message`` lists. The live outbound stream does not use these — it +emits wire-protocol deltas directly from ``Message.stream.events``. +""" + +from __future__ import annotations + +import json +from typing import Any + +from ....types import messages as messages_ +from . import _approvals, ui_message + + +def _normalize_tool_input(raw: str) -> str | dict[str, Any]: + """Parse tool args JSON string into a dict; fall back to raw string. + + TODO(datamodel-rework §4): once ``ToolCallPart.tool_args`` has a + canonical shape, drop this helper. 
+ """ + try: + parsed = json.loads(raw) + except (json.JSONDecodeError, TypeError): + return raw + return parsed if isinstance(parsed, dict) else raw + + +def to_ui_parts(parts: list[messages_.Part]) -> list[ui_message.UIMessagePart]: + """Convert internal Part objects to UIMessagePart objects.""" + result: list[ui_message.UIMessagePart] = [] + for part in parts: + if isinstance(part, messages_.TextPart) and part.text: + result.append(ui_message.UITextPart(type="text", text=part.text)) + elif isinstance(part, messages_.ReasoningPart) and part.text: + result.append( + ui_message.UIReasoningPart(type="reasoning", reasoning=part.text) + ) + elif isinstance(part, messages_.ToolCallPart): + result.append( + ui_message.UIToolPart.model_validate( + { + "type": f"tool-{part.tool_name}", + "toolCallId": part.tool_call_id, + "state": "input-available", + "input": _normalize_tool_input(part.tool_args), + } + ) + ) + elif isinstance(part, messages_.FilePart): + result.append( + ui_message.UIFilePart.model_validate( + { + "type": "file", + "mediaType": part.media_type, + "url": part.data if isinstance(part.data, str) else "", + "filename": part.filename, + } + ) + ) + return result + + +def merge_tool_results( + ui_parts: list[ui_message.UIMessagePart], + tool_parts: list[messages_.Part], +) -> None: + """Merge ToolResultParts into existing UIToolParts in-place.""" + tool_index: dict[str, int] = {} + for idx, ui_part in enumerate(ui_parts): + if isinstance(ui_part, ui_message.UIToolPart): + tool_index[ui_part.tool_call_id] = idx + + for part in tool_parts: + if not isinstance(part, messages_.ToolResultPart): + continue + idx_opt = tool_index.get(part.tool_call_id) + if idx_opt is None: + continue + idx = idx_opt + existing = ui_parts[idx] + if not isinstance(existing, ui_message.UIToolPart): + continue + if existing.state == "output-denied": + continue + state = "output-error" if part.is_error else "output-available" + ui_parts[idx] = existing.model_copy( + update={"state": 
state, "output": part.result} + ) + + +def merge_approval_signals( + ui_parts: list[ui_message.UIMessagePart], + internal_parts: list[messages_.Part], +) -> None: + """Merge HookPart approval state into existing UIToolParts in-place.""" + tool_index: dict[str, int] = {} + for idx, ui_part in enumerate(ui_parts): + if isinstance(ui_part, ui_message.UIToolPart): + tool_index[ui_part.tool_call_id] = idx + + for part in internal_parts: + if not isinstance(part, messages_.HookPart): + continue + + tool_call_id = _approvals.tool_call_id_for(part) + if tool_call_id is None: + continue + + idx_opt = tool_index.get(tool_call_id) + if idx_opt is None: + continue + idx = idx_opt + + existing = ui_parts[idx] + if not isinstance(existing, ui_message.UIToolPart): + continue + + updates: dict[str, Any] = {} + if part.status == "pending": + updates["state"] = "approval-requested" + updates["approval"] = ui_message.UIToolApproval(id=part.hook_id) + elif part.status == "resolved": + resolution = part.resolution or {} + updates["approval"] = ui_message.UIToolApproval( + id=part.hook_id, + approved=resolution.get("granted"), + reason=resolution.get("reason"), + ) + if resolution.get("granted", False): + updates["state"] = "approval-responded" + else: + updates["state"] = "output-denied" + updates["output"] = None + elif part.status == "cancelled": + updates["state"] = "output-error" + updates["error_text"] = "Hook cancelled" + + if updates: + ui_parts[idx] = existing.model_copy(update=updates) diff --git a/src/ai/agents/ui/ai_sdk/inbound.py b/src/ai/agents/ui/ai_sdk/inbound.py index e8f32d70..0dcadea2 100644 --- a/src/ai/agents/ui/ai_sdk/inbound.py +++ b/src/ai/agents/ui/ai_sdk/inbound.py @@ -1,8 +1,7 @@ -""" -Inbound: UI -> internal message conversion. +"""Inbound adapter: AI SDK v6 UIMessages → internal ``ai.Message`` list. -Converts AI SDK v6 UIMessages into internal Message objects -for use with the runtime/agent loop. 
+The primary entry point is :func:`to_messages`, which bundles normalization, +approval extraction, parsing, and pre-registration of approval resolutions. """ from __future__ import annotations @@ -11,7 +10,8 @@ import logging from typing import Any, NamedTuple -from ...types import messages as messages_ +from ....types import messages as messages_ +from ...hooks import resolve_hook from . import ui_message logger = logging.getLogger(__name__) @@ -22,15 +22,15 @@ def _is_tool_completed(state: ui_message.UIToolInvocationState) -> bool: - """Return True if the tool invocation state indicates a completed tool.""" return state in _TOOL_RESULT_STATES or state in _TOOL_ERROR_STATES def _is_tool_error(state: ui_message.UIToolInvocationState) -> bool: - """Return True if the tool invocation state indicates an error.""" return state in _TOOL_ERROR_STATES +# TODO(datamodel-rework §4): once tool args have a canonical shape, drop +# these normalizers. def _normalize_tool_args(tool_input: str | dict[str, Any] | None) -> str: """Normalize tool input (JSON string, dict, or None) to a JSON string.""" match tool_input: @@ -43,18 +43,13 @@ def _normalize_tool_args(tool_input: str | dict[str, Any] | None) -> str: def _normalize_tool_result(output: Any) -> dict[str, Any] | None: - """Normalize tool output to dict format for internal ToolResultPart. - - The internal ToolResultPart.result expects dict | None, but AI SDK - output can be any type. Wrap non-dict results for compatibility. 
- """ + """Normalize tool output to dict format for internal ToolResultPart.""" if output is None: return None return output if isinstance(output, dict) else {"value": output} def _error_result(error_text: str | None, output: Any) -> dict[str, Any] | None: - """Normalize an error-state tool result.""" normalized = _normalize_tool_result(output) if error_text: if normalized is None: @@ -64,8 +59,8 @@ def _error_result(error_text: str | None, output: Any) -> dict[str, Any] | None: return normalized -def _approval_signal_part(tp: ui_message.UIToolPart) -> messages_.HookPart | None: - """Reconstruct approval signal state from a UI tool part when possible.""" +def _approval_hook_part(tp: ui_message.UIToolPart) -> messages_.HookPart | None: + """Reconstruct approval hook state from a UI tool part when possible.""" approval = tp.approval if approval is None: return None @@ -103,12 +98,12 @@ def _approval_signal_part(tp: ui_message.UIToolPart) -> messages_.HookPart | Non # ============================================================================ -# Approval extraction +# Approval extraction + bulk resolution # ============================================================================ class ApprovalResponse(NamedTuple): - """Extracted approval response from a UIToolPart in approval-responded state.""" + """Approval response extracted from a UIToolPart in ``approval-responded`` state.""" hook_id: str granted: bool @@ -118,17 +113,9 @@ class ApprovalResponse(NamedTuple): def extract_approvals( ui_messages: list[ui_message.UIMessage], ) -> list[ApprovalResponse]: - """Extract approval responses from UI messages. - - Walks UIMessages looking for UIToolParts in ``approval-responded`` state - and returns the approval data as a list. Pure function -- does not - resolve hooks or trigger any side-effects. + """Return every approval response found in *ui_messages*. - Args: - ui_messages: List of UIMessage objects from the AI SDK v6 frontend. 
- - Returns: - List of ApprovalResponse tuples with hook_id, granted, and reason. + Pure function — does not resolve hooks or trigger side effects. """ approvals: list[ApprovalResponse] = [] for ui_msg in ui_messages: @@ -150,20 +137,24 @@ def extract_approvals( return approvals +def apply_approvals(approvals: list[ApprovalResponse]) -> None: + """Pre-register each approval resolution with the hooks registry.""" + for approval in approvals: + resolve_hook( + approval.hook_id, + {"granted": approval.granted, "reason": approval.reason}, + ) + + # ============================================================================ # UI message normalization (heal stale tool states) # ============================================================================ -def normalize_ui_messages( +def _normalize_ui_messages( ui_messages: list[ui_message.UIMessage], ) -> list[ui_message.UIMessage]: - """Heal stale tool-part states from previously persisted assistant history. - - Tool parts may be stored in transient states (e.g. ``"call"``) if the - stream was interrupted. This normalizes them to consistent terminal - states based on what data is actually present. - """ + """Heal stale tool-part states from previously persisted assistant history.""" normalized: list[ui_message.UIMessage] = [] for message in ui_messages: new_parts = [] @@ -201,38 +192,56 @@ def normalize_ui_messages( # ============================================================================ -# UI -> internal message conversion +# UI → internal message conversion # ============================================================================ -def ui_to_messages( +def to_messages( ui_messages: list[ui_message.UIMessage], + *, + apply_approvals_: bool = True, ) -> list[messages_.Message]: - """Convert AI SDK v6 UI messages to internal Message format. + """Parse a UI request into runtime messages. - This is a pure data transformation. It does not resolve hooks or - trigger any side-effects. 
Use ``extract_approvals()`` separately - to obtain approval responses for hook resolution. + Pipeline: + + 1. normalize stale tool states (``call`` → ``input-available``, etc.) + 2. extract approval responses + 3. parse UIMessages into ``ai.Message`` list, splitting at tool boundaries + 4. if *apply_approvals_* is True, pre-register resolutions via ``resolve_hook`` + 5. strip trailing assistant message when approval responses are present + """ + normalized = _normalize_ui_messages(ui_messages) + approvals = extract_approvals(normalized) + messages = _parse(normalized) + + if apply_approvals_: + apply_approvals(approvals) + + if approvals and messages: + # The assistant message that originated the approval-responded tool + # call would re-send the duplicate tool-use to the LLM on replay. + # Walk past any trailing internal (hook) messages and drop the + # assistant message beneath them. + idx = len(messages) - 1 + while idx >= 0 and messages[idx].role == "internal": + idx -= 1 + if idx >= 0 and messages[idx].role == "assistant": + logger.info("Stripping assistant message originating responded approvals") + messages = messages[:idx] + messages[idx + 1 :] - When the last message is an assistant message that contains - approval-responded tool parts, it is automatically stripped to - avoid sending duplicate tool-use content to the LLM on re-entry. - The caller should check ``extract_approvals()`` to determine - whether this stripping occurred. + return messages - Args: - ui_messages: List of UIMessage objects from the AI SDK v6 frontend. - Returns: - List of internal Message objects ready for use with the runtime. 
- """ +def _parse( + ui_messages: list[ui_message.UIMessage], +) -> list[messages_.Message]: result: list[messages_.Message] = [] - has_approval_responses = False for ui_msg in ui_messages: assistant_parts: list[messages_.Part] = [] tool_result_parts: list[messages_.ToolResultPart] = [] - signal_parts: list[messages_.HookPart] = [] + hook_parts: list[messages_.HookPart] = [] for part in ui_msg.parts: match part: @@ -243,7 +252,6 @@ def ui_to_messages( assistant_parts.append(messages_.ReasoningPart(text=reasoning)) case ui_message.UIToolInvocationPart() as inv: - # Legacy tool-invocation type tool_args = json.dumps(inv.args) if inv.args else "{}" assistant_parts.append( messages_.ToolCallPart( @@ -263,7 +271,6 @@ def ui_to_messages( ) case ui_message.UIToolPart() as tp: - # Dynamic tool-{toolName} type (e.g., "tool-get_weather") assistant_parts.append( messages_.ToolCallPart( tool_call_id=tp.tool_call_id, @@ -271,9 +278,9 @@ def ui_to_messages( tool_args=_normalize_tool_args(tp.input), ) ) - approval_signal = _approval_signal_part(tp) - if approval_signal is not None: - signal_parts.append(approval_signal) + approval_hook = _approval_hook_part(tp) + if approval_hook is not None: + hook_parts.append(approval_hook) if tp.state in _TOOL_RESULT_STATES: tool_result_parts.append( @@ -293,8 +300,6 @@ def ui_to_messages( is_error=True, ) ) - if tp.state == "approval-responded": - has_approval_responses = True case ui_message.UIFilePart() as fp: assistant_parts.append( @@ -310,9 +315,8 @@ def ui_to_messages( | ui_message.UISourceUrlPart() | ui_message.UISourceDocumentPart() ): - pass # Skip unsupported/boundary parts + pass - # Validate user/system messages have content - OpenAI requires it. if ui_msg.role in ("user", "system") and not assistant_parts: raise ValueError( f"Message '{ui_msg.id}' has role '{ui_msg.role}' but no content. 
" @@ -320,23 +324,23 @@ def ui_to_messages( ) # The UI sends one assistant message per conversation turn, but a - # single turn may span multiple loop iterations (e.g. - # [text, tool_call, tool_result, text, tool_call, tool_result, text]). + # single turn may span multiple loop iterations (e.g. [text, + # tool_call, tool_result, text, tool_call, tool_result, text]). # LLM APIs expect one message per iteration, so split into # assistant + tool message pairs at tool-result boundaries. if ui_msg.role == "assistant": - split_messages = _split_assistant_parts( - assistant_parts, tool_result_parts, msg_id=ui_msg.id + result.extend( + _split_assistant_parts( + assistant_parts, tool_result_parts, msg_id=ui_msg.id + ) ) - result.extend(split_messages) - if signal_parts: - result.extend( + for hp in hook_parts: + result.append( messages_.Message( id=ui_msg.id, - role="signal", - parts=[part], + role="internal", + parts=[hp], ) - for part in signal_parts ) else: result.append( @@ -347,13 +351,6 @@ def ui_to_messages( ) ) - # When resuming after approval responses, the frontend sends the full - # history including the assistant message from the interrupted run. - # Strip it to avoid sending duplicate tool-use content to the LLM. - if has_approval_responses and result and result[-1].role == "assistant": - logger.info("Stripping trailing assistant message (approval responses present)") - result = result[:-1] - return result @@ -362,52 +359,28 @@ def _split_assistant_parts( tool_results: list[messages_.ToolResultPart], msg_id: str, ) -> list[messages_.Message]: - """Split assistant parts into assistant + tool message pairs. - - The UI sends one big assistant message per turn, but internally each - loop iteration produces an assistant message (with tool calls) followed - by a tool message (with results). This reconstructs that structure. - - Returns a list of Messages: alternating assistant and tool messages, - split at tool-call boundaries when results are available. 
- """ - # Index tool results by their tool_call_id for lookup + """Split assistant parts into assistant + tool message pairs.""" results_by_id = {tr.tool_call_id: tr for tr in tool_results} - messages: list[messages_.Message] = [] - current: list[messages_.Part] = [] pending_results: list[messages_.ToolResultPart] = [] - for part in parts: - current.append(part) - - # When we see a ToolCallPart that has a result, accumulate it if ( isinstance(part, messages_.ToolCallPart) and part.tool_call_id in results_by_id ): pending_results.append(results_by_id[part.tool_call_id]) - # If there are pending results and more parts follow, we need to split. - # Walk again, splitting at boundaries where all accumulated tool calls - # have results and a non-tool part follows. if not pending_results: - # No completed tools -- single assistant message - if current: - messages.append( - messages_.Message(role="assistant", parts=current, id=msg_id) - ) - return messages + if parts: + return [messages_.Message(role="assistant", parts=parts, id=msg_id)] + return [] - # Re-walk to split at tool-call boundaries - messages = [] - current = [] + messages: list[messages_.Message] = [] + current: list[messages_.Part] = [] current_results: list[messages_.ToolResultPart] = [] seen_tool_call = False for part in parts: - # If we had a completed tool call group and now see a non-tool part, - # split here if ( seen_tool_call and current_results @@ -428,7 +401,6 @@ def _split_assistant_parts( if part.tool_call_id in results_by_id: current_results.append(results_by_id[part.tool_call_id]) - # Flush remaining if current: messages.append(messages_.Message(role="assistant", parts=current, id=msg_id)) if current_results: diff --git a/src/ai/agents/ui/ai_sdk/message_to_ui.py b/src/ai/agents/ui/ai_sdk/message_to_ui.py deleted file mode 100644 index 6778504e..00000000 --- a/src/ai/agents/ui/ai_sdk/message_to_ui.py +++ /dev/null @@ -1,456 +0,0 @@ -""" -Message to UI: internal Message -> UI format conversion. 
- -Converts internal Message objects back into AI SDK v6 UI format, -both for persistence (DB storage, GET endpoints) and for -accumulating streaming assistant turns. -""" - -from __future__ import annotations - -import json -import logging -from typing import Any - -from ...types import messages as messages_ -from . import ui_message - -logger = logging.getLogger(__name__) - - -# ============================================================================ -# Part-level conversions -# ============================================================================ - - -def parts_to_ui(parts: list[messages_.Part]) -> list[dict[str, Any]]: - """Convert internal Part objects to UI-compatible dicts. - - The frontend expects the AI SDK UI protocol shape (``type``, ``text``, - ``toolCallId``, ``toolName``, ``input``, ``output``, ``state``, etc.) - which differs from the internal model. - - Handles: TextPart, ReasoningPart, ToolCallPart, ToolResultPart, FilePart. - ToolResultParts are returned as standalone dicts (with ``output`` and - terminal state); callers that want merged tool-call+result dicts should - use ``UIMessageBuilder`` instead. 
- """ - result: list[dict[str, Any]] = [] - for part in parts: - if isinstance(part, messages_.TextPart): - if part.text: - result.append({"type": "text", "text": part.text}) - elif isinstance(part, messages_.ReasoningPart): - if part.text: - result.append({"type": "reasoning", "reasoning": part.text}) - elif isinstance(part, messages_.ToolCallPart): - result.append( - { - "type": f"tool-{part.tool_name}", - "toolCallId": part.tool_call_id, - "toolName": part.tool_name, - "state": "input-available", - "input": _normalize_tool_input(part.tool_args), - } - ) - elif isinstance(part, messages_.ToolResultPart): - state = "output-error" if part.is_error else "output-available" - result.append( - { - "type": f"tool-{part.tool_name}", - "toolCallId": part.tool_call_id, - "toolName": part.tool_name, - "state": state, - "output": part.result, - } - ) - elif isinstance(part, messages_.FilePart): - entry: dict[str, Any] = { - "type": "file", - "mediaType": part.media_type, - "url": part.data if isinstance(part.data, str) else "", - } - if part.filename: - entry["filename"] = part.filename - result.append(entry) - return result - - -def ui_parts_to_dicts( - parts: list[ui_message.UIMessagePart], -) -> list[dict[str, Any]]: - """Serialize UIMessage parts to plain dicts for DB storage.""" - return [ - part.model_dump() if hasattr(part, "model_dump") else dict(part) # type: ignore[call-overload] - for part in parts - ] - - -# ============================================================================ -# Message-level conversion (batch, for history loading) -# ============================================================================ - - -def messages_to_ui( - messages: list[messages_.Message], -) -> list[ui_message.UIMessage]: - """Convert internal Messages to UIMessages. - - This is the inverse of ``inbound.ui_to_messages()``. 
It merges - consecutive assistant + tool + signal message groups back into single - UIMessages with the correct tool states, producing the format - expected by the AI SDK frontend and suitable for DB persistence. - - User/system messages are converted directly. Assistant messages - are accumulated until a non assistant/tool/signal message is seen (or - the list ends), merging tool results and approval signals into the - preceding assistant's tool-call parts. - """ - result: list[ui_message.UIMessage] = [] - - i = 0 - while i < len(messages): - msg = messages[i] - - if msg.role in ("user", "system"): - result.append( - ui_message.UIMessage( - id=msg.id, - role=msg.role, - parts=_internal_parts_to_ui_parts(msg.parts), - ) - ) - i += 1 - continue - - if msg.role == "assistant": - # Accumulate: merge this assistant message with any following - # tool/signal messages, and possibly more assistant+tool pairs that - # belong to the same UI turn. - ui_parts: list[ui_message.UIMessagePart] = [] - turn_id = msg.id - - while i < len(messages) and messages[i].role in ( - "assistant", - "tool", - "signal", - ): - current = messages[i] - if current.role == "assistant": - ui_parts.extend(_internal_parts_to_ui_parts(current.parts)) - elif current.role == "tool": - _merge_tool_results(ui_parts, current.parts) - elif current.role == "signal": - _merge_signal_parts(ui_parts, current.parts) - i += 1 - - result.append( - ui_message.UIMessage( - id=turn_id, - role="assistant", - parts=ui_parts, - ) - ) - continue - - # Skip signal, tool (orphaned), or unknown roles - i += 1 - - return result - - -def _internal_parts_to_ui_parts( - parts: list[messages_.Part], -) -> list[ui_message.UIMessagePart]: - """Convert internal Part objects to UIMessagePart objects.""" - result: list[ui_message.UIMessagePart] = [] - for part in parts: - if isinstance(part, messages_.TextPart) and part.text: - result.append(ui_message.UITextPart(type="text", text=part.text)) - elif isinstance(part, 
messages_.ReasoningPart) and part.text: - result.append( - ui_message.UIReasoningPart(type="reasoning", reasoning=part.text) - ) - elif isinstance(part, messages_.ToolCallPart): - tool_input = _normalize_tool_input(part.tool_args) - result.append( - ui_message.UIToolPart( - type=f"tool-{part.tool_name}", - tool_call_id=part.tool_call_id, - state="input-available", - input=tool_input, - ) - ) - elif isinstance(part, messages_.FilePart): - result.append( - ui_message.UIFilePart( - type="file", - media_type=part.media_type, - url=part.data if isinstance(part.data, str) else "", - filename=part.filename, - ) - ) - return result - - -def _merge_tool_results( - ui_parts: list[ui_message.UIMessagePart], - tool_parts: list[messages_.Part], -) -> None: - """Merge tool result parts into existing UI tool-call parts in-place. - - Finds the matching UIToolPart by tool_call_id and updates its state - and output to reflect the tool result. - """ - # Index existing tool parts by tool_call_id - tool_index: dict[str, int] = {} - for idx, ui_part in enumerate(ui_parts): - if isinstance(ui_part, ui_message.UIToolPart): - tool_index[ui_part.tool_call_id] = idx - - for part in tool_parts: - if not isinstance(part, messages_.ToolResultPart): - continue - idx = tool_index.get(part.tool_call_id) - if idx is None: - continue - existing = ui_parts[idx] - if not isinstance(existing, ui_message.UIToolPart): - continue - if existing.state == "output-denied": - continue - state = "output-error" if part.is_error else "output-available" - ui_parts[idx] = existing.model_copy( - update={"state": state, "output": part.result} - ) - - -def _merge_signal_parts( - ui_parts: list[ui_message.UIMessagePart], - signal_parts: list[messages_.Part], -) -> None: - """Merge HookPart approval state into existing UI tool-call parts.""" - tool_index: dict[str, int] = {} - for idx, ui_part in enumerate(ui_parts): - if isinstance(ui_part, ui_message.UIToolPart): - tool_index[ui_part.tool_call_id] = idx - - for part 
in signal_parts: - if not isinstance(part, messages_.HookPart): - continue - - tool_call_id = _tool_call_id_from_approval_id(part.hook_id) - if tool_call_id is None: - continue - - idx = tool_index.get(tool_call_id) - if idx is None: - continue - - existing = ui_parts[idx] - if not isinstance(existing, ui_message.UIToolPart): - continue - - updates: dict[str, Any] = {} - if part.status == "pending": - updates["state"] = "approval-requested" - updates["approval"] = {"id": part.hook_id} - elif part.status == "resolved": - resolution = part.resolution or {} - updates["approval"] = { - "id": part.hook_id, - "approved": resolution.get("granted"), - "reason": resolution.get("reason"), - } - if resolution.get("granted", False): - updates["state"] = "approval-responded" - else: - updates["state"] = "output-denied" - updates["output"] = None - elif part.status == "cancelled": - updates["state"] = "output-error" - updates["error_text"] = "Hook cancelled" - - if updates: - ui_parts[idx] = existing.model_copy(update=updates) - - -# ============================================================================ -# UIMessageBuilder: streaming accumulator for assistant turns -# ============================================================================ - - -class UIMessageBuilder: - """Accumulate streaming runtime messages into a single UI assistant message. - - Processes internal Message objects as they arrive from the agent loop - and builds up a single UIMessage with all parts in the AI SDK UI format. - Handles text, reasoning, tool calls, tool results, and approval signals. 
- - Usage:: - - builder = UIMessageBuilder() - async for msg in agent_stream: - builder.ingest(msg) - ui_msg = builder.build() - """ - - def __init__(self, message_id: str | None = None) -> None: - self.message_id = message_id - self.parts: list[dict[str, Any]] = [] - self._tool_indexes: dict[str, int] = {} - - @classmethod - def from_ui_message(cls, message: ui_message.UIMessage) -> UIMessageBuilder: - """Seed the builder from an existing UI assistant message (for resume).""" - builder = cls(message_id=message.id) - builder.parts = ui_parts_to_dicts(message.parts) - for index, part in enumerate(builder.parts): - part_type = part.get("type") - tool_call_id = part.get("toolCallId") - if ( - isinstance(part_type, str) - and part_type.startswith("tool-") - and isinstance(tool_call_id, str) - ): - builder._tool_indexes[tool_call_id] = index - return builder - - def ingest(self, message: messages_.Message) -> None: - """Consume one runtime message. - - Routes by role: - - ``assistant`` (done): appends text, reasoning, and tool-call parts - - ``tool``: updates existing tool parts with results - - ``signal``: updates tool parts with approval state - """ - if message.role == "assistant" and message.is_done: - self._ingest_assistant(message) - elif message.role == "tool": - self._ingest_tool(message) - elif message.role == "signal": - self._ingest_signal(message) - - def build(self) -> ui_message.UIMessage | None: - """Return the accumulated UIMessage, or None if nothing was ingested.""" - if not self.parts or self.message_id is None: - return None - # Parse the accumulated dicts back into typed UIMessageParts - parsed_parts: list[ui_message.UIMessagePart] = [] - for part_dict in self.parts: - parsed = ui_message._parse_ui_part(part_dict) - if parsed is not None: - parsed_parts.append(parsed) - return ui_message.UIMessage( - id=self.message_id, - role="assistant", - parts=parsed_parts, - ) - - @property - def raw_parts(self) -> list[dict[str, Any]]: - """Access the 
accumulated parts as raw dicts (for direct DB storage).""" - return self.parts - - # -- Private ingest handlers -- - - def _ingest_assistant(self, message: messages_.Message) -> None: - if self.message_id is None: - self.message_id = message.id - for part in message.parts: - if isinstance(part, messages_.ReasoningPart) and part.text: - candidate = {"type": "reasoning", "reasoning": part.text} - if self.parts[-1:] != [candidate]: - self.parts.append(candidate) - elif isinstance(part, messages_.TextPart) and part.text: - candidate = {"type": "text", "text": part.text} - if self.parts[-1:] != [candidate]: - self.parts.append(candidate) - elif isinstance(part, messages_.ToolCallPart): - if part.tool_call_id in self._tool_indexes: - continue - self._tool_indexes[part.tool_call_id] = len(self.parts) - self.parts.append( - { - "type": f"tool-{part.tool_name}", - "toolCallId": part.tool_call_id, - "toolName": part.tool_name, - "state": "input-available", - "input": _normalize_tool_input(part.tool_args), - } - ) - - def _ingest_tool(self, message: messages_.Message) -> None: - for part in message.parts: - if not isinstance(part, messages_.ToolResultPart): - continue - index = self._tool_indexes.get(part.tool_call_id) - if index is None: - continue - tool_part = dict(self.parts[index]) - if tool_part.get("state") != "output-denied": - tool_part["state"] = ( - "output-error" if part.is_error else "output-available" - ) - tool_part["output"] = part.result - self.parts[index] = tool_part - - def _ingest_signal(self, message: messages_.Message) -> None: - hook_part = message.get_hook_part() - if hook_part is None: - return - tool_call_id = _tool_call_id_from_approval_id(hook_part.hook_id) - if tool_call_id is None: - return - index = self._tool_indexes.get(tool_call_id) - if index is None: - return - - tool_part = dict(self.parts[index]) - if hook_part.status == "pending": - tool_part["state"] = "approval-requested" - tool_part["approval"] = {"id": hook_part.hook_id} - elif 
hook_part.status == "resolved": - resolution = hook_part.resolution or {} - tool_part["approval"] = { - "id": hook_part.hook_id, - "approved": resolution.get("granted"), - "reason": resolution.get("reason"), - } - if resolution.get("granted", False): - tool_part["state"] = "approval-responded" - else: - tool_part["state"] = "output-denied" - elif hook_part.status == "cancelled": - tool_part["state"] = "output-error" - tool_part["errorText"] = "Hook cancelled" - self.parts[index] = tool_part - - -# ============================================================================ -# Shared helpers -# ============================================================================ - - -def _tool_call_id_from_approval_id(approval_id: str) -> str | None: - """Extract the tool_call_id from a ToolApproval hook label. - - E.g. ``"approve_tc_abc123"`` -> ``"tc_abc123"``. - """ - prefix = "approve_" - if approval_id.startswith(prefix): - return approval_id[len(prefix) :] - return None - - -def _normalize_tool_input(raw: str) -> str | dict[str, Any]: - """Normalize tool input for the UI's accepted string-or-dict shape. - - Tries to parse the JSON string into a dict. Falls back to the - raw string if parsing fails or the result isn't a dict. - """ - try: - parsed = json.loads(raw) - except (json.JSONDecodeError, TypeError): - return raw - return parsed if isinstance(parsed, dict) else raw diff --git a/src/ai/agents/ui/ai_sdk/outbound.py b/src/ai/agents/ui/ai_sdk/outbound.py deleted file mode 100644 index b5494f8a..00000000 --- a/src/ai/agents/ui/ai_sdk/outbound.py +++ /dev/null @@ -1,303 +0,0 @@ -""" -Outbound: internal message stream -> AI SDK UI stream. - -Converts the internal runtime stream into AI SDK UI stream protocol parts -and optionally serializes them as SSE payloads. 
-""" - -from __future__ import annotations - -import dataclasses -import json -from collections.abc import AsyncGenerator, AsyncIterable - -from ...agents.hooks import TOOL_APPROVAL_HOOK_TYPE -from ...types import messages as messages_ -from . import protocol - - -def _to_camel_case(snake_str: str) -> str: - """Convert snake_case to camelCase.""" - components = snake_str.split("_") - return components[0] + "".join(x.title() for x in components[1:]) - - -def serialize_part(part: protocol.UIMessageStreamPart) -> str: - """Serialize a stream part to JSON with camelCase keys.""" - d = dataclasses.asdict(part) - if isinstance(part, protocol.DataPart): - d["type"] = part.type - del d["data_type"] - camel_dict = {_to_camel_case(k): v for k, v in d.items() if v is not None} - return json.dumps(camel_dict) - - -def format_sse(part: protocol.UIMessageStreamPart) -> str: - """Format a stream part as an SSE data line.""" - return f"data: {serialize_part(part)}\n\n" - - -class _StreamState: - """Tracks state for UI message stream event sequencing.""" - - def __init__(self) -> None: - self.text_id: str | None = None - self.reasoning_id: str | None = None - self.label: str | None = None - self.message_id: str | None = None - self.emitted_start = False - self.in_step = False - self.started_tool_calls: set[str] = set() - self.emitted_tool_results: set[str] = set() - self.pending_tool_calls: set[str] = set() - self.emitted_approval_requests: set[str] = set() - - def close_open_blocks(self) -> list[protocol.UIMessageStreamPart]: - parts: list[protocol.UIMessageStreamPart] = [] - if self.reasoning_id: - parts.append(protocol.ReasoningEndPart(id=self.reasoning_id)) - self.reasoning_id = None - if self.text_id: - parts.append(protocol.TextEndPart(id=self.text_id)) - self.text_id = None - return parts - - def finish_step(self) -> list[protocol.UIMessageStreamPart]: - parts = self.close_open_blocks() - if self.in_step: - parts.append(protocol.FinishStepPart()) - self.in_step = False - 
return parts - - def reset_tool_tracking(self) -> None: - self.started_tool_calls = set() - self.emitted_tool_results = set() - self.pending_tool_calls = set() - self.emitted_approval_requests = set() - - def begin_message(self, msg: messages_.Message) -> list[protocol.UIMessageStreamPart]: - parts: list[protocol.UIMessageStreamPart] = [] - is_new_message = self.message_id is not None and msg.id != self.message_id - - if not self.emitted_start or (msg.label and msg.label != self.label): - parts.extend(self.finish_step()) - if self.emitted_start: - parts.append(protocol.FinishPart(finish_reason="stop")) - - parts.append(protocol.StartPart(message_id=msg.id)) - parts.append(protocol.StartStepPart()) - self.emitted_start = True - self.in_step = True - self.label = msg.label - self.message_id = msg.id - self.reset_tool_tracking() - elif is_new_message: - parts.extend(self.finish_step()) - parts.append(protocol.StartStepPart()) - self.in_step = True - self.message_id = msg.id - - return parts - - -def _tool_call_id_from_approval_hook(hook_part: messages_.HookPart) -> str | None: - """Extract tool_call_id from a ToolApproval HookPart.""" - if hook_part.hook_type != TOOL_APPROVAL_HOOK_TYPE: - return None - prefix = "approve_" - if hook_part.hook_id.startswith(prefix): - return hook_part.hook_id[len(prefix) :] - return None - - -def _is_tool_approval_hook_message(msg: messages_.Message) -> bool: - """True if this message contains only ToolApproval HookParts.""" - if not msg.parts: - return False - return all( - isinstance(p, messages_.HookPart) - and _tool_call_id_from_approval_hook(p) is not None - for p in msg.parts - ) - - -def _tool_error_text(part: messages_.ToolResultPart) -> str: - """Best-effort error text for failed tool executions.""" - if isinstance(part.result, str) and part.result: - return part.result - if isinstance(part.result, dict): - for key in ("error", "message", "detail"): - value = part.result.get(key) - if isinstance(value, str) and value: - return 
value - return "Tool execution failed" - - -async def to_ui_message_stream( - messages: AsyncIterable[messages_.Message], -) -> AsyncGenerator[protocol.UIMessageStreamPart]: - """ - Convert an internal message stream into AI SDK UI stream parts. - """ - state = _StreamState() - - async for msg in messages: - if msg.role == "tool" and state.message_id: - msg = msg.model_copy(update={"id": state.message_id}) - - if _is_tool_approval_hook_message(msg) and state.message_id: - msg = msg.model_copy(update={"id": state.message_id}) - - for part in state.begin_message(msg): - yield part - - if delta := msg.reasoning_delta: - if not state.reasoning_id: - state.reasoning_id = messages_.generate_id("reasoning") - yield protocol.ReasoningStartPart(id=state.reasoning_id) - yield protocol.ReasoningDeltaPart(id=state.reasoning_id, delta=delta) - - if delta := msg.text_delta: - if state.reasoning_id: - yield protocol.ReasoningEndPart(id=state.reasoning_id) - state.reasoning_id = None - - if not state.text_id: - state.text_id = messages_.generate_id("text") - yield protocol.TextStartPart(id=state.text_id) - yield protocol.TextDeltaPart(id=state.text_id, delta=delta) - - for tool_delta in msg.tool_deltas: - if tool_delta.tool_call_id not in state.started_tool_calls: - state.started_tool_calls.add(tool_delta.tool_call_id) - yield protocol.ToolInputStartPart( - tool_call_id=tool_delta.tool_call_id, - tool_name=tool_delta.tool_name, - ) - yield protocol.ToolInputDeltaPart( - tool_call_id=tool_delta.tool_call_id, - input_text_delta=tool_delta.args_delta, - ) - - if msg.is_done: - had_active_text = state.text_id is not None - for part in state.close_open_blocks(): - yield part - - has_new_pending_tools = any( - isinstance(p, messages_.ToolCallPart) - and p.tool_call_id not in state.pending_tool_calls - for p in msg.parts - ) - has_new_tool_results = any( - isinstance(p, messages_.ToolResultPart) - and p.tool_call_id not in state.emitted_tool_results - for p in msg.parts - ) - - for 
msg_part in msg.parts: - match msg_part: - case messages_.TextPart(text=text) if ( - text - and not had_active_text - and not has_new_pending_tools - and not has_new_tool_results - ): - text_id = messages_.generate_id("text") - yield protocol.TextStartPart(id=text_id) - yield protocol.TextEndPart(id=text_id) - case messages_.ToolCallPart( - tool_call_id=tc_id, - tool_name=name, - tool_args=args, - ): - if tc_id not in state.started_tool_calls: - state.started_tool_calls.add(tc_id) - yield protocol.ToolInputStartPart( - tool_call_id=tc_id, - tool_name=name, - ) - if tc_id not in state.pending_tool_calls: - state.pending_tool_calls.add(tc_id) - yield protocol.ToolInputAvailablePart( - tool_call_id=tc_id, - tool_name=name, - input=args, - ) - - if has_new_tool_results: - for msg_part in msg.parts: - if not isinstance(msg_part, messages_.ToolResultPart): - continue - if msg_part.tool_call_id in state.emitted_tool_results: - continue - - state.emitted_tool_results.add(msg_part.tool_call_id) - state.pending_tool_calls.discard(msg_part.tool_call_id) - - if msg_part.is_error: - yield protocol.ToolOutputErrorPart( - tool_call_id=msg_part.tool_call_id, - error_text=_tool_error_text(msg_part), - ) - else: - yield protocol.ToolOutputAvailablePart( - tool_call_id=msg_part.tool_call_id, - output=msg_part.result, - ) - - for msg_part in msg.parts: - if not isinstance(msg_part, messages_.HookPart): - continue - approval_tc_id = _tool_call_id_from_approval_hook(msg_part) - if approval_tc_id is None: - continue - - if msg_part.status == "pending": - if approval_tc_id not in state.emitted_approval_requests: - state.emitted_approval_requests.add(approval_tc_id) - yield protocol.ToolApprovalRequestPart( - approval_id=msg_part.hook_id, - tool_call_id=approval_tc_id, - ) - elif msg_part.status == "resolved": - resolution = msg_part.resolution or {} - if not resolution.get("granted", False): - yield protocol.ToolOutputDeniedPart( - tool_call_id=approval_tc_id, - ) - elif msg_part.status 
== "cancelled": - yield protocol.ToolOutputErrorPart( - tool_call_id=approval_tc_id, - error_text="Hook cancelled", - ) - - for part in state.finish_step(): - yield part - if state.emitted_start: - yield protocol.FinishPart(finish_reason="stop") - - -async def filter_by_label( - messages: AsyncIterable[messages_.Message], - label: str | None = None, -) -> AsyncGenerator[messages_.Message]: - """Filter a message stream to a single agent label.""" - async for msg in messages: - if label is None: - label = msg.label - if msg.label == label: - yield msg - - -async def to_sse_stream( - messages: AsyncIterable[messages_.Message], -) -> AsyncGenerator[str]: - """Convert an internal message stream into SSE strings.""" - async for part in to_ui_message_stream(messages): - yield format_sse(part) - - -# Backward-compatible aliases for the current package surface. -stream_to_ui = to_ui_message_stream -stream_to_sse = to_sse_stream diff --git a/src/ai/agents/ui/ai_sdk/outbound/__init__.py b/src/ai/agents/ui/ai_sdk/outbound/__init__.py new file mode 100644 index 00000000..9347170e --- /dev/null +++ b/src/ai/agents/ui/ai_sdk/outbound/__init__.py @@ -0,0 +1,7 @@ +"""Outbound adapter: ``ai.Message`` stream → AI SDK UI protocol.""" + +from .history import to_ui_messages +from .sse import to_sse +from .stream import to_stream + +__all__ = ["to_stream", "to_sse", "to_ui_messages"] diff --git a/src/ai/agents/ui/ai_sdk/outbound/_state.py b/src/ai/agents/ui/ai_sdk/outbound/_state.py new file mode 100644 index 00000000..58be0775 --- /dev/null +++ b/src/ai/agents/ui/ai_sdk/outbound/_state.py @@ -0,0 +1,310 @@ +"""Stream state bookkeeping for the live outbound walk. + +Owns message/step boundary logic (via ``turn_id`` + ``agent``), tracks +which parts have open text/reasoning blocks, and guards against +re-emission when the runtime re-yields an already-finalized message. 
+""" + +from __future__ import annotations + +from typing import Any + +from .....types import messages as messages_ +from .. import _approvals, protocol + + +def _tool_error_text(part: messages_.ToolResultPart) -> str: + """Best-effort error text extraction from a failed tool result.""" + if isinstance(part.result, str) and part.result: + return part.result + if isinstance(part.result, dict): + for key in ("error", "message", "detail"): + value = part.result.get(key) + if isinstance(value, str) and value: + return value + return "Tool execution failed" + + +class _StreamState: + """Single-pass state across one ``to_stream()`` call.""" + + def __init__(self) -> None: + self.current_turn_id: str | None = None + self.current_agent: str | None = None + self.ui_message_id: str | None = None + self.emitted_start: bool = False + self.in_step: bool = False + + # Message-level dedup — an ``is_done`` message re-emitted as input to a + # later ``stream()`` call must not fire events twice. + self.seen_done: set[str] = set() + + # Tool-call dedup — keyed by tool_call_id. + self.started_tool_inputs: set[str] = set() + self.input_available_emitted: set[str] = set() + self.emitted_tool_results: set[str] = set() + self.emitted_approval_requests: set[str] = set() + + # Open streaming blocks — keyed by part id. 
+ self.open_text_ids: set[str] = set() + self.open_reasoning_ids: set[str] = set() + + # -- boundary helpers ---------------------------------------------------- + + def _close_open_blocks(self) -> list[protocol.UIMessageStreamPart]: + parts: list[protocol.UIMessageStreamPart] = [] + for rid in list(self.open_reasoning_ids): + parts.append(protocol.ReasoningEndPart(id=rid)) + self.open_reasoning_ids.clear() + for tid in list(self.open_text_ids): + parts.append(protocol.TextEndPart(id=tid)) + self.open_text_ids.clear() + return parts + + def _finish_step(self) -> list[protocol.UIMessageStreamPart]: + parts = self._close_open_blocks() + if self.in_step: + parts.append(protocol.FinishStepPart()) + self.in_step = False + return parts + + def _reset_step_tracking(self) -> None: + self.started_tool_inputs.clear() + self.input_available_emitted.clear() + self.emitted_tool_results.clear() + self.emitted_approval_requests.clear() + + # -- phase: message start ------------------------------------------------ + + def on_message(self, msg: messages_.Message) -> list[protocol.UIMessageStreamPart]: + """Emit UIMessage/step boundary parts for *msg*.""" + parts: list[protocol.UIMessageStreamPart] = [] + + agent_changed = ( + self.emitted_start + and msg.agent is not None + and msg.agent != self.current_agent + ) + + if not self.emitted_start or agent_changed: + parts.extend(self._finish_step()) + if self.emitted_start: + parts.append(protocol.FinishPart(finish_reason="stop")) + + self.ui_message_id = msg.id + parts.append(protocol.StartPart(message_id=msg.id)) + parts.append(protocol.StartStepPart()) + self.emitted_start = True + self.in_step = True + self.current_agent = msg.agent + self.current_turn_id = msg.turn_id + self._reset_step_tracking() + return parts + + # Same UIMessage — check for step boundary via turn_id change. 
Only + # non-None → different-non-None transitions fire a step boundary; + # None carries the current step (tool results yielded by the loop are + # intentionally left unstamped until the next stream() stamps them). + if ( + msg.turn_id is not None + and self.current_turn_id is not None + and msg.turn_id != self.current_turn_id + ): + parts.extend(self._finish_step()) + parts.append(protocol.StartStepPart()) + self.in_step = True + self._reset_step_tracking() + self.current_turn_id = msg.turn_id + elif msg.turn_id is not None and self.current_turn_id is None: + self.current_turn_id = msg.turn_id + + return parts + + # -- phase: per-event (mid-stream) --------------------------------------- + + def on_event( + self, + msg: messages_.Message, + event: messages_.StreamEvent, + parts_by_id: dict[str, messages_.Part], + ) -> list[protocol.UIMessageStreamPart]: + part = parts_by_id.get(event.part_id) + if part is None: + return [] + + match event, part: + case messages_.PartOpened(), messages_.TextPart(): + self.open_text_ids.add(event.part_id) + return [protocol.TextStartPart(id=event.part_id)] + + case messages_.PartDelta(chunk=chunk), messages_.TextPart(): + if event.part_id not in self.open_text_ids: + self.open_text_ids.add(event.part_id) + return [ + protocol.TextStartPart(id=event.part_id), + protocol.TextDeltaPart(id=event.part_id, delta=chunk), + ] + return [protocol.TextDeltaPart(id=event.part_id, delta=chunk)] + + case messages_.PartClosed(), messages_.TextPart(): + if event.part_id in self.open_text_ids: + self.open_text_ids.discard(event.part_id) + return [protocol.TextEndPart(id=event.part_id)] + return [] + + case messages_.PartOpened(), messages_.ReasoningPart(): + self.open_reasoning_ids.add(event.part_id) + return [protocol.ReasoningStartPart(id=event.part_id)] + + case messages_.PartDelta(chunk=chunk), messages_.ReasoningPart(): + if event.part_id not in self.open_reasoning_ids: + self.open_reasoning_ids.add(event.part_id) + return [ + 
protocol.ReasoningStartPart(id=event.part_id), + protocol.ReasoningDeltaPart(id=event.part_id, delta=chunk), + ] + return [protocol.ReasoningDeltaPart(id=event.part_id, delta=chunk)] + + case messages_.PartClosed(), messages_.ReasoningPart(): + if event.part_id in self.open_reasoning_ids: + self.open_reasoning_ids.discard(event.part_id) + return [protocol.ReasoningEndPart(id=event.part_id)] + return [] + + case messages_.PartOpened(), messages_.ToolCallPart() as tc: + if tc.tool_call_id in self.started_tool_inputs: + return [] + self.started_tool_inputs.add(tc.tool_call_id) + return [ + protocol.ToolInputStartPart( + tool_call_id=tc.tool_call_id, + tool_name=tc.tool_name, + ) + ] + + case messages_.PartDelta(chunk=chunk), messages_.ToolCallPart() as tc: + out: list[protocol.UIMessageStreamPart] = [] + if tc.tool_call_id not in self.started_tool_inputs: + self.started_tool_inputs.add(tc.tool_call_id) + out.append( + protocol.ToolInputStartPart( + tool_call_id=tc.tool_call_id, + tool_name=tc.tool_name, + ) + ) + out.append( + protocol.ToolInputDeltaPart( + tool_call_id=tc.tool_call_id, + input_text_delta=chunk, + ) + ) + return out + + case messages_.PartClosed(), messages_.ToolCallPart(): + # ToolInputAvailablePart is emitted in ``on_terminal`` from + # the terminal ``tool_args`` snapshot. + return [] + + return [] + + # -- phase: terminal (tool results, approvals, final tool-input) --------- + + def on_terminal(self, msg: messages_.Message) -> list[protocol.UIMessageStreamPart]: + if not msg.is_done: + return [] + + out: list[protocol.UIMessageStreamPart] = [] + + # Close any blocks that were opened but didn't see an explicit + # PartClosed (e.g. provider terminates abruptly — safety net). 
+ if msg.stream is not None: + opened_ids = { + e.part_id + for e in msg.stream.events + if isinstance(e, messages_.PartOpened) + } + for tid in list(self.open_text_ids): + if tid in opened_ids and not any( + isinstance(e, messages_.PartClosed) and e.part_id == tid + for e in msg.stream.events + ): + out.append(protocol.TextEndPart(id=tid)) + self.open_text_ids.discard(tid) + + for part in msg.parts: + if isinstance(part, messages_.ToolCallPart): + if part.tool_call_id in self.input_available_emitted: + continue + self.input_available_emitted.add(part.tool_call_id) + # Ensure ToolInputStart was emitted (no streaming events case). + if part.tool_call_id not in self.started_tool_inputs: + self.started_tool_inputs.add(part.tool_call_id) + out.append( + protocol.ToolInputStartPart( + tool_call_id=part.tool_call_id, + tool_name=part.tool_name, + ) + ) + out.append( + protocol.ToolInputAvailablePart( + tool_call_id=part.tool_call_id, + tool_name=part.tool_name, + input=part.tool_args, + ) + ) + + elif isinstance(part, messages_.ToolResultPart): + if part.tool_call_id in self.emitted_tool_results: + continue + self.emitted_tool_results.add(part.tool_call_id) + if part.is_error: + out.append( + protocol.ToolOutputErrorPart( + tool_call_id=part.tool_call_id, + error_text=_tool_error_text(part), + ) + ) + else: + out.append( + protocol.ToolOutputAvailablePart( + tool_call_id=part.tool_call_id, + output=part.result, + ) + ) + + elif isinstance(part, messages_.HookPart): + tc_id = _approvals.tool_call_id_for(part) + if tc_id is None: + continue + + if part.status == "pending": + if tc_id in self.emitted_approval_requests: + continue + self.emitted_approval_requests.add(tc_id) + out.append( + protocol.ToolApprovalRequestPart( + approval_id=part.hook_id, + tool_call_id=tc_id, + ) + ) + elif part.status == "resolved": + resolution: dict[str, Any] = part.resolution or {} + if not resolution.get("granted", False): + out.append(protocol.ToolOutputDeniedPart(tool_call_id=tc_id)) + 
elif part.status == "cancelled": + out.append( + protocol.ToolOutputErrorPart( + tool_call_id=tc_id, + error_text="Hook cancelled", + ) + ) + + return out + + # -- phase: stream finish ------------------------------------------------ + + def finish(self) -> list[protocol.UIMessageStreamPart]: + parts = self._finish_step() + if self.emitted_start: + parts.append(protocol.FinishPart(finish_reason="stop")) + return parts diff --git a/src/ai/agents/ui/ai_sdk/outbound/history.py b/src/ai/agents/ui/ai_sdk/outbound/history.py new file mode 100644 index 00000000..c5d5809b --- /dev/null +++ b/src/ai/agents/ui/ai_sdk/outbound/history.py @@ -0,0 +1,66 @@ +"""Persisted-message → UIMessage list for history endpoints.""" + +from __future__ import annotations + +from .....types import messages as messages_ +from .. import _parts, ui_message + + +def to_ui_messages( + messages: list[messages_.Message], +) -> list[ui_message.UIMessage]: + """Group persisted messages into UIMessage bubbles. + + ``user``/``system`` messages become standalone UIMessages. Runs of + ``assistant``/``tool``/``internal`` messages merge into a single + assistant UIMessage, with tool results and approval state folded into + the corresponding tool-call parts. 
+ """ + result: list[ui_message.UIMessage] = [] + + i = 0 + while i < len(messages): + msg = messages[i] + + if msg.role in ("user", "system"): + result.append( + ui_message.UIMessage( + id=msg.id, + role=msg.role, + parts=_parts.to_ui_parts(msg.parts), + ) + ) + i += 1 + continue + + if msg.role == "assistant": + ui_parts: list[ui_message.UIMessagePart] = [] + bubble_id = msg.id + + while i < len(messages) and messages[i].role in ( + "assistant", + "tool", + "internal", + ): + current = messages[i] + if current.role == "assistant": + ui_parts.extend(_parts.to_ui_parts(current.parts)) + elif current.role == "tool": + _parts.merge_tool_results(ui_parts, current.parts) + elif current.role == "internal": + _parts.merge_approval_signals(ui_parts, current.parts) + i += 1 + + result.append( + ui_message.UIMessage( + id=bubble_id, + role="assistant", + parts=ui_parts, + ) + ) + continue + + # Orphan tool / internal messages — skip; they have no assistant anchor. + i += 1 + + return result diff --git a/src/ai/agents/ui/ai_sdk/outbound/sse.py b/src/ai/agents/ui/ai_sdk/outbound/sse.py new file mode 100644 index 00000000..019f2894 --- /dev/null +++ b/src/ai/agents/ui/ai_sdk/outbound/sse.py @@ -0,0 +1,39 @@ +"""Serialize the UI message stream as Server-Sent Events.""" + +from __future__ import annotations + +import dataclasses +import json +from collections.abc import AsyncGenerator, AsyncIterable + +from .....types import messages as messages_ +from .. 
import protocol +from .stream import to_stream + + +def _to_camel_case(snake_str: str) -> str: + components = snake_str.split("_") + return components[0] + "".join(x.title() for x in components[1:]) + + +def serialize_part(part: protocol.UIMessageStreamPart) -> str: + """Serialize a stream part to JSON with camelCase keys.""" + d = dataclasses.asdict(part) + if isinstance(part, protocol.DataPart): + d["type"] = part.type + del d["data_type"] + camel_dict = {_to_camel_case(k): v for k, v in d.items() if v is not None} + return json.dumps(camel_dict) + + +def format_sse(part: protocol.UIMessageStreamPart) -> str: + """Format a stream part as an SSE data line.""" + return f"data: {serialize_part(part)}\n\n" + + +async def to_sse( + messages: AsyncIterable[messages_.Message], +) -> AsyncGenerator[str]: + """Convert an internal message stream into SSE strings.""" + async for part in to_stream(messages): + yield format_sse(part) diff --git a/src/ai/agents/ui/ai_sdk/outbound/stream.py b/src/ai/agents/ui/ai_sdk/outbound/stream.py new file mode 100644 index 00000000..169715bd --- /dev/null +++ b/src/ai/agents/ui/ai_sdk/outbound/stream.py @@ -0,0 +1,43 @@ +"""Convert an internal ``ai.Message`` stream into AI SDK UI protocol parts.""" + +from __future__ import annotations + +from collections.abc import AsyncGenerator, AsyncIterable + +from .....types import messages as messages_ +from .. import protocol +from ._state import _StreamState + + +async def to_stream( + messages: AsyncIterable[messages_.Message], +) -> AsyncGenerator[protocol.UIMessageStreamPart]: + """Walk ``messages`` once, emitting AI SDK UI stream parts. + + Drives off ``Message.stream.events`` for incremental deltas and + ``Message.parts`` for terminal tool input/output/approval parts. + Re-emitted messages (same id, already seen ``is_done``) are skipped. 
+ """ + state = _StreamState() + + async for msg in messages: + if msg.id in state.seen_done: + continue + + for part in state.on_message(msg): + yield part + + if msg.stream is not None and msg.stream.events: + parts_by_id = {p.id: p for p in msg.parts} + for event in msg.stream.events: + for out in state.on_event(msg, event, parts_by_id): + yield out + + for part in state.on_terminal(msg): + yield part + + if msg.is_done: + state.seen_done.add(msg.id) + + for part in state.finish(): + yield part diff --git a/src/ai/agents/ui/ai_sdk/ui_message.py b/src/ai/agents/ui/ai_sdk/ui_message.py index b886444b..406a699f 100644 --- a/src/ai/agents/ui/ai_sdk/ui_message.py +++ b/src/ai/agents/ui/ai_sdk/ui_message.py @@ -13,7 +13,7 @@ import pydantic -from ...types import messages as messages_ +from ....types import messages as messages_ class UITextPart(pydantic.BaseModel): diff --git a/tests/adapters/ai_sdk_ui/test_adapter.py b/tests/adapters/ai_sdk_ui/test_adapter.py deleted file mode 100644 index dae8732f..00000000 --- a/tests/adapters/ai_sdk_ui/test_adapter.py +++ /dev/null @@ -1,693 +0,0 @@ -""" -Based on: .reference/ai/packages/ai/src/ui/process-ui-message-stream.test.ts -""" - -from __future__ import annotations - -import asyncio -from collections.abc import AsyncGenerator - -import ai -from ai.adapters.ai_sdk_ui import adapter, ui_message -from ai.agents import hooks -from ai.types import messages - -from ...conftest import MOCK_MODEL, mock_llm, tool_call_msg - - -async def get_event_types(msgs: list[messages.Message]) -> list[str]: - """Stream messages through adapter and return event type sequence.""" - - async def stream() -> AsyncGenerator[messages.Message]: - for m in msgs: - yield m - - return [p.type async for p in adapter.to_ui_message_stream(stream())] - - -# ----------------------------------------------------------------------------- -# Event sequence tests -# ----------------------------------------------------------------------------- - - -async def 
test_text_streaming() -> None: - """Text: start -> start-step -> text-start/delta/end -> finish-step -> finish""" - msgs = [ - messages.Message( - id="msg-1", - role="assistant", - parts=[messages.TextPart(text="Hello", delta="Hello", state="streaming")], - ), - messages.Message( - id="msg-1", - role="assistant", - parts=[ - messages.TextPart( - text="Hello, world!", delta=", world!", state="streaming" - ) - ], - ), - messages.Message( - id="msg-1", - role="assistant", - parts=[messages.TextPart(text="Hello, world!", state="done")], - ), - ] - - assert await get_event_types(msgs) == [ - "start", - "start-step", - "text-start", - "text-delta", - "text-delta", - "text-end", - "finish-step", - "finish", - ] - - -async def test_tool_roundtrip() -> None: - """Server-side tool: input-available -> output-available -> text response. - - Reference: process-ui-message-stream.test.ts "server-side tool roundtrip" - """ - msgs = [ - # Tool call (assistant message) - messages.Message( - id="msg-1", - role="assistant", - parts=[ - messages.ToolCallPart( - tool_call_id="tc-1", - tool_name="get_weather", - tool_args='{"city": "London"}', - state="done", - ), - ], - ), - # Tool result (tool message, pinned to same id for same step) - messages.Message( - id="msg-1", - role="tool", - parts=[ - messages.ToolResultPart( - tool_call_id="tc-1", - tool_name="get_weather", - result={"weather": "sunny"}, - ), - ], - ), - # Final text - messages.Message( - id="msg-2", - role="assistant", - parts=[ - messages.TextPart(text="The weather is sunny.", state="done"), - ], - ), - ] - - assert await get_event_types(msgs) == [ - "start", - "start-step", - "tool-input-start", - "tool-input-available", - "tool-output-available", - "finish-step", - "start-step", - "text-start", - "text-end", - "finish-step", - "finish", - ] - - -async def test_text_then_tool_then_text() -> None: - """Full mothership scenario: text -> tool -> result -> final text. - - Input: "when will the robots take over?" - 1. 
Text: "I'll check with the mothership..." - 2. Tool: talk_to_mothership(question="...") - 3. Result: "Soon." - 4. Text: "According to the mothership: Soon." - """ - msgs = [ - # Streaming initial text - messages.Message( - id="msg-1", - role="assistant", - parts=[ - messages.TextPart( - text="I'll check with the mothership.", - delta="I'll check with the mothership.", - state="streaming", - ) - ], - ), - # Text done + tool call - messages.Message( - id="msg-1", - role="assistant", - parts=[ - messages.TextPart(text="I'll check with the mothership.", state="done"), - messages.ToolCallPart( - tool_call_id="tc-1", - tool_name="talk_to_mothership", - tool_args='{"question": "when?"}', - state="done", - ), - ], - ), - # Tool result (pinned to same id for same step) - messages.Message( - id="msg-1", - role="tool", - parts=[ - messages.ToolResultPart( - tool_call_id="tc-1", - tool_name="talk_to_mothership", - result={"answer": "Soon."}, - ), - ], - ), - # Final text (new message) - messages.Message( - id="msg-2", - role="assistant", - parts=[ - messages.TextPart( - text="According to the mothership: Soon.", - delta="According to the mothership: Soon.", - state="streaming", - ) - ], - ), - messages.Message( - id="msg-2", - role="assistant", - parts=[ - messages.TextPart( - text="According to the mothership: Soon.", state="done" - ) - ], - ), - ] - - # Per AI SDK protocol, tool-input-available and tool-output-available - # are in the SAME step (one LLM turn). 
Reference: - # process-ui-message-stream.test.ts - # "server-side tool roundtrip with multiple assistant texts" - assert await get_event_types(msgs) == [ - "start", - "start-step", - "text-start", - "text-delta", - "text-end", - "tool-input-start", - "tool-input-available", - "tool-output-available", # Same step as tool-input (AI SDK protocol) - "finish-step", - # New step for second LLM call (new message ID) - "start-step", - "text-start", - "text-delta", - "text-end", - "finish-step", - "finish", - ] - - -# ----------------------------------------------------------------------------- -# Integration tests - runtime-based execution -# ----------------------------------------------------------------------------- - - -@ai.tool -async def get_weather(city: str) -> str: - """Get weather for a city.""" - return f"Sunny in {city}" - - -async def test_runtime_tool_roundtrip() -> None: - """ - Integration test: run an Agent through agent.run() and verify - that tool-input-available and tool-output-available events are emitted. - - This test demonstrates the bug: the runtime yields the message with - the tool call, but by the time it's yielded the tool has already been - executed and the ToolPart has been mutated to status="result". The UI - adapter never sees the intermediate status="pending" state. - - Root cause: the default loop appends the message, then executes tools - which mutate the message in-place. The message was already yielded with - status="pending", but pydantic models are mutable so when we collect - them at the end, we see the mutated state. 
- """ - weather_agent = ai.agent(tools=[get_weather]) - - # First LLM call: returns a tool call - tool_call_response = [ - messages.Message( - id="msg-1", - role="assistant", - parts=[ - messages.ToolCallPart( - tool_call_id="tc-1", - tool_name="get_weather", - tool_args='{"city": "London"}', - state="done", - ), - ], - ), - ] - - # Second LLM call: returns final text - final_text_response = [ - messages.Message( - id="msg-2", - role="assistant", - parts=[messages.TextPart(text="The weather is sunny.", state="done")], - ), - ] - - mock_llm([tool_call_response, final_text_response]) - - # Collect all messages from the runtime - runtime_messages: list[messages.Message] = [] - async for msg in weather_agent.run( - MOCK_MODEL, [ai.user_message("What's the weather in London?")] - ): - runtime_messages.append(msg) - - # Stream through UI adapter - event_types = [ - p.type - async for p in adapter.to_ui_message_stream(_async_iter(runtime_messages)) - ] - - # This is what SHOULD happen: - # 1. First step streams tool call args then completes - # -> tool-input-start, tool-input-delta, tool-input-available - # 2. After tool execution, we yield the same message with - # status="result" -> tool-output-available - # (same step because same message ID) - # 3. 
Second LLM step streams text then completes - # -> text-start, text-delta, text-end, (final done msg) text-start, text-end - expected = [ - "start", - "start-step", - "tool-input-start", - "tool-input-delta", - "tool-input-available", - "tool-output-available", # Same step as input (same message ID) - "finish-step", - # Second LLM call (new message ID = new step) - "start-step", - "text-start", - "text-delta", - "text-end", - "text-start", # Final done message re-emits completed text - "text-end", - "finish-step", - "finish", - ] - - assert event_types == expected - - -async def _async_iter( - items: list[messages.Message], -) -> AsyncGenerator[messages.Message]: - """Helper to convert a list to an async generator.""" - for item in items: - yield item - - -# ----------------------------------------------------------------------------- -# UI → Internal conversion tests -# ----------------------------------------------------------------------------- - - -def test_ui_to_internal_two_turn_with_tool() -> None: - """Test converting a realistic two-turn conversation with tool call. - - This test uses the exact payload structure from a real AI SDK frontend - that was causing 422 validation errors due to: - 1. step-start parts (boundary markers) - 2. 
tool-{toolName} dynamic type pattern (e.g., "tool-talk_to_mothership") - """ - # Exact structure from a failing request - raw_messages = [ - { - "id": "lmaOqWZJdKOVUbYT", - "role": "user", - "parts": [{"type": "text", "text": "when will the robots take over?"}], - }, - { - "id": "d04b88d9a82e", - "role": "assistant", - "parts": [ - {"type": "step-start"}, - { - "type": "text", - "text": "I'll check with the mothership " - "about this important question.", - "state": "done", - }, - { - "type": "tool-talk_to_mothership", - "toolCallId": "toolu_01FiXNXhq1kHx4TegRjSaJyv", - "state": "output-available", - "input": '{"question": "when will the robots take over?"}', - "output": "Soon.", - }, - {"type": "text", "text": "", "state": "done"}, # Empty text part - {"type": "step-start"}, - { - "type": "text", - "text": "The mothership has spoken: Soon.", - "state": "done", - }, - {"type": "text", "text": "", "state": "done"}, # Empty text part - ], - }, - { - "id": "ZLi3qVpgZLBjwMxZ", - "role": "user", - "parts": [ - { - "type": "text", - "text": "this is a test run. 
can you remember the first turn?", - } - ], - }, - ] - - # Parse UI messages - this should NOT raise validation errors - ui_messages = [ui_message.UIMessage.model_validate(m) for m in raw_messages] - - # Verify parsing worked - assert len(ui_messages) == 3 - assert ui_messages[0].role == "user" - assert ui_messages[1].role == "assistant" - assert ui_messages[2].role == "user" - - # Check that step-start and tool parts were parsed correctly - assistant_parts = ui_messages[1].parts - assert isinstance(assistant_parts[0], ui_message.UIStepStartPart) - assert isinstance(assistant_parts[1], ui_message.UITextPart) - assert isinstance(assistant_parts[2], ui_message.UIToolPart) - assert assistant_parts[2].tool_name == "talk_to_mothership" - assert assistant_parts[2].state == "output-available" - - # Convert to internal format - internal = adapter.to_messages(ui_messages) - - # The single UI assistant message contains [text, tool(done), text] from - # two stream_loop iterations. to_messages splits at the tool-result - # boundary so LLM adapters receive one message per iteration: - # user, assistant (text + tool_call), tool (result), assistant (text), user - assert len(internal) == 5 - assert internal[0].role == "user" - assert internal[0].text == "when will the robots take over?" - - # First iteration: text + tool call (assistant message) - assert internal[1].role == "assistant" - assert internal[1].text == ( - "I'll check with the mothership about this important question." 
- ) - assert len(internal[1].tool_calls) == 1 - assert internal[1].tool_calls[0].tool_name == "talk_to_mothership" - assert internal[1].tool_calls[0].tool_call_id == "toolu_01FiXNXhq1kHx4TegRjSaJyv" - - # Tool result (separate tool message) - assert internal[2].role == "tool" - assert len(internal[2].tool_results) == 1 - assert internal[2].tool_results[0].result == {"value": "Soon."} - - # Second iteration: follow-up text - assert internal[3].role == "assistant" - assert internal[3].text == "The mothership has spoken: Soon." - assert len(internal[3].tool_calls) == 0 - - assert internal[4].role == "user" - assert internal[4].text == "this is a test run. can you remember the first turn?" - - -def test_ui_tool_part_with_dict_input() -> None: - """Test that tool parts with dict input (not JSON string) are handled.""" - raw_message = { - "id": "msg-1", - "role": "assistant", - "parts": [ - { - "type": "tool-get_weather", - "toolCallId": "tc-1", - "state": "input-available", - "input": {"city": "London"}, # Dict, not JSON string - } - ], - } - - ui_msg = ui_message.UIMessage.model_validate(raw_message) - internal = adapter.to_messages([ui_msg]) - - assert len(internal) == 1 - tool_part = internal[0].tool_calls[0] - assert tool_part.tool_name == "get_weather" - assert tool_part.tool_args == '{"city": "London"}' - # input-available means call is present but no result yet (no tool message) - - -def test_ui_file_part_converted_to_core_file_part() -> None: - """UIFilePart from the frontend is converted to a core FilePart.""" - raw_message = { - "id": "msg-1", - "role": "user", - "parts": [ - {"type": "text", "text": "What's in this image?"}, - { - "type": "file", - "mediaType": "image/png", - "url": "https://example.com/photo.png", - "filename": "photo.png", - }, - ], - } - ui_msg = ui_message.UIMessage.model_validate(raw_message) - internal = adapter.to_messages([ui_msg]) - - assert len(internal) == 1 - msg = internal[0] - assert msg.role == "user" - assert len(msg.parts) == 
2 - assert isinstance(msg.parts[0], messages.TextPart) - assert isinstance(msg.parts[1], messages.FilePart) - fp = msg.parts[1] - assert fp.data == "https://example.com/photo.png" - assert fp.media_type == "image/png" - assert fp.filename == "photo.png" - - -def test_ui_skips_unsupported_parts() -> None: - """Test that unsupported part types are skipped gracefully.""" - raw_message = { - "id": "msg-1", - "role": "assistant", - "parts": [ - {"type": "text", "text": "Hello"}, - {"type": "data-custom", "data": {"foo": "bar"}}, # Unsupported - {"type": "unknown-type", "content": "xyz"}, # Unsupported - {"type": "text", "text": "World"}, - ], - } - - ui_msg = ui_message.UIMessage.model_validate(raw_message) - # Only text parts should be parsed (data-* and unknown skipped) - assert len(ui_msg.parts) == 2 - assert all(isinstance(p, ui_message.UITextPart) for p in ui_msg.parts) - - internal = adapter.to_messages([ui_msg]) - assert len(internal[0].parts) == 2 - - -# ----------------------------------------------------------------------------- -# Tool approval (human-in-the-loop) tests -# ----------------------------------------------------------------------------- - - -async def test_tool_approval_hook_emits_approval_request() -> None: - """Pending ToolApproval HookPart emits tool-approval-request on the wire. - - The HookPart message uses a *different* id from the tool message, - matching what the Runtime actually does (it creates an ad-hoc Message - with its own auto-generated id at runtime.py:452). The adapter must - keep both in the same step so the frontend's sendAutomaticallyWhen - helper can find the tool part when the user responds to the approval. 
- """ - msgs = [ - # Tool call (args complete, awaiting approval) - messages.Message( - id="msg-1", - role="assistant", - parts=[ - messages.ToolCallPart( - tool_call_id="tc-1", - tool_name="rm_rf", - tool_args='{"path": "/"}', - state="done", - ), - ], - ), - # Hook pending (approval requested) — different message id, - # just like the Runtime produces at runtime.py:452. - messages.Message( - id="hook-msg-1", - role="assistant", - parts=[ - messages.HookPart( - hook_id="approve_tc-1", - hook_type=hooks.TOOL_APPROVAL_HOOK_TYPE, - status="pending", - metadata={"tool_name": "rm_rf", "tool_args": '{"path": "/"}'}, - ), - ], - ), - ] - - event_types = await get_event_types(msgs) - # tool-approval-request must be in the SAME step as the tool input — - # no extra start-step/finish-step between them. - assert event_types == [ - "start", - "start-step", - "tool-input-start", - "tool-input-available", - "tool-approval-request", - "finish-step", - "finish", - ] - - -def test_approval_responded_resolves_hook() -> None: - """to_messages() resolves the ToolApproval hook for approval-responded parts.""" - label = "approve_tc-42" - raw_messages = [ - { - "id": "msg-1", - "role": "assistant", - "parts": [ - { - "type": "tool-dangerous_action", - "toolCallId": "tc-42", - "state": "approval-responded", - "input": '{"x": 1}', - "approval": { - "id": label, - "approved": True, - "reason": "looks safe", - }, - } - ], - }, - ] - - # Clean up any leftover state from other tests - hooks._pending_resolutions.pop(label, None) - - ui_msgs = [ui_message.UIMessage.model_validate(m) for m in raw_messages] - adapter.to_messages(ui_msgs) - - # The side-effect should have pre-registered the resolution - assert label in hooks._pending_resolutions - resolution = hooks._pending_resolutions.pop(label) - assert resolution == {"granted": True, "reason": "looks safe"} - - -async def test_runtime_tool_approval_same_step() -> None: - """E2E: tool-approval-request must land in the same SSE step as the tool 
call. - - Runs a graph with ToolApproval (interrupt_loop=True) through agent.run(), - collects runtime messages, streams through the adapter, and asserts - that no spurious step boundary appears between tool-input-available - and tool-approval-request. - """ - from collections.abc import AsyncGenerator as AG - - @ai.tool - async def dangerous_action(path: str) -> str: - """Do something dangerous.""" - return f"deleted {path}" - - approval_agent = ai.agent(tools=[dangerous_action]) - - @approval_agent.loop - async def custom(context: ai.Context) -> AG[ai.Message]: - stream = await ai.models.stream( - context.model, context.messages, tools=context.tools - ) - async for msg in stream: - yield msg - - tool_calls = context.resolve(stream.tool_calls) - if not tool_calls: - return - - async def approve_and_execute(tc: ai.ToolCall) -> ai.Message: - approval = await ai.hook( - f"approve_{tc.id}", - payload=ai.ToolApproval, - metadata={"tool_name": tc.name}, - interrupt_loop=True, - ) - if approval.granted: - return await tc() - return ai.Message( - role="tool", - parts=[ - ai.ToolResultPart( - tool_call_id=tc.id, - tool_name=tc.name, - result="denied", - is_error=True, - ) - ], - ) - - async with asyncio.TaskGroup() as tg: - tasks = [tg.create_task(approve_and_execute(tc)) for tc in tool_calls] - yield ai.tool_message(*(t.result() for t in tasks)) - - mock_llm( - [ - [ - tool_call_msg( - tc_id="tc-1", - name="dangerous_action", - args='{"path": "/tmp"}', - ) - ], - ] - ) - - runtime_messages: list[messages.Message] = [] - async for msg in approval_agent.run(MOCK_MODEL, [ai.user_message("delete /tmp")]): - runtime_messages.append(msg) - - # Stream through UI adapter - event_types = [ - p.type - async for p in adapter.to_ui_message_stream(_async_iter(runtime_messages)) - ] - - # tool-approval-request must be in the SAME step as tool-input. 
- assert event_types == [ - "start", - "start-step", - "tool-input-start", - "tool-input-delta", - "tool-input-available", - "tool-approval-request", - "finish-step", - "finish", - ] diff --git a/tests/adapters/__init__.py b/tests/agents/ui/__init__.py similarity index 100% rename from tests/adapters/__init__.py rename to tests/agents/ui/__init__.py diff --git a/tests/adapters/ai_sdk_ui/__init__.py b/tests/agents/ui/ai_sdk/__init__.py similarity index 100% rename from tests/adapters/ai_sdk_ui/__init__.py rename to tests/agents/ui/ai_sdk/__init__.py diff --git a/tests/agents/ui/ai_sdk/outbound/__init__.py b/tests/agents/ui/ai_sdk/outbound/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/agents/ui/ai_sdk/outbound/test_history.py b/tests/agents/ui/ai_sdk/outbound/test_history.py new file mode 100644 index 00000000..7ebdf3e8 --- /dev/null +++ b/tests/agents/ui/ai_sdk/outbound/test_history.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +from ai.agents.ui.ai_sdk import to_ui_messages +from ai.agents.ui.ai_sdk.ui_message import ( + UITextPart, + UIToolPart, +) +from ai.types import messages as messages_ + + +def test_to_ui_messages_user_and_assistant() -> None: + msgs = [ + messages_.Message(id="u1", role="user", parts=[messages_.TextPart(text="hi")]), + messages_.Message( + id="a1", + role="assistant", + parts=[messages_.TextPart(text="hello back")], + ), + ] + result = to_ui_messages(msgs) + assert len(result) == 2 + assert result[0].role == "user" + assert result[1].role == "assistant" + assert result[1].id == "a1" + + +def test_to_ui_messages_merges_assistant_tool_internal() -> None: + msgs = [ + messages_.Message( + id="a1", + role="assistant", + parts=[ + messages_.TextPart(text="calling"), + messages_.ToolCallPart( + tool_call_id="tc1", + tool_name="search", + tool_args='{"q":"x"}', + ), + ], + ), + messages_.Message( + role="tool", + parts=[ + messages_.ToolResultPart( + tool_call_id="tc1", + tool_name="search", + 
result={"hits": 2}, + ) + ], + ), + messages_.Message( + role="assistant", + parts=[messages_.TextPart(text="done")], + ), + ] + result = to_ui_messages(msgs) + assert len(result) == 1 + ui_msg = result[0] + assert ui_msg.role == "assistant" + assert ui_msg.id == "a1" + assert isinstance(ui_msg.parts[0], UITextPart) + assert ui_msg.parts[0].text == "calling" + assert isinstance(ui_msg.parts[1], UIToolPart) + assert ui_msg.parts[1].state == "output-available" + assert ui_msg.parts[1].output == {"hits": 2} + assert isinstance(ui_msg.parts[2], UITextPart) + assert ui_msg.parts[2].text == "done" + + +def test_to_ui_messages_internal_role_merges_approval() -> None: + msgs = [ + messages_.Message( + id="a1", + role="assistant", + parts=[ + messages_.ToolCallPart( + tool_call_id="tc1", + tool_name="delete", + tool_args="{}", + ) + ], + ), + messages_.Message( + role="internal", + parts=[ + messages_.HookPart( + hook_id="approve_tc1", + hook_type="ToolApproval", + status="pending", + ) + ], + ), + ] + result = to_ui_messages(msgs) + ui_msg = result[0] + tool_part = ui_msg.parts[0] + assert isinstance(tool_part, UIToolPart) + assert tool_part.state == "approval-requested" + assert tool_part.approval is not None + assert tool_part.approval.id == "approve_tc1" + + +def test_to_ui_messages_user_message_uses_own_id() -> None: + msgs = [ + messages_.Message(id="u1", role="user", parts=[messages_.TextPart(text="a")]) + ] + result = to_ui_messages(msgs) + assert result[0].id == "u1" + + +def test_to_ui_messages_uses_first_assistant_id_as_bubble_id() -> None: + msgs = [ + messages_.Message( + id="a1", + role="assistant", + parts=[messages_.TextPart(text="first")], + ), + messages_.Message( + id="a2", + role="assistant", + parts=[messages_.TextPart(text="second")], + ), + ] + result = to_ui_messages(msgs) + assert len(result) == 1 + assert result[0].id == "a1" diff --git a/tests/agents/ui/ai_sdk/outbound/test_sse.py b/tests/agents/ui/ai_sdk/outbound/test_sse.py new file mode 100644 
index 00000000..930c385b --- /dev/null +++ b/tests/agents/ui/ai_sdk/outbound/test_sse.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import json +from collections.abc import AsyncGenerator + +from ai.agents.ui.ai_sdk import protocol, to_sse +from ai.agents.ui.ai_sdk.outbound.sse import format_sse, serialize_part +from ai.types import messages as messages_ + + +def test_serialize_part_camelcases_keys() -> None: + part = protocol.StartPart(message_id="m1") + payload = json.loads(serialize_part(part)) + assert payload == {"type": "start", "messageId": "m1"} + + +def test_format_sse_wraps_data_line() -> None: + part = protocol.TextDeltaPart(id="t1", delta="hi") + line = format_sse(part) + assert line.startswith("data: ") + assert line.endswith("\n\n") + + +def test_serialize_data_part_uses_type_with_prefix() -> None: + part = protocol.DataPart(data_type="custom", data={"k": 1}) + payload = json.loads(serialize_part(part)) + assert payload["type"] == "data-custom" + assert "dataType" not in payload + + +async def _gen( + msgs: list[messages_.Message], +) -> AsyncGenerator[messages_.Message]: + for m in msgs: + yield m + + +async def test_to_sse_emits_data_prefixed_lines() -> None: + msgs = [ + messages_.Message( + id="m1", + role="assistant", + turn_id="t1", + parts=[messages_.TextPart(text="hi")], + stream=messages_.MessageStreamState(events=[], is_done=True), + ) + ] + lines = [line async for line in to_sse(_gen(msgs))] + assert all(line.startswith("data: ") for line in lines) + # first line is the start part + first = json.loads(lines[0].removeprefix("data: ").rstrip()) + assert first["type"] == "start" diff --git a/tests/agents/ui/ai_sdk/outbound/test_stream.py b/tests/agents/ui/ai_sdk/outbound/test_stream.py new file mode 100644 index 00000000..1a680525 --- /dev/null +++ b/tests/agents/ui/ai_sdk/outbound/test_stream.py @@ -0,0 +1,236 @@ +from __future__ import annotations + +from collections.abc import AsyncGenerator + +from ai.agents.ui.ai_sdk import 
protocol, to_stream +from ai.types import messages as messages_ + + +async def _gen( + msgs: list[messages_.Message], +) -> AsyncGenerator[messages_.Message]: + for m in msgs: + yield m + + +async def _collect( + msgs: list[messages_.Message], +) -> list[protocol.UIMessageStreamPart]: + return [p async for p in to_stream(_gen(msgs))] + + +def _text_stream_message( + msg_id: str, + turn_id: str | None, + text_id: str, + chunk: str, + *, + is_done: bool, + full_text: str | None = None, +) -> messages_.Message: + events: list[messages_.StreamEvent] + if is_done: + events = [messages_.PartClosed(part_id=text_id)] + else: + events = [messages_.PartDelta(part_id=text_id, chunk=chunk)] + return messages_.Message( + id=msg_id, + role="assistant", + turn_id=turn_id, + parts=[messages_.TextPart(id=text_id, text=full_text or chunk)], + stream=messages_.MessageStreamState(events=events, is_done=is_done), + ) + + +async def test_event_driven_text_streaming() -> None: + text_id = "txt1" + msgs = [ + # Initial: PartOpened + messages_.Message( + id="m1", + role="assistant", + turn_id="t1", + parts=[messages_.TextPart(id=text_id, text="")], + stream=messages_.MessageStreamState( + events=[messages_.PartOpened(part_id=text_id)], + is_done=False, + ), + ), + # Delta: "hi" + messages_.Message( + id="m1", + role="assistant", + turn_id="t1", + parts=[messages_.TextPart(id=text_id, text="hi")], + stream=messages_.MessageStreamState( + events=[messages_.PartDelta(part_id=text_id, chunk="hi")], + is_done=False, + ), + ), + # Closed + messages_.Message( + id="m1", + role="assistant", + turn_id="t1", + parts=[messages_.TextPart(id=text_id, text="hi")], + stream=messages_.MessageStreamState( + events=[messages_.PartClosed(part_id=text_id)], + is_done=True, + ), + ), + ] + out = await _collect(msgs) + # expect: Start, StartStep, TextStart, TextDelta, TextEnd, FinishStep, Finish + assert isinstance(out[0], protocol.StartPart) + assert out[0].message_id == "m1" + assert isinstance(out[1], 
protocol.StartStepPart) + assert isinstance(out[2], protocol.TextStartPart) and out[2].id == text_id + assert isinstance(out[3], protocol.TextDeltaPart) and out[3].delta == "hi" + assert isinstance(out[4], protocol.TextEndPart) and out[4].id == text_id + assert isinstance(out[5], protocol.FinishStepPart) + assert isinstance(out[6], protocol.FinishPart) + + +async def test_turn_id_change_emits_step_boundary() -> None: + msgs = [ + messages_.Message( + id="m1", + role="assistant", + turn_id="t1", + parts=[messages_.TextPart(text="hello")], + stream=messages_.MessageStreamState(events=[], is_done=True), + ), + messages_.Message( + id="m2", + role="assistant", + turn_id="t2", # different turn → step boundary + parts=[messages_.TextPart(text="world")], + stream=messages_.MessageStreamState(events=[], is_done=True), + ), + ] + out = await _collect(msgs) + # Look for FinishStep followed by StartStep between messages. + has_mid_step_boundary = any( + isinstance(out[i], protocol.FinishStepPart) + and i + 1 < len(out) + and isinstance(out[i + 1], protocol.StartStepPart) + for i in range(1, len(out) - 1) + ) + assert has_mid_step_boundary + + +async def test_agent_change_emits_message_boundary() -> None: + msgs = [ + messages_.Message( + id="m1", + role="assistant", + agent="a1", + parts=[messages_.TextPart(text="from a")], + stream=messages_.MessageStreamState(events=[], is_done=True), + ), + messages_.Message( + id="m2", + role="assistant", + agent="a2", # different agent → FinishPart + StartPart + parts=[messages_.TextPart(text="from b")], + stream=messages_.MessageStreamState(events=[], is_done=True), + ), + ] + out = await _collect(msgs) + # There should be a FinishPart+StartPart pair mid-stream. 
+ has_mid_msg_boundary = any( + isinstance(out[i], protocol.FinishPart) + and i + 1 < len(out) + and isinstance(out[i + 1], protocol.StartPart) + for i in range(1, len(out) - 1) + ) + assert has_mid_msg_boundary + + +async def test_tool_call_and_result_emit_terminal_parts() -> None: + msgs = [ + messages_.Message( + id="m1", + role="assistant", + turn_id="t1", + parts=[ + messages_.ToolCallPart( + id="tc1", + tool_call_id="tc1", + tool_name="search", + tool_args='{"q":"x"}', + ) + ], + stream=messages_.MessageStreamState(events=[], is_done=True), + ), + messages_.Message( + role="tool", + parts=[ + messages_.ToolResultPart( + tool_call_id="tc1", + tool_name="search", + result={"hits": 1}, + ) + ], + ), + ] + out = await _collect(msgs) + types = [type(p).__name__ for p in out] + assert "ToolInputStartPart" in types + assert "ToolInputAvailablePart" in types + assert "ToolOutputAvailablePart" in types + + +async def test_approval_request_hook_emits_approval_part() -> None: + msgs = [ + messages_.Message( + id="m1", + role="assistant", + turn_id="t1", + parts=[ + messages_.ToolCallPart( + id="tc1", + tool_call_id="tc1", + tool_name="delete", + tool_args="{}", + ) + ], + stream=messages_.MessageStreamState(events=[], is_done=True), + ), + messages_.Message( + role="internal", + parts=[ + messages_.HookPart( + hook_id="approve_tc1", + hook_type="ToolApproval", + status="pending", + ) + ], + ), + ] + out = await _collect(msgs) + approval_parts = [p for p in out if isinstance(p, protocol.ToolApprovalRequestPart)] + assert len(approval_parts) == 1 + assert approval_parts[0].tool_call_id == "tc1" + assert approval_parts[0].approval_id == "approve_tc1" + + +async def test_dedup_on_reemitted_message_id() -> None: + msg = messages_.Message( + id="m1", + role="assistant", + turn_id="t1", + parts=[messages_.TextPart(id="txt1", text="hi")], + stream=messages_.MessageStreamState( + events=[ + messages_.PartOpened(part_id="txt1"), + messages_.PartDelta(part_id="txt1", chunk="hi"), 
+ messages_.PartClosed(part_id="txt1"), + ], + is_done=True, + ), + ) + out = await _collect([msg, msg]) # re-emit the same done message + text_deltas = [p for p in out if isinstance(p, protocol.TextDeltaPart)] + # only the first emission should fire a TextDelta + assert len(text_deltas) == 1 diff --git a/tests/agents/ui/ai_sdk/test_approvals.py b/tests/agents/ui/ai_sdk/test_approvals.py new file mode 100644 index 00000000..e0733757 --- /dev/null +++ b/tests/agents/ui/ai_sdk/test_approvals.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +from ai.agents.ui.ai_sdk import _approvals +from ai.types import messages as messages_ + + +def test_tool_call_id_for_strips_prefix() -> None: + hook = messages_.HookPart( + hook_id="approve_tc_42", + hook_type="ToolApproval", + status="pending", + ) + assert _approvals.tool_call_id_for(hook) == "tc_42" + + +def test_tool_call_id_for_rejects_non_approval_type() -> None: + hook = messages_.HookPart( + hook_id="approve_tc_42", + hook_type="SomethingElse", + status="pending", + ) + assert _approvals.tool_call_id_for(hook) is None + + +def test_tool_call_id_for_rejects_bad_prefix() -> None: + hook = messages_.HookPart( + hook_id="tc_42", + hook_type="ToolApproval", + status="pending", + ) + assert _approvals.tool_call_id_for(hook) is None + + +def test_is_tool_approval_message_detects_all_approval_hooks() -> None: + msg = messages_.Message( + role="internal", + parts=[ + messages_.HookPart( + hook_id="approve_tc_1", + hook_type="ToolApproval", + status="pending", + ), + ], + ) + assert _approvals.is_tool_approval_message(msg) + + +def test_is_tool_approval_message_false_for_non_approval() -> None: + msg = messages_.Message( + role="internal", + parts=[ + messages_.HookPart( + hook_id="other", + hook_type="Something", + status="pending", + ), + ], + ) + assert not _approvals.is_tool_approval_message(msg) + + +def test_is_tool_approval_message_false_for_empty() -> None: + msg = messages_.Message(role="internal", parts=[]) + 
assert not _approvals.is_tool_approval_message(msg) diff --git a/tests/agents/ui/ai_sdk/test_inbound.py b/tests/agents/ui/ai_sdk/test_inbound.py new file mode 100644 index 00000000..9b808d82 --- /dev/null +++ b/tests/agents/ui/ai_sdk/test_inbound.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +from ai.agents.ui.ai_sdk import to_messages +from ai.agents.ui.ai_sdk.inbound import ( + _normalize_ui_messages, + extract_approvals, +) +from ai.agents.ui.ai_sdk.ui_message import UIMessage, UIToolPart + + +def _ui(role: str, *parts: dict[str, Any], id: str = "m1") -> UIMessage: + return UIMessage.model_validate({"id": id, "role": role, "parts": list(parts)}) + + +def _text(text: str) -> dict[str, Any]: + return {"type": "text", "text": text} + + +def _tool( + tool_name: str, + tool_call_id: str, + state: str, + **extra: Any, +) -> dict[str, Any]: + return { + "type": f"tool-{tool_name}", + "toolCallId": tool_call_id, + "state": state, + **extra, + } + + +def test_to_messages_user_text() -> None: + result = to_messages([_ui("user", _text("hello"))]) + assert len(result) == 1 + assert result[0].role == "user" + assert result[0].text == "hello" + + +def test_to_messages_splits_at_tool_boundary() -> None: + result = to_messages( + [ + _ui( + "assistant", + _text("before"), + _tool( + "search", + "tc1", + "output-available", + input={"q": "x"}, + output={"hits": 3}, + ), + _text("after"), + ) + ] + ) + assert [m.role for m in result] == ["assistant", "tool", "assistant"] + assert result[1].tool_results[0].tool_call_id == "tc1" + + +def test_to_messages_approval_hook_emitted_as_internal() -> None: + result = to_messages( + [ + _ui( + "assistant", + _tool( + "delete", + "tc1", + "approval-requested", + approval={"id": "approve_tc1"}, + ), + ) + ], + apply_approvals_=False, + ) + assert [m.role for m in result] == ["assistant", "internal"] + hook = result[1].parts[0] + assert hook.type == "hook" + assert hook.hook_id == 
"approve_tc1" + + +def test_to_messages_strips_trailing_assistant_when_approved() -> None: + result = to_messages( + [ + _ui("user", _text("delete it"), id="u1"), + _ui( + "assistant", + _tool( + "delete", + "tc1", + "approval-responded", + approval={"id": "approve_tc1", "approved": True, "reason": None}, + ), + id="a1", + ), + ], + apply_approvals_=False, + ) + assert [m.role for m in result] == ["user", "internal"] + + +def test_extract_approvals_returns_approved_responses() -> None: + approvals = extract_approvals( + [ + _ui( + "assistant", + _tool( + "x", + "tc1", + "approval-responded", + approval={ + "id": "approve_tc1", + "approved": False, + "reason": "nope", + }, + ), + ) + ] + ) + assert len(approvals) == 1 + assert approvals[0].hook_id == "approve_tc1" + assert approvals[0].granted is False + assert approvals[0].reason == "nope" + + +def test_normalize_ui_messages_heals_stale_tool_state() -> None: + ui = [ + _ui( + "assistant", + _tool("x", "tc1", "input-available", output={"ok": True}), + ) + ] + normalized = _normalize_ui_messages(ui) + tool_part = normalized[0].parts[0] + assert isinstance(tool_part, UIToolPart) + assert tool_part.state == "output-available" + + +def test_to_messages_rejects_empty_user() -> None: + ui = [UIMessage.model_validate({"id": "u1", "role": "user", "parts": []})] + with pytest.raises(ValueError): + to_messages(ui) diff --git a/tests/agents/ui/ai_sdk/test_parts.py b/tests/agents/ui/ai_sdk/test_parts.py new file mode 100644 index 00000000..69ca1566 --- /dev/null +++ b/tests/agents/ui/ai_sdk/test_parts.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +from ai.agents.ui.ai_sdk import _parts +from ai.agents.ui.ai_sdk.ui_message import ( + UIReasoningPart, + UITextPart, + UIToolApproval, + UIToolPart, +) +from ai.types import messages as messages_ + + +def test_to_ui_parts_text_and_reasoning() -> None: + parts: list[messages_.Part] = [ + messages_.ReasoningPart(text="thinking"), + messages_.TextPart(text="hi"), + ] + 
ui_parts = _parts.to_ui_parts(parts) + assert isinstance(ui_parts[0], UIReasoningPart) + assert ui_parts[0].reasoning == "thinking" + assert isinstance(ui_parts[1], UITextPart) + assert ui_parts[1].text == "hi" + + +def test_to_ui_parts_tool_call_parses_json_args() -> None: + parts: list[messages_.Part] = [ + messages_.ToolCallPart( + tool_call_id="tc1", + tool_name="search", + tool_args='{"q": "x"}', + ) + ] + ui_parts = _parts.to_ui_parts(parts) + assert isinstance(ui_parts[0], UIToolPart) + assert ui_parts[0].type == "tool-search" + assert ui_parts[0].input == {"q": "x"} + assert ui_parts[0].state == "input-available" + + +def test_merge_tool_results_updates_state_and_output() -> None: + parts: list[messages_.Part] = [ + messages_.ToolCallPart( + tool_call_id="tc1", + tool_name="search", + tool_args="{}", + ) + ] + ui_parts = _parts.to_ui_parts(parts) + _parts.merge_tool_results( + ui_parts, + [ + messages_.ToolResultPart( + tool_call_id="tc1", + tool_name="search", + result={"hits": 3}, + ) + ], + ) + merged = ui_parts[0] + assert isinstance(merged, UIToolPart) + assert merged.state == "output-available" + assert merged.output == {"hits": 3} + + +def test_merge_approval_signals_pending_then_resolved() -> None: + parts: list[messages_.Part] = [ + messages_.ToolCallPart( + tool_call_id="tc1", + tool_name="delete", + tool_args="{}", + ) + ] + ui_parts = _parts.to_ui_parts(parts) + + _parts.merge_approval_signals( + ui_parts, + [ + messages_.HookPart( + hook_id="approve_tc1", + hook_type="ToolApproval", + status="pending", + ) + ], + ) + requested = ui_parts[0] + assert isinstance(requested, UIToolPart) + assert requested.state == "approval-requested" + assert isinstance(requested.approval, UIToolApproval) + + _parts.merge_approval_signals( + ui_parts, + [ + messages_.HookPart( + hook_id="approve_tc1", + hook_type="ToolApproval", + status="resolved", + resolution={"granted": True, "reason": None}, + ) + ], + ) + responded = ui_parts[0] + assert 
isinstance(responded, UIToolPart) + assert responded.state == "approval-responded" + assert responded.approval is not None + assert responded.approval.approved is True From 5c1d7d9dfcf2105e900c12342971ecd41ab35c3c Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Sat, 18 Apr 2026 11:05:33 -0700 Subject: [PATCH 07/10] Rename some properties and types to be less confusing --- examples/multiagent-textual/client.py | 2 +- examples/samples/streaming_tool.py | 4 +-- src/ai/__init__.py | 4 +-- src/ai/agents/agent.py | 2 +- src/ai/agents/ui/ai_sdk/_parts.py | 2 +- src/ai/agents/ui/ai_sdk/outbound/_state.py | 10 +++--- src/ai/agents/ui/ai_sdk/outbound/stream.py | 6 ++-- src/ai/models/core/helpers/streaming.py | 4 +-- src/ai/types/__init__.py | 4 +-- src/ai/types/messages.py | 20 +++++------ tests/agents/test_generator_tools.py | 8 ++--- tests/agents/ui/ai_sdk/outbound/test_sse.py | 2 +- .../agents/ui/ai_sdk/outbound/test_stream.py | 34 +++++++++---------- tests/models/core/test_streaming.py | 34 +++++++++++-------- 14 files changed, 71 insertions(+), 65 deletions(-) diff --git a/examples/multiagent-textual/client.py b/examples/multiagent-textual/client.py index ec246a6a..85541c23 100644 --- a/examples/multiagent-textual/client.py +++ b/examples/multiagent-textual/client.py @@ -158,7 +158,7 @@ async def run_websocket(self) -> None: # ------------------------------------------------------------------ def _handle_message(self, msg: ai.Message) -> None: - label = msg.agent or "unknown" + label = msg.source_label or "unknown" if (hook_part := msg.get_hook_part()) is not None: if hook_part.status == "pending": diff --git a/examples/samples/streaming_tool.py b/examples/samples/streaming_tool.py index c53a58c3..06c57f02 100644 --- a/examples/samples/streaming_tool.py +++ b/examples/samples/streaming_tool.py @@ -18,7 +18,7 @@ async def talk_to_mothership(question: str) -> AsyncGenerator[ai.Message]: yield ai.Message( role="assistant", parts=[ai.TextPart(text=step)], - 
agent="tool_progress", + source_label="tool_progress", ) await asyncio.sleep(0.3) @@ -40,7 +40,7 @@ async def main() -> None: ] async for msg in my_agent.run(model, messages): - if msg.agent == "tool_progress": + if msg.source_label == "tool_progress": print(f" [{msg.text}]") elif msg.text_delta: print(msg.text_delta, end="", flush=True) diff --git a/src/ai/__init__.py b/src/ai/__init__.py index f70034a6..9c612923 100644 --- a/src/ai/__init__.py +++ b/src/ai/__init__.py @@ -35,13 +35,13 @@ FilePart, HookPart, Message, - MessageStreamState, Part, PartClosed, PartDelta, PartOpened, ReasoningPart, StreamResultLike, + StreamState, StructuredOutputPart, TextPart, ToolCallPart, @@ -64,7 +64,6 @@ __all__ = [ # Types (from types/) "Message", - "MessageStreamState", "Part", "PartClosed", "PartDelta", @@ -76,6 +75,7 @@ "ReasoningPart", "FilePart", "HookPart", + "StreamState", "StructuredOutputPart", "ToolLike", "ToolSchema", diff --git a/src/ai/agents/agent.py b/src/ai/agents/agent.py index 37a6153f..cd8bd3be 100644 --- a/src/ai/agents/agent.py +++ b/src/ai/agents/agent.py @@ -336,7 +336,7 @@ async def _real( source = _collect_messages(loop_fn(context), context.messages) async for message in runtime.run(source): if call.label is not None: - message = message.model_copy(update={"agent": call.label}) + message = message.model_copy(update={"source_label": call.label}) yield message # Activate middleware for this run (and everything it calls). diff --git a/src/ai/agents/ui/ai_sdk/_parts.py b/src/ai/agents/ui/ai_sdk/_parts.py index 7c198cbb..9d4d6ecc 100644 --- a/src/ai/agents/ui/ai_sdk/_parts.py +++ b/src/ai/agents/ui/ai_sdk/_parts.py @@ -2,7 +2,7 @@ Used by ``outbound.history`` to reconstruct UIMessages from persisted ``ai.Message`` lists. The live outbound stream does not use these — it -emits wire-protocol deltas directly from ``Message.stream.events``. +emits wire-protocol deltas directly from ``Message.stream.new_events``. 
""" from __future__ import annotations diff --git a/src/ai/agents/ui/ai_sdk/outbound/_state.py b/src/ai/agents/ui/ai_sdk/outbound/_state.py index 58be0775..77b592d5 100644 --- a/src/ai/agents/ui/ai_sdk/outbound/_state.py +++ b/src/ai/agents/ui/ai_sdk/outbound/_state.py @@ -82,8 +82,8 @@ def on_message(self, msg: messages_.Message) -> list[protocol.UIMessageStreamPar agent_changed = ( self.emitted_start - and msg.agent is not None - and msg.agent != self.current_agent + and msg.source_label is not None + and msg.source_label != self.current_agent ) if not self.emitted_start or agent_changed: @@ -96,7 +96,7 @@ def on_message(self, msg: messages_.Message) -> list[protocol.UIMessageStreamPar parts.append(protocol.StartStepPart()) self.emitted_start = True self.in_step = True - self.current_agent = msg.agent + self.current_agent = msg.source_label self.current_turn_id = msg.turn_id self._reset_step_tracking() return parts @@ -220,13 +220,13 @@ def on_terminal(self, msg: messages_.Message) -> list[protocol.UIMessageStreamPa if msg.stream is not None: opened_ids = { e.part_id - for e in msg.stream.events + for e in msg.stream.new_events if isinstance(e, messages_.PartOpened) } for tid in list(self.open_text_ids): if tid in opened_ids and not any( isinstance(e, messages_.PartClosed) and e.part_id == tid - for e in msg.stream.events + for e in msg.stream.new_events ): out.append(protocol.TextEndPart(id=tid)) self.open_text_ids.discard(tid) diff --git a/src/ai/agents/ui/ai_sdk/outbound/stream.py b/src/ai/agents/ui/ai_sdk/outbound/stream.py index 169715bd..9c2ab96f 100644 --- a/src/ai/agents/ui/ai_sdk/outbound/stream.py +++ b/src/ai/agents/ui/ai_sdk/outbound/stream.py @@ -14,7 +14,7 @@ async def to_stream( ) -> AsyncGenerator[protocol.UIMessageStreamPart]: """Walk ``messages`` once, emitting AI SDK UI stream parts. 
- Drives off ``Message.stream.events`` for incremental deltas and + Drives off ``Message.stream.new_events`` for incremental deltas and ``Message.parts`` for terminal tool input/output/approval parts. Re-emitted messages (same id, already seen ``is_done``) are skipped. """ @@ -27,9 +27,9 @@ async def to_stream( for part in state.on_message(msg): yield part - if msg.stream is not None and msg.stream.events: + if msg.stream is not None and msg.stream.new_events: parts_by_id = {p.id: p for p in msg.parts} - for event in msg.stream.events: + for event in msg.stream.new_events: for out in state.on_event(msg, event, parts_by_id): yield out diff --git a/src/ai/models/core/helpers/streaming.py b/src/ai/models/core/helpers/streaming.py index 855a7a99..3546cb2b 100644 --- a/src/ai/models/core/helpers/streaming.py +++ b/src/ai/models/core/helpers/streaming.py @@ -217,8 +217,8 @@ def _build_message( role="assistant", parts=parts, usage=self._usage if self._is_done else None, - stream=messages_.MessageStreamState( - events=stream_events, + stream=messages_.StreamState( + new_events=stream_events, is_done=self._is_done, ), ) diff --git a/src/ai/types/__init__.py b/src/ai/types/__init__.py index 4ce2d699..f10ce117 100644 --- a/src/ai/types/__init__.py +++ b/src/ai/types/__init__.py @@ -8,12 +8,12 @@ FilePart, HookPart, Message, - MessageStreamState, Part, PartClosed, PartDelta, PartOpened, ReasoningPart, + StreamState, StructuredOutputPart, TextPart, ToolCallPart, @@ -29,13 +29,13 @@ "FilePart", "HookPart", "Message", - "MessageStreamState", "Part", "PartClosed", "PartDelta", "PartOpened", "ReasoningPart", "StreamResultLike", + "StreamState", "StructuredOutputPart", "TextPart", "ToolCallPart", diff --git a/src/ai/types/messages.py b/src/ai/types/messages.py index 5ced280a..536d21a5 100644 --- a/src/ai/types/messages.py +++ b/src/ai/types/messages.py @@ -299,14 +299,14 @@ class PartClosed(pydantic.BaseModel): ] -class MessageStreamState(pydantic.BaseModel): +class 
StreamState(pydantic.BaseModel): """Transient streaming state attached to a Message during streaming. - ``events`` contains the events since the previous yield — never cumulative. + ``new_events`` contains the events since the previous yield — never cumulative. ``is_done`` is True once the stream has finished. """ - events: list[StreamEvent] = pydantic.Field(default_factory=list) + new_events: list[StreamEvent] = pydantic.Field(default_factory=list) is_done: bool = False @@ -317,9 +317,9 @@ class Message(pydantic.BaseModel): parts: list[Part] id: str = pydantic.Field(default_factory=generate_id) turn_id: str | None = None - agent: str | None = None + source_label: str | None = None usage: Usage | None = None - stream: MessageStreamState | None = pydantic.Field(default=None, exclude=True) + stream: StreamState | None = pydantic.Field(default=None, exclude=True) @overload def replace(self, new: Part, /) -> Message: ... @@ -380,11 +380,11 @@ def _parts_by_id(self) -> dict[str, Part]: @property def text_delta(self) -> str: - """Derive from ``stream.events`` — first PartDelta whose part is TextPart.""" + """First PartDelta in ``stream.new_events`` whose part is a TextPart.""" if self.stream is None: return "" parts_map = self._parts_by_id() - for ev in self.stream.events: + for ev in self.stream.new_events: if isinstance(ev, PartDelta): part = parts_map.get(ev.part_id) if isinstance(part, TextPart): @@ -397,7 +397,7 @@ def reasoning_delta(self) -> str: if self.stream is None: return "" parts_map = self._parts_by_id() - for ev in self.stream.events: + for ev in self.stream.new_events: if isinstance(ev, PartDelta): part = parts_map.get(ev.part_id) if isinstance(part, ReasoningPart): @@ -406,12 +406,12 @@ def reasoning_delta(self) -> str: @property def tool_deltas(self) -> list[ToolDelta]: - """Derive from ``stream.events`` — PartDeltas whose parts are ToolCallPart.""" + """PartDeltas in ``stream.new_events`` whose parts are ToolCallPart.""" if self.stream is None: return 
[] parts_map = self._parts_by_id() deltas: list[ToolDelta] = [] - for ev in self.stream.events: + for ev in self.stream.new_events: if isinstance(ev, PartDelta): part = parts_map.get(ev.part_id) if isinstance(part, ToolCallPart): diff --git a/tests/agents/test_generator_tools.py b/tests/agents/test_generator_tools.py index 6969d03d..c1e6a4f7 100644 --- a/tests/agents/test_generator_tools.py +++ b/tests/agents/test_generator_tools.py @@ -25,7 +25,7 @@ async def progress_tool(query: str) -> AsyncGenerator[ai.Message]: yield ai.Message( role="assistant", parts=[messages_.TextPart(text="Working...")], - agent="progress", + source_label="progress", ) yield ai.Message( role="assistant", @@ -51,7 +51,7 @@ async def test_generator_tool_streams_and_returns_result() -> None: assert llm.call_count == 2 # Intermediate progress message was forwarded to consumer. - progress = [m for m in collected if m.agent == "progress"] + progress = [m for m in collected if m.source_label == "progress"] assert len(progress) == 1 assert progress[0].text == "Working..." @@ -175,7 +175,7 @@ async def test_yield_from_nested_agent() -> None: assert adapter.call_count == 3 # Inner messages were forwarded to the consumer with label="inner". - inner_msgs = [m for m in collected if m.agent == "inner"] + inner_msgs = [m for m in collected if m.source_label == "inner"] assert len(inner_msgs) > 0 # The outer LLM's second call (index 2) must NOT contain any inner @@ -187,7 +187,7 @@ async def test_yield_from_nested_agent() -> None: # Specifically: no inner assistant text or inner tool results leaked. for m in outer_turn2_msgs: - assert m.agent is None or m.agent != "inner" + assert m.source_label is None or m.source_label != "inner" if m.role == "assistant": # This must be the outer tool-call message, not inner text. 
assert len(m.tool_calls) > 0 diff --git a/tests/agents/ui/ai_sdk/outbound/test_sse.py b/tests/agents/ui/ai_sdk/outbound/test_sse.py index 930c385b..af7f998f 100644 --- a/tests/agents/ui/ai_sdk/outbound/test_sse.py +++ b/tests/agents/ui/ai_sdk/outbound/test_sse.py @@ -42,7 +42,7 @@ async def test_to_sse_emits_data_prefixed_lines() -> None: role="assistant", turn_id="t1", parts=[messages_.TextPart(text="hi")], - stream=messages_.MessageStreamState(events=[], is_done=True), + stream=messages_.StreamState(new_events=[], is_done=True), ) ] lines = [line async for line in to_sse(_gen(msgs))] diff --git a/tests/agents/ui/ai_sdk/outbound/test_stream.py b/tests/agents/ui/ai_sdk/outbound/test_stream.py index 1a680525..23cc4c05 100644 --- a/tests/agents/ui/ai_sdk/outbound/test_stream.py +++ b/tests/agents/ui/ai_sdk/outbound/test_stream.py @@ -38,7 +38,7 @@ def _text_stream_message( role="assistant", turn_id=turn_id, parts=[messages_.TextPart(id=text_id, text=full_text or chunk)], - stream=messages_.MessageStreamState(events=events, is_done=is_done), + stream=messages_.StreamState(new_events=events, is_done=is_done), ) @@ -51,8 +51,8 @@ async def test_event_driven_text_streaming() -> None: role="assistant", turn_id="t1", parts=[messages_.TextPart(id=text_id, text="")], - stream=messages_.MessageStreamState( - events=[messages_.PartOpened(part_id=text_id)], + stream=messages_.StreamState( + new_events=[messages_.PartOpened(part_id=text_id)], is_done=False, ), ), @@ -62,8 +62,8 @@ async def test_event_driven_text_streaming() -> None: role="assistant", turn_id="t1", parts=[messages_.TextPart(id=text_id, text="hi")], - stream=messages_.MessageStreamState( - events=[messages_.PartDelta(part_id=text_id, chunk="hi")], + stream=messages_.StreamState( + new_events=[messages_.PartDelta(part_id=text_id, chunk="hi")], is_done=False, ), ), @@ -73,8 +73,8 @@ async def test_event_driven_text_streaming() -> None: role="assistant", turn_id="t1", parts=[messages_.TextPart(id=text_id, 
text="hi")], - stream=messages_.MessageStreamState( - events=[messages_.PartClosed(part_id=text_id)], + stream=messages_.StreamState( + new_events=[messages_.PartClosed(part_id=text_id)], is_done=True, ), ), @@ -98,14 +98,14 @@ async def test_turn_id_change_emits_step_boundary() -> None: role="assistant", turn_id="t1", parts=[messages_.TextPart(text="hello")], - stream=messages_.MessageStreamState(events=[], is_done=True), + stream=messages_.StreamState(new_events=[], is_done=True), ), messages_.Message( id="m2", role="assistant", turn_id="t2", # different turn → step boundary parts=[messages_.TextPart(text="world")], - stream=messages_.MessageStreamState(events=[], is_done=True), + stream=messages_.StreamState(new_events=[], is_done=True), ), ] out = await _collect(msgs) @@ -124,16 +124,16 @@ async def test_agent_change_emits_message_boundary() -> None: messages_.Message( id="m1", role="assistant", - agent="a1", + source_label="a1", parts=[messages_.TextPart(text="from a")], - stream=messages_.MessageStreamState(events=[], is_done=True), + stream=messages_.StreamState(new_events=[], is_done=True), ), messages_.Message( id="m2", role="assistant", - agent="a2", # different agent → FinishPart + StartPart + source_label="a2", # different source → FinishPart + StartPart parts=[messages_.TextPart(text="from b")], - stream=messages_.MessageStreamState(events=[], is_done=True), + stream=messages_.StreamState(new_events=[], is_done=True), ), ] out = await _collect(msgs) @@ -161,7 +161,7 @@ async def test_tool_call_and_result_emit_terminal_parts() -> None: tool_args='{"q":"x"}', ) ], - stream=messages_.MessageStreamState(events=[], is_done=True), + stream=messages_.StreamState(new_events=[], is_done=True), ), messages_.Message( role="tool", @@ -195,7 +195,7 @@ async def test_approval_request_hook_emits_approval_part() -> None: tool_args="{}", ) ], - stream=messages_.MessageStreamState(events=[], is_done=True), + stream=messages_.StreamState(new_events=[], is_done=True), ), 
messages_.Message( role="internal", @@ -221,8 +221,8 @@ async def test_dedup_on_reemitted_message_id() -> None: role="assistant", turn_id="t1", parts=[messages_.TextPart(id="txt1", text="hi")], - stream=messages_.MessageStreamState( - events=[ + stream=messages_.StreamState( + new_events=[ messages_.PartOpened(part_id="txt1"), messages_.PartDelta(part_id="txt1", chunk="hi"), messages_.PartClosed(part_id="txt1"), diff --git a/tests/models/core/test_streaming.py b/tests/models/core/test_streaming.py index 7490a934..78e3c3b6 100644 --- a/tests/models/core/test_streaming.py +++ b/tests/models/core/test_streaming.py @@ -17,7 +17,9 @@ def test_text_lifecycle() -> None: assert isinstance(part, messages.TextPart) assert part.text == "" assert m.stream is not None - assert any(isinstance(e, PartOpened) and e.part_id == "b1" for e in m.stream.events) + assert any( + isinstance(e, PartOpened) and e.part_id == "b1" for e in m.stream.new_events + ) m = h.handle_event(streaming.TextDelta(block_id="b1", delta="Hello")) part = m.parts[0] @@ -26,7 +28,7 @@ def test_text_lifecycle() -> None: assert m.stream is not None assert any( isinstance(e, PartDelta) and e.part_id == "b1" and e.chunk == "Hello" - for e in m.stream.events + for e in m.stream.new_events ) m = h.handle_event(streaming.TextDelta(block_id="b1", delta=" world")) @@ -36,16 +38,18 @@ def test_text_lifecycle() -> None: assert m.stream is not None assert any( isinstance(e, PartDelta) and e.part_id == "b1" and e.chunk == " world" - for e in m.stream.events + for e in m.stream.new_events ) m = h.handle_event(streaming.TextEnd(block_id="b1")) part = m.parts[0] assert isinstance(part, messages.TextPart) assert m.stream is not None - assert any(isinstance(e, PartClosed) and e.part_id == "b1" for e in m.stream.events) + assert any( + isinstance(e, PartClosed) and e.part_id == "b1" for e in m.stream.new_events + ) # No delta events in this yield - assert not any(isinstance(e, PartDelta) for e in m.stream.events) + assert not 
any(isinstance(e, PartDelta) for e in m.stream.new_events) # -- Reasoning streaming --------------------------------------------------- @@ -61,7 +65,7 @@ def test_reasoning_lifecycle() -> None: assert m.stream is not None assert any( isinstance(e, PartDelta) and e.part_id == "r1" and e.chunk == "thinking" - for e in m.stream.events + for e in m.stream.new_events ) m = h.handle_event(streaming.ReasoningEnd(block_id="r1", signature="sig123")) @@ -69,7 +73,9 @@ def test_reasoning_lifecycle() -> None: assert isinstance(part, messages.ReasoningPart) assert part.signature == "sig123" assert m.stream is not None - assert any(isinstance(e, PartClosed) and e.part_id == "r1" for e in m.stream.events) + assert any( + isinstance(e, PartClosed) and e.part_id == "r1" for e in m.stream.new_events + ) # -- Tool streaming -------------------------------------------------------- @@ -86,7 +92,7 @@ def test_tool_lifecycle() -> None: assert m.stream is not None assert any( isinstance(e, PartDelta) and e.part_id == "tc1" and e.chunk == '{"ci' - for e in m.stream.events + for e in m.stream.new_events ) m = h.handle_event( @@ -101,10 +107,10 @@ def test_tool_lifecycle() -> None: assert isinstance(part, messages.ToolCallPart) assert m.stream is not None assert any( - isinstance(e, PartClosed) and e.part_id == "tc1" for e in m.stream.events + isinstance(e, PartClosed) and e.part_id == "tc1" for e in m.stream.new_events ) # No delta events in this yield - assert not any(isinstance(e, PartDelta) for e in m.stream.events) + assert not any(isinstance(e, PartDelta) for e in m.stream.new_events) # -- Multi-part messages --------------------------------------------------- @@ -132,7 +138,7 @@ def test_reasoning_then_text_then_tool() -> None: # The last event was ToolEnd(tc1), so only that PartClosed is in events assert m.stream is not None assert any( - isinstance(e, PartClosed) and e.part_id == "tc1" for e in m.stream.events + isinstance(e, PartClosed) and e.part_id == "tc1" for e in 
m.stream.new_events ) @@ -158,7 +164,7 @@ def test_multiple_tool_calls() -> None: # Last event was ToolEnd(tc2), so its PartClosed is in events assert m.stream is not None assert any( - isinstance(e, PartClosed) and e.part_id == "tc2" for e in m.stream.events + isinstance(e, PartClosed) and e.part_id == "tc2" for e in m.stream.new_events ) @@ -216,10 +222,10 @@ def test_deltas_only_on_active_blocks() -> None: assert m.stream is not None assert any( isinstance(e, PartDelta) and e.part_id == "t2" and e.chunk == "second" - for e in m.stream.events + for e in m.stream.new_events ) assert not any( - isinstance(e, PartDelta) and e.part_id == "t1" for e in m.stream.events + isinstance(e, PartDelta) and e.part_id == "t1" for e in m.stream.new_events ) From faa062112c5969c537fa95c60552a591bdac6153 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Sun, 19 Apr 2026 14:13:31 -0700 Subject: [PATCH 08/10] Remove dead function from streaming.py --- src/ai/models/core/helpers/streaming.py | 32 ------------------------- 1 file changed, 32 deletions(-) diff --git a/src/ai/models/core/helpers/streaming.py b/src/ai/models/core/helpers/streaming.py index 3546cb2b..aaba2afb 100644 --- a/src/ai/models/core/helpers/streaming.py +++ b/src/ai/models/core/helpers/streaming.py @@ -1,10 +1,6 @@ from __future__ import annotations import dataclasses -import json -from collections.abc import AsyncGenerator - -import pydantic from ....types import messages as messages_ @@ -222,31 +218,3 @@ def _build_message( is_done=self._is_done, ), ) - - -async def events_to_messages( - events: AsyncGenerator[StreamEvent], - output_type: type[pydantic.BaseModel] | None = None, -) -> AsyncGenerator[messages_.Message]: - """Convert a stream of events into Message snapshots. - - This is the standalone version of the logic that ``LanguageModel.stream()`` - uses. Wire functions call this to turn their ``StreamEvent`` generators - into ``Message`` generators suitable for ``Stream``. 
- """ - handler = StreamHandler() - msg: messages_.Message | None = None - async for event in events: - msg = handler.handle_event(event) - yield msg - - # After stream completes, validate and attach structured output part - if output_type is not None and msg is not None and msg.text: - data = json.loads(msg.text) - output_type.model_validate(data) # fail fast on bad data - part = messages_.StructuredOutputPart( - data=data, - output_type_name=f"{output_type.__module__}.{output_type.__qualname__}", - ) - msg = msg.model_copy(update={"parts": [*msg.parts, part]}) - yield msg From de9c421b33cf9b7d9fe0be498d945fc41285a0ca Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Mon, 20 Apr 2026 11:17:36 -0700 Subject: [PATCH 09/10] Replace _get_parts_by_id with get_part(id) --- src/ai/types/messages.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/src/ai/types/messages.py b/src/ai/types/messages.py index 536d21a5..5a761abd 100644 --- a/src/ai/types/messages.py +++ b/src/ai/types/messages.py @@ -375,18 +375,21 @@ def is_done(self) -> bool: return True return self.stream.is_done - def _parts_by_id(self) -> dict[str, Part]: - return {p.id: p for p in self.parts} + def get_part(self, part_id: str) -> Part | None: + """Find a part by id, or return None if not found.""" + for part in self.parts: + if part.id == part_id: + return part + return None @property def text_delta(self) -> str: """First PartDelta in ``stream.new_events`` whose part is a TextPart.""" if self.stream is None: return "" - parts_map = self._parts_by_id() for ev in self.stream.new_events: if isinstance(ev, PartDelta): - part = parts_map.get(ev.part_id) + part = self.get_part(ev.part_id) if isinstance(part, TextPart): return ev.chunk return "" @@ -396,10 +399,9 @@ def reasoning_delta(self) -> str: """First PartDelta whose part is a ReasoningPart.""" if self.stream is None: return "" - parts_map = self._parts_by_id() for ev in self.stream.new_events: if isinstance(ev, 
PartDelta): - part = parts_map.get(ev.part_id) + part = self.get_part(ev.part_id) if isinstance(part, ReasoningPart): return ev.chunk return "" @@ -409,11 +411,10 @@ def tool_deltas(self) -> list[ToolDelta]: """PartDeltas in ``stream.new_events`` whose parts are ToolCallPart.""" if self.stream is None: return [] - parts_map = self._parts_by_id() deltas: list[ToolDelta] = [] for ev in self.stream.new_events: if isinstance(ev, PartDelta): - part = parts_map.get(ev.part_id) + part = self.get_part(ev.part_id) if isinstance(part, ToolCallPart): deltas.append( ToolDelta( From b4b9dad95210fa5ccc738fa29e871671c811d217 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Mon, 20 Apr 2026 16:08:40 -0700 Subject: [PATCH 10/10] Expose PartDelta directly --- examples/multiagent-textual/client.py | 18 +-- examples/samples/agent_custom_loop.py | 5 +- examples/samples/agent_hooks.py | 5 +- examples/samples/agent_hooks_serverless.py | 12 +- examples/samples/agent_nested.py | 5 +- examples/samples/agent_simple.py | 5 +- examples/samples/explicit_client.py | 5 +- examples/samples/inline_image.py | 5 +- examples/samples/mcp_tools.py | 5 +- examples/samples/middleware_simple.py | 5 +- examples/samples/multimodal_input.py | 5 +- examples/samples/stream.py | 5 +- examples/samples/streaming_tool.py | 6 +- examples/samples/structured_output.py | 5 +- examples/samples/tools_schema.py | 5 +- src/ai/__init__.py | 2 - src/ai/agents/ui/ai_sdk/outbound/_state.py | 71 +++++------ src/ai/agents/ui/ai_sdk/outbound/stream.py | 3 +- src/ai/models/__init__.py | 4 +- src/ai/models/core/helpers/streaming.py | 118 +++++++++--------- src/ai/types/__init__.py | 2 - src/ai/types/messages.py | 77 +++++------- .../agents/ui/ai_sdk/outbound/test_stream.py | 32 +++-- tests/models/core/test_streaming.py | 24 ++-- tests/models/test_public_api.py | 5 +- 25 files changed, 217 insertions(+), 217 deletions(-) diff --git a/examples/multiagent-textual/client.py b/examples/multiagent-textual/client.py index 
85541c23..819e9f74 100644 --- a/examples/multiagent-textual/client.py +++ b/examples/multiagent-textual/client.py @@ -176,15 +176,15 @@ def _handle_message(self, msg: ai.Message) -> None: if panel.status == "idle": panel.status = "streaming..." - # Text deltas - if msg.text_delta: - panel.append_text(msg.text_delta) - if msg.reasoning_delta: - panel.append_text(msg.reasoning_delta, style="dim") - - # Tool argument deltas - for delta in msg.tool_deltas: - panel.append_text(delta.args_delta, style="dim") + # Text / reasoning / tool-arg deltas + for ev in msg.deltas: + match ev.part: + case ai.TextPart(): + panel.append_text(ev.chunk) + case ai.ReasoningPart(): + panel.append_text(ev.chunk, style="dim") + case ai.ToolCallPart(): + panel.append_text(ev.chunk, style="dim") # Completed message — show tool calls and results if msg.is_done: diff --git a/examples/samples/agent_custom_loop.py b/examples/samples/agent_custom_loop.py index 27a5a17a..5f2bce27 100644 --- a/examples/samples/agent_custom_loop.py +++ b/examples/samples/agent_custom_loop.py @@ -55,8 +55,9 @@ async def custom(context: ai.Context) -> AsyncGenerator[ai.Message]: model, [ai.user_message("Compare the weather and population of New York and Tokyo.")], ): - if msg.text_delta: - print(msg.text_delta, end="", flush=True) + for ev in msg.deltas: + if isinstance(ev.part, ai.TextPart): + print(ev.chunk, end="", flush=True) print() diff --git a/examples/samples/agent_hooks.py b/examples/samples/agent_hooks.py index aaec1ffc..d062bee5 100644 --- a/examples/samples/agent_hooks.py +++ b/examples/samples/agent_hooks.py @@ -90,8 +90,9 @@ async def with_approval(context: ai.Context) -> AsyncGenerator[ai.Message]: ) continue - if msg.text_delta: - print(msg.text_delta, end="", flush=True) + for ev in msg.deltas: + if isinstance(ev.part, ai.TextPart): + print(ev.chunk, end="", flush=True) print() diff --git a/examples/samples/agent_hooks_serverless.py b/examples/samples/agent_hooks_serverless.py index 5b0110e8..b0974f5e 
100644 --- a/examples/samples/agent_hooks_serverless.py +++ b/examples/samples/agent_hooks_serverless.py @@ -95,8 +95,10 @@ async def with_confirmation(context: ai.Context) -> AsyncGenerator[ai.Message]: f" Hook pending: {hook_part.hook_id}" f" (metadata={hook_part.metadata})" ) - elif msg.text_delta: - print(msg.text_delta, end="", flush=True) + else: + for ev in msg.deltas: + if isinstance(ev.part, ai.TextPart): + print(ev.chunk, end="", flush=True) saved_checkpoint = durability.checkpoint() print(f"\n Checkpoint saved: {len(saved_checkpoint.steps)} steps\n") @@ -112,8 +114,10 @@ async def with_confirmation(context: ai.Context) -> AsyncGenerator[ai.Message]: hook_part = msg.get_hook_part() if hook_part: print(f" Hook {hook_part.status}: {hook_part.hook_id}") - elif msg.text_delta: - print(msg.text_delta, end="", flush=True) + else: + for ev in msg.deltas: + if isinstance(ev.part, ai.TextPart): + print(ev.chunk, end="", flush=True) print() diff --git a/examples/samples/agent_nested.py b/examples/samples/agent_nested.py index df0b5b54..9a6b9dc6 100644 --- a/examples/samples/agent_nested.py +++ b/examples/samples/agent_nested.py @@ -45,8 +45,9 @@ async def main() -> None: ] async for msg in orchestrator.run(model, messages): - if msg.text_delta: - print(msg.text_delta, end="", flush=True) + for ev in msg.deltas: + if isinstance(ev.part, ai.TextPart): + print(ev.chunk, end="", flush=True) print() diff --git a/examples/samples/agent_simple.py b/examples/samples/agent_simple.py index b0fc3485..1f62588e 100644 --- a/examples/samples/agent_simple.py +++ b/examples/samples/agent_simple.py @@ -22,8 +22,9 @@ async def main() -> None: ] async for msg in my_agent.run(model, messages): - if msg.text_delta: - print(msg.text_delta, end="", flush=True) + for ev in msg.deltas: + if isinstance(ev.part, ai.TextPart): + print(ev.chunk, end="", flush=True) print() diff --git a/examples/samples/explicit_client.py b/examples/samples/explicit_client.py index fa95ada8..2fc2c5d6 100644 --- 
a/examples/samples/explicit_client.py +++ b/examples/samples/explicit_client.py @@ -20,8 +20,9 @@ async def main() -> None: try: async for msg in await ai.models.stream(model, messages): - if msg.text_delta: - print(msg.text_delta, end="", flush=True) + for ev in msg.deltas: + if isinstance(ev.part, ai.TextPart): + print(ev.chunk, end="", flush=True) print() finally: await client.aclose() diff --git a/examples/samples/inline_image.py b/examples/samples/inline_image.py index 3eb195ab..2d2228c7 100644 --- a/examples/samples/inline_image.py +++ b/examples/samples/inline_image.py @@ -26,8 +26,9 @@ async def main() -> None: # Stream — text deltas arrive as usual, images arrive as FileParts async for msg in await ai.stream(model, messages): - if msg.text_delta: - print(msg.text_delta, end="", flush=True) + for ev in msg.deltas: + if isinstance(ev.part, ai.TextPart): + print(ev.chunk, end="", flush=True) last_msg = msg print() diff --git a/examples/samples/mcp_tools.py b/examples/samples/mcp_tools.py index 76a84d5b..d5f8329f 100644 --- a/examples/samples/mcp_tools.py +++ b/examples/samples/mcp_tools.py @@ -26,8 +26,9 @@ async def main() -> None: ] async for msg in my_agent.run(model, messages): - if msg.text_delta: - print(msg.text_delta, end="", flush=True) + for ev in msg.deltas: + if isinstance(ev.part, ai.TextPart): + print(ev.chunk, end="", flush=True) print() diff --git a/examples/samples/middleware_simple.py b/examples/samples/middleware_simple.py index e0deb06b..a9fdf933 100644 --- a/examples/samples/middleware_simple.py +++ b/examples/samples/middleware_simple.py @@ -99,8 +99,9 @@ async def main() -> None: print("--- starting agent run ---\n") async for msg in my_agent.run(model, messages, middleware=[PrintMiddleware()]): - if msg.text_delta: - print(msg.text_delta, end="", flush=True) + for ev in msg.deltas: + if isinstance(ev.part, ai.TextPart): + print(ev.chunk, end="", flush=True) print("\n\n--- done ---") diff --git a/examples/samples/multimodal_input.py 
b/examples/samples/multimodal_input.py index 690213da..2663ec46 100644 --- a/examples/samples/multimodal_input.py +++ b/examples/samples/multimodal_input.py @@ -21,8 +21,9 @@ async def main() -> None: async for msg in await ai.stream(model, messages): - if msg.text_delta: - print(msg.text_delta, end="", flush=True) + for ev in msg.deltas: + if isinstance(ev.part, ai.TextPart): + print(ev.chunk, end="", flush=True) print() diff --git a/examples/samples/stream.py b/examples/samples/stream.py index dbfaae48..70bdafee 100644 --- a/examples/samples/stream.py +++ b/examples/samples/stream.py @@ -14,8 +14,9 @@ async def main() -> None: async for msg in await ai.stream(model, messages): - if msg.text_delta: - print(msg.text_delta, end="", flush=True) + for ev in msg.deltas: + if isinstance(ev.part, ai.TextPart): + print(ev.chunk, end="", flush=True) print() diff --git a/examples/samples/streaming_tool.py b/examples/samples/streaming_tool.py index 06c57f02..9cf68e0c 100644 --- a/examples/samples/streaming_tool.py +++ b/examples/samples/streaming_tool.py @@ -42,8 +42,10 @@ async def main() -> None: async for msg in my_agent.run(model, messages): if msg.source_label == "tool_progress": print(f" [{msg.text}]") - elif msg.text_delta: - print(msg.text_delta, end="", flush=True) + else: + for ev in msg.deltas: + if isinstance(ev.part, ai.TextPart): + print(ev.chunk, end="", flush=True) print() diff --git a/examples/samples/structured_output.py b/examples/samples/structured_output.py index 87b36e5f..333841c6 100644 --- a/examples/samples/structured_output.py +++ b/examples/samples/structured_output.py @@ -22,8 +22,9 @@ class Recipe(pydantic.BaseModel): async def main() -> None: # Stream with structured output — watch JSON arrive, get validated at the end async for msg in await ai.stream(model, messages, output_type=Recipe): - if msg.text_delta: - print(msg.text_delta, end="", flush=True) + for ev in msg.deltas: + if isinstance(ev.part, ai.TextPart): + print(ev.chunk, end="", 
flush=True) if msg.output: recipe: Recipe = msg.output print(f"\n\nParsed recipe: {recipe.name}") diff --git a/examples/samples/tools_schema.py b/examples/samples/tools_schema.py index ac714820..c10ae4c0 100644 --- a/examples/samples/tools_schema.py +++ b/examples/samples/tools_schema.py @@ -26,8 +26,9 @@ async def main() -> None: # Stream with tools — the model may emit tool calls async for msg in await ai.stream(model, messages, tools=[get_weather]): - if msg.text_delta: - print(msg.text_delta, end="", flush=True) + for ev in msg.deltas: + if isinstance(ev.part, ai.TextPart): + print(ev.chunk, end="", flush=True) if msg.is_done: for tc in msg.tool_calls: diff --git a/src/ai/__init__.py b/src/ai/__init__.py index 9c612923..570d9cdc 100644 --- a/src/ai/__init__.py +++ b/src/ai/__init__.py @@ -45,7 +45,6 @@ StructuredOutputPart, TextPart, ToolCallPart, - ToolDelta, ToolLike, ToolResultPart, ToolSchema, @@ -71,7 +70,6 @@ "TextPart", "ToolCallPart", "ToolResultPart", - "ToolDelta", "ReasoningPart", "FilePart", "HookPart", diff --git a/src/ai/agents/ui/ai_sdk/outbound/_state.py b/src/ai/agents/ui/ai_sdk/outbound/_state.py index 77b592d5..14b26e45 100644 --- a/src/ai/agents/ui/ai_sdk/outbound/_state.py +++ b/src/ai/agents/ui/ai_sdk/outbound/_state.py @@ -126,52 +126,47 @@ def on_event( self, msg: messages_.Message, event: messages_.StreamEvent, - parts_by_id: dict[str, messages_.Part], ) -> list[protocol.UIMessageStreamPart]: - part = parts_by_id.get(event.part_id) - if part is None: - return [] - - match event, part: - case messages_.PartOpened(), messages_.TextPart(): - self.open_text_ids.add(event.part_id) - return [protocol.TextStartPart(id=event.part_id)] - - case messages_.PartDelta(chunk=chunk), messages_.TextPart(): - if event.part_id not in self.open_text_ids: - self.open_text_ids.add(event.part_id) + match event: + case messages_.PartOpened(part=messages_.TextPart(id=pid)): + self.open_text_ids.add(pid) + return [protocol.TextStartPart(id=pid)] + + case 
messages_.PartDelta(part=messages_.TextPart(id=pid), chunk=chunk): + if pid not in self.open_text_ids: + self.open_text_ids.add(pid) return [ - protocol.TextStartPart(id=event.part_id), - protocol.TextDeltaPart(id=event.part_id, delta=chunk), + protocol.TextStartPart(id=pid), + protocol.TextDeltaPart(id=pid, delta=chunk), ] - return [protocol.TextDeltaPart(id=event.part_id, delta=chunk)] + return [protocol.TextDeltaPart(id=pid, delta=chunk)] - case messages_.PartClosed(), messages_.TextPart(): - if event.part_id in self.open_text_ids: - self.open_text_ids.discard(event.part_id) - return [protocol.TextEndPart(id=event.part_id)] + case messages_.PartClosed(part=messages_.TextPart(id=pid)): + if pid in self.open_text_ids: + self.open_text_ids.discard(pid) + return [protocol.TextEndPart(id=pid)] return [] - case messages_.PartOpened(), messages_.ReasoningPart(): - self.open_reasoning_ids.add(event.part_id) - return [protocol.ReasoningStartPart(id=event.part_id)] + case messages_.PartOpened(part=messages_.ReasoningPart(id=pid)): + self.open_reasoning_ids.add(pid) + return [protocol.ReasoningStartPart(id=pid)] - case messages_.PartDelta(chunk=chunk), messages_.ReasoningPart(): - if event.part_id not in self.open_reasoning_ids: - self.open_reasoning_ids.add(event.part_id) + case messages_.PartDelta(part=messages_.ReasoningPart(id=pid), chunk=chunk): + if pid not in self.open_reasoning_ids: + self.open_reasoning_ids.add(pid) return [ - protocol.ReasoningStartPart(id=event.part_id), - protocol.ReasoningDeltaPart(id=event.part_id, delta=chunk), + protocol.ReasoningStartPart(id=pid), + protocol.ReasoningDeltaPart(id=pid, delta=chunk), ] - return [protocol.ReasoningDeltaPart(id=event.part_id, delta=chunk)] + return [protocol.ReasoningDeltaPart(id=pid, delta=chunk)] - case messages_.PartClosed(), messages_.ReasoningPart(): - if event.part_id in self.open_reasoning_ids: - self.open_reasoning_ids.discard(event.part_id) - return [protocol.ReasoningEndPart(id=event.part_id)] + case 
messages_.PartClosed(part=messages_.ReasoningPart(id=pid)): + if pid in self.open_reasoning_ids: + self.open_reasoning_ids.discard(pid) + return [protocol.ReasoningEndPart(id=pid)] return [] - case messages_.PartOpened(), messages_.ToolCallPart() as tc: + case messages_.PartOpened(part=messages_.ToolCallPart() as tc): if tc.tool_call_id in self.started_tool_inputs: return [] self.started_tool_inputs.add(tc.tool_call_id) @@ -182,7 +177,7 @@ def on_event( ) ] - case messages_.PartDelta(chunk=chunk), messages_.ToolCallPart() as tc: + case messages_.PartDelta(part=messages_.ToolCallPart() as tc, chunk=chunk): out: list[protocol.UIMessageStreamPart] = [] if tc.tool_call_id not in self.started_tool_inputs: self.started_tool_inputs.add(tc.tool_call_id) @@ -200,7 +195,7 @@ def on_event( ) return out - case messages_.PartClosed(), messages_.ToolCallPart(): + case messages_.PartClosed(part=messages_.ToolCallPart()): # ToolInputAvailablePart is emitted in ``on_terminal`` from # the terminal ``tool_args`` snapshot. return [] @@ -219,13 +214,13 @@ def on_terminal(self, msg: messages_.Message) -> list[protocol.UIMessageStreamPa # PartClosed (e.g. provider terminates abruptly — safety net). 
if msg.stream is not None: opened_ids = { - e.part_id + e.part.id for e in msg.stream.new_events if isinstance(e, messages_.PartOpened) } for tid in list(self.open_text_ids): if tid in opened_ids and not any( - isinstance(e, messages_.PartClosed) and e.part_id == tid + isinstance(e, messages_.PartClosed) and e.part.id == tid for e in msg.stream.new_events ): out.append(protocol.TextEndPart(id=tid)) diff --git a/src/ai/agents/ui/ai_sdk/outbound/stream.py b/src/ai/agents/ui/ai_sdk/outbound/stream.py index 9c2ab96f..3bccf95f 100644 --- a/src/ai/agents/ui/ai_sdk/outbound/stream.py +++ b/src/ai/agents/ui/ai_sdk/outbound/stream.py @@ -28,9 +28,8 @@ async def to_stream( yield part if msg.stream is not None and msg.stream.new_events: - parts_by_id = {p.id: p for p in msg.parts} for event in msg.stream.new_events: - for out in state.on_event(msg, event, parts_by_id): + for out in state.on_event(msg, event): yield out for part in state.on_terminal(msg): diff --git a/src/ai/models/__init__.py b/src/ai/models/__init__.py index 55ba9648..a4d9c4ca 100644 --- a/src/ai/models/__init__.py +++ b/src/ai/models/__init__.py @@ -13,7 +13,9 @@ msgs = [ai.user_message("hello")] s = await ai.stream(model, msgs) async for msg in s: - print(msg.text_delta, end="") + for ev in msg.deltas: + if isinstance(ev.part, ai.TextPart): + print(ev.chunk, end="", flush=True) # explicit client for custom auth client = ai.Client(base_url="https://custom.example.com/v1", api_key="sk-...") diff --git a/src/ai/models/core/helpers/streaming.py b/src/ai/models/core/helpers/streaming.py index aaba2afb..bd7db603 100644 --- a/src/ai/models/core/helpers/streaming.py +++ b/src/ai/models/core/helpers/streaming.py @@ -91,21 +91,16 @@ class StreamHandler: Accumulates LLM adapter events and produces Messages with stateful parts. This is the normalization layer between LLM adapters and the rest of the system. 
+ Parts are tracked in a single ``_current_parts`` dict keyed by block/tool id, + updated in place as events stream in. Each event carries the just-constructed + frozen part snapshot, so consumers never need to look parts up by id. """ message_id: str = dataclasses.field(default_factory=messages_.generate_id) - # Accumulators - _text_blocks: dict[str, str] = dataclasses.field(default_factory=dict) - _reasoning_blocks: dict[str, tuple[str, str | None]] = dataclasses.field( - default_factory=dict - ) # (text, signature) - _tool_calls: dict[str, tuple[str, str]] = dataclasses.field( - default_factory=dict - ) # (name, args) - _files: dict[str, tuple[str, str]] = dataclasses.field( - default_factory=dict - ) # block_id -> (media_type, data) + # Single source of truth for part state, keyed by id. Insertion order + # preserves provider emission order. + _current_parts: dict[str, messages_.Part] = dataclasses.field(default_factory=dict) # Active tracking _active_text_id: str | None = None @@ -123,52 +118,86 @@ def handle_event(self, event: StreamEvent) -> messages_.Message: match event: case TextStart(block_id=bid): - self._text_blocks[bid] = "" + part: messages_.Part = messages_.TextPart(id=bid, text="") + self._current_parts[bid] = part self._active_text_id = bid - stream_events.append(messages_.PartOpened(part_id=bid)) + stream_events.append(messages_.PartOpened(part=part)) case TextDelta(block_id=bid, delta=d): - self._text_blocks[bid] += d - stream_events.append(messages_.PartDelta(part_id=bid, chunk=d)) + existing = self._current_parts[bid] + assert isinstance(existing, messages_.TextPart) + part = messages_.TextPart(id=bid, text=existing.text + d) + self._current_parts[bid] = part + stream_events.append(messages_.PartDelta(part=part, chunk=d)) case TextEnd(block_id=bid): if self._active_text_id == bid: self._active_text_id = None - stream_events.append(messages_.PartClosed(part_id=bid)) + stream_events.append( + messages_.PartClosed(part=self._current_parts[bid]) + 
) case ReasoningStart(block_id=bid): - self._reasoning_blocks[bid] = ("", None) + part = messages_.ReasoningPart(id=bid, text="") + self._current_parts[bid] = part self._active_reasoning_id = bid - stream_events.append(messages_.PartOpened(part_id=bid)) + stream_events.append(messages_.PartOpened(part=part)) case ReasoningDelta(block_id=bid, delta=d): - text, sig = self._reasoning_blocks[bid] - self._reasoning_blocks[bid] = (text + d, sig) - stream_events.append(messages_.PartDelta(part_id=bid, chunk=d)) + existing = self._current_parts[bid] + assert isinstance(existing, messages_.ReasoningPart) + part = messages_.ReasoningPart( + id=bid, + text=existing.text + d, + signature=existing.signature, + ) + self._current_parts[bid] = part + stream_events.append(messages_.PartDelta(part=part, chunk=d)) case ReasoningEnd(block_id=bid, signature=sig): - text, _ = self._reasoning_blocks[bid] - self._reasoning_blocks[bid] = (text, sig) + existing = self._current_parts[bid] + assert isinstance(existing, messages_.ReasoningPart) + part = messages_.ReasoningPart( + id=bid, text=existing.text, signature=sig + ) + self._current_parts[bid] = part if self._active_reasoning_id == bid: self._active_reasoning_id = None - stream_events.append(messages_.PartClosed(part_id=bid)) + stream_events.append(messages_.PartClosed(part=part)) case ToolStart(tool_call_id=tcid, tool_name=name): - self._tool_calls[tcid] = (name, "") + part = messages_.ToolCallPart( + id=tcid, + tool_call_id=tcid, + tool_name=name, + tool_args="", + ) + self._current_parts[tcid] = part self._active_tool_ids.add(tcid) - stream_events.append(messages_.PartOpened(part_id=tcid)) + stream_events.append(messages_.PartOpened(part=part)) case ToolArgsDelta(tool_call_id=tcid, delta=d): - name, args = self._tool_calls[tcid] - self._tool_calls[tcid] = (name, args + d) - stream_events.append(messages_.PartDelta(part_id=tcid, chunk=d)) + existing = self._current_parts[tcid] + assert isinstance(existing, messages_.ToolCallPart) + 
part = messages_.ToolCallPart( + id=tcid, + tool_call_id=existing.tool_call_id, + tool_name=existing.tool_name, + tool_args=existing.tool_args + d, + ) + self._current_parts[tcid] = part + stream_events.append(messages_.PartDelta(part=part, chunk=d)) case ToolEnd(tool_call_id=tcid): self._active_tool_ids.discard(tcid) - stream_events.append(messages_.PartClosed(part_id=tcid)) + stream_events.append( + messages_.PartClosed(part=self._current_parts[tcid]) + ) case FileEvent(block_id=bid, media_type=mt, data=d): - self._files[bid] = (mt, d) + self._current_parts[bid] = messages_.FilePart( + id=bid, data=d, media_type=mt + ) case MessageDone(usage=usage): self._is_done = True @@ -183,35 +212,10 @@ def _build_message( self, stream_events: list[messages_.StreamEvent], ) -> messages_.Message: - parts: list[messages_.Part] = [] - - # Reasoning parts first (like thinking blocks) - for bid, (text, sig) in self._reasoning_blocks.items(): - parts.append(messages_.ReasoningPart(id=bid, text=text, signature=sig)) - - # Text parts - for bid, text in self._text_blocks.items(): - parts.append(messages_.TextPart(id=bid, text=text)) - - # Tool call parts - for tcid, (name, args) in self._tool_calls.items(): - parts.append( - messages_.ToolCallPart( - id=tcid, - tool_call_id=tcid, - tool_name=name, - tool_args=args, - ) - ) - - # File parts (inline images/videos from LLMs like Gemini, GPT-5) - for bid, (media_type, data) in self._files.items(): - parts.append(messages_.FilePart(id=bid, data=data, media_type=media_type)) - return messages_.Message( id=self.message_id, role="assistant", - parts=parts, + parts=list(self._current_parts.values()), usage=self._usage if self._is_done else None, stream=messages_.StreamState( new_events=stream_events, diff --git a/src/ai/types/__init__.py b/src/ai/types/__init__.py index f10ce117..6a41f9d6 100644 --- a/src/ai/types/__init__.py +++ b/src/ai/types/__init__.py @@ -17,7 +17,6 @@ StructuredOutputPart, TextPart, ToolCallPart, - ToolDelta, 
ToolResultPart, Usage, generate_id, @@ -39,7 +38,6 @@ "StructuredOutputPart", "TextPart", "ToolCallPart", - "ToolDelta", "ToolLike", "ToolResultPart", "ToolSchema", diff --git a/src/ai/types/messages.py b/src/ai/types/messages.py index 5a761abd..25ecd6ce 100644 --- a/src/ai/types/messages.py +++ b/src/ai/types/messages.py @@ -258,38 +258,48 @@ def _add_optional(a: int | None, b: int | None) -> int | None: ) -class ToolDelta(pydantic.BaseModel): - model_config = pydantic.ConfigDict(frozen=True) - - tool_call_id: str - tool_name: str - args_delta: str - - # --------------------------------------------------------------------------- # Streaming sidecar — transient state excluded from persistence. # --------------------------------------------------------------------------- class PartOpened(pydantic.BaseModel): + """A new streaming block was opened by the LLM. + + ``part`` holds the initial snapshot of the part (empty text/args). + """ + model_config = pydantic.ConfigDict(frozen=True) - part_id: str + part: Part type: Literal["part_opened"] = "part_opened" class PartDelta(pydantic.BaseModel): + """An incremental update to a streaming part. + + ``part`` is the post-delta snapshot (state accumulated up to and including + ``chunk``). ``chunk`` is the new fragment appended this step (plain text + for :class:`TextPart` / :class:`ReasoningPart`, a JSON-args fragment for + :class:`ToolCallPart`). + """ + model_config = pydantic.ConfigDict(frozen=True) - part_id: str + part: Part chunk: str type: Literal["part_delta"] = "part_delta" class PartClosed(pydantic.BaseModel): + """A streaming block was closed by the LLM. + + ``part`` holds the final snapshot of the part. 
+ """ + model_config = pydantic.ConfigDict(frozen=True) - part_id: str + part: Part type: Literal["part_closed"] = "part_closed" @@ -383,47 +393,16 @@ def get_part(self, part_id: str) -> Part | None: return None @property - def text_delta(self) -> str: - """First PartDelta in ``stream.new_events`` whose part is a TextPart.""" - if self.stream is None: - return "" - for ev in self.stream.new_events: - if isinstance(ev, PartDelta): - part = self.get_part(ev.part_id) - if isinstance(part, TextPart): - return ev.chunk - return "" + def deltas(self) -> list[PartDelta]: + """PartDelta events from this yield step, in order. - @property - def reasoning_delta(self) -> str: - """First PartDelta whose part is a ReasoningPart.""" - if self.stream is None: - return "" - for ev in self.stream.new_events: - if isinstance(ev, PartDelta): - part = self.get_part(ev.part_id) - if isinstance(part, ReasoningPart): - return ev.chunk - return "" - - @property - def tool_deltas(self) -> list[ToolDelta]: - """PartDeltas in ``stream.new_events`` whose parts are ToolCallPart.""" + Empty list means nothing streamed in this step. Each event carries + its post-delta :class:`Part` snapshot via ``ev.part`` and the chunk + fragment via ``ev.chunk``. 
+ """ if self.stream is None: return [] - deltas: list[ToolDelta] = [] - for ev in self.stream.new_events: - if isinstance(ev, PartDelta): - part = self.get_part(ev.part_id) - if isinstance(part, ToolCallPart): - deltas.append( - ToolDelta( - tool_call_id=part.tool_call_id, - tool_name=part.tool_name, - args_delta=ev.chunk, - ) - ) - return deltas + return [ev for ev in self.stream.new_events if isinstance(ev, PartDelta)] @property def files(self) -> list[FilePart]: diff --git a/tests/agents/ui/ai_sdk/outbound/test_stream.py b/tests/agents/ui/ai_sdk/outbound/test_stream.py index 23cc4c05..7a14fd1e 100644 --- a/tests/agents/ui/ai_sdk/outbound/test_stream.py +++ b/tests/agents/ui/ai_sdk/outbound/test_stream.py @@ -28,31 +28,35 @@ def _text_stream_message( is_done: bool, full_text: str | None = None, ) -> messages_.Message: + text = full_text or chunk + part = messages_.TextPart(id=text_id, text=text) events: list[messages_.StreamEvent] if is_done: - events = [messages_.PartClosed(part_id=text_id)] + events = [messages_.PartClosed(part=part)] else: - events = [messages_.PartDelta(part_id=text_id, chunk=chunk)] + events = [messages_.PartDelta(part=part, chunk=chunk)] return messages_.Message( id=msg_id, role="assistant", turn_id=turn_id, - parts=[messages_.TextPart(id=text_id, text=full_text or chunk)], + parts=[part], stream=messages_.StreamState(new_events=events, is_done=is_done), ) async def test_event_driven_text_streaming() -> None: text_id = "txt1" + empty_text = messages_.TextPart(id=text_id, text="") + hi_text = messages_.TextPart(id=text_id, text="hi") msgs = [ # Initial: PartOpened messages_.Message( id="m1", role="assistant", turn_id="t1", - parts=[messages_.TextPart(id=text_id, text="")], + parts=[empty_text], stream=messages_.StreamState( - new_events=[messages_.PartOpened(part_id=text_id)], + new_events=[messages_.PartOpened(part=empty_text)], is_done=False, ), ), @@ -61,9 +65,9 @@ async def test_event_driven_text_streaming() -> None: id="m1", 
role="assistant", turn_id="t1", - parts=[messages_.TextPart(id=text_id, text="hi")], + parts=[hi_text], stream=messages_.StreamState( - new_events=[messages_.PartDelta(part_id=text_id, chunk="hi")], + new_events=[messages_.PartDelta(part=hi_text, chunk="hi")], is_done=False, ), ), @@ -72,9 +76,9 @@ async def test_event_driven_text_streaming() -> None: id="m1", role="assistant", turn_id="t1", - parts=[messages_.TextPart(id=text_id, text="hi")], + parts=[hi_text], stream=messages_.StreamState( - new_events=[messages_.PartClosed(part_id=text_id)], + new_events=[messages_.PartClosed(part=hi_text)], is_done=True, ), ), @@ -216,16 +220,18 @@ async def test_approval_request_hook_emits_approval_part() -> None: async def test_dedup_on_reemitted_message_id() -> None: + empty = messages_.TextPart(id="txt1", text="") + hi = messages_.TextPart(id="txt1", text="hi") msg = messages_.Message( id="m1", role="assistant", turn_id="t1", - parts=[messages_.TextPart(id="txt1", text="hi")], + parts=[hi], stream=messages_.StreamState( new_events=[ - messages_.PartOpened(part_id="txt1"), - messages_.PartDelta(part_id="txt1", chunk="hi"), - messages_.PartClosed(part_id="txt1"), + messages_.PartOpened(part=empty), + messages_.PartDelta(part=hi, chunk="hi"), + messages_.PartClosed(part=hi), ], is_done=True, ), diff --git a/tests/models/core/test_streaming.py b/tests/models/core/test_streaming.py index 78e3c3b6..b123f0e8 100644 --- a/tests/models/core/test_streaming.py +++ b/tests/models/core/test_streaming.py @@ -18,7 +18,7 @@ def test_text_lifecycle() -> None: assert part.text == "" assert m.stream is not None assert any( - isinstance(e, PartOpened) and e.part_id == "b1" for e in m.stream.new_events + isinstance(e, PartOpened) and e.part.id == "b1" for e in m.stream.new_events ) m = h.handle_event(streaming.TextDelta(block_id="b1", delta="Hello")) @@ -27,7 +27,7 @@ def test_text_lifecycle() -> None: assert part.text == "Hello" assert m.stream is not None assert any( - isinstance(e, 
PartDelta) and e.part_id == "b1" and e.chunk == "Hello" + isinstance(e, PartDelta) and e.part.id == "b1" and e.chunk == "Hello" for e in m.stream.new_events ) @@ -37,7 +37,7 @@ def test_text_lifecycle() -> None: assert part.text == "Hello world" assert m.stream is not None assert any( - isinstance(e, PartDelta) and e.part_id == "b1" and e.chunk == " world" + isinstance(e, PartDelta) and e.part.id == "b1" and e.chunk == " world" for e in m.stream.new_events ) @@ -46,7 +46,7 @@ def test_text_lifecycle() -> None: assert isinstance(part, messages.TextPart) assert m.stream is not None assert any( - isinstance(e, PartClosed) and e.part_id == "b1" for e in m.stream.new_events + isinstance(e, PartClosed) and e.part.id == "b1" for e in m.stream.new_events ) # No delta events in this yield assert not any(isinstance(e, PartDelta) for e in m.stream.new_events) @@ -64,7 +64,7 @@ def test_reasoning_lifecycle() -> None: assert part.text == "thinking" assert m.stream is not None assert any( - isinstance(e, PartDelta) and e.part_id == "r1" and e.chunk == "thinking" + isinstance(e, PartDelta) and e.part.id == "r1" and e.chunk == "thinking" for e in m.stream.new_events ) @@ -74,7 +74,7 @@ def test_reasoning_lifecycle() -> None: assert part.signature == "sig123" assert m.stream is not None assert any( - isinstance(e, PartClosed) and e.part_id == "r1" for e in m.stream.new_events + isinstance(e, PartClosed) and e.part.id == "r1" for e in m.stream.new_events ) @@ -91,7 +91,7 @@ def test_tool_lifecycle() -> None: assert part.tool_args == '{"ci' assert m.stream is not None assert any( - isinstance(e, PartDelta) and e.part_id == "tc1" and e.chunk == '{"ci' + isinstance(e, PartDelta) and e.part.id == "tc1" and e.chunk == '{"ci' for e in m.stream.new_events ) @@ -107,7 +107,7 @@ def test_tool_lifecycle() -> None: assert isinstance(part, messages.ToolCallPart) assert m.stream is not None assert any( - isinstance(e, PartClosed) and e.part_id == "tc1" for e in m.stream.new_events + 
isinstance(e, PartClosed) and e.part.id == "tc1" for e in m.stream.new_events ) # No delta events in this yield assert not any(isinstance(e, PartDelta) for e in m.stream.new_events) @@ -138,7 +138,7 @@ def test_reasoning_then_text_then_tool() -> None: # The last event was ToolEnd(tc1), so only that PartClosed is in events assert m.stream is not None assert any( - isinstance(e, PartClosed) and e.part_id == "tc1" for e in m.stream.new_events + isinstance(e, PartClosed) and e.part.id == "tc1" for e in m.stream.new_events ) @@ -164,7 +164,7 @@ def test_multiple_tool_calls() -> None: # Last event was ToolEnd(tc2), so its PartClosed is in events assert m.stream is not None assert any( - isinstance(e, PartClosed) and e.part_id == "tc2" for e in m.stream.new_events + isinstance(e, PartClosed) and e.part.id == "tc2" for e in m.stream.new_events ) @@ -221,11 +221,11 @@ def test_deltas_only_on_active_blocks() -> None: # Only t2 has a delta event in this yield assert m.stream is not None assert any( - isinstance(e, PartDelta) and e.part_id == "t2" and e.chunk == "second" + isinstance(e, PartDelta) and e.part.id == "t2" and e.chunk == "second" for e in m.stream.new_events ) assert not any( - isinstance(e, PartDelta) and e.part_id == "t1" for e in m.stream.new_events + isinstance(e, PartDelta) and e.part.id == "t1" for e in m.stream.new_events ) diff --git a/tests/models/test_public_api.py b/tests/models/test_public_api.py index 46319b67..8f1ddb1b 100644 --- a/tests/models/test_public_api.py +++ b/tests/models/test_public_api.py @@ -32,8 +32,9 @@ async def test_stream_basic() -> None: s = await models.stream(MOCK_MODEL, [ai.user_message("Hi")]) deltas: list[str] = [] async for msg in s: - if msg.text_delta: - deltas.append(msg.text_delta) + for ev in msg.deltas: + if isinstance(ev.part, messages_.TextPart): + deltas.append(ev.chunk) assert mock.call_count == 1 assert s.text == "Hello world"