diff --git a/examples/fastapi-vite/README.md b/examples/fastapi-vite/README.md index ae25200b..baf1c13b 100644 --- a/examples/fastapi-vite/README.md +++ b/examples/fastapi-vite/README.md @@ -16,7 +16,7 @@ to suspend execution whenever the LLM wants to call a tool. The flow is: 1. LLM emits a tool call 2. Backend calls `await ai.hook(...)` with `payload=ai.ToolApproval` -3. The runtime emits a `role="signal"` message containing a pending `HookPart` +3. The runtime emits a `role="internal"` message containing a pending `HookPart` 4. The frontend renders Approve / Reject buttons via the `` component (from AI Elements) 5. When the user clicks a button, `addToolApprovalResponse()` patches diff --git a/examples/multiagent-textual/README.md b/examples/multiagent-textual/README.md index c4dcc71d..5126e7de 100644 --- a/examples/multiagent-textual/README.md +++ b/examples/multiagent-textual/README.md @@ -10,7 +10,7 @@ The current implementation uses: - `ai.agent(...)` for each branch and the orchestrator - `await ai.hook(...)` for branch-specific approvals - `ai.yield_from(...)` to forward nested agent output into the outer run -- `role="signal"` messages for hook state updates over the WebSocket +- `role="internal"` messages for hook state updates over the WebSocket ## Setup diff --git a/examples/multiagent-textual/client.py b/examples/multiagent-textual/client.py index 154d4c52..819e9f74 100644 --- a/examples/multiagent-textual/client.py +++ b/examples/multiagent-textual/client.py @@ -158,7 +158,7 @@ async def run_websocket(self) -> None: # ------------------------------------------------------------------ def _handle_message(self, msg: ai.Message) -> None: - label = msg.label or "unknown" + label = msg.source_label or "unknown" if (hook_part := msg.get_hook_part()) is not None: if hook_part.status == "pending": @@ -176,21 +176,21 @@ def _handle_message(self, msg: ai.Message) -> None: if panel.status == "idle": panel.status = "streaming..." - # Text deltas - if msg.text_delta: - panel.append_text(msg.text_delta) - if msg.reasoning_delta: - panel.append_text(msg.reasoning_delta, style="dim") - - # Tool argument deltas - for delta in msg.tool_deltas: - panel.append_text(delta.args_delta, style="dim") + # Text / reasoning / tool-arg deltas + for ev in msg.deltas: + match ev.part: + case ai.TextPart(): + panel.append_text(ev.chunk) + case ai.ReasoningPart(): + panel.append_text(ev.chunk, style="dim") + case ai.ToolCallPart(): + panel.append_text(ev.chunk, style="dim") # Completed message — show tool calls and results if msg.is_done: for part in msg.parts: match part: - case ai.ToolCallPart(tool_name=name, tool_args=args, state="done"): + case ai.ToolCallPart(tool_name=name, tool_args=args): panel.append_line(f"> {name}({args})") case ai.ToolResultPart(tool_name=name, result=result): panel.append_line(f"< {name} = {result}") diff --git a/examples/multiagent-textual/server.py b/examples/multiagent-textual/server.py index a35935f0..37a5bf8d 100644 --- a/examples/multiagent-textual/server.py +++ b/examples/multiagent-textual/server.py @@ -185,7 +185,7 @@ async def multiagent_loop(context: ai.Context) -> AsyncGenerator[ai.Message]: ], ) async for msg in s: - yield msg.model_copy(update={"label": "summary"}) + yield msg.model_copy(update={"agent": "summary"}) # --------------------------------------------------------------------------- diff --git a/examples/samples/agent_custom_loop.py b/examples/samples/agent_custom_loop.py index 27a5a17a..5f2bce27 100644 --- a/examples/samples/agent_custom_loop.py +++ b/examples/samples/agent_custom_loop.py @@ -55,8 +55,9 @@ async def custom(context: ai.Context) -> AsyncGenerator[ai.Message]: model, [ai.user_message("Compare the weather and population of New York and Tokyo.")], ): - if msg.text_delta: - print(msg.text_delta, end="", flush=True) + for ev in msg.deltas: + if isinstance(ev.part, ai.TextPart): + print(ev.chunk, end="", flush=True) print() diff --git a/examples/samples/agent_hooks.py b/examples/samples/agent_hooks.py index 7180311b..d062bee5 100644 --- a/examples/samples/agent_hooks.py +++ b/examples/samples/agent_hooks.py @@ -3,7 +3,7 @@ Demonstrates the function-based hook API: - await hook("label", payload=Model) to suspend inside the loop - resolve_hook("label", data) to unblock from outside - - Hook messages arrive with role="signal" + - Hook messages arrive with role="internal" """ import asyncio @@ -76,8 +76,8 @@ async def with_approval(context: ai.Context) -> AsyncGenerator[ai.Message]: ] async for msg in my_agent.run(model, messages): - # Hook signals arrive with role="signal" - if msg.role == "signal": + # Hook signals arrive with role="internal" + if msg.role == "internal": hook_part = msg.get_hook_part() if hook_part and hook_part.status == "pending": answer = input(f"Approve {hook_part.hook_id}? [y/n] ") @@ -90,8 +90,9 @@ async def with_approval(context: ai.Context) -> AsyncGenerator[ai.Message]: ) continue - if msg.text_delta: - print(msg.text_delta, end="", flush=True) + for ev in msg.deltas: + if isinstance(ev.part, ai.TextPart): + print(ev.chunk, end="", flush=True) print() diff --git a/examples/samples/agent_hooks_serverless.py b/examples/samples/agent_hooks_serverless.py index 23717689..b0974f5e 100644 --- a/examples/samples/agent_hooks_serverless.py +++ b/examples/samples/agent_hooks_serverless.py @@ -87,7 +87,7 @@ async def with_confirmation(context: ai.Context) -> AsyncGenerator[ai.Message]: durability = ai.EventLogProvider() async for msg in my_agent.run(model, messages, durability=durability): - if msg.role == "signal": + if msg.role == "internal": hook_part = msg.get_hook_part() if hook_part and hook_part.status == "pending": pending_hook_labels.append(hook_part.hook_id) @@ -95,8 +95,10 @@ async def with_confirmation(context: ai.Context) -> AsyncGenerator[ai.Message]: f" Hook pending: {hook_part.hook_id}" f" (metadata={hook_part.metadata})" ) - elif msg.text_delta: - print(msg.text_delta, end="", flush=True) + else: + for ev in msg.deltas: + if isinstance(ev.part, ai.TextPart): + print(ev.chunk, end="", flush=True) saved_checkpoint = durability.checkpoint() print(f"\n Checkpoint saved: {len(saved_checkpoint.steps)} steps\n") @@ -108,12 +110,14 @@ async def with_confirmation(context: ai.Context) -> AsyncGenerator[ai.Message]: durability = ai.EventLogProvider(saved_checkpoint) async for msg in my_agent.run(model, messages, durability=durability): - if msg.role == "signal": + if msg.role == "internal": hook_part = msg.get_hook_part() if hook_part: print(f" Hook {hook_part.status}: {hook_part.hook_id}") - elif msg.text_delta: - print(msg.text_delta, end="", flush=True) + else: + for ev in msg.deltas: + if isinstance(ev.part, ai.TextPart): + print(ev.chunk, end="", flush=True) print() diff --git a/examples/samples/agent_nested.py b/examples/samples/agent_nested.py index b4a2f535..9a6b9dc6 100644 --- a/examples/samples/agent_nested.py +++ b/examples/samples/agent_nested.py @@ -5,7 +5,7 @@ import ai -model = ai.model("ai-gateway", "anthropic/claude-sonnet-4") +model = ai.ai_gateway("anthropic/claude-sonnet-4") @ai.tool @@ -45,8 +45,9 @@ async def main() -> None: ] async for msg in orchestrator.run(model, messages): - if msg.text_delta: - print(msg.text_delta, end="", flush=True) + for ev in msg.deltas: + if isinstance(ev.part, ai.TextPart): + print(ev.chunk, end="", flush=True) print() diff --git a/examples/samples/agent_simple.py b/examples/samples/agent_simple.py index 6a90b888..1f62588e 100644 --- a/examples/samples/agent_simple.py +++ b/examples/samples/agent_simple.py @@ -12,7 +12,7 @@ async def get_weather(city: str) -> str: async def main() -> None: - model = ai.model("ai-gateway", "anthropic/claude-sonnet-4") + model = ai.ai_gateway("anthropic/claude-sonnet-4") my_agent = ai.agent(tools=[get_weather]) @@ -22,8 +22,9 @@ async def main() -> None: ] async for msg in my_agent.run(model, messages): - if msg.text_delta: - print(msg.text_delta, end="", flush=True) + for ev in msg.deltas: + if isinstance(ev.part, ai.TextPart): + print(ev.chunk, end="", flush=True) print() diff --git a/examples/samples/explicit_client.py b/examples/samples/explicit_client.py index fa95ada8..2fc2c5d6 100644 --- a/examples/samples/explicit_client.py +++ b/examples/samples/explicit_client.py @@ -20,8 +20,9 @@ async def main() -> None: try: async for msg in await ai.models.stream(model, messages): - if msg.text_delta: - print(msg.text_delta, end="", flush=True) + for ev in msg.deltas: + if isinstance(ev.part, ai.TextPart): + print(ev.chunk, end="", flush=True) print() finally: await client.aclose() diff --git a/examples/samples/image_edit.py b/examples/samples/image_edit.py index ac07fc4f..f99780a1 100644 --- a/examples/samples/image_edit.py +++ b/examples/samples/image_edit.py @@ -11,7 +11,7 @@ import ai -model = ai.model("ai-gateway", "openai/gpt-image-1") +model = ai.ai_gateway("openai/gpt-image-1") async def main() -> None: diff --git a/examples/samples/inline_image.py b/examples/samples/inline_image.py index 2b2fdcc4..2d2228c7 100644 --- a/examples/samples/inline_image.py +++ b/examples/samples/inline_image.py @@ -11,7 +11,7 @@ import ai -model = ai.model("ai-gateway", "google/gemini-3-pro-image") +model = ai.ai_gateway("google/gemini-3-pro-image") messages = [ ai.system_message( @@ -26,8 +26,9 @@ async def main() -> None: # Stream — text deltas arrive as usual, images arrive as FileParts async for msg in await ai.stream(model, messages): - if msg.text_delta: - print(msg.text_delta, end="", flush=True) + for ev in msg.deltas: + if isinstance(ev.part, ai.TextPart): + print(ev.chunk, end="", flush=True) last_msg = msg print() diff --git a/examples/samples/mcp_tools.py b/examples/samples/mcp_tools.py index 6925fd0a..d5f8329f 100644 --- a/examples/samples/mcp_tools.py +++ b/examples/samples/mcp_tools.py @@ -8,7 +8,7 @@ async def main() -> None: - model = ai.model("ai-gateway", "anthropic/claude-sonnet-4") + model = ai.ai_gateway("anthropic/claude-sonnet-4") context7_tools: list[ai.Tool[..., Any]] = await ai.mcp.get_http_tools( "https://mcp.context7.com/mcp", @@ -26,8 +26,9 @@ async def main() -> None: ] async for msg in my_agent.run(model, messages): - if msg.text_delta: - print(msg.text_delta, end="", flush=True) + for ev in msg.deltas: + if isinstance(ev.part, ai.TextPart): + print(ev.chunk, end="", flush=True) print() diff --git a/examples/samples/middleware_simple.py b/examples/samples/middleware_simple.py index e0deb06b..a9fdf933 100644 --- a/examples/samples/middleware_simple.py +++ b/examples/samples/middleware_simple.py @@ -99,8 +99,9 @@ async def main() -> None: print("--- starting agent run ---\n") async for msg in my_agent.run(model, messages, middleware=[PrintMiddleware()]): - if msg.text_delta: - print(msg.text_delta, end="", flush=True) + for ev in msg.deltas: + if isinstance(ev.part, ai.TextPart): + print(ev.chunk, end="", flush=True) print("\n\n--- done ---") diff --git a/examples/samples/multimodal_input.py b/examples/samples/multimodal_input.py index 07ce4935..2663ec46 100644 --- a/examples/samples/multimodal_input.py +++ b/examples/samples/multimodal_input.py @@ -5,7 +5,7 @@ import ai -model = ai.model("ai-gateway", "anthropic/claude-sonnet-4") +model = ai.ai_gateway("anthropic/claude-sonnet-4") # Load a local image file (replace with your own path). image_path = pathlib.Path("sample_image.jpg") @@ -21,8 +21,9 @@ async def main() -> None: async for msg in await ai.stream(model, messages): - if msg.text_delta: - print(msg.text_delta, end="", flush=True) + for ev in msg.deltas: + if isinstance(ev.part, ai.TextPart): + print(ev.chunk, end="", flush=True) print() diff --git a/examples/samples/stream.py b/examples/samples/stream.py index dbfaae48..70bdafee 100644 --- a/examples/samples/stream.py +++ b/examples/samples/stream.py @@ -14,8 +14,9 @@ async def main() -> None: async for msg in await ai.stream(model, messages): - if msg.text_delta: - print(msg.text_delta, end="", flush=True) + for ev in msg.deltas: + if isinstance(ev.part, ai.TextPart): + print(ev.chunk, end="", flush=True) print() diff --git a/examples/samples/streaming_tool.py b/examples/samples/streaming_tool.py index 5bf8fedd..9cf68e0c 100644 --- a/examples/samples/streaming_tool.py +++ b/examples/samples/streaming_tool.py @@ -17,20 +17,20 @@ async def talk_to_mothership(question: str) -> AsyncGenerator[ai.Message]: for step in ["Connecting...", "Transmitting...", "Awaiting response..."]: yield ai.Message( role="assistant", - parts=[ai.TextPart(text=step, state="done")], - label="tool_progress", + parts=[ai.TextPart(text=step)], + source_label="tool_progress", ) await asyncio.sleep(0.3) # The final yielded message's text is returned as the tool result. yield ai.Message( role="assistant", - parts=[ai.TextPart(text="The mothership says: Soon.", state="done")], + parts=[ai.TextPart(text="The mothership says: Soon.")], ) async def main() -> None: - model = ai.model("ai-gateway", "anthropic/claude-sonnet-4") + model = ai.ai_gateway("anthropic/claude-sonnet-4") my_agent = ai.agent(tools=[talk_to_mothership]) @@ -40,10 +40,12 @@ async def main() -> None: ] async for msg in my_agent.run(model, messages): - if msg.label == "tool_progress": + if msg.source_label == "tool_progress": print(f" [{msg.text}]") - elif msg.text_delta: - print(msg.text_delta, end="", flush=True) + else: + for ev in msg.deltas: + if isinstance(ev.part, ai.TextPart): + print(ev.chunk, end="", flush=True) print() diff --git a/examples/samples/structured_output.py b/examples/samples/structured_output.py index 11928053..333841c6 100644 --- a/examples/samples/structured_output.py +++ b/examples/samples/structured_output.py @@ -6,7 +6,7 @@ import ai -model = ai.model("ai-gateway", "anthropic/claude-sonnet-4") +model = ai.ai_gateway("anthropic/claude-sonnet-4") class Recipe(pydantic.BaseModel): @@ -22,8 +22,9 @@ class Recipe(pydantic.BaseModel): async def main() -> None: # Stream with structured output — watch JSON arrive, get validated at the end async for msg in await ai.stream(model, messages, output_type=Recipe): - if msg.text_delta: - print(msg.text_delta, end="", flush=True) + for ev in msg.deltas: + if isinstance(ev.part, ai.TextPart): + print(ev.chunk, end="", flush=True) if msg.output: recipe: Recipe = msg.output print(f"\n\nParsed recipe: {recipe.name}") diff --git a/examples/samples/tools_schema.py b/examples/samples/tools_schema.py index 0c13ad42..c10ae4c0 100644 --- a/examples/samples/tools_schema.py +++ b/examples/samples/tools_schema.py @@ -4,7 +4,7 @@ import ai -model = ai.model("ai-gateway", "anthropic/claude-sonnet-4") +model = ai.ai_gateway("anthropic/claude-sonnet-4") # Define a tool schema — anything matching the ToolLike protocol works. get_weather = ai.ToolSchema( @@ -26,11 +26,12 @@ async def main() -> None: # Stream with tools — the model may emit tool calls async for msg in await ai.stream(model, messages, tools=[get_weather]): - if msg.text_delta: - print(msg.text_delta, end="", flush=True) + for ev in msg.deltas: + if isinstance(ev.part, ai.TextPart): + print(ev.chunk, end="", flush=True) - for tc in msg.tool_calls: - if tc.state == "done": + if msg.is_done: + for tc in msg.tool_calls: print(f"\nTool call: {tc.tool_name}({tc.tool_args})") print() diff --git a/src/ai/__init__.py b/src/ai/__init__.py index 00a7c589..570d9cdc 100644 --- a/src/ai/__init__.py +++ b/src/ai/__init__.py @@ -1,5 +1,4 @@ -from . import adapters, middleware, models -from .adapters import ai_sdk_ui +from . import middleware, models from .agents import ( TOOL_APPROVAL_HOOK_TYPE, Agent, @@ -37,13 +36,15 @@ HookPart, Message, Part, - PartState, + PartClosed, + PartDelta, + PartOpened, ReasoningPart, StreamResultLike, + StreamState, StructuredOutputPart, TextPart, ToolCallPart, - ToolDelta, ToolLike, ToolResultPart, ToolSchema, @@ -63,14 +64,16 @@ # Types (from types/) "Message", "Part", - "PartState", + "PartClosed", + "PartDelta", + "PartOpened", "TextPart", "ToolCallPart", "ToolResultPart", - "ToolDelta", "ReasoningPart", "FilePart", "HookPart", + "StreamState", "StructuredOutputPart", "ToolLike", "ToolSchema", @@ -121,6 +124,4 @@ "middleware", # Submodules "mcp", - "ai_sdk_ui", - "adapters", ] diff --git a/src/ai/adapters/ai_sdk_ui/__init__.py b/src/ai/adapters/ai_sdk_ui/__init__.py deleted file mode 100644 index 4b8bf175..00000000 --- a/src/ai/adapters/ai_sdk_ui/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -from .adapter import filter_by_label, to_messages, to_sse_stream, to_ui_message_stream -from .protocol import UI_MESSAGE_STREAM_HEADERS -from .ui_message import UIMessage - -__all__ = [ - "to_ui_message_stream", - "filter_by_label", - "to_sse_stream", - "to_messages", - "UIMessage", - "UI_MESSAGE_STREAM_HEADERS", -] diff --git a/src/ai/adapters/ai_sdk_ui/adapter.py b/src/ai/adapters/ai_sdk_ui/adapter.py deleted file mode 100644 index 99a20a82..00000000 --- a/src/ai/adapters/ai_sdk_ui/adapter.py +++ /dev/null @@ -1,606 +0,0 @@ -""" -Reference: https://ai-sdk.dev/docs/ai-sdk-ui/stream-protocol -""" - -from __future__ import annotations - -import dataclasses -import json -import logging -from collections.abc import AsyncGenerator, AsyncIterable -from typing import Any - -from ...agents import hooks -from ...agents.hooks import TOOL_APPROVAL_HOOK_TYPE -from ...types import messages as messages_ -from . import protocol, ui_message - -logger = logging.getLogger(__name__) - -# ============================================================================ -# Serialization utilities -# ============================================================================ - - -def _to_camel_case(snake_str: str) -> str: - """Convert snake_case to camelCase.""" - components = snake_str.split("_") - return components[0] + "".join(x.title() for x in components[1:]) - - -def serialize_part(part: protocol.UIMessageStreamPart) -> str: - """Serialize a stream part to JSON with camelCase keys.""" - d = dataclasses.asdict(part) - if isinstance(part, protocol.DataPart): - # DataPart's wire type is computed (``data-{data_type}``); replace - # the raw ``data_type`` field with the protocol ``type`` key. - d["type"] = part.type - del d["data_type"] - camel_dict = {_to_camel_case(k): v for k, v in d.items() if v is not None} - return json.dumps(camel_dict) - - -def format_sse(part: protocol.UIMessageStreamPart) -> str: - """Format a stream part as an SSE data line.""" - return f"data: {serialize_part(part)}\n\n" - - -# ============================================================================ -# Internal Message → UI Message Stream Conversion -# ============================================================================ - - -class _StreamState: - """Tracks state for UI message stream event sequencing. - - Encapsulates the mutable state needed to properly sequence events - (reasoning blocks, text blocks, steps, tool calls) when converting - an internal message stream to the AI SDK UI protocol. - """ - - def __init__(self) -> None: - self.text_id: str | None = None - self.reasoning_id: str | None = None - self.label: str | None = None - self.message_id: str | None = None - self.emitted_start: bool = False - self.in_step: bool = False - self.started_tool_calls: set[str] = set() - self.emitted_tool_results: set[str] = set() - self.pending_tool_calls: set[str] = set() - self.emitted_approval_requests: set[str] = set() - - def close_open_blocks(self) -> list[protocol.UIMessageStreamPart]: - """Close any open reasoning/text blocks, returning parts to emit.""" - parts: list[protocol.UIMessageStreamPart] = [] - if self.reasoning_id: - parts.append(protocol.ReasoningEndPart(id=self.reasoning_id)) - self.reasoning_id = None - if self.text_id: - parts.append(protocol.TextEndPart(id=self.text_id)) - self.text_id = None - return parts - - def finish_step(self) -> list[protocol.UIMessageStreamPart]: - """Close open blocks and finish the current step if active.""" - parts = self.close_open_blocks() - if self.in_step: - parts.append(protocol.FinishStepPart()) - self.in_step = False - return parts - - def reset_tool_tracking(self) -> None: - """Reset tool tracking sets (for new message/agent boundaries).""" - self.started_tool_calls = set() - self.emitted_tool_results = set() - self.pending_tool_calls = set() - self.emitted_approval_requests = set() - - def begin_message( - self, msg: messages_.Message - ) -> list[protocol.UIMessageStreamPart]: - """Handle message/step boundaries, returning parts to emit. - - Decides whether to start a new message (first message or agent switch) - or a new step (same stream, different message ID), closing any open - blocks and steps as needed. - """ - parts: list[protocol.UIMessageStreamPart] = [] - is_new_message = self.message_id is not None and msg.id != self.message_id - - if not self.emitted_start or (msg.label and msg.label != self.label): - # First message or label change (new agent) - parts.extend(self.finish_step()) - if self.emitted_start: - parts.append(protocol.FinishPart(finish_reason="stop")) - - parts.append(protocol.StartPart(message_id=msg.id)) - parts.append(protocol.StartStepPart()) - self.emitted_start = True - self.in_step = True - self.label = msg.label - self.message_id = msg.id - self.reset_tool_tracking() - elif is_new_message: - # New message ID within the same stream = new step - parts.extend(self.finish_step()) - parts.append(protocol.StartStepPart()) - self.in_step = True - self.message_id = msg.id - - return parts - - -def _tool_call_id_from_approval_hook( - hook_part: messages_.HookPart, -) -> str | None: - """Extract tool_call_id from a ToolApproval HookPart. - - Returns the tool_call_id if this is a ToolApproval hook whose hook_id - follows the ``approve_{tool_call_id}`` convention, otherwise None. - """ - if hook_part.hook_type != TOOL_APPROVAL_HOOK_TYPE: - return None - prefix = "approve_" - if hook_part.hook_id.startswith(prefix): - return hook_part.hook_id[len(prefix) :] - return None - - -def _is_tool_approval_hook_message(msg: messages_.Message) -> bool: - """True if this message contains only ToolApproval HookParts.""" - if not msg.parts: - return False - return all( - isinstance(p, messages_.HookPart) - and _tool_call_id_from_approval_hook(p) is not None - for p in msg.parts - ) - - -async def to_ui_message_stream( - messages: AsyncIterable[messages_.Message], -) -> AsyncGenerator[protocol.UIMessageStreamPart]: - """ - Convert a proto_sdk message stream into AI SDK UI message stream parts. - - This adapter transforms the internal message format into the AI SDK - protocol that can be consumed by useChat and other AI SDK UI hooks. - """ - state = _StreamState() - - async for msg in messages: - # Tool-result messages (role="tool") are emitted by the Runtime - # as separate Message objects (with their own auto-generated id). - # To the frontend they belong to the *same* step as the tool - # call, so we pin the message id to avoid a spurious step boundary. - if msg.role == "tool" and state.message_id: - msg = msg.model_copy(update={"id": state.message_id}) - - # Tool-approval hook messages are emitted by the Runtime as - # separate Message objects (with their own id). To the frontend - # they belong to the *same* step as the tool call, so we pin - # the message id to avoid creating a spurious step boundary. - if _is_tool_approval_hook_message(msg) and state.message_id: - msg = msg.model_copy(update={"id": state.message_id}) - - for part in state.begin_message(msg): - yield part - - # Handle reasoning streaming (deltas) - reasoning comes before text - if delta := msg.reasoning_delta: - if not state.reasoning_id: - state.reasoning_id = messages_.generate_id("reasoning") - yield protocol.ReasoningStartPart(id=state.reasoning_id) - yield protocol.ReasoningDeltaPart(id=state.reasoning_id, delta=delta) - - # Handle text streaming (deltas) - if delta := msg.text_delta: - # Close reasoning block when text starts (reasoning precedes text) - if state.reasoning_id: - yield protocol.ReasoningEndPart(id=state.reasoning_id) - state.reasoning_id = None - - if not state.text_id: - state.text_id = messages_.generate_id("text") - yield protocol.TextStartPart(id=state.text_id) - yield protocol.TextDeltaPart(id=state.text_id, delta=delta) - - # Handle streaming tool call arguments - for tool_delta in msg.tool_deltas: - if tool_delta.tool_call_id not in state.started_tool_calls: - state.started_tool_calls.add(tool_delta.tool_call_id) - yield protocol.ToolInputStartPart( - tool_call_id=tool_delta.tool_call_id, - tool_name=tool_delta.tool_name, - ) - yield protocol.ToolInputDeltaPart( - tool_call_id=tool_delta.tool_call_id, - input_text_delta=tool_delta.args_delta, - ) - - # Handle completed messages - if msg.is_done: - had_active_text = state.text_id is not None - for part in state.close_open_blocks(): - yield part - - # Scan for new pending tool calls or tool results - has_new_pending_tools = any( - isinstance(p, messages_.ToolCallPart) - and p.tool_call_id not in state.pending_tool_calls - for p in msg.parts - ) - has_new_tool_results = any( - isinstance(p, messages_.ToolResultPart) - and p.tool_call_id not in state.emitted_tool_results - for p in msg.parts - ) - - # Process parts in passes: - # 1. Text and pending tool calls (from assistant messages) - # 2. Tool results (from tool messages) - - # Pass 1: Text and pending tool inputs - for msg_part in msg.parts: - match msg_part: - case messages_.TextPart(text=text) if ( - text - and not had_active_text - and not has_new_pending_tools - and not has_new_tool_results - ): - text_id = messages_.generate_id("text") - yield protocol.TextStartPart(id=text_id) - yield protocol.TextEndPart(id=text_id) - case messages_.ToolCallPart( - tool_call_id=tc_id, - tool_name=name, - tool_args=args, - ): - if tc_id not in state.started_tool_calls: - state.started_tool_calls.add(tc_id) - yield protocol.ToolInputStartPart( - tool_call_id=tc_id, - tool_name=name, - ) - if tc_id not in state.pending_tool_calls: - state.pending_tool_calls.add(tc_id) - yield protocol.ToolInputAvailablePart( - tool_call_id=tc_id, - tool_name=name, - input=args, - ) - - # Pass 2: Tool results - if has_new_tool_results: - for msg_part in msg.parts: - if ( - isinstance(msg_part, messages_.ToolResultPart) - and msg_part.tool_call_id not in state.emitted_tool_results - ): - state.emitted_tool_results.add(msg_part.tool_call_id) - state.pending_tool_calls.discard(msg_part.tool_call_id) - yield protocol.ToolOutputAvailablePart( - tool_call_id=msg_part.tool_call_id, - output=msg_part.result, - ) - - # Pass 3: Hook-based tool approvals - for msg_part in msg.parts: - if not isinstance(msg_part, messages_.HookPart): - continue - approval_tc_id = _tool_call_id_from_approval_hook(msg_part) - if approval_tc_id is None: - continue - - if msg_part.status == "pending": - if approval_tc_id not in state.emitted_approval_requests: - state.emitted_approval_requests.add(approval_tc_id) - yield protocol.ToolApprovalRequestPart( - approval_id=msg_part.hook_id, - tool_call_id=approval_tc_id, - ) - elif msg_part.status == "resolved": - resolution = msg_part.resolution or {} - if not resolution.get("granted", False): - yield protocol.ToolOutputDeniedPart( - tool_call_id=approval_tc_id, - ) - elif msg_part.status == "cancelled": - yield protocol.ToolOutputErrorPart( - tool_call_id=approval_tc_id, - error_text="Hook cancelled", - ) - - # Final cleanup - for part in state.finish_step(): - yield part - if state.emitted_start: - yield protocol.FinishPart(finish_reason="stop") - - -async def filter_by_label( - messages: AsyncIterable[messages_.Message], - label: str | None = None, -) -> AsyncGenerator[messages_.Message]: - """Filter a message stream to a single agent label. - - If label is provided, only messages with that label pass through. - If label is None, auto-locks to whichever label arrives first. - """ - async for msg in messages: - if label is None: - label = msg.label - if msg.label == label: - yield msg - - -async def to_sse_stream( - messages: AsyncIterable[messages_.Message], -) -> AsyncGenerator[str]: - """Convert a proto_sdk message stream directly into SSE-formatted strings.""" - async for part in to_ui_message_stream(messages): - yield format_sse(part) - - -# ============================================================================ -# Tool conversion helpers -# ============================================================================ - -_TOOL_RESULT_STATES: frozenset[str] = frozenset({"output-available"}) -_TOOL_ERROR_STATES: frozenset[str] = frozenset({"output-error", "output-denied"}) - - -def _is_tool_completed(state: ui_message.UIToolInvocationState) -> bool: - """Return True if the tool invocation state indicates a completed tool.""" - return state in _TOOL_RESULT_STATES or state in _TOOL_ERROR_STATES - - -def _is_tool_error(state: ui_message.UIToolInvocationState) -> bool: - """Return True if the tool invocation state indicates an error.""" - return state in _TOOL_ERROR_STATES - - -def _normalize_tool_args(tool_input: str | dict[str, Any] | None) -> str: - """Normalize tool input (JSON string, dict, or None) to a JSON string.""" - match tool_input: - case str(): - return tool_input - case dict(): - return json.dumps(tool_input) - case _: - return "{}" - - -def _normalize_tool_result(output: Any) -> dict[str, Any] | None: - """Normalize tool output to dict format for internal ToolResultPart. - - The internal ToolResultPart.result expects dict | None, but AI SDK - output can be any type. Wrap non-dict results for compatibility. - """ - if output is None: - return None - return output if isinstance(output, dict) else {"value": output} - - -def to_messages( - ui_messages: list[ui_message.UIMessage], -) -> list[messages_.Message]: - """Convert AI SDK v6 UI messages to internal Message format. - - As a side-effect, tool parts in ``approval-responded`` state trigger - ``ToolApproval.resolve()`` so the agent loop can resume execution - without the caller needing to handle approval routing explicitly. - - When approvals are resolved, the trailing assistant message is - automatically stripped to avoid sending duplicate tool-use content - to the LLM on re-entry. - - Args: - ui_messages: List of UIMessage objects from the AI SDK v6 frontend. - - Returns: - List of internal Message objects ready for use with the runtime. - """ - result: list[messages_.Message] = [] - resolved_any_approval = False - - for ui_msg in ui_messages: - # For assistant messages, separate tool calls from tool results. - assistant_parts: list[messages_.Part] = [] - tool_result_parts: list[messages_.ToolResultPart] = [] - - for part in ui_msg.parts: - match part: - case ui_message.UITextPart(text=text) if text: - assistant_parts.append(messages_.TextPart(text=text)) - - case ui_message.UIReasoningPart(reasoning=reasoning): - assistant_parts.append(messages_.ReasoningPart(text=reasoning)) - - case ui_message.UIToolInvocationPart() as inv: - # Legacy tool-invocation type — always create the call part - tool_args = json.dumps(inv.args) if inv.args else "{}" - assistant_parts.append( - messages_.ToolCallPart( - tool_call_id=inv.tool_invocation_id, - tool_name=inv.tool_name, - tool_args=tool_args, - ) - ) - if _is_tool_completed(inv.state): - tool_result_parts.append( - messages_.ToolResultPart( - tool_call_id=inv.tool_invocation_id, - tool_name=inv.tool_name, - result=inv.result, - is_error=_is_tool_error(inv.state), - ) - ) - - case ui_message.UIToolPart() as tp: - # Dynamic tool-{toolName} type (e.g., "tool-get_weather") - assistant_parts.append( - messages_.ToolCallPart( - tool_call_id=tp.tool_call_id, - tool_name=tp.tool_name, - tool_args=_normalize_tool_args(tp.input), - ) - ) - if _is_tool_completed(tp.state): - tool_result_parts.append( - messages_.ToolResultPart( - tool_call_id=tp.tool_call_id, - tool_name=tp.tool_name, - result=_normalize_tool_result(tp.output), - is_error=_is_tool_error(tp.state), - ) - ) - # Side-effect: resolve ToolApproval hooks from approval - # responses so the agent loop can resume execution. - if ( - tp.state == "approval-responded" - and tp.approval is not None - and tp.approval.approved is not None - ): - hooks.resolve_hook( - tp.approval.id, - { - "granted": tp.approval.approved, - "reason": tp.approval.reason, - }, - ) - resolved_any_approval = True - - case ui_message.UIFilePart() as fp: - assistant_parts.append( - messages_.FilePart( - data=fp.url, - media_type=fp.media_type, - filename=fp.filename, - ) - ) - - case ( - ui_message.UIStepStartPart() - | ui_message.UISourceUrlPart() - | ui_message.UISourceDocumentPart() - ): - pass # Skip unsupported/boundary parts - - # Validate user/system messages have content - OpenAI requires it. - if ui_msg.role in ("user", "system") and not assistant_parts: - raise ValueError( - f"Message '{ui_msg.id}' has role '{ui_msg.role}' but no content. " - "User and system messages require non-empty content." - ) - - # The UI sends one assistant message per conversation turn, but a - # single turn may span multiple default-loop iterations (e.g. - # [text, tool_call, tool_result, text, tool_call, tool_result, text]). - # LLM APIs expect one message per iteration, so split into - # assistant + tool message pairs at tool-result boundaries. - if ui_msg.role == "assistant": - result.extend( - _split_assistant_parts( - assistant_parts, tool_result_parts, msg_id=ui_msg.id - ) - ) - else: - result.append( - messages_.Message( - id=ui_msg.id, - role=ui_msg.role, - parts=assistant_parts, - ) - ) - - # When resuming after approvals were resolved above, the frontend - # sends the full history including the assistant message from the - # interrupted run. Strip it to avoid duplicate tool-use content. - if resolved_any_approval and result and result[-1].role == "assistant": - logger.info("Stripping trailing assistant message (approvals resolved)") - result = result[:-1] - - return result - - -def _split_assistant_parts( - parts: list[messages_.Part], - tool_results: list[messages_.ToolResultPart], - msg_id: str, -) -> list[messages_.Message]: - """Split assistant parts into assistant + tool message pairs. - - The UI sends one big assistant message per turn, but internally each - loop iteration produces an assistant message (with tool calls) followed - by a tool message (with results). This reconstructs that structure. - - Returns a list of Messages: alternating assistant and tool messages, - split at tool-call boundaries when results are available. - """ - # Index tool results by their tool_call_id for lookup - results_by_id = {tr.tool_call_id: tr for tr in tool_results} - - messages: list[messages_.Message] = [] - current: list[messages_.Part] = [] - pending_results: list[messages_.ToolResultPart] = [] - - for part in parts: - current.append(part) - - # When we see a ToolCallPart that has a result, accumulate it - if ( - isinstance(part, messages_.ToolCallPart) - and part.tool_call_id in results_by_id - ): - pending_results.append(results_by_id[part.tool_call_id]) - - # If there are pending results and more parts follow, we need to split. - # Walk again, splitting at boundaries where all accumulated tool calls - # have results and a non-tool part follows. - if not pending_results: - # No completed tools — single assistant message - if current: - messages.append( - messages_.Message(role="assistant", parts=current, id=msg_id) - ) - return messages - - # Re-walk to split at tool-call boundaries - messages = [] - current = [] - current_results: list[messages_.ToolResultPart] = [] - seen_tool_call = False - - for part in parts: - # If we had a completed tool call group and now see a non-tool part, - # split here - if ( - seen_tool_call - and current_results - and not isinstance(part, messages_.ToolCallPart) - ): - messages.append( - messages_.Message(role="assistant", parts=current, id=msg_id) - ) - messages.append(messages_.Message(role="tool", parts=list(current_results))) - current = [] - current_results = [] - seen_tool_call = False - - current.append(part) - - if isinstance(part, messages_.ToolCallPart): - seen_tool_call = True - if part.tool_call_id in results_by_id: - current_results.append(results_by_id[part.tool_call_id]) - - # Flush remaining - if current: - messages.append(messages_.Message(role="assistant", parts=current, id=msg_id)) - if current_results: - messages.append(messages_.Message(role="tool", parts=list(current_results))) - - return messages diff --git a/src/ai/agents/agent.py b/src/ai/agents/agent.py index c7d73a9d..cd8bd3be 100644 --- a/src/ai/agents/agent.py +++ b/src/ai/agents/agent.py @@ -206,7 +206,9 @@ def __call__(self, context: Context) -> AsyncGenerator[types.Message]: ... async def _default_loop(context: Context) -> AsyncGenerator[types.Message]: while True: stream = await models.stream( - context.model, context.messages, tools=context.tools + context.model, + context.messages, + tools=context.tools, ) async for message in stream: yield message @@ -220,7 +222,10 @@ async def _default_loop(context: Context) -> AsyncGenerator[types.Message]: tasks = [tg.create_task(tc()) for tc in tool_calls] # Yield one merged tool-result message — history auto-collects it. - yield builders.tool_message(*(t.result() for t in tasks)) + # Left un-stamped: the tool result is the input of the *next* turn, + # so the next stream() call will stamp it with that turn's id. + tool_msg = builders.tool_message(*(t.result() for t in tasks)) + yield tool_msg async def _collect_messages( @@ -331,7 +336,7 @@ async def _real( source = _collect_messages(loop_fn(context), context.messages) async for message in runtime.run(source): if call.label is not None: - message = message.model_copy(update={"label": call.label}) + message = message.model_copy(update={"source_label": call.label}) yield message # Activate middleware for this run (and everything it calls). diff --git a/src/ai/agents/hooks.py b/src/ai/agents/hooks.py index 871a2a9d..6205dcb0 100644 --- a/src/ai/agents/hooks.py +++ b/src/ai/agents/hooks.py @@ -125,7 +125,7 @@ async def _hook_impl(call: middleware_.HookContext) -> pydantic.BaseModel: # Emit pending signal message. await rt.put_message( messages_.Message( - role="signal", + role="internal", parts=[ messages_.HookPart( hook_id=label, @@ -150,10 +150,10 @@ async def _hook_impl(call: middleware_.HookContext) -> pydantic.BaseModel: # Clean up live registry. _live_hooks.pop(label, None) - # Emit resolved signal message. + # Emit resolved internal message. await rt.put_message( messages_.Message( - role="signal", + role="internal", parts=[ messages_.HookPart( hook_id=label, @@ -229,10 +229,10 @@ async def cancel_hook(label: str, *, reason: str | None = None) -> None: future, hook_metadata, rt = _live_hooks.pop(label) future.cancel(reason) - # Emit cancelled signal message. + # Emit cancelled internal message. await rt.put_message( messages_.Message( - role="signal", + role="internal", parts=[ messages_.HookPart( hook_id=label, diff --git a/src/ai/adapters/__init__.py b/src/ai/agents/ui/__init__.py similarity index 61% rename from src/ai/adapters/__init__.py rename to src/ai/agents/ui/__init__.py index db357d09..2c272a79 100644 --- a/src/ai/adapters/__init__.py +++ b/src/ai/agents/ui/__init__.py @@ -1,5 +1,5 @@ """Downstream adapters — protocol bridges for the agent loop output.""" -from . import ai_sdk_ui +from . import ai_sdk -__all__ = ["ai_sdk_ui"] +__all__ = ["ai_sdk"] diff --git a/src/ai/agents/ui/ai_sdk/__init__.py b/src/ai/agents/ui/ai_sdk/__init__.py new file mode 100644 index 00000000..8f283199 --- /dev/null +++ b/src/ai/agents/ui/ai_sdk/__init__.py @@ -0,0 +1,23 @@ +"""AI SDK UI adapter — ``ai.Messages`` in, ``ai.Messages`` out, SSE on the wire.""" + +from .inbound import ( + ApprovalResponse, + apply_approvals, + extract_approvals, + to_messages, +) +from .outbound import to_sse, to_stream, to_ui_messages +from .protocol import UI_MESSAGE_STREAM_HEADERS +from .ui_message import UIMessage + +__all__ = [ + "ApprovalResponse", + "UIMessage", + "UI_MESSAGE_STREAM_HEADERS", + "apply_approvals", + "extract_approvals", + "to_messages", + "to_sse", + "to_stream", + "to_ui_messages", +] diff --git a/src/ai/agents/ui/ai_sdk/_approvals.py b/src/ai/agents/ui/ai_sdk/_approvals.py new file mode 100644 index 00000000..f4e98d79 --- /dev/null +++ b/src/ai/agents/ui/ai_sdk/_approvals.py @@ -0,0 +1,31 @@ +"""Approval-prefix linkage between ToolApproval hooks and tool calls. + +TODO(datamodel-rework §4): once ``HookPart.target_tool_call_id`` is added, +delete this module and replace call sites with direct field access. +""" + +from __future__ import annotations + +from ....types import messages as messages_ +from ...hooks import TOOL_APPROVAL_HOOK_TYPE + +_PREFIX = "approve_" + + +def tool_call_id_for(hook_part: messages_.HookPart) -> str | None: + """Return the tool_call_id encoded in a ToolApproval hook id, or None.""" + if hook_part.hook_type != TOOL_APPROVAL_HOOK_TYPE: + return None + if hook_part.hook_id.startswith(_PREFIX): + return hook_part.hook_id[len(_PREFIX) :] + return None + + +def is_tool_approval_message(msg: messages_.Message) -> bool: + """True if every part of ``msg`` is a ToolApproval HookPart.""" + if not msg.parts: + return False + return all( + isinstance(p, messages_.HookPart) and tool_call_id_for(p) is not None + for p in msg.parts + ) diff --git a/src/ai/agents/ui/ai_sdk/_parts.py b/src/ai/agents/ui/ai_sdk/_parts.py new file mode 100644 index 00000000..9d4d6ecc --- /dev/null +++ b/src/ai/agents/ui/ai_sdk/_parts.py @@ -0,0 +1,141 @@ +"""Shared conversions between internal Part objects and UIMessagePart objects. + +Used by ``outbound.history`` to reconstruct UIMessages from persisted +``ai.Message`` lists. The live outbound stream does not use these — it +emits wire-protocol deltas directly from ``Message.stream.new_events``. +""" + +from __future__ import annotations + +import json +from typing import Any + +from ....types import messages as messages_ +from . import _approvals, ui_message + + +def _normalize_tool_input(raw: str) -> str | dict[str, Any]: + """Parse tool args JSON string into a dict; fall back to raw string. + + TODO(datamodel-rework §4): once ``ToolCallPart.tool_args`` has a + canonical shape, drop this helper. + """ + try: + parsed = json.loads(raw) + except (json.JSONDecodeError, TypeError): + return raw + return parsed if isinstance(parsed, dict) else raw + + +def to_ui_parts(parts: list[messages_.Part]) -> list[ui_message.UIMessagePart]: + """Convert internal Part objects to UIMessagePart objects.""" + result: list[ui_message.UIMessagePart] = [] + for part in parts: + if isinstance(part, messages_.TextPart) and part.text: + result.append(ui_message.UITextPart(type="text", text=part.text)) + elif isinstance(part, messages_.ReasoningPart) and part.text: + result.append( + ui_message.UIReasoningPart(type="reasoning", reasoning=part.text) + ) + elif isinstance(part, messages_.ToolCallPart): + result.append( + ui_message.UIToolPart.model_validate( + { + "type": f"tool-{part.tool_name}", + "toolCallId": part.tool_call_id, + "state": "input-available", + "input": _normalize_tool_input(part.tool_args), + } + ) + ) + elif isinstance(part, messages_.FilePart): + result.append( + ui_message.UIFilePart.model_validate( + { + "type": "file", + "mediaType": part.media_type, + "url": part.data if isinstance(part.data, str) else "", + "filename": part.filename, + } + ) + ) + return result + + +def merge_tool_results( + ui_parts: list[ui_message.UIMessagePart], + tool_parts: list[messages_.Part], +) -> None: + """Merge ToolResultParts into existing UIToolParts in-place.""" + tool_index: dict[str, int] = {} + for idx, ui_part in enumerate(ui_parts): + if isinstance(ui_part, ui_message.UIToolPart): + tool_index[ui_part.tool_call_id] = idx + + for part in tool_parts: + if not isinstance(part, messages_.ToolResultPart): + continue + idx_opt = tool_index.get(part.tool_call_id) + if idx_opt is None: + continue + idx = idx_opt + existing = ui_parts[idx] + if not isinstance(existing, ui_message.UIToolPart): + continue + if existing.state == "output-denied": + continue + state = "output-error" if part.is_error else "output-available" + ui_parts[idx] = existing.model_copy( + update={"state": state, "output": part.result} + ) + + +def merge_approval_signals( + ui_parts: list[ui_message.UIMessagePart], + internal_parts: list[messages_.Part], +) -> None: + """Merge HookPart approval state into existing UIToolParts in-place.""" + tool_index: dict[str, int] = {} + for idx, ui_part in enumerate(ui_parts): + if isinstance(ui_part, ui_message.UIToolPart): + tool_index[ui_part.tool_call_id] = idx + + for part in internal_parts: + if not isinstance(part, messages_.HookPart): + continue + + tool_call_id = _approvals.tool_call_id_for(part) + if tool_call_id is None: + continue + + idx_opt = tool_index.get(tool_call_id) + if idx_opt is None: + continue + idx = idx_opt + + existing = ui_parts[idx] + if not isinstance(existing, ui_message.UIToolPart): + continue + + updates: dict[str, Any] = {} + if part.status == "pending": + updates["state"] = "approval-requested" + updates["approval"] = ui_message.UIToolApproval(id=part.hook_id) + elif part.status == "resolved": + resolution = part.resolution or {} + updates["approval"] = ui_message.UIToolApproval( + id=part.hook_id, + approved=resolution.get("granted"), + reason=resolution.get("reason"), + ) + if resolution.get("granted", False): + updates["state"] = "approval-responded" + else: + updates["state"] = "output-denied" + updates["output"] = None + elif part.status == "cancelled": + updates["state"] = "output-error" + updates["error_text"] = "Hook cancelled" + + if updates: + ui_parts[idx] = existing.model_copy(update=updates) diff --git a/src/ai/agents/ui/ai_sdk/inbound.py b/src/ai/agents/ui/ai_sdk/inbound.py new file mode 100644 index 00000000..0dcadea2 --- /dev/null +++ b/src/ai/agents/ui/ai_sdk/inbound.py @@ -0,0 +1,409 @@ +"""Inbound adapter: AI SDK v6 UIMessages → internal ``ai.Message`` list. + +The primary entry point is :func:`to_messages`, which bundles normalization, +approval extraction, parsing, and pre-registration of approval resolutions. +""" + +from __future__ import annotations + +import json +import logging +from typing import Any, NamedTuple + +from ....types import messages as messages_ +from ...hooks import resolve_hook +from . import ui_message + +logger = logging.getLogger(__name__) + + +_TOOL_RESULT_STATES: frozenset[str] = frozenset({"output-available"}) +_TOOL_ERROR_STATES: frozenset[str] = frozenset({"output-error", "output-denied"}) + + +def _is_tool_completed(state: ui_message.UIToolInvocationState) -> bool: + return state in _TOOL_RESULT_STATES or state in _TOOL_ERROR_STATES + + +def _is_tool_error(state: ui_message.UIToolInvocationState) -> bool: + return state in _TOOL_ERROR_STATES + + +# TODO(datamodel-rework §4): once tool args have a canonical shape, drop +# these normalizers. +def _normalize_tool_args(tool_input: str | dict[str, Any] | None) -> str: + """Normalize tool input (JSON string, dict, or None) to a JSON string.""" + match tool_input: + case str(): + return tool_input + case dict(): + return json.dumps(tool_input) + case _: + return "{}" + + +def _normalize_tool_result(output: Any) -> dict[str, Any] | None: + """Normalize tool output to dict format for internal ToolResultPart.""" + if output is None: + return None + return output if isinstance(output, dict) else {"value": output} + + +def _error_result(error_text: str | None, output: Any) -> dict[str, Any] | None: + normalized = _normalize_tool_result(output) + if error_text: + if normalized is None: + return {"error": error_text} + if isinstance(normalized, dict) and "error" not in normalized: + return {"error": error_text, **normalized} + return normalized + + +def _approval_hook_part(tp: ui_message.UIToolPart) -> messages_.HookPart | None: + """Reconstruct approval hook state from a UI tool part when possible.""" + approval = tp.approval + if approval is None: + return None + + if tp.state == "approval-requested": + return messages_.HookPart( + hook_id=approval.id, + hook_type="ToolApproval", + status="pending", + ) + + if tp.state == "approval-responded" and approval.approved is not None: + return messages_.HookPart( + hook_id=approval.id, + hook_type="ToolApproval", + status="resolved", + resolution={ + "granted": approval.approved, + "reason": approval.reason, + }, + ) + + if tp.state == "output-denied": + return messages_.HookPart( + hook_id=approval.id, + hook_type="ToolApproval", + status="resolved", + resolution={ + "granted": False, + "reason": approval.reason, + }, + ) + + return None + + +# ============================================================================ +# Approval extraction + bulk resolution +# ============================================================================ + + +class ApprovalResponse(NamedTuple): + """Approval response extracted from a UIToolPart in ``approval-responded`` state.""" + + hook_id: str + granted: bool + reason: str | None + + +def extract_approvals( + ui_messages: list[ui_message.UIMessage], +) -> list[ApprovalResponse]: + """Return every approval response found in *ui_messages*. + + Pure function — does not resolve hooks or trigger side effects. + """ + approvals: list[ApprovalResponse] = [] + for ui_msg in ui_messages: + for part in ui_msg.parts: + if not isinstance(part, ui_message.UIToolPart): + continue + if ( + part.state == "approval-responded" + and part.approval is not None + and part.approval.approved is not None + ): + approvals.append( + ApprovalResponse( + hook_id=part.approval.id, + granted=part.approval.approved, + reason=part.approval.reason, + ) + ) + return approvals + + +def apply_approvals(approvals: list[ApprovalResponse]) -> None: + """Pre-register each approval resolution with the hooks registry.""" + for approval in approvals: + resolve_hook( + approval.hook_id, + {"granted": approval.granted, "reason": approval.reason}, + ) + + +# ============================================================================ +# UI message normalization (heal stale tool states) +# ============================================================================ + + +def _normalize_ui_messages( + ui_messages: list[ui_message.UIMessage], +) -> list[ui_message.UIMessage]: + """Heal stale tool-part states from previously persisted assistant history.""" + normalized: list[ui_message.UIMessage] = [] + for message in ui_messages: + new_parts = [] + changed = False + for part in message.parts: + part_type = getattr(part, "type", None) + state = getattr(part, "state", None) + if isinstance(part_type, str) and part_type.startswith("tool-"): + output = getattr(part, "output", None) + approval = getattr(part, "approval", None) + approved = approval.approved if approval is not None else None + error_text = getattr(part, "error_text", None) + + next_state = state + if output is not None: + if state == "output-error" or error_text is not None: + next_state = "output-error" + elif state == "output-denied" or approved is False: + next_state = "output-denied" + else: + next_state = "output-available" + elif state == "call": + next_state = "input-available" + + if next_state != state: + part = part.model_copy(update={"state": next_state}) + changed = True + + new_parts.append(part) + + normalized.append( + message.model_copy(update={"parts": new_parts}) if changed else message + ) + return normalized + + +# ============================================================================ +# UI → internal message conversion +# ============================================================================ + + +def to_messages( + ui_messages: list[ui_message.UIMessage], + *, + apply_approvals_: bool = True, +) -> list[messages_.Message]: + """Parse a UI request into runtime messages. + + Pipeline: + + 1. normalize stale tool states (``call`` → ``input-available``, etc.) + 2. extract approval responses + 3. parse UIMessages into ``ai.Message`` list, splitting at tool boundaries + 4. if *apply_approvals_* is True, pre-register resolutions via ``resolve_hook`` + 5. strip trailing assistant message when approval responses are present + """ + normalized = _normalize_ui_messages(ui_messages) + approvals = extract_approvals(normalized) + messages = _parse(normalized) + + if apply_approvals_: + apply_approvals(approvals) + + if approvals and messages: + # The assistant message that originated the approval-responded tool + # call would re-send the duplicate tool-use to the LLM on replay. + # Walk past any trailing internal (hook) messages and drop the + # assistant message beneath them. + idx = len(messages) - 1 + while idx >= 0 and messages[idx].role == "internal": + idx -= 1 + if idx >= 0 and messages[idx].role == "assistant": + logger.info("Stripping assistant message originating responded approvals") + messages = messages[:idx] + messages[idx + 1 :] + + return messages + + +def _parse( + ui_messages: list[ui_message.UIMessage], +) -> list[messages_.Message]: + result: list[messages_.Message] = [] + + for ui_msg in ui_messages: + assistant_parts: list[messages_.Part] = [] + tool_result_parts: list[messages_.ToolResultPart] = [] + hook_parts: list[messages_.HookPart] = [] + + for part in ui_msg.parts: + match part: + case ui_message.UITextPart(text=text) if text: + assistant_parts.append(messages_.TextPart(text=text)) + + case ui_message.UIReasoningPart(reasoning=reasoning): + assistant_parts.append(messages_.ReasoningPart(text=reasoning)) + + case ui_message.UIToolInvocationPart() as inv: + tool_args = json.dumps(inv.args) if inv.args else "{}" + assistant_parts.append( + messages_.ToolCallPart( + tool_call_id=inv.tool_invocation_id, + tool_name=inv.tool_name, + tool_args=tool_args, + ) + ) + if _is_tool_completed(inv.state): + tool_result_parts.append( + messages_.ToolResultPart( + tool_call_id=inv.tool_invocation_id, + tool_name=inv.tool_name, + result=inv.result, + is_error=_is_tool_error(inv.state), + ) + ) + + case ui_message.UIToolPart() as tp: + assistant_parts.append( + messages_.ToolCallPart( + tool_call_id=tp.tool_call_id, + tool_name=tp.tool_name, + tool_args=_normalize_tool_args(tp.input), + ) + ) + approval_hook = _approval_hook_part(tp) + if approval_hook is not None: + hook_parts.append(approval_hook) + + if tp.state in _TOOL_RESULT_STATES: + tool_result_parts.append( + messages_.ToolResultPart( + tool_call_id=tp.tool_call_id, + tool_name=tp.tool_name, + result=_normalize_tool_result(tp.output), + is_error=False, + ) + ) + elif tp.state == "output-error": + tool_result_parts.append( + messages_.ToolResultPart( + tool_call_id=tp.tool_call_id, + tool_name=tp.tool_name, + result=_error_result(tp.error_text, tp.output), + is_error=True, + ) + ) + + case ui_message.UIFilePart() as fp: + assistant_parts.append( + messages_.FilePart( + data=fp.url, + media_type=fp.media_type, + filename=fp.filename, + ) + ) + + case ( + ui_message.UIStepStartPart() + | ui_message.UISourceUrlPart() + | ui_message.UISourceDocumentPart() + ): + pass + + if ui_msg.role in ("user", "system") and not assistant_parts: + raise ValueError( + f"Message '{ui_msg.id}' has role '{ui_msg.role}' but no content. " + "User and system messages require non-empty content." + ) + + # The UI sends one assistant message per conversation turn, but a + # single turn may span multiple loop iterations (e.g. [text, + # tool_call, tool_result, text, tool_call, tool_result, text]). + # LLM APIs expect one message per iteration, so split into + # assistant + tool message pairs at tool-result boundaries. + if ui_msg.role == "assistant": + result.extend( + _split_assistant_parts( + assistant_parts, tool_result_parts, msg_id=ui_msg.id + ) + ) + for hp in hook_parts: + result.append( + messages_.Message( + id=ui_msg.id, + role="internal", + parts=[hp], + ) + ) + else: + result.append( + messages_.Message( + id=ui_msg.id, + role=ui_msg.role, + parts=assistant_parts, + ) + ) + + return result + + +def _split_assistant_parts( + parts: list[messages_.Part], + tool_results: list[messages_.ToolResultPart], + msg_id: str, +) -> list[messages_.Message]: + """Split assistant parts into assistant + tool message pairs.""" + results_by_id = {tr.tool_call_id: tr for tr in tool_results} + + pending_results: list[messages_.ToolResultPart] = [] + for part in parts: + if ( + isinstance(part, messages_.ToolCallPart) + and part.tool_call_id in results_by_id + ): + pending_results.append(results_by_id[part.tool_call_id]) + + if not pending_results: + if parts: + return [messages_.Message(role="assistant", parts=parts, id=msg_id)] + return [] + + messages: list[messages_.Message] = [] + current: list[messages_.Part] = [] + current_results: list[messages_.ToolResultPart] = [] + seen_tool_call = False + + for part in parts: + if ( + seen_tool_call + and current_results + and not isinstance(part, messages_.ToolCallPart) + ): + messages.append( + messages_.Message(role="assistant", parts=current, id=msg_id) + ) + messages.append(messages_.Message(role="tool", parts=list(current_results))) + current = [] + current_results = [] + seen_tool_call = False + + current.append(part) + + if isinstance(part, messages_.ToolCallPart): + seen_tool_call = True + if part.tool_call_id in results_by_id: + current_results.append(results_by_id[part.tool_call_id]) + + if current: + messages.append(messages_.Message(role="assistant", parts=current, id=msg_id)) + if current_results: + messages.append(messages_.Message(role="tool", parts=list(current_results))) + + return messages diff --git a/src/ai/agents/ui/ai_sdk/outbound/__init__.py b/src/ai/agents/ui/ai_sdk/outbound/__init__.py new file mode 100644 index 00000000..9347170e --- /dev/null +++ b/src/ai/agents/ui/ai_sdk/outbound/__init__.py @@ -0,0 +1,7 @@ +"""Outbound adapter: ``ai.Message`` stream → AI SDK UI protocol.""" + +from .history import to_ui_messages +from .sse import to_sse +from .stream import to_stream + +__all__ = ["to_stream", "to_sse", "to_ui_messages"] diff --git a/src/ai/agents/ui/ai_sdk/outbound/_state.py b/src/ai/agents/ui/ai_sdk/outbound/_state.py new file mode 100644 index 00000000..14b26e45 --- /dev/null +++ b/src/ai/agents/ui/ai_sdk/outbound/_state.py @@ -0,0 +1,305 @@ +"""Stream state bookkeeping for the live outbound walk. + +Owns message/step boundary logic (via ``turn_id`` + ``agent``), tracks +which parts have open text/reasoning blocks, and guards against +re-emission when the runtime re-yields an already-finalized message. +""" + +from __future__ import annotations + +from typing import Any + +from .....types import messages as messages_ +from .. import _approvals, protocol + + +def _tool_error_text(part: messages_.ToolResultPart) -> str: + """Best-effort error text extraction from a failed tool result.""" + if isinstance(part.result, str) and part.result: + return part.result + if isinstance(part.result, dict): + for key in ("error", "message", "detail"): + value = part.result.get(key) + if isinstance(value, str) and value: + return value + return "Tool execution failed" + + +class _StreamState: + """Single-pass state across one ``to_stream()`` call.""" + + def __init__(self) -> None: + self.current_turn_id: str | None = None + self.current_agent: str | None = None + self.ui_message_id: str | None = None + self.emitted_start: bool = False + self.in_step: bool = False + + # Message-level dedup — an ``is_done`` message re-emitted as input to a + # later ``stream()`` call must not fire events twice. + self.seen_done: set[str] = set() + + # Tool-call dedup — keyed by tool_call_id. + self.started_tool_inputs: set[str] = set() + self.input_available_emitted: set[str] = set() + self.emitted_tool_results: set[str] = set() + self.emitted_approval_requests: set[str] = set() + + # Open streaming blocks — keyed by part id. + self.open_text_ids: set[str] = set() + self.open_reasoning_ids: set[str] = set() + + # -- boundary helpers ---------------------------------------------------- + + def _close_open_blocks(self) -> list[protocol.UIMessageStreamPart]: + parts: list[protocol.UIMessageStreamPart] = [] + for rid in list(self.open_reasoning_ids): + parts.append(protocol.ReasoningEndPart(id=rid)) + self.open_reasoning_ids.clear() + for tid in list(self.open_text_ids): + parts.append(protocol.TextEndPart(id=tid)) + self.open_text_ids.clear() + return parts + + def _finish_step(self) -> list[protocol.UIMessageStreamPart]: + parts = self._close_open_blocks() + if self.in_step: + parts.append(protocol.FinishStepPart()) + self.in_step = False + return parts + + def _reset_step_tracking(self) -> None: + self.started_tool_inputs.clear() + self.input_available_emitted.clear() + self.emitted_tool_results.clear() + self.emitted_approval_requests.clear() + + # -- phase: message start ------------------------------------------------ + + def on_message(self, msg: messages_.Message) -> list[protocol.UIMessageStreamPart]: + """Emit UIMessage/step boundary parts for *msg*.""" + parts: list[protocol.UIMessageStreamPart] = [] + + agent_changed = ( + self.emitted_start + and msg.source_label is not None + and msg.source_label != self.current_agent + ) + + if not self.emitted_start or agent_changed: + parts.extend(self._finish_step()) + if self.emitted_start: + parts.append(protocol.FinishPart(finish_reason="stop")) + + self.ui_message_id = msg.id + parts.append(protocol.StartPart(message_id=msg.id)) + parts.append(protocol.StartStepPart()) + self.emitted_start = True + self.in_step = True + self.current_agent = msg.source_label + self.current_turn_id = msg.turn_id + self._reset_step_tracking() + return parts + + # Same UIMessage — check for step boundary via turn_id change. Only + # non-None → different-non-None transitions fire a step boundary; + # None carries the current step (tool results yielded by the loop are + # intentionally left unstamped until the next stream() stamps them). + if ( + msg.turn_id is not None + and self.current_turn_id is not None + and msg.turn_id != self.current_turn_id + ): + parts.extend(self._finish_step()) + parts.append(protocol.StartStepPart()) + self.in_step = True + self._reset_step_tracking() + self.current_turn_id = msg.turn_id + elif msg.turn_id is not None and self.current_turn_id is None: + self.current_turn_id = msg.turn_id + + return parts + + # -- phase: per-event (mid-stream) --------------------------------------- + + def on_event( + self, + msg: messages_.Message, + event: messages_.StreamEvent, + ) -> list[protocol.UIMessageStreamPart]: + match event: + case messages_.PartOpened(part=messages_.TextPart(id=pid)): + self.open_text_ids.add(pid) + return [protocol.TextStartPart(id=pid)] + + case messages_.PartDelta(part=messages_.TextPart(id=pid), chunk=chunk): + if pid not in self.open_text_ids: + self.open_text_ids.add(pid) + return [ + protocol.TextStartPart(id=pid), + protocol.TextDeltaPart(id=pid, delta=chunk), + ] + return [protocol.TextDeltaPart(id=pid, delta=chunk)] + + case messages_.PartClosed(part=messages_.TextPart(id=pid)): + if pid in self.open_text_ids: + self.open_text_ids.discard(pid) + return [protocol.TextEndPart(id=pid)] + return [] + + case messages_.PartOpened(part=messages_.ReasoningPart(id=pid)): + self.open_reasoning_ids.add(pid) + return [protocol.ReasoningStartPart(id=pid)] + + case messages_.PartDelta(part=messages_.ReasoningPart(id=pid), chunk=chunk): + if pid not in self.open_reasoning_ids: + self.open_reasoning_ids.add(pid) + return [ + protocol.ReasoningStartPart(id=pid), + protocol.ReasoningDeltaPart(id=pid, delta=chunk), + ] + return [protocol.ReasoningDeltaPart(id=pid, delta=chunk)] + + case messages_.PartClosed(part=messages_.ReasoningPart(id=pid)): + if pid in self.open_reasoning_ids: + self.open_reasoning_ids.discard(pid) + return [protocol.ReasoningEndPart(id=pid)] + return [] + + case messages_.PartOpened(part=messages_.ToolCallPart() as tc): + if tc.tool_call_id in self.started_tool_inputs: + return [] + self.started_tool_inputs.add(tc.tool_call_id) + return [ + protocol.ToolInputStartPart( + tool_call_id=tc.tool_call_id, + tool_name=tc.tool_name, + ) + ] + + case messages_.PartDelta(part=messages_.ToolCallPart() as tc, chunk=chunk): + out: list[protocol.UIMessageStreamPart] = [] + if tc.tool_call_id not in self.started_tool_inputs: + self.started_tool_inputs.add(tc.tool_call_id) + out.append( + protocol.ToolInputStartPart( + tool_call_id=tc.tool_call_id, + tool_name=tc.tool_name, + ) + ) + out.append( + protocol.ToolInputDeltaPart( + tool_call_id=tc.tool_call_id, + input_text_delta=chunk, + ) + ) + return out + + case messages_.PartClosed(part=messages_.ToolCallPart()): + # ToolInputAvailablePart is emitted in ``on_terminal`` from + # the terminal ``tool_args`` snapshot. + return [] + + return [] + + # -- phase: terminal (tool results, approvals, final tool-input) --------- + + def on_terminal(self, msg: messages_.Message) -> list[protocol.UIMessageStreamPart]: + if not msg.is_done: + return [] + + out: list[protocol.UIMessageStreamPart] = [] + + # Close any blocks that were opened but didn't see an explicit + # PartClosed (e.g. provider terminates abruptly — safety net). + if msg.stream is not None: + opened_ids = { + e.part.id + for e in msg.stream.new_events + if isinstance(e, messages_.PartOpened) + } + for tid in list(self.open_text_ids): + if tid in opened_ids and not any( + isinstance(e, messages_.PartClosed) and e.part.id == tid + for e in msg.stream.new_events + ): + out.append(protocol.TextEndPart(id=tid)) + self.open_text_ids.discard(tid) + + for part in msg.parts: + if isinstance(part, messages_.ToolCallPart): + if part.tool_call_id in self.input_available_emitted: + continue + self.input_available_emitted.add(part.tool_call_id) + # Ensure ToolInputStart was emitted (no streaming events case). + if part.tool_call_id not in self.started_tool_inputs: + self.started_tool_inputs.add(part.tool_call_id) + out.append( + protocol.ToolInputStartPart( + tool_call_id=part.tool_call_id, + tool_name=part.tool_name, + ) + ) + out.append( + protocol.ToolInputAvailablePart( + tool_call_id=part.tool_call_id, + tool_name=part.tool_name, + input=part.tool_args, + ) + ) + + elif isinstance(part, messages_.ToolResultPart): + if part.tool_call_id in self.emitted_tool_results: + continue + self.emitted_tool_results.add(part.tool_call_id) + if part.is_error: + out.append( + protocol.ToolOutputErrorPart( + tool_call_id=part.tool_call_id, + error_text=_tool_error_text(part), + ) + ) + else: + out.append( + protocol.ToolOutputAvailablePart( + tool_call_id=part.tool_call_id, + output=part.result, + ) + ) + + elif isinstance(part, messages_.HookPart): + tc_id = _approvals.tool_call_id_for(part) + if tc_id is None: + continue + + if part.status == "pending": + if tc_id in self.emitted_approval_requests: + continue + self.emitted_approval_requests.add(tc_id) + out.append( + protocol.ToolApprovalRequestPart( + approval_id=part.hook_id, + tool_call_id=tc_id, + ) + ) + elif part.status == "resolved": + resolution: dict[str, Any] = part.resolution or {} + if not resolution.get("granted", False): + out.append(protocol.ToolOutputDeniedPart(tool_call_id=tc_id)) + elif part.status == "cancelled": + out.append( + protocol.ToolOutputErrorPart( + tool_call_id=tc_id, + error_text="Hook cancelled", + ) + ) + + return out + + # -- phase: stream finish ------------------------------------------------ + + def finish(self) -> list[protocol.UIMessageStreamPart]: + parts = self._finish_step() + if self.emitted_start: + parts.append(protocol.FinishPart(finish_reason="stop")) + return parts diff --git a/src/ai/agents/ui/ai_sdk/outbound/history.py b/src/ai/agents/ui/ai_sdk/outbound/history.py new file mode 100644 index 00000000..c5d5809b --- /dev/null +++ b/src/ai/agents/ui/ai_sdk/outbound/history.py @@ -0,0 +1,66 @@ +"""Persisted-message → UIMessage list for history endpoints.""" + +from __future__ import annotations + +from .....types import messages as messages_ +from .. import _parts, ui_message + + +def to_ui_messages( + messages: list[messages_.Message], +) -> list[ui_message.UIMessage]: + """Group persisted messages into UIMessage bubbles. + + ``user``/``system`` messages become standalone UIMessages. Runs of + ``assistant``/``tool``/``internal`` messages merge into a single + assistant UIMessage, with tool results and approval state folded into + the corresponding tool-call parts. + """ + result: list[ui_message.UIMessage] = [] + + i = 0 + while i < len(messages): + msg = messages[i] + + if msg.role in ("user", "system"): + result.append( + ui_message.UIMessage( + id=msg.id, + role=msg.role, + parts=_parts.to_ui_parts(msg.parts), + ) + ) + i += 1 + continue + + if msg.role == "assistant": + ui_parts: list[ui_message.UIMessagePart] = [] + bubble_id = msg.id + + while i < len(messages) and messages[i].role in ( + "assistant", + "tool", + "internal", + ): + current = messages[i] + if current.role == "assistant": + ui_parts.extend(_parts.to_ui_parts(current.parts)) + elif current.role == "tool": + _parts.merge_tool_results(ui_parts, current.parts) + elif current.role == "internal": + _parts.merge_approval_signals(ui_parts, current.parts) + i += 1 + + result.append( + ui_message.UIMessage( + id=bubble_id, + role="assistant", + parts=ui_parts, + ) + ) + continue + + # Orphan tool / internal messages — skip; they have no assistant anchor. + i += 1 + + return result diff --git a/src/ai/agents/ui/ai_sdk/outbound/sse.py b/src/ai/agents/ui/ai_sdk/outbound/sse.py new file mode 100644 index 00000000..019f2894 --- /dev/null +++ b/src/ai/agents/ui/ai_sdk/outbound/sse.py @@ -0,0 +1,39 @@ +"""Serialize the UI message stream as Server-Sent Events.""" + +from __future__ import annotations + +import dataclasses +import json +from collections.abc import AsyncGenerator, AsyncIterable + +from .....types import messages as messages_ +from .. import protocol +from .stream import to_stream + + +def _to_camel_case(snake_str: str) -> str: + components = snake_str.split("_") + return components[0] + "".join(x.title() for x in components[1:]) + + +def serialize_part(part: protocol.UIMessageStreamPart) -> str: + """Serialize a stream part to JSON with camelCase keys.""" + d = dataclasses.asdict(part) + if isinstance(part, protocol.DataPart): + d["type"] = part.type + del d["data_type"] + camel_dict = {_to_camel_case(k): v for k, v in d.items() if v is not None} + return json.dumps(camel_dict) + + +def format_sse(part: protocol.UIMessageStreamPart) -> str: + """Format a stream part as an SSE data line.""" + return f"data: {serialize_part(part)}\n\n" + + +async def to_sse( + messages: AsyncIterable[messages_.Message], +) -> AsyncGenerator[str]: + """Convert an internal message stream into SSE strings.""" + async for part in to_stream(messages): + yield format_sse(part) diff --git a/src/ai/agents/ui/ai_sdk/outbound/stream.py b/src/ai/agents/ui/ai_sdk/outbound/stream.py new file mode 100644 index 00000000..3bccf95f --- /dev/null +++ b/src/ai/agents/ui/ai_sdk/outbound/stream.py @@ -0,0 +1,42 @@ +"""Convert an internal ``ai.Message`` stream into AI SDK UI protocol parts.""" + +from __future__ import annotations + +from collections.abc import AsyncGenerator, AsyncIterable + +from .....types import messages as messages_ +from .. import protocol +from ._state import _StreamState + + +async def to_stream( + messages: AsyncIterable[messages_.Message], +) -> AsyncGenerator[protocol.UIMessageStreamPart]: + """Walk ``messages`` once, emitting AI SDK UI stream parts. + + Drives off ``Message.stream.new_events`` for incremental deltas and + ``Message.parts`` for terminal tool input/output/approval parts. + Re-emitted messages (same id, already seen ``is_done``) are skipped. + """ + state = _StreamState() + + async for msg in messages: + if msg.id in state.seen_done: + continue + + for part in state.on_message(msg): + yield part + + if msg.stream is not None and msg.stream.new_events: + for event in msg.stream.new_events: + for out in state.on_event(msg, event): + yield out + + for part in state.on_terminal(msg): + yield part + + if msg.is_done: + state.seen_done.add(msg.id) + + for part in state.finish(): + yield part diff --git a/src/ai/adapters/ai_sdk_ui/protocol.py b/src/ai/agents/ui/ai_sdk/protocol.py similarity index 100% rename from src/ai/adapters/ai_sdk_ui/protocol.py rename to src/ai/agents/ui/ai_sdk/protocol.py diff --git a/src/ai/adapters/ai_sdk_ui/ui_message.py b/src/ai/agents/ui/ai_sdk/ui_message.py similarity index 99% rename from src/ai/adapters/ai_sdk_ui/ui_message.py rename to src/ai/agents/ui/ai_sdk/ui_message.py index b886444b..406a699f 100644 --- a/src/ai/adapters/ai_sdk_ui/ui_message.py +++ b/src/ai/agents/ui/ai_sdk/ui_message.py @@ -13,7 +13,7 @@ import pydantic -from ...types import messages as messages_ +from ....types import messages as messages_ class UITextPart(pydantic.BaseModel): diff --git a/src/ai/models/__init__.py b/src/ai/models/__init__.py index 55ba9648..a4d9c4ca 100644 --- a/src/ai/models/__init__.py +++ b/src/ai/models/__init__.py @@ -13,7 +13,9 @@ msgs = [ai.user_message("hello")] s = await ai.stream(model, msgs) async for msg in s: - print(msg.text_delta, end="") + for ev in msg.deltas: + if isinstance(ev.part, ai.TextPart): + print(ev.chunk, end="", flush=True) # explicit client for custom auth client = ai.Client(base_url="https://custom.example.com/v1", api_key="sk-...") diff --git a/src/ai/models/core/api.py b/src/ai/models/core/api.py index d79e5ca0..898eafcd 100644 --- a/src/ai/models/core/api.py +++ b/src/ai/models/core/api.py @@ -27,6 +27,7 @@ async def stream( *, tools: Sequence[tools_.ToolLike] | None = None, output_type: type[pydantic.BaseModel] | None = None, + turn_id: str | None = None, **kwargs: Any, ) -> stream_.StreamResultLike: """Stream an LLM response. @@ -35,11 +36,19 @@ async def stream( collects the final ``Message``. After iteration, access ``.text``, ``.tool_calls``, ``.usage``, etc. + One call is one turn: a single request and its response. The model + response carries ``turn_id``; re-emitted input messages keep any + existing ``turn_id`` from prior turns and only receive the current + one when unstamped. If *turn_id* is not provided, one is generated. + The client is resolved from the model: ``model.client`` if set, otherwise auto-created from ``model.base_url`` / ``model.api_key_env``. """ messages = integrity_.prepare_messages(messages) + if turn_id is None: + turn_id = messages_.generate_id("turn") + call = middleware_.ModelContext( model=model, messages=messages, @@ -48,6 +57,9 @@ async def stream( kwargs=kwargs, ) + # Capture in closure for the inner function. + _turn_id = turn_id + async def _real(call: middleware_.ModelContext) -> stream_.StreamResultLike: c = client_.auto_client(call.model) adapter_fn = adapters.get_stream_adapter(call.model.adapter) @@ -59,7 +71,9 @@ async def _real(call: middleware_.ModelContext) -> stream_.StreamResultLike: tools=call.tools, output_type=call.output_type, **call.kwargs, - ) + ), + turn_id=_turn_id, + input_messages=call.messages, ) chain = middleware_._build_model_chain(_real) diff --git a/src/ai/models/core/helpers/streaming.py b/src/ai/models/core/helpers/streaming.py index 1fa91441..bd7db603 100644 --- a/src/ai/models/core/helpers/streaming.py +++ b/src/ai/models/core/helpers/streaming.py @@ -1,10 +1,6 @@ from __future__ import annotations import dataclasses -import json -from collections.abc import AsyncGenerator - -import pydantic from ....types import messages as messages_ @@ -95,21 +91,16 @@ class StreamHandler: Accumulates LLM adapter events and produces Messages with stateful parts. This is the normalization layer between LLM adapters and the rest of the system. + Parts are tracked in a single ``_current_parts`` dict keyed by block/tool id, + updated in place as events stream in. Each event carries the just-constructed + frozen part snapshot, so consumers never need to look parts up by id. """ message_id: str = dataclasses.field(default_factory=messages_.generate_id) - # Accumulators - _text_blocks: dict[str, str] = dataclasses.field(default_factory=dict) - _reasoning_blocks: dict[str, tuple[str, str | None]] = dataclasses.field( - default_factory=dict - ) # (text, signature) - _tool_calls: dict[str, tuple[str, str]] = dataclasses.field( - default_factory=dict - ) # (name, args) - _files: dict[str, tuple[str, str]] = dataclasses.field( - default_factory=dict - ) # block_id -> (media_type, data) + # Single source of truth for part state, keyed by id. Insertion order + # preserves provider emission order. + _current_parts: dict[str, messages_.Part] = dataclasses.field(default_factory=dict) # Active tracking _active_text_id: str | None = None @@ -122,53 +113,91 @@ class StreamHandler: def handle_event(self, event: StreamEvent) -> messages_.Message: """Process event and return current Message state.""" - # Current deltas (reset each call) - text_delta: str | None = None - reasoning_delta: str | None = None - tool_deltas: dict[str, str] = {} # tool_call_id -> delta + # Sidecar events for this yield (reset each call). + stream_events: list[messages_.StreamEvent] = [] match event: case TextStart(block_id=bid): - self._text_blocks[bid] = "" + part: messages_.Part = messages_.TextPart(id=bid, text="") + self._current_parts[bid] = part self._active_text_id = bid + stream_events.append(messages_.PartOpened(part=part)) case TextDelta(block_id=bid, delta=d): - self._text_blocks[bid] += d - text_delta = d + existing = self._current_parts[bid] + assert isinstance(existing, messages_.TextPart) + part = messages_.TextPart(id=bid, text=existing.text + d) + self._current_parts[bid] = part + stream_events.append(messages_.PartDelta(part=part, chunk=d)) case TextEnd(block_id=bid): if self._active_text_id == bid: self._active_text_id = None + stream_events.append( + messages_.PartClosed(part=self._current_parts[bid]) + ) case ReasoningStart(block_id=bid): - self._reasoning_blocks[bid] = ("", None) + part = messages_.ReasoningPart(id=bid, text="") + self._current_parts[bid] = part self._active_reasoning_id = bid + stream_events.append(messages_.PartOpened(part=part)) case ReasoningDelta(block_id=bid, delta=d): - text, sig = self._reasoning_blocks[bid] - self._reasoning_blocks[bid] = (text + d, sig) - reasoning_delta = d + existing = self._current_parts[bid] + assert isinstance(existing, messages_.ReasoningPart) + part = messages_.ReasoningPart( + id=bid, + text=existing.text + d, + signature=existing.signature, + ) + self._current_parts[bid] = part + stream_events.append(messages_.PartDelta(part=part, chunk=d)) case ReasoningEnd(block_id=bid, signature=sig): - text, _ = self._reasoning_blocks[bid] - self._reasoning_blocks[bid] = (text, sig) + existing = self._current_parts[bid] + assert isinstance(existing, messages_.ReasoningPart) + part = messages_.ReasoningPart( + id=bid, text=existing.text, signature=sig + ) + self._current_parts[bid] = part if self._active_reasoning_id == bid: self._active_reasoning_id = None + stream_events.append(messages_.PartClosed(part=part)) case ToolStart(tool_call_id=tcid, tool_name=name): - self._tool_calls[tcid] = (name, "") + part = messages_.ToolCallPart( + id=tcid, + tool_call_id=tcid, + tool_name=name, + tool_args="", + ) + self._current_parts[tcid] = part self._active_tool_ids.add(tcid) + stream_events.append(messages_.PartOpened(part=part)) case ToolArgsDelta(tool_call_id=tcid, delta=d): - name, args = self._tool_calls[tcid] - self._tool_calls[tcid] = (name, args + d) - tool_deltas[tcid] = d + existing = self._current_parts[tcid] + assert isinstance(existing, messages_.ToolCallPart) + part = messages_.ToolCallPart( + id=tcid, + tool_call_id=existing.tool_call_id, + tool_name=existing.tool_name, + tool_args=existing.tool_args + d, + ) + self._current_parts[tcid] = part + stream_events.append(messages_.PartDelta(part=part, chunk=d)) case ToolEnd(tool_call_id=tcid): self._active_tool_ids.discard(tcid) + stream_events.append( + messages_.PartClosed(part=self._current_parts[tcid]) + ) case FileEvent(block_id=bid, media_type=mt, data=d): - self._files[bid] = (mt, d) + self._current_parts[bid] = messages_.FilePart( + id=bid, data=d, media_type=mt + ) case MessageDone(usage=usage): self._is_done = True @@ -177,90 +206,19 @@ def handle_event(self, event: StreamEvent) -> messages_.Message: self._active_reasoning_id = None self._active_tool_ids.clear() - return self._build_message(text_delta, reasoning_delta, tool_deltas) + return self._build_message(stream_events) def _build_message( self, - text_delta: str | None, - reasoning_delta: str | None, - tool_deltas: dict[str, str], + stream_events: list[messages_.StreamEvent], ) -> messages_.Message: - parts: list[messages_.Part] = [] - - # Reasoning parts first (like thinking blocks) - for bid, (text, sig) in self._reasoning_blocks.items(): - is_active = bid == self._active_reasoning_id - parts.append( - messages_.ReasoningPart( - id=bid, - text=text, - signature=sig, - state="streaming" if is_active else "done", - delta=reasoning_delta if is_active else None, - ) - ) - - # Text parts - for bid, text in self._text_blocks.items(): - is_active = bid == self._active_text_id - parts.append( - messages_.TextPart( - id=bid, - text=text, - state="streaming" if is_active else "done", - delta=text_delta if is_active else None, - ) - ) - - # Tool call parts - for tcid, (name, args) in self._tool_calls.items(): - is_active = tcid in self._active_tool_ids - parts.append( - messages_.ToolCallPart( - id=tcid, - tool_call_id=tcid, - tool_name=name, - tool_args=args, - state="streaming" if is_active else "done", - args_delta=tool_deltas.get(tcid), - ) - ) - - # File parts (inline images/videos from LLMs like Gemini, GPT-5) - for bid, (media_type, data) in self._files.items(): - parts.append(messages_.FilePart(id=bid, data=data, media_type=media_type)) - return messages_.Message( id=self.message_id, role="assistant", - parts=parts, + parts=list(self._current_parts.values()), usage=self._usage if self._is_done else None, + stream=messages_.StreamState( + new_events=stream_events, + is_done=self._is_done, + ), ) - - -async def events_to_messages( - events: AsyncGenerator[StreamEvent], - output_type: type[pydantic.BaseModel] | None = None, -) -> AsyncGenerator[messages_.Message]: - """Convert a stream of events into Message snapshots. - - This is the standalone version of the logic that ``LanguageModel.stream()`` - uses. Wire functions call this to turn their ``StreamEvent`` generators - into ``Message`` generators suitable for ``Stream``. - """ - handler = StreamHandler() - msg: messages_.Message | None = None - async for event in events: - msg = handler.handle_event(event) - yield msg - - # After stream completes, validate and attach structured output part - if output_type is not None and msg is not None and msg.text: - data = json.loads(msg.text) - output_type.model_validate(data) # fail fast on bad data - part = messages_.StructuredOutputPart( - data=data, - output_type_name=f"{output_type.__module__}.{output_type.__qualname__}", - ) - msg = msg.model_copy(update={"parts": [*msg.parts, part]}) - yield msg diff --git a/src/ai/models/core/types.py b/src/ai/models/core/types.py index eff6e5af..a15d6dd1 100644 --- a/src/ai/models/core/types.py +++ b/src/ai/models/core/types.py @@ -69,11 +69,26 @@ class StreamResult: Properties like ``.text`` and ``.tool_calls`` delegate to the final ``Message`` snapshot and are available after iteration completes. + One ``StreamResult`` represents one turn: a single LLM request and its + response. When *turn_id* is provided, the model response is stamped + with it. When *input_messages* is provided, they are re-emitted ahead + of the response; inputs that already carry a ``turn_id`` (from earlier + turns) are preserved as-is, only inputs with ``turn_id=None`` receive + the current *turn_id*. + Satisfies :class:`~ai.types.StreamResultLike`. """ - def __init__(self, gen: AsyncGenerator[messages_.Message]) -> None: + def __init__( + self, + gen: AsyncGenerator[messages_.Message], + *, + turn_id: str | None = None, + input_messages: list[messages_.Message] | None = None, + ) -> None: self._gen = gen + self._turn_id = turn_id + self._input_messages = input_messages or [] self._final: messages_.Message | None = None @classmethod @@ -98,10 +113,25 @@ def __aiter__(self) -> AsyncGenerator[messages_.Message]: return self._iterate() async def _iterate(self) -> AsyncGenerator[messages_.Message]: + # Re-emit input messages; stamp only the ones without a turn_id. + # Prior turns keep their existing ids. + for msg in self._input_messages: + if msg.turn_id is None and self._turn_id is not None: + msg = msg.model_copy(update={"turn_id": self._turn_id}) + yield msg + + # Stream model response with turn_id stamped (when missing). async for msg in self._gen: + if msg.turn_id is None and self._turn_id is not None: + msg = msg.model_copy(update={"turn_id": self._turn_id}) self._final = msg yield msg + @property + def turn_id(self) -> str | None: + """The turn id stamped on this stream's response (if any).""" + return self._turn_id + @property def text(self) -> str: return self._final.text if self._final else "" diff --git a/src/ai/types/__init__.py b/src/ai/types/__init__.py index aa7a9435..6a41f9d6 100644 --- a/src/ai/types/__init__.py +++ b/src/ai/types/__init__.py @@ -9,12 +9,14 @@ HookPart, Message, Part, - PartState, + PartClosed, + PartDelta, + PartOpened, ReasoningPart, + StreamState, StructuredOutputPart, TextPart, ToolCallPart, - ToolDelta, ToolResultPart, Usage, generate_id, @@ -27,13 +29,15 @@ "HookPart", "Message", "Part", - "PartState", + "PartClosed", + "PartDelta", + "PartOpened", "ReasoningPart", "StreamResultLike", + "StreamState", "StructuredOutputPart", "TextPart", "ToolCallPart", - "ToolDelta", "ToolLike", "ToolResultPart", "ToolSchema", diff --git a/src/ai/types/integrity.py b/src/ai/types/integrity.py index f2b8ca87..9f737101 100644 --- a/src/ai/types/integrity.py +++ b/src/ai/types/integrity.py @@ -16,7 +16,7 @@ "invalid-tool-args", "orphaned-tool-call", "orphaned-tool-result", - "signal-message", + "internal-message", ] @@ -47,9 +47,9 @@ def _clean_messages( result: list[messages_.Message] = [] for msg in messages: - # 1. drop signal messages emitted by hooks - if msg.role == "signal": - issues.append("signal-message") + # 1. drop internal messages emitted by hooks + if msg.role == "internal": + issues.append("internal-message") if mode == "strict": result.append(msg) continue diff --git a/src/ai/types/messages.py b/src/ai/types/messages.py index e0adc8ef..25ecd6ce 100644 --- a/src/ai/types/messages.py +++ b/src/ai/types/messages.py @@ -13,19 +13,12 @@ def generate_id(prefix: str | None = None) -> str: return f"{prefix}_{raw}" if prefix else raw -# Streaming state for parts -PartState = Literal["streaming", "done"] - - class TextPart(pydantic.BaseModel): model_config = pydantic.ConfigDict(frozen=True) id: str = pydantic.Field(default_factory=generate_id) text: str type: Literal["text"] = "text" - # Streaming state - state: PartState | None = None # None = finalized/restored from storage - delta: str | None = None # Current delta, None when not actively streaming class ToolCallPart(pydantic.BaseModel): @@ -43,9 +36,6 @@ class ToolCallPart(pydantic.BaseModel): tool_name: str tool_args: str type: Literal["tool_call"] = "tool_call" - # Streaming state (for args streaming) - state: PartState | None = None - args_delta: str | None = None # Delta for tool_args class ToolResultPart(pydantic.BaseModel): @@ -74,9 +64,6 @@ class ReasoningPart(pydantic.BaseModel): # Anthropic's thinking blocks include a signature for cache/verification. # This must be preserved and sent back in multi-turn conversations. signature: str | None = None - # Streaming state - state: PartState | None = None - delta: str | None = None class HookPart(pydantic.BaseModel): @@ -271,22 +258,78 @@ def _add_optional(a: int | None, b: int | None) -> int | None: ) -class ToolDelta(pydantic.BaseModel): +# --------------------------------------------------------------------------- +# Streaming sidecar — transient state excluded from persistence. +# --------------------------------------------------------------------------- + + +class PartOpened(pydantic.BaseModel): + """A new streaming block was opened by the LLM. + + ``part`` holds the initial snapshot of the part (empty text/args). + """ + model_config = pydantic.ConfigDict(frozen=True) - tool_call_id: str - tool_name: str - args_delta: str + part: Part + type: Literal["part_opened"] = "part_opened" + + +class PartDelta(pydantic.BaseModel): + """An incremental update to a streaming part. + + ``part`` is the post-delta snapshot (state accumulated up to and including + ``chunk``). ``chunk`` is the new fragment appended this step (plain text + for :class:`TextPart` / :class:`ReasoningPart`, a JSON-args fragment for + :class:`ToolCallPart`). + """ + + model_config = pydantic.ConfigDict(frozen=True) + + part: Part + chunk: str + type: Literal["part_delta"] = "part_delta" + + +class PartClosed(pydantic.BaseModel): + """A streaming block was closed by the LLM. + + ``part`` holds the final snapshot of the part. + """ + + model_config = pydantic.ConfigDict(frozen=True) + + part: Part + type: Literal["part_closed"] = "part_closed" + + +StreamEvent = Annotated[ + PartOpened | PartDelta | PartClosed, + pydantic.Field(discriminator="type"), +] + + +class StreamState(pydantic.BaseModel): + """Transient streaming state attached to a Message during streaming. + + ``new_events`` contains the events since the previous yield — never cumulative. + ``is_done`` is True once the stream has finished. + """ + + new_events: list[StreamEvent] = pydantic.Field(default_factory=list) + is_done: bool = False class Message(pydantic.BaseModel): model_config = pydantic.ConfigDict(frozen=True) - role: Literal["user", "assistant", "system", "tool", "signal"] + role: Literal["user", "assistant", "system", "tool", "internal"] parts: list[Part] id: str = pydantic.Field(default_factory=generate_id) - label: str | None = None + turn_id: str | None = None + source_label: str | None = None usage: Usage | None = None + stream: StreamState | None = pydantic.Field(default=None, exclude=True) @overload def replace(self, new: Part, /) -> Message: ... @@ -337,45 +380,29 @@ def output(self) -> Any: @property def is_done(self) -> bool: - """Message is done when all parts are done (or have no streaming state).""" - for part in self.parts: - if ( - isinstance(part, (TextPart, ReasoningPart, ToolCallPart)) - and part.state == "streaming" - ): - return False - return True + """No sidecar (persisted/restored) means done. Otherwise ``stream.is_done``.""" + if self.stream is None: + return True + return self.stream.is_done - @property - def text_delta(self) -> str: - """Get current text delta from parts.""" + def get_part(self, part_id: str) -> Part | None: + """Find a part by id, or return None if not found.""" for part in self.parts: - if isinstance(part, TextPart) and part.delta: - return part.delta - return "" + if part.id == part_id: + return part + return None @property - def reasoning_delta(self) -> str: - """Get current reasoning delta from parts.""" - for part in self.parts: - if isinstance(part, ReasoningPart) and part.delta: - return part.delta - return "" + def deltas(self) -> list[PartDelta]: + """PartDelta events from this yield step, in order. - @property - def tool_deltas(self) -> list[ToolDelta]: - """Get current tool deltas from parts.""" - deltas = [] - for part in self.parts: - if isinstance(part, ToolCallPart) and part.args_delta: - deltas.append( - ToolDelta( - tool_call_id=part.tool_call_id, - tool_name=part.tool_name, - args_delta=part.args_delta, - ) - ) - return deltas + Empty list means nothing streamed in this step. Each event carries + its post-delta :class:`Part` snapshot via ``ev.part`` and the chunk + fragment via ``ev.chunk``. + """ + if self.stream is None: + return [] + return [ev for ev in self.stream.new_events if isinstance(ev, PartDelta)] @property def files(self) -> list[FilePart]: diff --git a/src/ai/types/stream.py b/src/ai/types/stream.py index 3fa43a99..f6389bc0 100644 --- a/src/ai/types/stream.py +++ b/src/ai/types/stream.py @@ -34,3 +34,6 @@ def usage(self) -> messages_.Usage | None: ... @property def output(self) -> Any: ... + + @property + def turn_id(self) -> str | None: ... diff --git a/tests/adapters/ai_sdk_ui/test_adapter.py b/tests/adapters/ai_sdk_ui/test_adapter.py deleted file mode 100644 index dae8732f..00000000 --- a/tests/adapters/ai_sdk_ui/test_adapter.py +++ /dev/null @@ -1,693 +0,0 @@ -""" -Based on: .reference/ai/packages/ai/src/ui/process-ui-message-stream.test.ts -""" - -from __future__ import annotations - -import asyncio -from collections.abc import AsyncGenerator - -import ai -from ai.adapters.ai_sdk_ui import adapter, ui_message -from ai.agents import hooks -from ai.types import messages - -from ...conftest import MOCK_MODEL, mock_llm, tool_call_msg - - -async def get_event_types(msgs: list[messages.Message]) -> list[str]: - """Stream messages through adapter and return event type sequence.""" - - async def stream() -> AsyncGenerator[messages.Message]: - for m in msgs: - yield m - - return [p.type async for p in adapter.to_ui_message_stream(stream())] - - -# ----------------------------------------------------------------------------- -# Event sequence tests -# ----------------------------------------------------------------------------- - - -async def test_text_streaming() -> None: - """Text: start -> start-step -> text-start/delta/end -> finish-step -> finish""" - msgs = [ - messages.Message( - id="msg-1", - role="assistant", - parts=[messages.TextPart(text="Hello", delta="Hello", state="streaming")], - ), - messages.Message( - id="msg-1", - role="assistant", - parts=[ - messages.TextPart( - text="Hello, world!", delta=", world!", state="streaming" - ) - ], - ), - messages.Message( - id="msg-1", - role="assistant", - parts=[messages.TextPart(text="Hello, world!", state="done")], - ), - ] - - assert await get_event_types(msgs) == [ - "start", - "start-step", - "text-start", - "text-delta", - "text-delta", - "text-end", - "finish-step", - "finish", - ] - - -async def test_tool_roundtrip() -> None: - """Server-side tool: input-available -> output-available -> text response. - - Reference: process-ui-message-stream.test.ts "server-side tool roundtrip" - """ - msgs = [ - # Tool call (assistant message) - messages.Message( - id="msg-1", - role="assistant", - parts=[ - messages.ToolCallPart( - tool_call_id="tc-1", - tool_name="get_weather", - tool_args='{"city": "London"}', - state="done", - ), - ], - ), - # Tool result (tool message, pinned to same id for same step) - messages.Message( - id="msg-1", - role="tool", - parts=[ - messages.ToolResultPart( - tool_call_id="tc-1", - tool_name="get_weather", - result={"weather": "sunny"}, - ), - ], - ), - # Final text - messages.Message( - id="msg-2", - role="assistant", - parts=[ - messages.TextPart(text="The weather is sunny.", state="done"), - ], - ), - ] - - assert await get_event_types(msgs) == [ - "start", - "start-step", - "tool-input-start", - "tool-input-available", - "tool-output-available", - "finish-step", - "start-step", - "text-start", - "text-end", - "finish-step", - "finish", - ] - - -async def test_text_then_tool_then_text() -> None: - """Full mothership scenario: text -> tool -> result -> final text. - - Input: "when will the robots take over?" - 1. Text: "I'll check with the mothership..." - 2. Tool: talk_to_mothership(question="...") - 3. Result: "Soon." - 4. Text: "According to the mothership: Soon." - """ - msgs = [ - # Streaming initial text - messages.Message( - id="msg-1", - role="assistant", - parts=[ - messages.TextPart( - text="I'll check with the mothership.", - delta="I'll check with the mothership.", - state="streaming", - ) - ], - ), - # Text done + tool call - messages.Message( - id="msg-1", - role="assistant", - parts=[ - messages.TextPart(text="I'll check with the mothership.", state="done"), - messages.ToolCallPart( - tool_call_id="tc-1", - tool_name="talk_to_mothership", - tool_args='{"question": "when?"}', - state="done", - ), - ], - ), - # Tool result (pinned to same id for same step) - messages.Message( - id="msg-1", - role="tool", - parts=[ - messages.ToolResultPart( - tool_call_id="tc-1", - tool_name="talk_to_mothership", - result={"answer": "Soon."}, - ), - ], - ), - # Final text (new message) - messages.Message( - id="msg-2", - role="assistant", - parts=[ - messages.TextPart( - text="According to the mothership: Soon.", - delta="According to the mothership: Soon.", - state="streaming", - ) - ], - ), - messages.Message( - id="msg-2", - role="assistant", - parts=[ - messages.TextPart( - text="According to the mothership: Soon.", state="done" - ) - ], - ), - ] - - # Per AI SDK protocol, tool-input-available and tool-output-available - # are in the SAME step (one LLM turn). Reference: - # process-ui-message-stream.test.ts - # "server-side tool roundtrip with multiple assistant texts" - assert await get_event_types(msgs) == [ - "start", - "start-step", - "text-start", - "text-delta", - "text-end", - "tool-input-start", - "tool-input-available", - "tool-output-available", # Same step as tool-input (AI SDK protocol) - "finish-step", - # New step for second LLM call (new message ID) - "start-step", - "text-start", - "text-delta", - "text-end", - "finish-step", - "finish", - ] - - -# ----------------------------------------------------------------------------- -# Integration tests - runtime-based execution -# ----------------------------------------------------------------------------- - - -@ai.tool -async def get_weather(city: str) -> str: - """Get weather for a city.""" - return f"Sunny in {city}" - - -async def test_runtime_tool_roundtrip() -> None: - """ - Integration test: run an Agent through agent.run() and verify - that tool-input-available and tool-output-available events are emitted. - - This test demonstrates the bug: the runtime yields the message with - the tool call, but by the time it's yielded the tool has already been - executed and the ToolPart has been mutated to status="result". The UI - adapter never sees the intermediate status="pending" state. - - Root cause: the default loop appends the message, then executes tools - which mutate the message in-place. The message was already yielded with - status="pending", but pydantic models are mutable so when we collect - them at the end, we see the mutated state. - """ - weather_agent = ai.agent(tools=[get_weather]) - - # First LLM call: returns a tool call - tool_call_response = [ - messages.Message( - id="msg-1", - role="assistant", - parts=[ - messages.ToolCallPart( - tool_call_id="tc-1", - tool_name="get_weather", - tool_args='{"city": "London"}', - state="done", - ), - ], - ), - ] - - # Second LLM call: returns final text - final_text_response = [ - messages.Message( - id="msg-2", - role="assistant", - parts=[messages.TextPart(text="The weather is sunny.", state="done")], - ), - ] - - mock_llm([tool_call_response, final_text_response]) - - # Collect all messages from the runtime - runtime_messages: list[messages.Message] = [] - async for msg in weather_agent.run( - MOCK_MODEL, [ai.user_message("What's the weather in London?")] - ): - runtime_messages.append(msg) - - # Stream through UI adapter - event_types = [ - p.type - async for p in adapter.to_ui_message_stream(_async_iter(runtime_messages)) - ] - - # This is what SHOULD happen: - # 1. First step streams tool call args then completes - # -> tool-input-start, tool-input-delta, tool-input-available - # 2. After tool execution, we yield the same message with - # status="result" -> tool-output-available - # (same step because same message ID) - # 3. Second LLM step streams text then completes - # -> text-start, text-delta, text-end, (final done msg) text-start, text-end - expected = [ - "start", - "start-step", - "tool-input-start", - "tool-input-delta", - "tool-input-available", - "tool-output-available", # Same step as input (same message ID) - "finish-step", - # Second LLM call (new message ID = new step) - "start-step", - "text-start", - "text-delta", - "text-end", - "text-start", # Final done message re-emits completed text - "text-end", - "finish-step", - "finish", - ] - - assert event_types == expected - - -async def _async_iter( - items: list[messages.Message], -) -> AsyncGenerator[messages.Message]: - """Helper to convert a list to an async generator.""" - for item in items: - yield item - - -# ----------------------------------------------------------------------------- -# UI → Internal conversion tests -# ----------------------------------------------------------------------------- - - -def test_ui_to_internal_two_turn_with_tool() -> None: - """Test converting a realistic two-turn conversation with tool call. - - This test uses the exact payload structure from a real AI SDK frontend - that was causing 422 validation errors due to: - 1. step-start parts (boundary markers) - 2. tool-{toolName} dynamic type pattern (e.g., "tool-talk_to_mothership") - """ - # Exact structure from a failing request - raw_messages = [ - { - "id": "lmaOqWZJdKOVUbYT", - "role": "user", - "parts": [{"type": "text", "text": "when will the robots take over?"}], - }, - { - "id": "d04b88d9a82e", - "role": "assistant", - "parts": [ - {"type": "step-start"}, - { - "type": "text", - "text": "I'll check with the mothership " - "about this important question.", - "state": "done", - }, - { - "type": "tool-talk_to_mothership", - "toolCallId": "toolu_01FiXNXhq1kHx4TegRjSaJyv", - "state": "output-available", - "input": '{"question": "when will the robots take over?"}', - "output": "Soon.", - }, - {"type": "text", "text": "", "state": "done"}, # Empty text part - {"type": "step-start"}, - { - "type": "text", - "text": "The mothership has spoken: Soon.", - "state": "done", - }, - {"type": "text", "text": "", "state": "done"}, # Empty text part - ], - }, - { - "id": "ZLi3qVpgZLBjwMxZ", - "role": "user", - "parts": [ - { - "type": "text", - "text": "this is a test run. can you remember the first turn?", - } - ], - }, - ] - - # Parse UI messages - this should NOT raise validation errors - ui_messages = [ui_message.UIMessage.model_validate(m) for m in raw_messages] - - # Verify parsing worked - assert len(ui_messages) == 3 - assert ui_messages[0].role == "user" - assert ui_messages[1].role == "assistant" - assert ui_messages[2].role == "user" - - # Check that step-start and tool parts were parsed correctly - assistant_parts = ui_messages[1].parts - assert isinstance(assistant_parts[0], ui_message.UIStepStartPart) - assert isinstance(assistant_parts[1], ui_message.UITextPart) - assert isinstance(assistant_parts[2], ui_message.UIToolPart) - assert assistant_parts[2].tool_name == "talk_to_mothership" - assert assistant_parts[2].state == "output-available" - - # Convert to internal format - internal = adapter.to_messages(ui_messages) - - # The single UI assistant message contains [text, tool(done), text] from - # two stream_loop iterations. to_messages splits at the tool-result - # boundary so LLM adapters receive one message per iteration: - # user, assistant (text + tool_call), tool (result), assistant (text), user - assert len(internal) == 5 - assert internal[0].role == "user" - assert internal[0].text == "when will the robots take over?" - - # First iteration: text + tool call (assistant message) - assert internal[1].role == "assistant" - assert internal[1].text == ( - "I'll check with the mothership about this important question." - ) - assert len(internal[1].tool_calls) == 1 - assert internal[1].tool_calls[0].tool_name == "talk_to_mothership" - assert internal[1].tool_calls[0].tool_call_id == "toolu_01FiXNXhq1kHx4TegRjSaJyv" - - # Tool result (separate tool message) - assert internal[2].role == "tool" - assert len(internal[2].tool_results) == 1 - assert internal[2].tool_results[0].result == {"value": "Soon."} - - # Second iteration: follow-up text - assert internal[3].role == "assistant" - assert internal[3].text == "The mothership has spoken: Soon." - assert len(internal[3].tool_calls) == 0 - - assert internal[4].role == "user" - assert internal[4].text == "this is a test run. can you remember the first turn?" - - -def test_ui_tool_part_with_dict_input() -> None: - """Test that tool parts with dict input (not JSON string) are handled.""" - raw_message = { - "id": "msg-1", - "role": "assistant", - "parts": [ - { - "type": "tool-get_weather", - "toolCallId": "tc-1", - "state": "input-available", - "input": {"city": "London"}, # Dict, not JSON string - } - ], - } - - ui_msg = ui_message.UIMessage.model_validate(raw_message) - internal = adapter.to_messages([ui_msg]) - - assert len(internal) == 1 - tool_part = internal[0].tool_calls[0] - assert tool_part.tool_name == "get_weather" - assert tool_part.tool_args == '{"city": "London"}' - # input-available means call is present but no result yet (no tool message) - - -def test_ui_file_part_converted_to_core_file_part() -> None: - """UIFilePart from the frontend is converted to a core FilePart.""" - raw_message = { - "id": "msg-1", - "role": "user", - "parts": [ - {"type": "text", "text": "What's in this image?"}, - { - "type": "file", - "mediaType": "image/png", - "url": "https://example.com/photo.png", - "filename": "photo.png", - }, - ], - } - ui_msg = ui_message.UIMessage.model_validate(raw_message) - internal = adapter.to_messages([ui_msg]) - - assert len(internal) == 1 - msg = internal[0] - assert msg.role == "user" - assert len(msg.parts) == 2 - assert isinstance(msg.parts[0], messages.TextPart) - assert isinstance(msg.parts[1], messages.FilePart) - fp = msg.parts[1] - assert fp.data == "https://example.com/photo.png" - assert fp.media_type == "image/png" - assert fp.filename == "photo.png" - - -def test_ui_skips_unsupported_parts() -> None: - """Test that unsupported part types are skipped gracefully.""" - raw_message = { - "id": "msg-1", - "role": "assistant", - "parts": [ - {"type": "text", "text": "Hello"}, - {"type": "data-custom", "data": {"foo": "bar"}}, # Unsupported - {"type": "unknown-type", "content": "xyz"}, # Unsupported - {"type": "text", "text": "World"}, - ], - } - - ui_msg = ui_message.UIMessage.model_validate(raw_message) - # Only text parts should be parsed (data-* and unknown skipped) - assert len(ui_msg.parts) == 2 - assert all(isinstance(p, ui_message.UITextPart) for p in ui_msg.parts) - - internal = adapter.to_messages([ui_msg]) - assert len(internal[0].parts) == 2 - - -# ----------------------------------------------------------------------------- -# Tool approval (human-in-the-loop) tests -# ----------------------------------------------------------------------------- - - -async def test_tool_approval_hook_emits_approval_request() -> None: - """Pending ToolApproval HookPart emits tool-approval-request on the wire. - - The HookPart message uses a *different* id from the tool message, - matching what the Runtime actually does (it creates an ad-hoc Message - with its own auto-generated id at runtime.py:452). The adapter must - keep both in the same step so the frontend's sendAutomaticallyWhen - helper can find the tool part when the user responds to the approval. - """ - msgs = [ - # Tool call (args complete, awaiting approval) - messages.Message( - id="msg-1", - role="assistant", - parts=[ - messages.ToolCallPart( - tool_call_id="tc-1", - tool_name="rm_rf", - tool_args='{"path": "/"}', - state="done", - ), - ], - ), - # Hook pending (approval requested) — different message id, - # just like the Runtime produces at runtime.py:452. - messages.Message( - id="hook-msg-1", - role="assistant", - parts=[ - messages.HookPart( - hook_id="approve_tc-1", - hook_type=hooks.TOOL_APPROVAL_HOOK_TYPE, - status="pending", - metadata={"tool_name": "rm_rf", "tool_args": '{"path": "/"}'}, - ), - ], - ), - ] - - event_types = await get_event_types(msgs) - # tool-approval-request must be in the SAME step as the tool input — - # no extra start-step/finish-step between them. - assert event_types == [ - "start", - "start-step", - "tool-input-start", - "tool-input-available", - "tool-approval-request", - "finish-step", - "finish", - ] - - -def test_approval_responded_resolves_hook() -> None: - """to_messages() resolves the ToolApproval hook for approval-responded parts.""" - label = "approve_tc-42" - raw_messages = [ - { - "id": "msg-1", - "role": "assistant", - "parts": [ - { - "type": "tool-dangerous_action", - "toolCallId": "tc-42", - "state": "approval-responded", - "input": '{"x": 1}', - "approval": { - "id": label, - "approved": True, - "reason": "looks safe", - }, - } - ], - }, - ] - - # Clean up any leftover state from other tests - hooks._pending_resolutions.pop(label, None) - - ui_msgs = [ui_message.UIMessage.model_validate(m) for m in raw_messages] - adapter.to_messages(ui_msgs) - - # The side-effect should have pre-registered the resolution - assert label in hooks._pending_resolutions - resolution = hooks._pending_resolutions.pop(label) - assert resolution == {"granted": True, "reason": "looks safe"} - - -async def test_runtime_tool_approval_same_step() -> None: - """E2E: tool-approval-request must land in the same SSE step as the tool call. - - Runs a graph with ToolApproval (interrupt_loop=True) through agent.run(), - collects runtime messages, streams through the adapter, and asserts - that no spurious step boundary appears between tool-input-available - and tool-approval-request. - """ - from collections.abc import AsyncGenerator as AG - - @ai.tool - async def dangerous_action(path: str) -> str: - """Do something dangerous.""" - return f"deleted {path}" - - approval_agent = ai.agent(tools=[dangerous_action]) - - @approval_agent.loop - async def custom(context: ai.Context) -> AG[ai.Message]: - stream = await ai.models.stream( - context.model, context.messages, tools=context.tools - ) - async for msg in stream: - yield msg - - tool_calls = context.resolve(stream.tool_calls) - if not tool_calls: - return - - async def approve_and_execute(tc: ai.ToolCall) -> ai.Message: - approval = await ai.hook( - f"approve_{tc.id}", - payload=ai.ToolApproval, - metadata={"tool_name": tc.name}, - interrupt_loop=True, - ) - if approval.granted: - return await tc() - return ai.Message( - role="tool", - parts=[ - ai.ToolResultPart( - tool_call_id=tc.id, - tool_name=tc.name, - result="denied", - is_error=True, - ) - ], - ) - - async with asyncio.TaskGroup() as tg: - tasks = [tg.create_task(approve_and_execute(tc)) for tc in tool_calls] - yield ai.tool_message(*(t.result() for t in tasks)) - - mock_llm( - [ - [ - tool_call_msg( - tc_id="tc-1", - name="dangerous_action", - args='{"path": "/tmp"}', - ) - ], - ] - ) - - runtime_messages: list[messages.Message] = [] - async for msg in approval_agent.run(MOCK_MODEL, [ai.user_message("delete /tmp")]): - runtime_messages.append(msg) - - # Stream through UI adapter - event_types = [ - p.type - async for p in adapter.to_ui_message_stream(_async_iter(runtime_messages)) - ] - - # tool-approval-request must be in the SAME step as tool-input. - assert event_types == [ - "start", - "start-step", - "tool-input-start", - "tool-input-delta", - "tool-input-available", - "tool-approval-request", - "finish-step", - "finish", - ] diff --git a/tests/agents/test_generator_tools.py b/tests/agents/test_generator_tools.py index 37259b18..c1e6a4f7 100644 --- a/tests/agents/test_generator_tools.py +++ b/tests/agents/test_generator_tools.py @@ -24,12 +24,12 @@ async def progress_tool(query: str) -> AsyncGenerator[ai.Message]: """Tool that streams progress, then returns a final answer.""" yield ai.Message( role="assistant", - parts=[messages_.TextPart(text="Working...", state="done")], - label="progress", + parts=[messages_.TextPart(text="Working...")], + source_label="progress", ) yield ai.Message( role="assistant", - parts=[messages_.TextPart(text=f"Answer for {query}", state="done")], + parts=[messages_.TextPart(text=f"Answer for {query}")], ) @@ -51,7 +51,7 @@ async def test_generator_tool_streams_and_returns_result() -> None: assert llm.call_count == 2 # Intermediate progress message was forwarded to consumer. - progress = [m for m in collected if m.label == "progress"] + progress = [m for m in collected if m.source_label == "progress"] assert len(progress) == 1 assert progress[0].text == "Working..." @@ -175,7 +175,7 @@ async def test_yield_from_nested_agent() -> None: assert adapter.call_count == 3 # Inner messages were forwarded to the consumer with label="inner". - inner_msgs = [m for m in collected if m.label == "inner"] + inner_msgs = [m for m in collected if m.source_label == "inner"] assert len(inner_msgs) > 0 # The outer LLM's second call (index 2) must NOT contain any inner @@ -187,7 +187,7 @@ async def test_yield_from_nested_agent() -> None: # Specifically: no inner assistant text or inner tool results leaked. for m in outer_turn2_msgs: - assert m.label is None or m.label != "inner" + assert m.source_label is None or m.source_label != "inner" if m.role == "assistant": # This must be the outer tool-call message, not inner text. assert len(m.tool_calls) > 0 diff --git a/tests/agents/test_runtime.py b/tests/agents/test_runtime.py index aa08d1c3..6cda3a8d 100644 --- a/tests/agents/test_runtime.py +++ b/tests/agents/test_runtime.py @@ -72,13 +72,11 @@ async def test_agent_parallel_tools() -> None: tool_call_id="tc-1", tool_name="double", tool_args='{"x": 3}', - state="done", ), messages.ToolCallPart( tool_call_id="tc-2", tool_name="double", tool_args='{"x": 7}', - state="done", ), ], ) @@ -111,3 +109,74 @@ async def test_agent_multi_turn() -> None: async for m in my_agent.run(MOCK_MODEL, [ai.user_message("Concat then double")]): msgs.append(m) assert llm.call_count == 3 + + +# -- turn_id semantics: one turn per LLM round-trip ------------------------- + + +async def test_two_user_messages_produce_four_turns() -> None: + """Two agent.run invocations, each with a tool call + final reply, + produce four distinct turn ids; history from the first run keeps its + original turn ids when fed into the second run.""" + my_agent = ai.agent(tools=[double]) + + # Run 1: tool call, then text. + r1_turn1 = [tool_call_msg(tc_id="tc-1", name="double", args='{"x": 5}', id="m-1a")] + r1_turn2 = [text_msg("Ten.", id="m-1b")] + # Run 2: tool call, then text. + r2_turn1 = [tool_call_msg(tc_id="tc-2", name="double", args='{"x": 7}', id="m-2a")] + r2_turn2 = [text_msg("Fourteen.", id="m-2b")] + mock_llm([r1_turn1, r1_turn2, r2_turn1, r2_turn2]) + + def dedup(stream: list[ai.Message]) -> list[ai.Message]: + seen: dict[str, ai.Message] = {} + for m in stream: + if m.is_done: + seen[m.id] = m + return list(seen.values()) + + run1_stream: list[ai.Message] = [] + async for m in my_agent.run(MOCK_MODEL, [ai.user_message("Double 5")]): + run1_stream.append(m) + history = dedup(run1_stream) + + run2_stream: list[ai.Message] = [] + async for m in my_agent.run(MOCK_MODEL, [*history, ai.user_message("Double 7")]): + run2_stream.append(m) + final = dedup(run2_stream) + + # Chronological list of terminal non-internal messages. Insertion order + # of ``dedup`` reflects the order they first appeared in the stream. + chronological = [m for m in final if m.role != "internal"] + assert len(chronological) == 8 + + # Expected shape: four turns, each a (input, assistant) pair. + # turn 1: user → assistant (tool call) + # turn 2: tool → assistant (text) + # turn 3: user → assistant (tool call) + # turn 4: tool → assistant (text) + expected_roles = [ + ("user", "assistant"), + ("tool", "assistant"), + ("user", "assistant"), + ("tool", "assistant"), + ] + pairs = [(chronological[2 * i], chronological[2 * i + 1]) for i in range(4)] + for (left, right), (expected_left, expected_right) in zip( + pairs, expected_roles, strict=True + ): + assert (left.role, right.role) == (expected_left, expected_right) + # Both messages in the pair share the same turn_id. + assert left.turn_id is not None + assert left.turn_id == right.turn_id + + # The four turn_ids are all distinct. + turn_ids = [left.turn_id for left, _ in pairs] + assert len(set(turn_ids)) == 4 + + # Run 1's history survives untouched into run 2. + history_ids = {h.id for h in history} + for h in history: + same = next(m for m in final if m.id == h.id) + assert same.turn_id == h.turn_id + assert any(m.id in history_ids for m in final) diff --git a/tests/adapters/__init__.py b/tests/agents/ui/__init__.py similarity index 100% rename from tests/adapters/__init__.py rename to tests/agents/ui/__init__.py diff --git a/tests/adapters/ai_sdk_ui/__init__.py b/tests/agents/ui/ai_sdk/__init__.py similarity index 100% rename from tests/adapters/ai_sdk_ui/__init__.py rename to tests/agents/ui/ai_sdk/__init__.py diff --git a/tests/agents/ui/ai_sdk/outbound/__init__.py b/tests/agents/ui/ai_sdk/outbound/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/agents/ui/ai_sdk/outbound/test_history.py b/tests/agents/ui/ai_sdk/outbound/test_history.py new file mode 100644 index 00000000..7ebdf3e8 --- /dev/null +++ b/tests/agents/ui/ai_sdk/outbound/test_history.py @@ -0,0 +1,126 @@ +from __future__ import annotations + +from ai.agents.ui.ai_sdk import to_ui_messages +from ai.agents.ui.ai_sdk.ui_message import ( + UITextPart, + UIToolPart, +) +from ai.types import messages as messages_ + + +def test_to_ui_messages_user_and_assistant() -> None: + msgs = [ + messages_.Message(id="u1", role="user", parts=[messages_.TextPart(text="hi")]), + messages_.Message( + id="a1", + role="assistant", + parts=[messages_.TextPart(text="hello back")], + ), + ] + result = to_ui_messages(msgs) + assert len(result) == 2 + assert result[0].role == "user" + assert result[1].role == "assistant" + assert result[1].id == "a1" + + +def test_to_ui_messages_merges_assistant_tool_internal() -> None: + msgs = [ + messages_.Message( + id="a1", + role="assistant", + parts=[ + messages_.TextPart(text="calling"), + messages_.ToolCallPart( + tool_call_id="tc1", + tool_name="search", + tool_args='{"q":"x"}', + ), + ], + ), + messages_.Message( + role="tool", + parts=[ + messages_.ToolResultPart( + tool_call_id="tc1", + tool_name="search", + result={"hits": 2}, + ) + ], + ), + messages_.Message( + role="assistant", + parts=[messages_.TextPart(text="done")], + ), + ] + result = to_ui_messages(msgs) + assert len(result) == 1 + ui_msg = result[0] + assert ui_msg.role == "assistant" + assert ui_msg.id == "a1" + assert isinstance(ui_msg.parts[0], UITextPart) + assert ui_msg.parts[0].text == "calling" + assert isinstance(ui_msg.parts[1], UIToolPart) + assert ui_msg.parts[1].state == "output-available" + assert ui_msg.parts[1].output == {"hits": 2} + assert isinstance(ui_msg.parts[2], UITextPart) + assert ui_msg.parts[2].text == "done" + + +def test_to_ui_messages_internal_role_merges_approval() -> None: + msgs = [ + messages_.Message( + id="a1", + role="assistant", + parts=[ + messages_.ToolCallPart( + tool_call_id="tc1", + tool_name="delete", + tool_args="{}", + ) + ], + ), + messages_.Message( + role="internal", + parts=[ + messages_.HookPart( + hook_id="approve_tc1", + hook_type="ToolApproval", + status="pending", + ) + ], + ), + ] + result = to_ui_messages(msgs) + ui_msg = result[0] + tool_part = ui_msg.parts[0] + assert isinstance(tool_part, UIToolPart) + assert tool_part.state == "approval-requested" + assert tool_part.approval is not None + assert tool_part.approval.id == "approve_tc1" + + +def test_to_ui_messages_user_message_uses_own_id() -> None: + msgs = [ + messages_.Message(id="u1", role="user", parts=[messages_.TextPart(text="a")]) + ] + result = to_ui_messages(msgs) + assert result[0].id == "u1" + + +def test_to_ui_messages_uses_first_assistant_id_as_bubble_id() -> None: + msgs = [ + messages_.Message( + id="a1", + role="assistant", + parts=[messages_.TextPart(text="first")], + ), + messages_.Message( + id="a2", + role="assistant", + parts=[messages_.TextPart(text="second")], + ), + ] + result = to_ui_messages(msgs) + assert len(result) == 1 + assert result[0].id == "a1" diff --git a/tests/agents/ui/ai_sdk/outbound/test_sse.py b/tests/agents/ui/ai_sdk/outbound/test_sse.py new file mode 100644 index 00000000..af7f998f --- /dev/null +++ b/tests/agents/ui/ai_sdk/outbound/test_sse.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import json +from collections.abc import AsyncGenerator + +from ai.agents.ui.ai_sdk import protocol, to_sse +from ai.agents.ui.ai_sdk.outbound.sse import format_sse, serialize_part +from ai.types import messages as messages_ + + +def test_serialize_part_camelcases_keys() -> None: + part = protocol.StartPart(message_id="m1") + payload = json.loads(serialize_part(part)) + assert payload == {"type": "start", "messageId": "m1"} + + +def test_format_sse_wraps_data_line() -> None: + part = protocol.TextDeltaPart(id="t1", delta="hi") + line = format_sse(part) + assert line.startswith("data: ") + assert line.endswith("\n\n") + + +def test_serialize_data_part_uses_type_with_prefix() -> None: + part = protocol.DataPart(data_type="custom", data={"k": 1}) + payload = json.loads(serialize_part(part)) + assert payload["type"] == "data-custom" + assert "dataType" not in payload + + +async def _gen( + msgs: list[messages_.Message], +) -> AsyncGenerator[messages_.Message]: + for m in msgs: + yield m + + +async def test_to_sse_emits_data_prefixed_lines() -> None: + msgs = [ + messages_.Message( + id="m1", + role="assistant", + turn_id="t1", + parts=[messages_.TextPart(text="hi")], + stream=messages_.StreamState(new_events=[], is_done=True), + ) + ] + lines = [line async for line in to_sse(_gen(msgs))] + assert all(line.startswith("data: ") for line in lines) + # first line is the start part + first = json.loads(lines[0].removeprefix("data: ").rstrip()) + assert first["type"] == "start" diff --git a/tests/agents/ui/ai_sdk/outbound/test_stream.py b/tests/agents/ui/ai_sdk/outbound/test_stream.py new file mode 100644 index 00000000..7a14fd1e --- /dev/null +++ b/tests/agents/ui/ai_sdk/outbound/test_stream.py @@ -0,0 +1,242 @@ +from __future__ import annotations + +from collections.abc import AsyncGenerator + +from ai.agents.ui.ai_sdk import protocol, to_stream +from ai.types import messages as messages_ + + +async def _gen( + msgs: list[messages_.Message], +) -> AsyncGenerator[messages_.Message]: + for m in msgs: + yield m + + +async def _collect( + msgs: list[messages_.Message], +) -> list[protocol.UIMessageStreamPart]: + return [p async for p in to_stream(_gen(msgs))] + + +def _text_stream_message( + msg_id: str, + turn_id: str | None, + text_id: str, + chunk: str, + *, + is_done: bool, + full_text: str | None = None, +) -> messages_.Message: + text = full_text or chunk + part = messages_.TextPart(id=text_id, text=text) + events: list[messages_.StreamEvent] + if is_done: + events = [messages_.PartClosed(part=part)] + else: + events = [messages_.PartDelta(part=part, chunk=chunk)] + return messages_.Message( + id=msg_id, + role="assistant", + turn_id=turn_id, + parts=[part], + stream=messages_.StreamState(new_events=events, is_done=is_done), + ) + + +async def test_event_driven_text_streaming() -> None: + text_id = "txt1" + empty_text = messages_.TextPart(id=text_id, text="") + hi_text = messages_.TextPart(id=text_id, text="hi") + msgs = [ + # Initial: PartOpened + messages_.Message( + id="m1", + role="assistant", + turn_id="t1", + parts=[empty_text], + stream=messages_.StreamState( + new_events=[messages_.PartOpened(part=empty_text)], + is_done=False, + ), + ), + # Delta: "hi" + messages_.Message( + id="m1", + role="assistant", + turn_id="t1", + parts=[hi_text], + stream=messages_.StreamState( + new_events=[messages_.PartDelta(part=hi_text, chunk="hi")], + is_done=False, + ), + ), + # Closed + messages_.Message( + id="m1", + role="assistant", + turn_id="t1", + parts=[hi_text], + stream=messages_.StreamState( + new_events=[messages_.PartClosed(part=hi_text)], + is_done=True, + ), + ), + ] + out = await _collect(msgs) + # expect: Start, StartStep, TextStart, TextDelta, TextEnd, FinishStep, Finish + assert isinstance(out[0], protocol.StartPart) + assert out[0].message_id == "m1" + assert isinstance(out[1], protocol.StartStepPart) + assert isinstance(out[2], protocol.TextStartPart) and out[2].id == text_id + assert isinstance(out[3], protocol.TextDeltaPart) and out[3].delta == "hi" + assert isinstance(out[4], protocol.TextEndPart) and out[4].id == text_id + assert isinstance(out[5], protocol.FinishStepPart) + assert isinstance(out[6], protocol.FinishPart) + + +async def test_turn_id_change_emits_step_boundary() -> None: + msgs = [ + messages_.Message( + id="m1", + role="assistant", + turn_id="t1", + parts=[messages_.TextPart(text="hello")], + stream=messages_.StreamState(new_events=[], is_done=True), + ), + messages_.Message( + id="m2", + role="assistant", + turn_id="t2", # different turn → step boundary + parts=[messages_.TextPart(text="world")], + stream=messages_.StreamState(new_events=[], is_done=True), + ), + ] + out = await _collect(msgs) + # Look for FinishStep followed by StartStep between messages. + has_mid_step_boundary = any( + isinstance(out[i], protocol.FinishStepPart) + and i + 1 < len(out) + and isinstance(out[i + 1], protocol.StartStepPart) + for i in range(1, len(out) - 1) + ) + assert has_mid_step_boundary + + +async def test_agent_change_emits_message_boundary() -> None: + msgs = [ + messages_.Message( + id="m1", + role="assistant", + source_label="a1", + parts=[messages_.TextPart(text="from a")], + stream=messages_.StreamState(new_events=[], is_done=True), + ), + messages_.Message( + id="m2", + role="assistant", + source_label="a2", # different source → FinishPart + StartPart + parts=[messages_.TextPart(text="from b")], + stream=messages_.StreamState(new_events=[], is_done=True), + ), + ] + out = await _collect(msgs) + # There should be a FinishPart+StartPart pair mid-stream. + has_mid_msg_boundary = any( + isinstance(out[i], protocol.FinishPart) + and i + 1 < len(out) + and isinstance(out[i + 1], protocol.StartPart) + for i in range(1, len(out) - 1) + ) + assert has_mid_msg_boundary + + +async def test_tool_call_and_result_emit_terminal_parts() -> None: + msgs = [ + messages_.Message( + id="m1", + role="assistant", + turn_id="t1", + parts=[ + messages_.ToolCallPart( + id="tc1", + tool_call_id="tc1", + tool_name="search", + tool_args='{"q":"x"}', + ) + ], + stream=messages_.StreamState(new_events=[], is_done=True), + ), + messages_.Message( + role="tool", + parts=[ + messages_.ToolResultPart( + tool_call_id="tc1", + tool_name="search", + result={"hits": 1}, + ) + ], + ), + ] + out = await _collect(msgs) + types = [type(p).__name__ for p in out] + assert "ToolInputStartPart" in types + assert "ToolInputAvailablePart" in types + assert "ToolOutputAvailablePart" in types + + +async def test_approval_request_hook_emits_approval_part() -> None: + msgs = [ + messages_.Message( + id="m1", + role="assistant", + turn_id="t1", + parts=[ + messages_.ToolCallPart( + id="tc1", + tool_call_id="tc1", + tool_name="delete", + tool_args="{}", + ) + ], + stream=messages_.StreamState(new_events=[], is_done=True), + ), + messages_.Message( + role="internal", + parts=[ + messages_.HookPart( + hook_id="approve_tc1", + hook_type="ToolApproval", + status="pending", + ) + ], + ), + ] + out = await _collect(msgs) + approval_parts = [p for p in out if isinstance(p, protocol.ToolApprovalRequestPart)] + assert len(approval_parts) == 1 + assert approval_parts[0].tool_call_id == "tc1" + assert approval_parts[0].approval_id == "approve_tc1" + + +async def test_dedup_on_reemitted_message_id() -> None: + empty = messages_.TextPart(id="txt1", text="") + hi = messages_.TextPart(id="txt1", text="hi") + msg = messages_.Message( + id="m1", + role="assistant", + turn_id="t1", + parts=[hi], + stream=messages_.StreamState( + new_events=[ + messages_.PartOpened(part=empty), + messages_.PartDelta(part=hi, chunk="hi"), + messages_.PartClosed(part=hi), + ], + is_done=True, + ), + ) + out = await _collect([msg, msg]) # re-emit the same done message + text_deltas = [p for p in out if isinstance(p, protocol.TextDeltaPart)] + # only the first emission should fire a TextDelta + assert len(text_deltas) == 1 diff --git a/tests/agents/ui/ai_sdk/test_approvals.py b/tests/agents/ui/ai_sdk/test_approvals.py new file mode 100644 index 00000000..e0733757 --- /dev/null +++ b/tests/agents/ui/ai_sdk/test_approvals.py @@ -0,0 +1,64 @@ +from __future__ import annotations + +from ai.agents.ui.ai_sdk import _approvals +from ai.types import messages as messages_ + + +def test_tool_call_id_for_strips_prefix() -> None: + hook = messages_.HookPart( + hook_id="approve_tc_42", + hook_type="ToolApproval", + status="pending", + ) + assert _approvals.tool_call_id_for(hook) == "tc_42" + + +def test_tool_call_id_for_rejects_non_approval_type() -> None: + hook = messages_.HookPart( + hook_id="approve_tc_42", + hook_type="SomethingElse", + status="pending", + ) + assert _approvals.tool_call_id_for(hook) is None + + +def test_tool_call_id_for_rejects_bad_prefix() -> None: + hook = messages_.HookPart( + hook_id="tc_42", + hook_type="ToolApproval", + status="pending", + ) + assert _approvals.tool_call_id_for(hook) is None + + +def test_is_tool_approval_message_detects_all_approval_hooks() -> None: + msg = messages_.Message( + role="internal", + parts=[ + messages_.HookPart( + hook_id="approve_tc_1", + hook_type="ToolApproval", + status="pending", + ), + ], + ) + assert _approvals.is_tool_approval_message(msg) + + +def test_is_tool_approval_message_false_for_non_approval() -> None: + msg = messages_.Message( + role="internal", + parts=[ + messages_.HookPart( + hook_id="other", + hook_type="Something", + status="pending", + ), + ], + ) + assert not _approvals.is_tool_approval_message(msg) + + +def test_is_tool_approval_message_false_for_empty() -> None: + msg = messages_.Message(role="internal", parts=[]) + assert not _approvals.is_tool_approval_message(msg) diff --git a/tests/agents/ui/ai_sdk/test_inbound.py b/tests/agents/ui/ai_sdk/test_inbound.py new file mode 100644 index 00000000..9b808d82 --- /dev/null +++ b/tests/agents/ui/ai_sdk/test_inbound.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +from typing import Any + +import pytest + +from ai.agents.ui.ai_sdk import to_messages +from ai.agents.ui.ai_sdk.inbound import ( + _normalize_ui_messages, + extract_approvals, +) +from ai.agents.ui.ai_sdk.ui_message import UIMessage, UIToolPart + + +def _ui(role: str, *parts: dict[str, Any], id: str = "m1") -> UIMessage: + return UIMessage.model_validate({"id": id, "role": role, "parts": list(parts)}) + + +def _text(text: str) -> dict[str, Any]: + return {"type": "text", "text": text} + + +def _tool( + tool_name: str, + tool_call_id: str, + state: str, + **extra: Any, +) -> dict[str, Any]: + return { + "type": f"tool-{tool_name}", + "toolCallId": tool_call_id, + "state": state, + **extra, + } + + +def test_to_messages_user_text() -> None: + result = to_messages([_ui("user", _text("hello"))]) + assert len(result) == 1 + assert result[0].role == "user" + assert result[0].text == "hello" + + +def test_to_messages_splits_at_tool_boundary() -> None: + result = to_messages( + [ + _ui( + "assistant", + _text("before"), + _tool( + "search", + "tc1", + "output-available", + input={"q": "x"}, + output={"hits": 3}, + ), + _text("after"), + ) + ] + ) + assert [m.role for m in result] == ["assistant", "tool", "assistant"] + assert result[1].tool_results[0].tool_call_id == "tc1" + + +def test_to_messages_approval_hook_emitted_as_internal() -> None: + result = to_messages( + [ + _ui( + "assistant", + _tool( + "delete", + "tc1", + "approval-requested", + approval={"id": "approve_tc1"}, + ), + ) + ], + apply_approvals_=False, + ) + assert [m.role for m in result] == ["assistant", "internal"] + hook = result[1].parts[0] + assert hook.type == "hook" + assert hook.hook_id == "approve_tc1" + + +def test_to_messages_strips_trailing_assistant_when_approved() -> None: + result = to_messages( + [ + _ui("user", _text("delete it"), id="u1"), + _ui( + "assistant", + _tool( + "delete", + "tc1", + "approval-responded", + approval={"id": "approve_tc1", "approved": True, "reason": None}, + ), + id="a1", + ), + ], + apply_approvals_=False, + ) + assert [m.role for m in result] == ["user", "internal"] + + +def test_extract_approvals_returns_approved_responses() -> None: + approvals = extract_approvals( + [ + _ui( + "assistant", + _tool( + "x", + "tc1", + "approval-responded", + approval={ + "id": "approve_tc1", + "approved": False, + "reason": "nope", + }, + ), + ) + ] + ) + assert len(approvals) == 1 + assert approvals[0].hook_id == "approve_tc1" + assert approvals[0].granted is False + assert approvals[0].reason == "nope" + + +def test_normalize_ui_messages_heals_stale_tool_state() -> None: + ui = [ + _ui( + "assistant", + _tool("x", "tc1", "input-available", output={"ok": True}), + ) + ] + normalized = _normalize_ui_messages(ui) + tool_part = normalized[0].parts[0] + assert isinstance(tool_part, UIToolPart) + assert tool_part.state == "output-available" + + +def test_to_messages_rejects_empty_user() -> None: + ui = [UIMessage.model_validate({"id": "u1", "role": "user", "parts": []})] + with pytest.raises(ValueError): + to_messages(ui) diff --git a/tests/agents/ui/ai_sdk/test_parts.py b/tests/agents/ui/ai_sdk/test_parts.py new file mode 100644 index 00000000..69ca1566 --- /dev/null +++ b/tests/agents/ui/ai_sdk/test_parts.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +from ai.agents.ui.ai_sdk import _parts +from ai.agents.ui.ai_sdk.ui_message import ( + UIReasoningPart, + UITextPart, + UIToolApproval, + UIToolPart, +) +from ai.types import messages as messages_ + + +def test_to_ui_parts_text_and_reasoning() -> None: + parts: list[messages_.Part] = [ + messages_.ReasoningPart(text="thinking"), + messages_.TextPart(text="hi"), + ] + ui_parts = _parts.to_ui_parts(parts) + assert isinstance(ui_parts[0], UIReasoningPart) + assert ui_parts[0].reasoning == "thinking" + assert isinstance(ui_parts[1], UITextPart) + assert ui_parts[1].text == "hi" + + +def test_to_ui_parts_tool_call_parses_json_args() -> None: + parts: list[messages_.Part] = [ + messages_.ToolCallPart( + tool_call_id="tc1", + tool_name="search", + tool_args='{"q": "x"}', + ) + ] + ui_parts = _parts.to_ui_parts(parts) + assert isinstance(ui_parts[0], UIToolPart) + assert ui_parts[0].type == "tool-search" + assert ui_parts[0].input == {"q": "x"} + assert ui_parts[0].state == "input-available" + + +def test_merge_tool_results_updates_state_and_output() -> None: + parts: list[messages_.Part] = [ + messages_.ToolCallPart( + tool_call_id="tc1", + tool_name="search", + tool_args="{}", + ) + ] + ui_parts = _parts.to_ui_parts(parts) + _parts.merge_tool_results( + ui_parts, + [ + messages_.ToolResultPart( + tool_call_id="tc1", + tool_name="search", + result={"hits": 3}, + ) + ], + ) + merged = ui_parts[0] + assert isinstance(merged, UIToolPart) + assert merged.state == "output-available" + assert merged.output == {"hits": 3} + + +def test_merge_approval_signals_pending_then_resolved() -> None: + parts: list[messages_.Part] = [ + messages_.ToolCallPart( + tool_call_id="tc1", + tool_name="delete", + tool_args="{}", + ) + ] + ui_parts = _parts.to_ui_parts(parts) + + _parts.merge_approval_signals( + ui_parts, + [ + messages_.HookPart( + hook_id="approve_tc1", + hook_type="ToolApproval", + status="pending", + ) + ], + ) + requested = ui_parts[0] + assert isinstance(requested, UIToolPart) + assert requested.state == "approval-requested" + assert isinstance(requested.approval, UIToolApproval) + + _parts.merge_approval_signals( + ui_parts, + [ + messages_.HookPart( + hook_id="approve_tc1", + hook_type="ToolApproval", + status="resolved", + resolution={"granted": True, "reason": None}, + ) + ], + ) + responded = ui_parts[0] + assert isinstance(responded, UIToolPart) + assert responded.state == "approval-responded" + assert responded.approval is not None + assert responded.approval.approved is True diff --git a/tests/conftest.py b/tests/conftest.py index 1e8529f1..dd876b3a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -215,10 +215,8 @@ def text_msg( text: str, *, id: str = "msg-1", - state: messages_.PartState | None = "done", - delta: str | None = None, ) -> messages_.Message: - part: messages_.Part = messages_.TextPart(text=text, state=state, delta=delta) + part: messages_.Part = messages_.TextPart(text=text) return messages_.Message(id=id, role="assistant", parts=[part]) @@ -234,7 +232,6 @@ def tool_call_msg( tool_call_id=tc_id, tool_name=name, tool_args=args, - state="done", ) return messages_.Message(id=id, role="assistant", parts=[part]) diff --git a/tests/models/core/test_streaming.py b/tests/models/core/test_streaming.py index 7cc7c3a2..b123f0e8 100644 --- a/tests/models/core/test_streaming.py +++ b/tests/models/core/test_streaming.py @@ -4,6 +4,7 @@ from ai.models.core.helpers import streaming from ai.types import messages +from ai.types.messages import PartClosed, PartDelta, PartOpened # -- Text streaming -------------------------------------------------------- @@ -14,27 +15,41 @@ def test_text_lifecycle() -> None: assert len(m.parts) == 1 part = m.parts[0] assert isinstance(part, messages.TextPart) - assert part.state == "streaming" assert part.text == "" + assert m.stream is not None + assert any( + isinstance(e, PartOpened) and e.part.id == "b1" for e in m.stream.new_events + ) m = h.handle_event(streaming.TextDelta(block_id="b1", delta="Hello")) part = m.parts[0] assert isinstance(part, messages.TextPart) assert part.text == "Hello" - assert part.delta == "Hello" - assert part.state == "streaming" + assert m.stream is not None + assert any( + isinstance(e, PartDelta) and e.part.id == "b1" and e.chunk == "Hello" + for e in m.stream.new_events + ) m = h.handle_event(streaming.TextDelta(block_id="b1", delta=" world")) part = m.parts[0] assert isinstance(part, messages.TextPart) assert part.text == "Hello world" - assert part.delta == " world" + assert m.stream is not None + assert any( + isinstance(e, PartDelta) and e.part.id == "b1" and e.chunk == " world" + for e in m.stream.new_events + ) m = h.handle_event(streaming.TextEnd(block_id="b1")) part = m.parts[0] assert isinstance(part, messages.TextPart) - assert part.state == "done" - assert part.delta is None + assert m.stream is not None + assert any( + isinstance(e, PartClosed) and e.part.id == "b1" for e in m.stream.new_events + ) + # No delta events in this yield + assert not any(isinstance(e, PartDelta) for e in m.stream.new_events) # -- Reasoning streaming --------------------------------------------------- @@ -47,13 +62,20 @@ def test_reasoning_lifecycle() -> None: part = m.parts[0] assert isinstance(part, messages.ReasoningPart) assert part.text == "thinking" - assert part.state == "streaming" + assert m.stream is not None + assert any( + isinstance(e, PartDelta) and e.part.id == "r1" and e.chunk == "thinking" + for e in m.stream.new_events + ) m = h.handle_event(streaming.ReasoningEnd(block_id="r1", signature="sig123")) part = m.parts[0] assert isinstance(part, messages.ReasoningPart) - assert part.state == "done" assert part.signature == "sig123" + assert m.stream is not None + assert any( + isinstance(e, PartClosed) and e.part.id == "r1" for e in m.stream.new_events + ) # -- Tool streaming -------------------------------------------------------- @@ -67,8 +89,11 @@ def test_tool_lifecycle() -> None: assert isinstance(part, messages.ToolCallPart) assert part.tool_name == "get_weather" assert part.tool_args == '{"ci' - assert part.state == "streaming" - assert part.args_delta == '{"ci' + assert m.stream is not None + assert any( + isinstance(e, PartDelta) and e.part.id == "tc1" and e.chunk == '{"ci' + for e in m.stream.new_events + ) m = h.handle_event( streaming.ToolArgsDelta(tool_call_id="tc1", delta='ty":"London"}') @@ -80,8 +105,12 @@ def test_tool_lifecycle() -> None: m = h.handle_event(streaming.ToolEnd(tool_call_id="tc1")) part = m.parts[0] assert isinstance(part, messages.ToolCallPart) - assert part.state == "done" - assert part.args_delta is None + assert m.stream is not None + assert any( + isinstance(e, PartClosed) and e.part.id == "tc1" for e in m.stream.new_events + ) + # No delta events in this yield + assert not any(isinstance(e, PartDelta) for e in m.stream.new_events) # -- Multi-part messages --------------------------------------------------- @@ -106,12 +135,10 @@ def test_reasoning_then_text_then_tool() -> None: assert isinstance(m.parts[0], messages.ReasoningPart) assert isinstance(m.parts[1], messages.TextPart) assert isinstance(m.parts[2], messages.ToolCallPart) - assert all( - p.state == "done" - for p in m.parts - if isinstance( - p, (messages.TextPart, messages.ToolCallPart, messages.ReasoningPart) - ) + # The last event was ToolEnd(tc1), so only that PartClosed is in events + assert m.stream is not None + assert any( + isinstance(e, PartClosed) and e.part.id == "tc1" for e in m.stream.new_events ) @@ -134,12 +161,10 @@ def test_multiple_tool_calls() -> None: h.handle_event(streaming.ToolArgsDelta(tool_call_id="tc2", delta='{"dir":"."}')) h.handle_event(streaming.ToolEnd(tool_call_id="tc1")) m = h.handle_event(streaming.ToolEnd(tool_call_id="tc2")) - assert all( - p.state == "done" - for p in m.parts - if isinstance( - p, (messages.TextPart, messages.ToolCallPart, messages.ReasoningPart) - ) + # Last event was ToolEnd(tc2), so its PartClosed is in events + assert m.stream is not None + assert any( + isinstance(e, PartClosed) and e.part.id == "tc2" for e in m.stream.new_events ) @@ -154,8 +179,9 @@ def test_message_done_finalizes_all() -> None: m = h.handle_event(streaming.MessageDone(finish_reason="end_turn")) part = m.parts[0] assert isinstance(part, messages.TextPart) - assert part.state == "done" assert m.is_done + assert m.stream is not None + assert m.stream.is_done def test_message_done_propagates_usage() -> None: @@ -180,7 +206,7 @@ def test_message_done_propagates_usage() -> None: def test_deltas_only_on_active_blocks() -> None: - """Delta should be None on inactive blocks, present only on active.""" + """Delta events should only reference the active block.""" h = streaming.StreamHandler(message_id="m1") h.handle_event(streaming.TextStart(block_id="t1")) h.handle_event(streaming.TextDelta(block_id="t1", delta="first")) @@ -190,8 +216,17 @@ def test_deltas_only_on_active_blocks() -> None: m = h.handle_event(streaming.TextDelta(block_id="t2", delta="second")) text_parts = [p for p in m.parts if isinstance(p, messages.TextPart)] - assert text_parts[0].delta is None # t1 is done - assert text_parts[1].delta == "second" # t2 is active + assert text_parts[0].text == "first" # t1 snapshot + assert text_parts[1].text == "second" # t2 snapshot + # Only t2 has a delta event in this yield + assert m.stream is not None + assert any( + isinstance(e, PartDelta) and e.part.id == "t2" and e.chunk == "second" + for e in m.stream.new_events + ) + assert not any( + isinstance(e, PartDelta) and e.part.id == "t1" for e in m.stream.new_events + ) # -- File event (inline images from LLMs like Gemini/GPT-5) --------------- diff --git a/tests/models/test_public_api.py b/tests/models/test_public_api.py index 7ca4b1b1..8f1ddb1b 100644 --- a/tests/models/test_public_api.py +++ b/tests/models/test_public_api.py @@ -32,14 +32,54 @@ async def test_stream_basic() -> None: s = await models.stream(MOCK_MODEL, [ai.user_message("Hi")]) deltas: list[str] = [] async for msg in s: - if msg.text_delta: - deltas.append(msg.text_delta) + for ev in msg.deltas: + if isinstance(ev.part, messages_.TextPart): + deltas.append(ev.chunk) assert mock.call_count == 1 assert s.text == "Hello world" assert "".join(deltas) == "Hello world" +async def test_stream_preserves_existing_turn_ids() -> None: + """ai.stream() stamps only inputs without a turn_id; older turns survive.""" + mock = mock_llm([[text_msg("reply")]]) + + old = ai.user_message("earlier") + old = old.model_copy(update={"turn_id": "prev"}) + fresh = ai.user_message("latest") + + s = await models.stream(MOCK_MODEL, [old, fresh]) + yielded: list[messages_.Message] = [] + async for msg in s: + yielded.append(msg) + + assert mock.call_count == 1 + # First yielded is the old input — unchanged. + assert yielded[0].turn_id == "prev" + # Fresh input was stamped with the current turn's id. + assert yielded[1].turn_id is not None + assert yielded[1].turn_id != "prev" + # Response shares the current turn id. + response_ids = [m.turn_id for m in yielded if m.role == "assistant"] + assert response_ids and all(tid == yielded[1].turn_id for tid in response_ids) + + +async def test_stream_accepts_explicit_turn_id() -> None: + """Explicit turn_id kwarg is used verbatim.""" + mock_llm([[text_msg("ok")]]) + fresh = ai.user_message("hi") + + s = await models.stream(MOCK_MODEL, [fresh], turn_id="custom-turn") + yielded: list[messages_.Message] = [] + async for msg in s: + yielded.append(msg) + + assert s.turn_id == "custom-turn" + assert yielded[0].turn_id == "custom-turn" + assert yielded[-1].turn_id == "custom-turn" + + async def test_stream_with_explicit_client() -> None: """Model with explicit client= forwards it to the adapter.""" received_clients: list[models.Client] = [] @@ -57,7 +97,7 @@ async def _spy_stream( yield messages_.Message( id="m1", role="assistant", - parts=[messages_.TextPart(text="ok", state="done")], + parts=[messages_.TextPart(text="ok")], ) models.register_stream("mock", _spy_stream) @@ -89,7 +129,7 @@ async def _structured_stream( output_type: type[pydantic.BaseModel] | None = None, **kwargs: Any, ) -> AsyncGenerator[messages_.Message]: - text_part = messages_.TextPart(text=json_text, state="done") + text_part = messages_.TextPart(text=json_text) parts: list[messages_.Part] = [text_part] if output_type is not None: import json diff --git a/tests/types/test_integrity.py b/tests/types/test_integrity.py index ac308752..16a94f18 100644 --- a/tests/types/test_integrity.py +++ b/tests/types/test_integrity.py @@ -93,14 +93,14 @@ def test_idempotent() -> None: # --------------------------------------------------------------------------- -# Signal messages +# Internal messages # --------------------------------------------------------------------------- -def test_drops_signal_messages() -> None: +def test_drops_internal_messages() -> None: msgs = [ builders.user_message("hi"), - messages.Message(role="signal", parts=[messages.TextPart(text="internal")]), + messages.Message(role="internal", parts=[messages.TextPart(text="internal")]), builders.assistant_message("hello"), ] result = prepare_messages(msgs) @@ -109,13 +109,13 @@ def test_drops_signal_messages() -> None: assert result[1].role == "assistant" -def test_signal_strict_raises() -> None: +def test_internal_strict_raises() -> None: msgs = [ - messages.Message(role="signal", parts=[messages.TextPart(text="x")]), + messages.Message(role="internal", parts=[messages.TextPart(text="x")]), ] with pytest.raises(IntegrityError) as exc_info: prepare_messages(msgs, mode="strict") - assert "signal-message" in exc_info.value.issues + assert "internal-message" in exc_info.value.issues # --------------------------------------------------------------------------- @@ -346,7 +346,7 @@ def test_complete_tool_flow_unchanged() -> None: def test_strict_collects_all_issues() -> None: msgs = [ - messages.Message(role="signal", parts=[messages.TextPart(text="x")]), + messages.Message(role="internal", parts=[messages.TextPart(text="x")]), messages.Message( role="assistant", parts=[ @@ -358,13 +358,13 @@ def test_strict_collects_all_issues() -> None: with pytest.raises(IntegrityError) as exc_info: prepare_messages(msgs, mode="strict") issues = exc_info.value.issues - assert "signal-message" in issues + assert "internal-message" in issues assert "internal-part" in issues def test_strict_keeps_recoverable_issues_when_history_is_corrupt() -> None: msgs = [ - messages.Message(role="signal", parts=[messages.TextPart(text="x")]), + messages.Message(role="internal", parts=[messages.TextPart(text="x")]), builders.user_message("go"), messages.Message( role="assistant", @@ -380,7 +380,7 @@ def test_strict_keeps_recoverable_issues_when_history_is_corrupt() -> None: ] with pytest.raises(IntegrityError) as exc_info: prepare_messages(msgs, mode="strict") - assert "signal-message" in exc_info.value.issues + assert "internal-message" in exc_info.value.issues assert "duplicate-tool-call" in exc_info.value.issues @@ -487,8 +487,8 @@ async def test_stream_calls_prepare_messages() -> None: spy.assert_called_once_with(msgs) -async def test_stream_sanitizes_signal_messages() -> None: - """Signal messages are stripped before reaching the adapter.""" +async def test_stream_sanitizes_internal_messages() -> None: + """Internal messages are stripped before reaching the adapter.""" received: list[list[messages.Message]] = [] mock = mock_llm([[text_msg("ok")]]) @@ -514,17 +514,17 @@ async def _spy_stream( msgs = [ ai.user_message("hi"), - messages.Message(role="signal", parts=[messages.TextPart(text="internal")]), + messages.Message(role="internal", parts=[messages.TextPart(text="internal")]), ai.assistant_message("hello"), ] s = await models.stream(MOCK_MODEL, msgs) async for _ in s: pass - # The adapter should have received only 2 messages (signal stripped) + # The adapter should have received only 2 messages (internal stripped) assert len(received) == 1 assert len(received[0]) == 2 - assert all(m.role != "signal" for m in received[0]) + assert all(m.role != "internal" for m in received[0]) async def test_generate_calls_prepare_messages() -> None: @@ -543,8 +543,8 @@ async def test_generate_calls_prepare_messages() -> None: spy.assert_called_once_with(msgs) -async def test_generate_sanitizes_signal_messages() -> None: - """Signal messages are stripped before reaching generate adapter.""" +async def test_generate_sanitizes_internal_messages() -> None: + """Internal messages are stripped before reaching generate adapter.""" received: list[list[messages.Message]] = [] sentinel = messages.Message( role="assistant", @@ -564,7 +564,7 @@ async def _spy_gen( msgs = [ ai.user_message("A cat"), - messages.Message(role="signal", parts=[messages.TextPart(text="internal")]), + messages.Message(role="internal", parts=[messages.TextPart(text="internal")]), ] await models.generate(MOCK_MODEL, msgs, models.ImageParams(n=1))