From d256105d44132d7f8caacd744734b302266ed044 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Tue, 19 May 2026 10:49:14 -0700 Subject: [PATCH 01/13] Add tests for broken message id logic in the ui adapter --- .../agents/ui/ai_sdk/outbound/test_history.py | 159 ++++++++++++++++++ .../agents/ui/ai_sdk/outbound/test_stream.py | 21 +++ tests/types/test_integrity.py | 67 ++++++++ 3 files changed, 247 insertions(+) diff --git a/tests/agents/ui/ai_sdk/outbound/test_history.py b/tests/agents/ui/ai_sdk/outbound/test_history.py index 7665bc13..8f701373 100644 --- a/tests/agents/ui/ai_sdk/outbound/test_history.py +++ b/tests/agents/ui/ai_sdk/outbound/test_history.py @@ -1,13 +1,118 @@ from __future__ import annotations +from collections import Counter + +from ai.agents.ui import ai_sdk from ai.agents.ui.ai_sdk import to_ui_messages from ai.agents.ui.ai_sdk.ui_message import ( UITextPart, UIToolPart, ) +from ai.types import integrity from ai.types import messages as messages_ +def _parallel_tool_turn( + *, + turn_id: str, + assistant_prefix: str | None = None, + tool_call_ids: tuple[str, str] = ("tc-bash", "tc-web"), +) -> list[messages_.Message]: + prefix = assistant_prefix or turn_id + tc_bash, tc_web = tool_call_ids + + return [ + messages_.Message( + id=f"{prefix}:assistant:0", + turn_id=turn_id, + role="assistant", + parts=[ + messages_.TextPart( + id=f"{prefix}:text:0", + text="I will run two tools.", + ), + messages_.ToolCallPart( + id=f"{prefix}:call:bash", + tool_call_id=tc_bash, + tool_name="bash", + tool_args='{"command":"date"}', + ), + messages_.ToolCallPart( + id=f"{prefix}:call:web", + tool_call_id=tc_web, + tool_name="web_fetch", + tool_args='{"url":"https://httpbin.org/get"}', + ), + ], + ), + messages_.Message( + id=f"{prefix}:tool:0", + turn_id=turn_id, + role="tool", + parts=[ + messages_.ToolResultPart( + id=f"{prefix}:result:bash", + tool_call_id=tc_bash, + tool_name="bash", + result="Tue May 19 2026", + ), + messages_.ToolResultPart( + id=f"{prefix}:result:web", + tool_call_id=tc_web, + tool_name="web_fetch", + result={"status": 200}, + ), + ], + ), + messages_.Message( + id=f"{prefix}:assistant:1", + turn_id=turn_id, + role="assistant", + parts=[ + messages_.TextPart( + id=f"{prefix}:text:1", + text="Both tools finished.", + ), + ], + ), + ] + + +def _tool_counts( + messages: list[messages_.Message], +) -> Counter[tuple[str, str]]: + counts: Counter[tuple[str, str]] = Counter() + for message in messages: + for part in message.parts: + if isinstance(part, messages_.ToolCallPart): + counts["tool_call", part.tool_call_id] += 1 + elif isinstance(part, messages_.ToolResultPart): + counts["tool_result", part.tool_call_id] += 1 + return counts + + +class IdUpsertStore: + """Small app-like store: persist full history by message id.""" + + def __init__(self) -> None: + self._rows: list[messages_.Message] = [] + + def save_full_history(self, messages: list[messages_.Message]) -> None: + for message in messages: + if message.role == "system": + continue + + for index, existing in enumerate(self._rows): + if existing.id == message.id: + self._rows[index] = message + break + else: + self._rows.append(message) + + def load(self) -> list[messages_.Message]: + return list(self._rows) + + def test_to_ui_messages_user_and_assistant() -> None: msgs = [ messages_.Message( @@ -128,3 +233,57 @@ def test_to_ui_messages_uses_first_assistant_id_as_bubble_id() -> None: result = to_ui_messages(msgs) assert len(result) == 1 assert result[0].id == "a1" + + +def test_common_id_upsert_persistence_is_idempotent_after_reload() -> None: + store = IdUpsertStore() + + first_run = [ + messages_.Message( + id="user-1", + role="user", + parts=[messages_.TextPart(id="user-1:text", text="run two tools")], + ), + *_parallel_tool_turn(turn_id="turn-1"), + ] + store.save_full_history(first_run) + + reloaded_ui = ai_sdk.to_ui_messages(store.load()) + request_history, _ = ai_sdk.to_messages(reloaded_ui) + + second_run_result = [ + *request_history, + messages_.Message( + id="user-2", + role="user", + parts=[messages_.TextPart(id="user-2:text", text="do nothing")], + ), + messages_.Message( + id="turn-2:assistant:0", + turn_id="turn-2", + role="assistant", + parts=[messages_.TextPart(id="turn-2:text:0", text="standing by")], + ), + ] + store.save_full_history(second_run_result) + + loaded = store.load() + integrity.prepare_messages(loaded) + + counts = _tool_counts(loaded) + assert counts["tool_call", "tc-bash"] == 1 + assert counts["tool_result", "tc-bash"] == 1 + assert counts["tool_call", "tc-web"] == 1 + assert counts["tool_result", "tc-web"] == 1 + + +def test_duplicate_tool_copies_do_not_reach_model_integrity() -> None: + history = [ + *_parallel_tool_turn(turn_id="turn-1", assistant_prefix="server"), + *_parallel_tool_turn(turn_id="turn-1", assistant_prefix="client"), + ] + + reloaded_ui = ai_sdk.to_ui_messages(history) + next_request_history, _ = ai_sdk.to_messages(reloaded_ui) + + integrity.prepare_messages(next_request_history) diff --git a/tests/agents/ui/ai_sdk/outbound/test_stream.py b/tests/agents/ui/ai_sdk/outbound/test_stream.py index 1022de46..ebaf85da 100644 --- a/tests/agents/ui/ai_sdk/outbound/test_stream.py +++ b/tests/agents/ui/ai_sdk/outbound/test_stream.py @@ -22,6 +22,27 @@ async def _collect( return [part async for part in to_stream(_gen(stream_events))] +async def test_stream_start_uses_runtime_message_id() -> None: + assistant = messages_.Message( + id="assistant-runtime-id", + role="assistant", + parts=[messages_.TextPart(id="text-1", text="hello")], + ) + + out = await _collect( + [ + events_.TextStart(block_id="text-1", message=assistant), + events_.TextDelta( + block_id="text-1", chunk="hello", message=assistant + ), + events_.TextEnd(block_id="text-1", message=assistant), + ] + ) + + start = next(part for part in out if isinstance(part, protocol.StartPart)) + assert start.message_id == "assistant-runtime-id" + + async def test_event_driven_text_streaming() -> None: """Streaming text events lazily open a UI message.""" text_id = "txt1" diff --git a/tests/types/test_integrity.py b/tests/types/test_integrity.py index 8e41101d..e046005f 100644 --- a/tests/types/test_integrity.py +++ b/tests/types/test_integrity.py @@ -57,6 +57,56 @@ def _tool_result( ) +def _parallel_tool_turn( + *, + turn_id: str, + assistant_prefix: str, + tool_call_ids: tuple[str, str] = ("tc-bash", "tc-web"), +) -> list[messages.Message]: + tc_bash, tc_web = tool_call_ids + + return [ + messages.Message( + id=f"{assistant_prefix}:assistant:0", + turn_id=turn_id, + role="assistant", + parts=[ + messages.ToolCallPart( + id=f"{assistant_prefix}:call:bash", + tool_call_id=tc_bash, + tool_name="bash", + tool_args='{"command":"date"}', + ), + messages.ToolCallPart( + id=f"{assistant_prefix}:call:web", + tool_call_id=tc_web, + tool_name="web_fetch", + tool_args='{"url":"https://httpbin.org/get"}', + ), + ], + ), + messages.Message( + id=f"{assistant_prefix}:tool:0", + turn_id=turn_id, + role="tool", + parts=[ + messages.ToolResultPart( + id=f"{assistant_prefix}:result:bash", + tool_call_id=tc_bash, + tool_name="bash", + result="Tue May 19 2026", + ), + messages.ToolResultPart( + id=f"{assistant_prefix}:result:web", + tool_call_id=tc_web, + tool_name="web_fetch", + result={"status": 200}, + ), + ], + ), + ] + + def _assert_raises_issue( msgs: list[messages.Message], issue: str, @@ -478,6 +528,23 @@ def test_duplicate_tool_results_within_same_message_raises() -> None: assert "duplicate-tool-result" in exc_info.value.issues +def test_integrity_error_includes_tool_ids_and_message_locations() -> None: + bad_history = [ + *_parallel_tool_turn(turn_id="turn-1", assistant_prefix="server"), + *_parallel_tool_turn(turn_id="turn-1", assistant_prefix="client"), + ] + + with pytest.raises(IntegrityError) as exc_info: + prepare_messages(bad_history) + + error = exc_info.value + assert "duplicate-tool-call" in str(error) + assert "duplicate-tool-result" in str(error) + assert "tc-bash" in str(error) + assert "server:assistant:0" in str(error) + assert "client:assistant:0" in str(error) + + # --------------------------------------------------------------------------- # Does not mutate input # --------------------------------------------------------------------------- From eb2d2288c24a0d4ecae084d7c30155f49d278ca3 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Tue, 19 May 2026 12:51:01 -0700 Subject: [PATCH 02/13] Fix adapter id bugs and namespace message/part ids --- src/ai/agents/ui/ai_sdk/_parts.py | 61 +++++++++ src/ai/agents/ui/ai_sdk/inbound.py | 129 +++++++++++++++++--- src/ai/agents/ui/ai_sdk/outbound/_state.py | 23 +++- src/ai/agents/ui/ai_sdk/outbound/history.py | 29 ++++- src/ai/types/messages.py | 18 +-- tests/types/test_integrity.py | 17 --- 6 files changed, 226 insertions(+), 51 deletions(-) diff --git a/src/ai/agents/ui/ai_sdk/_parts.py b/src/ai/agents/ui/ai_sdk/_parts.py index 3637174b..1f3f00aa 100644 --- a/src/ai/agents/ui/ai_sdk/_parts.py +++ b/src/ai/agents/ui/ai_sdk/_parts.py @@ -13,6 +13,16 @@ from ....types import messages as messages_ from . import _approvals, ui_message +_TOOL_STATE_RANK: dict[ui_message.UIToolInvocationState, int] = { + "input-streaming": 0, + "input-available": 1, + "approval-requested": 2, + "approval-responded": 3, + "output-denied": 4, + "output-error": 5, + "output-available": 6, +} + def _normalize_tool_input(raw: str) -> str | dict[str, Any]: """Parse tool args JSON string into a dict; fall back to raw string. @@ -62,6 +72,57 @@ def to_ui_parts(parts: list[messages_.Part]) -> list[ui_message.UIMessagePart]: return result +def _merge_tool_part( + existing: ui_message.UIToolPart, + candidate: ui_message.UIToolPart, +) -> ui_message.UIToolPart: + """Merge duplicate UI tool parts, keeping the first display position.""" + existing_rank = _TOOL_STATE_RANK.get(existing.state, 0) + candidate_rank = _TOOL_STATE_RANK.get(candidate.state, 0) + updates: dict[str, Any] = {} + + if candidate_rank >= existing_rank: + updates["state"] = candidate.state + if candidate.state == "output-denied": + updates["output"] = None + + if existing.input is None and candidate.input is not None: + updates["input"] = candidate.input + if candidate.output is not None: + updates["output"] = candidate.output + if candidate.error_text is not None: + updates["error_text"] = candidate.error_text + if candidate.approval is not None: + updates["approval"] = candidate.approval + + return existing.model_copy(update=updates) if updates else existing + + +def dedupe_tool_parts( + ui_parts: list[ui_message.UIMessagePart], +) -> list[ui_message.UIMessagePart]: + """Collapse duplicate UIToolParts by tool_call_id.""" + result: list[ui_message.UIMessagePart] = [] + tool_index: dict[str, int] = {} + + for part in ui_parts: + if not isinstance(part, ui_message.UIToolPart): + result.append(part) + continue + + idx = tool_index.get(part.tool_call_id) + if idx is None: + tool_index[part.tool_call_id] = len(result) + result.append(part) + continue + + existing = result[idx] + if isinstance(existing, ui_message.UIToolPart): + result[idx] = _merge_tool_part(existing, part) + + return result + + def merge_tool_results( ui_parts: list[ui_message.UIMessagePart], tool_parts: list[messages_.Part], diff --git a/src/ai/agents/ui/ai_sdk/inbound.py b/src/ai/agents/ui/ai_sdk/inbound.py index e4918a22..bac251c4 100644 --- a/src/ai/agents/ui/ai_sdk/inbound.py +++ b/src/ai/agents/ui/ai_sdk/inbound.py @@ -32,6 +32,54 @@ def _is_tool_error(state: ui_message.UIToolInvocationState) -> bool: return state in _TOOL_ERROR_STATES +class _AdapterIdPolicy: + """Stable ids for reconstructing canonical messages from UI bubbles.""" + + @staticmethod + def message_id(turn_id: str, role: str, index: int) -> str: + return f"{turn_id}:{role}:{index}" + + @classmethod + def with_stable_part_ids( + cls, + message_id: str, + parts: list[messages_.Part], + ) -> list[messages_.Part]: + counts: dict[str, int] = {} + result: list[messages_.Part] = [] + for index, part in enumerate(parts): + base = cls.part_id(message_id, part, index) + seen = counts.get(base, 0) + counts[base] = seen + 1 + part_id = base if seen == 0 else f"{base}:{seen}" + result.append(part.model_copy(update={"id": part_id})) + return result + + @staticmethod + def part_id( + message_id: str, + part: messages_.Part, + index: int, + ) -> str: + match part: + case messages_.TextPart(): + return f"{message_id}:text:{index}" + case messages_.ReasoningPart(): + return f"{message_id}:reasoning:{index}" + case messages_.ToolCallPart(tool_call_id=tool_call_id): + return f"{message_id}:call:{tool_call_id}" + case messages_.ToolResultPart(tool_call_id=tool_call_id): + return f"{message_id}:result:{tool_call_id}" + case messages_.BuiltinToolCallPart(tool_call_id=tool_call_id): + return f"{message_id}:builtin-call:{tool_call_id}" + case messages_.BuiltinToolReturnPart(tool_call_id=tool_call_id): + return f"{message_id}:builtin-result:{tool_call_id}" + case messages_.HookPart(): + return f"{message_id}:hook:{index}" + case messages_.FilePart(): + return f"{message_id}:file:{index}" + + # TODO(datamodel-rework §4): once tool args have a canonical shape, drop # these normalizers. def _normalize_tool_args(tool_input: str | dict[str, Any] | None) -> str: @@ -295,6 +343,7 @@ def _patch_pending_hook_aborts( continue new_parts.append( messages_.ToolResultPart( + id=f"{tool_msg.id}:result:{tc.tool_call_id}:hook-pending", tool_call_id=tc.tool_call_id, tool_name=tc.tool_name, result=f"Pending on hook '{hook.hook_id}'", @@ -443,15 +492,23 @@ def _build_result_part( if ui_msg.role == "assistant": result.extend( _split_assistant_parts( - assistant_parts, tool_result_parts, msg_id=ui_msg.id + assistant_parts, + tool_result_parts, + turn_id=ui_msg.id, ) ) - for hp in hook_parts: + for index, hp in enumerate(hook_parts): + msg_id = _AdapterIdPolicy.message_id( + ui_msg.id, "internal", index + ) result.append( messages_.Message( - id=ui_msg.id, + id=msg_id, + turn_id=ui_msg.id, role="internal", - parts=[hp], + parts=_AdapterIdPolicy.with_stable_part_ids( + msg_id, [hp] + ), ) ) else: @@ -459,7 +516,9 @@ def _build_result_part( messages_.Message( id=ui_msg.id, role=ui_msg.role, - parts=assistant_parts, + parts=_AdapterIdPolicy.with_stable_part_ids( + ui_msg.id, assistant_parts + ), ) ) @@ -469,7 +528,7 @@ def _build_result_part( def _split_assistant_parts( parts: list[messages_.Part], tool_results: list[messages_.ToolResultPart], - msg_id: str, + turn_id: str, ) -> list[messages_.Message]: """Split assistant parts into assistant + tool message pairs.""" results_by_id = {tr.tool_call_id: tr for tr in tool_results} @@ -484,13 +543,53 @@ def _split_assistant_parts( if not pending_results: if parts: - return [messages_.Message(role="assistant", parts=parts, id=msg_id)] + msg_id = _AdapterIdPolicy.message_id(turn_id, "assistant", 0) + return [ + messages_.Message( + role="assistant", + parts=_AdapterIdPolicy.with_stable_part_ids(msg_id, parts), + id=msg_id, + turn_id=turn_id, + ) + ] return [] messages: list[messages_.Message] = [] current: list[messages_.Part] = [] current_results: list[messages_.ToolResultPart] = [] seen_tool_call = False + assistant_index = 0 + tool_index = 0 + + def _append_assistant(parts_: list[messages_.Part]) -> None: + nonlocal assistant_index + msg_id = _AdapterIdPolicy.message_id( + turn_id, "assistant", assistant_index + ) + messages.append( + messages_.Message( + role="assistant", + parts=_AdapterIdPolicy.with_stable_part_ids(msg_id, parts_), + id=msg_id, + turn_id=turn_id, + ) + ) + assistant_index += 1 + + def _append_tool(parts_: list[messages_.ToolResultPart]) -> None: + nonlocal tool_index + msg_id = _AdapterIdPolicy.message_id(turn_id, "tool", tool_index) + messages.append( + messages_.Message( + role="tool", + parts=_AdapterIdPolicy.with_stable_part_ids( + msg_id, list(parts_) + ), + id=msg_id, + turn_id=turn_id, + ) + ) + tool_index += 1 for part in parts: if ( @@ -498,12 +597,8 @@ def _split_assistant_parts( and current_results and not isinstance(part, messages_.ToolCallPart) ): - messages.append( - messages_.Message(role="assistant", parts=current, id=msg_id) - ) - messages.append( - messages_.Message(role="tool", parts=list(current_results)) - ) + _append_assistant(current) + _append_tool(current_results) current = [] current_results = [] seen_tool_call = False @@ -516,12 +611,8 @@ def _split_assistant_parts( current_results.append(results_by_id[part.tool_call_id]) if current: - messages.append( - messages_.Message(role="assistant", parts=current, id=msg_id) - ) + _append_assistant(current) if current_results: - messages.append( - messages_.Message(role="tool", parts=list(current_results)) - ) + _append_tool(current_results) return messages diff --git a/src/ai/agents/ui/ai_sdk/outbound/_state.py b/src/ai/agents/ui/ai_sdk/outbound/_state.py index baf8ab1c..e080dc61 100644 --- a/src/ai/agents/ui/ai_sdk/outbound/_state.py +++ b/src/ai/agents/ui/ai_sdk/outbound/_state.py @@ -40,6 +40,15 @@ def _to_wire_output(snapshot: Any) -> Any: return snapshot +def _stream_message_id(event: events_.Event) -> str | None: + message = event.message + if message.role != "assistant": + return None + if message.turn_id is not None: + return message.turn_id + return None if message.id == "" else message.id + + class _StreamState: """Single-pass state across one ``to_stream()`` call.""" @@ -96,12 +105,16 @@ def _reset_step_tracking(self) -> None: self.emitted_tool_results.clear() self.emitted_approval_requests.clear() - def _ensure_started(self) -> list[protocol.UIMessageStreamPart]: + def _ensure_started( + self, + message_id: str | None = None, + ) -> list[protocol.UIMessageStreamPart]: """Lazily emit StartPart / StartStepPart on the first event.""" parts: list[protocol.UIMessageStreamPart] = [] if not self.emitted_start: - parts.append(protocol.StartPart(message_id=None)) + self.ui_message_id = message_id + parts.append(protocol.StartPart(message_id=self.ui_message_id)) parts.append(protocol.StartStepPart()) self.emitted_start = True self.in_step = True @@ -118,7 +131,7 @@ def on_event( # Lazily open the UI message on the first streaming event. if not self.emitted_start: - out.extend(self._ensure_started()) + out.extend(self._ensure_started(_stream_message_id(event))) match event: case events_.TextStart(block_id=pid): @@ -197,7 +210,7 @@ def on_tool_result( msg = event.message out: list[protocol.UIMessageStreamPart] = [] - out.extend(self._ensure_started()) + out.extend(self._ensure_started(msg.turn_id)) # Emit ToolInputAvailable for each tool call that triggered # these results (from the assistant message's ToolCallParts). @@ -307,7 +320,7 @@ def on_hook( out: list[protocol.UIMessageStreamPart] = [] # Ensure the UI message is started. - out.extend(self._ensure_started()) + out.extend(self._ensure_started(event.message.turn_id)) tc_id = _approvals.tool_call_id_for(hook_part) if tc_id is None: diff --git a/src/ai/agents/ui/ai_sdk/outbound/history.py b/src/ai/agents/ui/ai_sdk/outbound/history.py index eb5a20a4..215fc233 100644 --- a/src/ai/agents/ui/ai_sdk/outbound/history.py +++ b/src/ai/agents/ui/ai_sdk/outbound/history.py @@ -10,6 +10,29 @@ from .....types import messages as messages_ +def _turn_id_from_message_id(message_id: str) -> str | None: + for marker in (":assistant:", ":tool:", ":internal:"): + if marker in message_id: + return message_id.split(marker, 1)[0] + return None + + +def _message_turn_key(message: messages_.Message) -> str | None: + return message.turn_id or _turn_id_from_message_id(message.id) + + +def _assistant_bubble_id(message: messages_.Message) -> str: + return _message_turn_key(message) or message.id + + +def _belongs_to_bubble( + message: messages_.Message, + bubble_id: str, +) -> bool: + key = _message_turn_key(message) + return key is None or key == bubble_id + + def to_ui_messages( messages: list[messages_.Message], ) -> list[ui_message.UIMessage]: @@ -39,7 +62,7 @@ def to_ui_messages( if msg.role == "assistant": ui_parts: list[ui_message.UIMessagePart] = [] - bubble_id = msg.id + bubble_id = _assistant_bubble_id(msg) while i < len(messages) and messages[i].role in ( "assistant", @@ -47,13 +70,17 @@ def to_ui_messages( "internal", ): current = messages[i] + if not _belongs_to_bubble(current, bubble_id): + break if current.role == "assistant": ui_parts.extend(_parts.to_ui_parts(current.parts)) + ui_parts = _parts.dedupe_tool_parts(ui_parts) elif current.role == "tool": _parts.merge_tool_results(ui_parts, current.parts) elif current.role == "internal": _parts.merge_approval_signals(ui_parts, current.parts) i += 1 + ui_parts = _parts.dedupe_tool_parts(ui_parts) result.append( ui_message.UIMessage( diff --git a/src/ai/types/messages.py b/src/ai/types/messages.py index bae4c3c0..217b4287 100644 --- a/src/ai/types/messages.py +++ b/src/ai/types/messages.py @@ -14,7 +14,7 @@ def generate_id(prefix: str | None = None) -> str: class TextPart(pydantic.BaseModel): - id: str = pydantic.Field(default_factory=generate_id) + id: str = pydantic.Field(default_factory=lambda: generate_id("part")) text: str provider_metadata: dict[str, Any] | None = None @@ -25,7 +25,7 @@ class TextPart(pydantic.BaseModel): class ToolResultPart(pydantic.BaseModel): - id: str = pydantic.Field(default_factory=generate_id) + id: str = pydantic.Field(default_factory=lambda: generate_id("part")) tool_call_id: str tool_name: str is_error: bool = False @@ -67,7 +67,7 @@ def has_model_input(self) -> bool: class ToolCallPart(pydantic.BaseModel): - id: str = pydantic.Field(default_factory=generate_id) + id: str = pydantic.Field(default_factory=lambda: generate_id("part")) tool_call_id: str tool_name: str tool_args: str @@ -98,7 +98,7 @@ class BuiltinToolCallPart(pydantic.BaseModel): host. Adapters emit them when a model uses a built-in tool. """ - id: str = pydantic.Field(default_factory=generate_id) + id: str = pydantic.Field(default_factory=lambda: generate_id("part")) tool_call_id: str tool_name: str tool_args: str = "" @@ -110,7 +110,7 @@ class BuiltinToolCallPart(pydantic.BaseModel): class BuiltinToolReturnPart(pydantic.BaseModel): """The provider's result for a :class:`BuiltinToolCallPart`.""" - id: str = pydantic.Field(default_factory=generate_id) + id: str = pydantic.Field(default_factory=lambda: generate_id("part")) tool_call_id: str tool_name: str result: Any = None @@ -122,7 +122,7 @@ class BuiltinToolReturnPart(pydantic.BaseModel): class ReasoningPart(pydantic.BaseModel): - id: str = pydantic.Field(default_factory=generate_id) + id: str = pydantic.Field(default_factory=lambda: generate_id("part")) text: str provider_metadata: dict[str, Any] | None = None @@ -130,7 +130,7 @@ class ReasoningPart(pydantic.BaseModel): class HookPart[T](pydantic.BaseModel): - id: str = pydantic.Field(default_factory=generate_id) + id: str = pydantic.Field(default_factory=lambda: generate_id("part")) hook_id: str hook_type: str status: Literal["pending", "resolved", "cancelled"] @@ -158,7 +158,7 @@ class FilePart(pydantic.BaseModel): model_config = pydantic.ConfigDict(frozen=True) - id: str = pydantic.Field(default_factory=generate_id) + id: str = pydantic.Field(default_factory=lambda: generate_id("part")) data: str | bytes media_type: str # IANA media type, e.g. "image/png", "audio/wav" filename: str | None = None @@ -220,7 +220,7 @@ def from_bytes( class Message(pydantic.BaseModel): role: Literal["user", "assistant", "system", "tool", "internal"] parts: list[Part] - id: str = pydantic.Field(default_factory=generate_id) + id: str = pydantic.Field(default_factory=lambda: generate_id("msg")) turn_id: str | None = None usage: usage_.Usage | None = None provider_metadata: dict[str, Any] | None = None diff --git a/tests/types/test_integrity.py b/tests/types/test_integrity.py index e046005f..d65eecc9 100644 --- a/tests/types/test_integrity.py +++ b/tests/types/test_integrity.py @@ -528,23 +528,6 @@ def test_duplicate_tool_results_within_same_message_raises() -> None: assert "duplicate-tool-result" in exc_info.value.issues -def test_integrity_error_includes_tool_ids_and_message_locations() -> None: - bad_history = [ - *_parallel_tool_turn(turn_id="turn-1", assistant_prefix="server"), - *_parallel_tool_turn(turn_id="turn-1", assistant_prefix="client"), - ] - - with pytest.raises(IntegrityError) as exc_info: - prepare_messages(bad_history) - - error = exc_info.value - assert "duplicate-tool-call" in str(error) - assert "duplicate-tool-result" in str(error) - assert "tc-bash" in str(error) - assert "server:assistant:0" in str(error) - assert "client:assistant:0" in str(error) - - # --------------------------------------------------------------------------- # Does not mutate input # --------------------------------------------------------------------------- From 45e8c920794546c6fe462ba45bd60818f83585ae Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Wed, 20 May 2026 11:04:04 -0700 Subject: [PATCH 03/13] Extend UI protocol support --- src/ai/agents/ui/ai_sdk/_parts.py | 188 ++++++++++++++--- src/ai/agents/ui/ai_sdk/inbound.py | 194 ++++++++++++++---- src/ai/agents/ui/ai_sdk/outbound/_state.py | 192 ++++++++++++++++- src/ai/agents/ui/ai_sdk/outbound/history.py | 28 +++ src/ai/agents/ui/ai_sdk/outbound/sse.py | 6 + src/ai/agents/ui/ai_sdk/protocol.py | 48 +++++ src/ai/agents/ui/ai_sdk/ui_message.py | 141 ++++++++++++- .../agents/ui/ai_sdk/outbound/test_history.py | 136 ++++++++++++ tests/agents/ui/ai_sdk/outbound/test_sse.py | 32 ++- .../agents/ui/ai_sdk/outbound/test_stream.py | 106 ++++++++++ tests/agents/ui/ai_sdk/test_inbound.py | 105 ++++++++++ 11 files changed, 1086 insertions(+), 90 deletions(-) diff --git a/src/ai/agents/ui/ai_sdk/_parts.py b/src/ai/agents/ui/ai_sdk/_parts.py index 1f3f00aa..bc21e7b0 100644 --- a/src/ai/agents/ui/ai_sdk/_parts.py +++ b/src/ai/agents/ui/ai_sdk/_parts.py @@ -8,11 +8,14 @@ from __future__ import annotations import json -from typing import Any, cast +from typing import Any, TypeGuard, cast +from ....types import media from ....types import messages as messages_ from . import _approvals, ui_message +UIToolLike = ui_message.UIToolPart | ui_message.UIDynamicToolPart + _TOOL_STATE_RANK: dict[ui_message.UIToolInvocationState, int] = { "input-streaming": 0, "input-available": 1, @@ -24,17 +27,31 @@ } -def _normalize_tool_input(raw: str) -> str | dict[str, Any]: - """Parse tool args JSON string into a dict; fall back to raw string. +def _normalize_tool_input(raw: str) -> Any: + """Parse tool args JSON string into a JSON value; fall back to raw string. TODO(datamodel-rework §4): once ``ToolCallPart.tool_args`` has a canonical shape, drop this helper. """ try: - parsed = json.loads(raw) + return json.loads(raw) except (json.JSONDecodeError, TypeError): return raw - return parsed if isinstance(parsed, dict) else raw + + +def _metadata_bool(metadata: dict[str, Any], key: str) -> bool | None: + value = metadata.get(key) + return value if isinstance(value, bool) else None + + +def _is_tool_part(part: ui_message.UIMessagePart) -> TypeGuard[UIToolLike]: + return isinstance( + part, ui_message.UIToolPart | ui_message.UIDynamicToolPart + ) + + +def _tool_call_id(part: UIToolLike) -> str: + return part.tool_call_id def to_ui_parts(parts: list[messages_.Part]) -> list[ui_message.UIMessagePart]: @@ -42,10 +59,24 @@ def to_ui_parts(parts: list[messages_.Part]) -> list[ui_message.UIMessagePart]: result: list[ui_message.UIMessagePart] = [] for part in parts: if isinstance(part, messages_.TextPart) and part.text: - result.append(ui_message.UITextPart(type="text", text=part.text)) + result.append( + ui_message.UITextPart.model_validate( + { + "type": "text", + "text": part.text, + "providerMetadata": part.provider_metadata, + } + ) + ) elif isinstance(part, messages_.ReasoningPart) and part.text: result.append( - ui_message.UIReasoningPart(type="reasoning", text=part.text) + ui_message.UIReasoningPart.model_validate( + { + "type": "reasoning", + "text": part.text, + "providerMetadata": part.provider_metadata, + } + ) ) elif isinstance(part, messages_.ToolCallPart): result.append( @@ -55,6 +86,43 @@ def to_ui_parts(parts: list[messages_.Part]) -> list[ui_message.UIMessagePart]: "toolCallId": part.tool_call_id, "state": "input-available", "input": _normalize_tool_input(part.tool_args), + "callProviderMetadata": part.provider_metadata, + } + ) + ) + elif isinstance(part, messages_.BuiltinToolCallPart): + result.append( + ui_message.UIDynamicToolPart.model_validate( + { + "type": "dynamic-tool", + "toolName": part.tool_name, + "toolCallId": part.tool_call_id, + "state": "input-available", + "input": _normalize_tool_input(part.tool_args), + "providerExecuted": True, + "callProviderMetadata": part.provider_metadata, + } + ) + ) + elif isinstance(part, messages_.BuiltinToolReturnPart): + result.append( + ui_message.UIDynamicToolPart.model_validate( + { + "type": "dynamic-tool", + "toolName": part.tool_name, + "toolCallId": part.tool_call_id, + "state": ( + "output-error" + if part.is_error + else "output-available" + ), + "input": None, + "output": None if part.is_error else part.result, + "errorText": ( + str(part.result) if part.is_error else None + ), + "providerExecuted": True, + "resultProviderMetadata": part.provider_metadata, } ) ) @@ -64,8 +132,11 @@ def to_ui_parts(parts: list[messages_.Part]) -> list[ui_message.UIMessagePart]: { "type": "file", "mediaType": part.media_type, - "url": part.data if isinstance(part.data, str) else "", + "url": media.data_to_data_url( + part.data, part.media_type + ), "filename": part.filename, + "providerMetadata": part.provider_metadata, } ) ) @@ -73,9 +144,9 @@ def to_ui_parts(parts: list[messages_.Part]) -> list[ui_message.UIMessagePart]: def _merge_tool_part( - existing: ui_message.UIToolPart, - candidate: ui_message.UIToolPart, -) -> ui_message.UIToolPart: + existing: UIToolLike, + candidate: UIToolLike, +) -> UIToolLike: """Merge duplicate UI tool parts, keeping the first display position.""" existing_rank = _TOOL_STATE_RANK.get(existing.state, 0) candidate_rank = _TOOL_STATE_RANK.get(candidate.state, 0) @@ -90,10 +161,24 @@ def _merge_tool_part( updates["input"] = candidate.input if candidate.output is not None: updates["output"] = candidate.output + if candidate.raw_input is not None: + updates["raw_input"] = candidate.raw_input if candidate.error_text is not None: updates["error_text"] = candidate.error_text if candidate.approval is not None: updates["approval"] = candidate.approval + if candidate.provider_executed is not None: + updates["provider_executed"] = candidate.provider_executed + if candidate.call_provider_metadata is not None: + updates["call_provider_metadata"] = candidate.call_provider_metadata + if candidate.result_provider_metadata is not None: + updates["result_provider_metadata"] = candidate.result_provider_metadata + if candidate.tool_metadata is not None: + updates["tool_metadata"] = candidate.tool_metadata + if candidate.preliminary is not None: + updates["preliminary"] = candidate.preliminary + if candidate.title is not None: + updates["title"] = candidate.title return existing.model_copy(update=updates) if updates else existing @@ -106,18 +191,18 @@ def dedupe_tool_parts( tool_index: dict[str, int] = {} for part in ui_parts: - if not isinstance(part, ui_message.UIToolPart): + if not _is_tool_part(part): result.append(part) continue - idx = tool_index.get(part.tool_call_id) + idx = tool_index.get(_tool_call_id(part)) if idx is None: - tool_index[part.tool_call_id] = len(result) + tool_index[_tool_call_id(part)] = len(result) result.append(part) continue existing = result[idx] - if isinstance(existing, ui_message.UIToolPart): + if _is_tool_part(existing): result[idx] = _merge_tool_part(existing, part) return result @@ -130,30 +215,51 @@ def merge_tool_results( """Merge ToolResultParts into existing UIToolParts in-place.""" tool_index: dict[str, int] = {} for idx, ui_part in enumerate(ui_parts): - if isinstance(ui_part, ui_message.UIToolPart): - tool_index[ui_part.tool_call_id] = idx + if _is_tool_part(ui_part): + tool_index[_tool_call_id(ui_part)] = idx for part in tool_parts: - if not isinstance(part, messages_.ToolResultPart): + if isinstance(part, messages_.ToolResultPart): + tool_call_id = part.tool_call_id + state = "output-error" if part.is_error else "output-available" + updates: dict[str, Any] = { + "state": state, + "result_provider_metadata": part.provider_metadata, + } + if part.is_error: + updates["error_text"] = str(part.result) + else: + updates["output"] = part.result + elif isinstance(part, messages_.BuiltinToolReturnPart): + tool_call_id = part.tool_call_id + updates = { + "state": ( + "output-error" if part.is_error else "output-available" + ), + "provider_executed": True, + "result_provider_metadata": part.provider_metadata, + } + if part.is_error: + updates["error_text"] = str(part.result) + else: + updates["output"] = part.result + else: continue # Hook-abort placeholders are internal: the corresponding # HookPart(pending) carries the user-visible state via # merge_approval_signals. - if part.is_hook_pending: + if isinstance(part, messages_.ToolResultPart) and part.is_hook_pending: continue - idx_opt = tool_index.get(part.tool_call_id) + idx_opt = tool_index.get(tool_call_id) if idx_opt is None: continue idx = idx_opt existing = ui_parts[idx] - if not isinstance(existing, ui_message.UIToolPart): + if not _is_tool_part(existing): continue if existing.state == "output-denied": continue - state = "output-error" if part.is_error else "output-available" - ui_parts[idx] = existing.model_copy( - update={"state": state, "output": part.result} - ) + ui_parts[idx] = existing.model_copy(update=updates) def merge_approval_signals( @@ -163,8 +269,8 @@ def merge_approval_signals( """Merge HookPart approval state into existing UIToolParts in-place.""" tool_index: dict[str, int] = {} for idx, ui_part in enumerate(ui_parts): - if isinstance(ui_part, ui_message.UIToolPart): - tool_index[ui_part.tool_call_id] = idx + if _is_tool_part(ui_part): + tool_index[_tool_call_id(ui_part)] = idx for part in internal_parts: if not isinstance(part, messages_.HookPart): @@ -180,22 +286,38 @@ def merge_approval_signals( idx = idx_opt existing = ui_parts[idx] - if not isinstance(existing, ui_message.UIToolPart): + if not _is_tool_part(existing): continue updates: dict[str, Any] = {} + if (provider_executed := _metadata_bool( + part.metadata, "providerExecuted" + )) is not None: + updates["provider_executed"] = provider_executed if part.status == "pending": updates["state"] = "approval-requested" - updates["approval"] = ui_message.UIToolApproval(id=part.hook_id) + updates["approval"] = ui_message.UIToolApproval.model_validate( + { + "id": part.hook_id, + "isAutomatic": _metadata_bool( + part.metadata, "isAutomatic" + ), + } + ) elif part.status == "resolved": resolution = cast( "dict[str, Any]", part.resolution if isinstance(part.resolution, dict) else {}, ) - updates["approval"] = ui_message.UIToolApproval( - id=part.hook_id, - approved=resolution.get("granted"), - reason=resolution.get("reason"), + updates["approval"] = ui_message.UIToolApproval.model_validate( + { + "id": part.hook_id, + "approved": resolution.get("granted"), + "reason": resolution.get("reason"), + "isAutomatic": _metadata_bool( + part.metadata, "isAutomatic" + ), + } ) if resolution.get("granted", False): updates["state"] = "approval-responded" diff --git a/src/ai/agents/ui/ai_sdk/inbound.py b/src/ai/agents/ui/ai_sdk/inbound.py index bac251c4..d4ad85cd 100644 --- a/src/ai/agents/ui/ai_sdk/inbound.py +++ b/src/ai/agents/ui/ai_sdk/inbound.py @@ -82,15 +82,37 @@ def part_id( # TODO(datamodel-rework §4): once tool args have a canonical shape, drop # these normalizers. -def _normalize_tool_args(tool_input: str | dict[str, Any] | None) -> str: - """Normalize tool input (JSON string, dict, or None) to a JSON string.""" +def _normalize_tool_args(tool_input: Any) -> str: + """Normalize tool input (JSON string, JSON value, or None) to a string.""" match tool_input: case str(): return tool_input - case dict(): - return json.dumps(tool_input) - case _: + case None: return "{}" + case _: + return json.dumps(tool_input) + + +def _tool_input_for_args( + part: ui_message.UIToolPart | ui_message.UIDynamicToolPart, +) -> Any: + if part.state == "output-error" and part.input is None: + return part.raw_input + return part.input + + +def _tool_result_output( + part: ui_message.UIToolPart | ui_message.UIDynamicToolPart, +) -> Any: + if part.state == "output-error": + return _error_result(part.error_text, part.output) + if part.state == "output-denied": + reason = part.approval.reason if part.approval is not None else None + return { + "type": "error-text", + "value": reason or "Tool call execution denied.", + } + return part.output def _normalize_tool_result(output: Any) -> dict[str, Any] | None: @@ -132,18 +154,20 @@ def _decode_wire_output(output: Any) -> Any: def _approval_hook_part( - tp: ui_message.UIToolPart, + tp: ui_message.UIToolPart | ui_message.UIDynamicToolPart, ) -> messages_.HookPart[Any] | None: """Reconstruct approval hook state from a UI tool part when possible.""" approval = tp.approval if approval is None: return None + metadata = _approval_metadata(tp) if tp.state == "approval-requested": return messages_.HookPart( hook_id=approval.id, hook_type="ToolApproval", status="pending", + metadata=metadata, ) if tp.state == "approval-responded" and approval.approved is not None: @@ -151,6 +175,7 @@ def _approval_hook_part( hook_id=approval.id, hook_type="ToolApproval", status="resolved", + metadata=metadata, resolution={ "granted": approval.approved, "reason": approval.reason, @@ -162,6 +187,7 @@ def _approval_hook_part( hook_id=approval.id, hook_type="ToolApproval", status="resolved", + metadata=metadata, resolution={ "granted": False, "reason": approval.reason, @@ -171,6 +197,19 @@ def _approval_hook_part( return None +def _approval_metadata( + tp: ui_message.UIToolPart | ui_message.UIDynamicToolPart, +) -> dict[str, Any]: + metadata: dict[str, Any] = {} + if tp.approval is not None and tp.approval.is_automatic is not None: + metadata["isAutomatic"] = tp.approval.is_automatic + if tp.provider_executed is not None: + metadata["providerExecuted"] = tp.provider_executed + if tp.call_provider_metadata is not None: + metadata["callProviderMetadata"] = tp.call_provider_metadata + return metadata + + # ============================================================================ # Approval extraction + bulk resolution # ============================================================================ @@ -195,7 +234,9 @@ def extract_approvals( approvals: list[ApprovalResponse] = [] for ui_msg in ui_messages: for part in ui_msg.parts: - if not isinstance(part, ui_message.UIToolPart): + if not isinstance( + part, ui_message.UIToolPart | ui_message.UIDynamicToolPart + ): continue if ( part.state == "approval-responded" @@ -238,7 +279,9 @@ def _normalize_ui_messages( for part in message.parts: part_type = getattr(part, "type", None) state = getattr(part, "state", None) - if isinstance(part_type, str) and part_type.startswith("tool-"): + if isinstance(part_type, str) and ( + part_type.startswith("tool-") or part_type == "dynamic-tool" + ): output = getattr(part, "output", None) approval = getattr(part, "approval", None) approved = approval.approved if approval is not None else None @@ -394,6 +437,22 @@ def _build_result_part( is_error=is_error, ) + def _build_builtin_return_part( + *, + tool_call_id: str, + tool_name: str, + output: Any, + is_error: bool, + provider_metadata: dict[str, Any] | None, + ) -> messages_.BuiltinToolReturnPart: + return messages_.BuiltinToolReturnPart( + tool_call_id=tool_call_id, + tool_name=tool_name, + result=output, + is_error=is_error, + provider_metadata=provider_metadata, + ) + result: list[messages_.Message] = [] for ui_msg in ui_messages: @@ -404,62 +463,119 @@ def _build_result_part( for part in ui_msg.parts: match part: case ui_message.UITextPart(text=text) if text: - assistant_parts.append(messages_.TextPart(text=text)) + assistant_parts.append( + messages_.TextPart( + text=text, + provider_metadata=part.provider_metadata, + ) + ) case ui_message.UIReasoningPart(text=reasoning) if reasoning: assistant_parts.append( - messages_.ReasoningPart(text=reasoning) + messages_.ReasoningPart( + text=reasoning, + provider_metadata=part.provider_metadata, + ) ) case ui_message.UIToolInvocationPart() as inv: tool_args = json.dumps(inv.args) if inv.args else "{}" - assistant_parts.append( - messages_.ToolCallPart( - tool_call_id=inv.tool_invocation_id, - tool_name=inv.tool_name, - tool_args=tool_args, + if inv.provider_executed: + assistant_parts.append( + messages_.BuiltinToolCallPart( + tool_call_id=inv.tool_invocation_id, + tool_name=inv.tool_name, + tool_args=tool_args, + ) ) - ) - if _is_tool_completed(inv.state): - tool_result_parts.append( - _build_result_part( + if _is_tool_completed(inv.state): + assistant_parts.append( + _build_builtin_return_part( + tool_call_id=inv.tool_invocation_id, + tool_name=inv.tool_name, + output=inv.result, + is_error=_is_tool_error(inv.state), + provider_metadata=None, + ) + ) + else: + assistant_parts.append( + messages_.ToolCallPart( tool_call_id=inv.tool_invocation_id, tool_name=inv.tool_name, - output=inv.result, - is_error=_is_tool_error(inv.state), + tool_args=tool_args, ) ) + if _is_tool_completed(inv.state): + tool_result_parts.append( + _build_result_part( + tool_call_id=inv.tool_invocation_id, + tool_name=inv.tool_name, + output=inv.result, + is_error=_is_tool_error(inv.state), + ) + ) - case ui_message.UIToolPart() as tp: - assistant_parts.append( - messages_.ToolCallPart( - tool_call_id=tp.tool_call_id, - tool_name=tp.tool_name, - tool_args=_normalize_tool_args(tp.input), + case ( + ui_message.UIToolPart() | ui_message.UIDynamicToolPart() + ) as tp: + tool_input = _tool_input_for_args(tp) + tool_args = _normalize_tool_args(tool_input) + + if tp.provider_executed: + assistant_parts.append( + messages_.BuiltinToolCallPart( + tool_call_id=tp.tool_call_id, + tool_name=tp.tool_name, + tool_args=tool_args, + provider_metadata=tp.call_provider_metadata, + ) + ) + else: + assistant_parts.append( + messages_.ToolCallPart( + tool_call_id=tp.tool_call_id, + tool_name=tp.tool_name, + tool_args=tool_args, + provider_metadata=tp.call_provider_metadata, + ) ) - ) approval_hook = _approval_hook_part(tp) if approval_hook is not None: hook_parts.append(approval_hook) - if tp.state in _TOOL_RESULT_STATES: - tool_result_parts.append( - _build_result_part( + if tp.provider_executed and _is_tool_completed(tp.state): + assistant_parts.append( + _build_builtin_return_part( tool_call_id=tp.tool_call_id, tool_name=tp.tool_name, - output=tp.output, - is_error=False, + output=_tool_result_output(tp), + is_error=_is_tool_error(tp.state), + provider_metadata=( + tp.result_provider_metadata + or tp.call_provider_metadata + ), ) ) - elif tp.state == "output-error": + elif tp.state in _TOOL_RESULT_STATES: tool_result_parts.append( - messages_.ToolResultPart( + _build_result_part( tool_call_id=tp.tool_call_id, tool_name=tp.tool_name, - result=_error_result(tp.error_text, tp.output), - is_error=True, + output=_tool_result_output(tp), + is_error=_is_tool_error(tp.state), ) ) + if tp.result_provider_metadata is not None: + tool_result_parts[-1] = tool_result_parts[ + -1 + ].model_copy( + update={ + "provider_metadata": ( + tp.result_provider_metadata + ) + } + ) case ui_message.UIFilePart() as fp: assistant_parts.append( @@ -467,6 +583,7 @@ def _build_result_part( data=fp.url, media_type=fp.media_type, filename=fp.filename, + provider_metadata=fp.provider_metadata, ) ) @@ -474,6 +591,9 @@ def _build_result_part( ui_message.UIStepStartPart() | ui_message.UISourceUrlPart() | ui_message.UISourceDocumentPart() + | ui_message.UIReasoningFilePart() + | ui_message.UICustomPart() + | ui_message.UIDataPart() ): pass diff --git a/src/ai/agents/ui/ai_sdk/outbound/_state.py b/src/ai/agents/ui/ai_sdk/outbound/_state.py index e080dc61..7d480718 100644 --- a/src/ai/agents/ui/ai_sdk/outbound/_state.py +++ b/src/ai/agents/ui/ai_sdk/outbound/_state.py @@ -2,9 +2,11 @@ from __future__ import annotations +import json from typing import Any from .....types import events as events_ +from .....types import media from .....types import messages as messages_ from ....agent import MessageBundle from .. import _approvals, protocol @@ -23,6 +25,26 @@ def _tool_error_text(part: messages_.ToolResultPart) -> str: return "Tool execution failed" +def _normalize_tool_input(raw: str) -> Any: + try: + return json.loads(raw) + except Exception: + return raw + + +def _metadata_bool(metadata: dict[str, Any], key: str) -> bool | None: + value = metadata.get(key) + return value if isinstance(value, bool) else None + + +def _metadata_dict( + metadata: dict[str, Any], + key: str, +) -> dict[str, Any] | None: + value = metadata.get(key) + return value if isinstance(value, dict) else None + + def _to_wire_output(snapshot: Any) -> Any: """Convert an aggregator snapshot to its UI wire representation. @@ -136,37 +158,79 @@ def on_event( match event: case events_.TextStart(block_id=pid): self.open_text_ids.add(pid) - out.append(protocol.TextStartPart(id=pid)) + out.append( + protocol.TextStartPart( + id=pid, + provider_metadata=event.provider_metadata, + ) + ) case events_.TextDelta(block_id=pid, chunk=chunk): if pid not in self.open_text_ids: self.open_text_ids.add(pid) - out.append(protocol.TextStartPart(id=pid)) + out.append( + protocol.TextStartPart( + id=pid, + provider_metadata=event.provider_metadata, + ) + ) self.text_delta_ids.add(pid) - out.append(protocol.TextDeltaPart(id=pid, delta=chunk)) + out.append( + protocol.TextDeltaPart( + id=pid, + delta=chunk, + provider_metadata=event.provider_metadata, + ) + ) case events_.TextEnd(block_id=pid): if pid in self.open_text_ids: self.open_text_ids.discard(pid) self.completed_text_ids.add(pid) - out.append(protocol.TextEndPart(id=pid)) + out.append( + protocol.TextEndPart( + id=pid, + provider_metadata=event.provider_metadata, + ) + ) case events_.ReasoningStart(block_id=pid): self.open_reasoning_ids.add(pid) - out.append(protocol.ReasoningStartPart(id=pid)) + out.append( + protocol.ReasoningStartPart( + id=pid, + provider_metadata=event.provider_metadata, + ) + ) case events_.ReasoningDelta(block_id=pid, chunk=chunk): if pid not in self.open_reasoning_ids: self.open_reasoning_ids.add(pid) - out.append(protocol.ReasoningStartPart(id=pid)) + out.append( + protocol.ReasoningStartPart( + id=pid, + provider_metadata=event.provider_metadata, + ) + ) self.reasoning_delta_ids.add(pid) - out.append(protocol.ReasoningDeltaPart(id=pid, delta=chunk)) + out.append( + protocol.ReasoningDeltaPart( + id=pid, + delta=chunk, + provider_metadata=event.provider_metadata, + ) + ) case events_.ReasoningEnd(block_id=pid): if pid in self.open_reasoning_ids: self.open_reasoning_ids.discard(pid) self.completed_reasoning_ids.add(pid) - out.append(protocol.ReasoningEndPart(id=pid)) + out.append( + protocol.ReasoningEndPart( + id=pid, + provider_metadata=event.provider_metadata, + ) + ) case events_.ToolStart(tool_call_id=tcid, tool_name=name): self.tool_names[tcid] = name @@ -177,6 +241,7 @@ def on_event( protocol.ToolInputStartPart( tool_call_id=tcid, tool_name=name, + provider_metadata=event.provider_metadata, ) ) @@ -187,6 +252,7 @@ def on_event( protocol.ToolInputStartPart( tool_call_id=tcid, tool_name=self.tool_names.get(tcid, ""), + provider_metadata=event.provider_metadata, ) ) out.append( @@ -199,6 +265,94 @@ def on_event( case events_.ToolEnd(): pass + case events_.BuiltinToolStart(tool_call_id=tcid, tool_name=name): + self.tool_names[tcid] = name + if tcid in self.started_tool_inputs: + return out + self.started_tool_inputs.add(tcid) + out.append( + protocol.ToolInputStartPart( + tool_call_id=tcid, + tool_name=name, + provider_executed=True, + provider_metadata=event.provider_metadata, + dynamic=True, + ) + ) + + case events_.BuiltinToolDelta(tool_call_id=tcid, chunk=chunk): + if tcid not in self.started_tool_inputs: + self.started_tool_inputs.add(tcid) + out.append( + protocol.ToolInputStartPart( + tool_call_id=tcid, + tool_name=self.tool_names.get(tcid, ""), + provider_executed=True, + provider_metadata=event.provider_metadata, + dynamic=True, + ) + ) + out.append( + protocol.ToolInputDeltaPart( + tool_call_id=tcid, + input_text_delta=chunk, + ) + ) + + case events_.BuiltinToolEnd(tool_call_id=tcid, tool_call=tc): + if tcid not in self.input_available_emitted: + self.input_available_emitted.add(tcid) + out.append( + protocol.ToolInputAvailablePart( + tool_call_id=tcid, + tool_name=tc.tool_name, + input=_normalize_tool_input(tc.tool_args), + provider_executed=True, + provider_metadata=tc.provider_metadata + or event.provider_metadata, + dynamic=True, + ) + ) + + case events_.BuiltinToolResult(tool_call_id=tcid, result=result): + if tcid in self.emitted_tool_results: + return out + self.emitted_tool_results.add(tcid) + if result.is_error: + out.append( + protocol.ToolOutputErrorPart( + tool_call_id=tcid, + error_text=str(result.result), + provider_executed=True, + provider_metadata=result.provider_metadata + or event.provider_metadata, + dynamic=True, + ) + ) + else: + out.append( + protocol.ToolOutputAvailablePart( + tool_call_id=tcid, + output=result.result, + provider_executed=True, + provider_metadata=result.provider_metadata + or event.provider_metadata, + dynamic=True, + ) + ) + + case events_.FileEvent( + media_type=media_type, + data=data, + ): + out.append( + protocol.FilePart( + url=media.data_to_data_url(data, media_type), + media_type=media_type, + provider_metadata=event.provider_metadata, + ) + ) + return out # -- phase: tool results ------------------------------------------------ @@ -225,13 +379,15 @@ def on_tool_result( protocol.ToolInputStartPart( tool_call_id=part.tool_call_id, tool_name=part.tool_name, + provider_metadata=part.provider_metadata, ) ) out.append( protocol.ToolInputAvailablePart( tool_call_id=part.tool_call_id, tool_name=part.tool_name, - input=part.tool_args, + input=_normalize_tool_input(part.tool_args), + provider_metadata=part.provider_metadata, ) ) @@ -249,6 +405,7 @@ def on_tool_result( protocol.ToolOutputErrorPart( tool_call_id=part.tool_call_id, error_text=_tool_error_text(part), + provider_metadata=part.provider_metadata, ) ) else: @@ -263,6 +420,7 @@ def on_tool_result( protocol.ToolOutputAvailablePart( tool_call_id=part.tool_call_id, output=wire_output, + provider_metadata=part.provider_metadata, ) ) @@ -334,10 +492,26 @@ def on_hook( protocol.ToolApprovalRequestPart( approval_id=hook_part.hook_id, tool_call_id=tc_id, + is_automatic=_metadata_bool( + hook_part.metadata, "isAutomatic" + ), ) ) elif hook_part.status == "resolved": resolution: dict[str, Any] = hook_part.resolution or {} + out.append( + protocol.ToolApprovalResponsePart( + approval_id=hook_part.hook_id, + approved=bool(resolution.get("granted")), + reason=resolution.get("reason"), + provider_executed=_metadata_bool( + hook_part.metadata, "providerExecuted" + ), + provider_metadata=_metadata_dict( + hook_part.metadata, "callProviderMetadata" + ), + ) + ) if not resolution.get("granted"): out.append(protocol.ToolOutputDeniedPart(tool_call_id=tc_id)) elif hook_part.status == "cancelled": diff --git a/src/ai/agents/ui/ai_sdk/outbound/history.py b/src/ai/agents/ui/ai_sdk/outbound/history.py index 215fc233..5dc289c0 100644 --- a/src/ai/agents/ui/ai_sdk/outbound/history.py +++ b/src/ai/agents/ui/ai_sdk/outbound/history.py @@ -9,6 +9,9 @@ if TYPE_CHECKING: from .....types import messages as messages_ +_ADAPTER_METADATA_KEY = "aiPython" +_SOURCE_MESSAGES_KEY = "sourceMessages" + def _turn_id_from_message_id(message_id: str) -> str | None: for marker in (":assistant:", ":tool:", ":internal:"): @@ -33,6 +36,27 @@ def _belongs_to_bubble( return key is None or key == bubble_id +def _source_message_entry(message: messages_.Message) -> dict[str, object]: + return { + "id": message.id, + "role": message.role, + "turnId": message.turn_id, + "partIds": [part.id for part in message.parts], + } + + +def _adapter_metadata( + source_messages: list[messages_.Message], +) -> dict[str, object]: + return { + _ADAPTER_METADATA_KEY: { + _SOURCE_MESSAGES_KEY: [ + _source_message_entry(message) for message in source_messages + ] + } + } + + def to_ui_messages( messages: list[messages_.Message], ) -> list[ui_message.UIMessage]: @@ -54,6 +78,7 @@ def to_ui_messages( ui_message.UIMessage( id=msg.id, role=msg.role, + metadata=_adapter_metadata([msg]), parts=_parts.to_ui_parts(msg.parts), ) ) @@ -62,6 +87,7 @@ def to_ui_messages( if msg.role == "assistant": ui_parts: list[ui_message.UIMessagePart] = [] + source_messages: list[messages_.Message] = [] bubble_id = _assistant_bubble_id(msg) while i < len(messages) and messages[i].role in ( @@ -72,6 +98,7 @@ def to_ui_messages( current = messages[i] if not _belongs_to_bubble(current, bubble_id): break + source_messages.append(current) if current.role == "assistant": ui_parts.extend(_parts.to_ui_parts(current.parts)) ui_parts = _parts.dedupe_tool_parts(ui_parts) @@ -86,6 +113,7 @@ def to_ui_messages( ui_message.UIMessage( id=bubble_id, role="assistant", + metadata=_adapter_metadata(source_messages), parts=ui_parts, ) ) diff --git a/src/ai/agents/ui/ai_sdk/outbound/sse.py b/src/ai/agents/ui/ai_sdk/outbound/sse.py index 88207c54..b859bc2d 100644 --- a/src/ai/agents/ui/ai_sdk/outbound/sse.py +++ b/src/ai/agents/ui/ai_sdk/outbound/sse.py @@ -51,9 +51,15 @@ def format_sse(part: protocol.UIMessageStreamPart) -> str: return f"data: {serialize_part(part)}\n\n" +def format_done_sse() -> str: + """Format the AI SDK UI stream termination marker.""" + return "data: [DONE]\n\n" + + async def to_sse( events: AsyncIterable[events_.AgentEvent], ) -> AsyncGenerator[str]: """Convert an internal event stream into SSE strings.""" async for part in to_stream(events): yield format_sse(part) + yield format_done_sse() diff --git a/src/ai/agents/ui/ai_sdk/protocol.py b/src/ai/agents/ui/ai_sdk/protocol.py index e0eb8902..c8db7b3e 100644 --- a/src/ai/agents/ui/ai_sdk/protocol.py +++ b/src/ai/agents/ui/ai_sdk/protocol.py @@ -97,6 +97,15 @@ class ReasoningEndPart: provider_metadata: dict[str, Any] | None = None +@dataclasses.dataclass +class CustomPart: + """Provider-specific content that does not fit standard UI parts.""" + + kind: str + type: Literal["custom"] = dataclasses.field(default="custom", init=False) + provider_metadata: dict[str, Any] | None = None + + @dataclasses.dataclass class SourceUrlPart: """References to external URLs.""" @@ -134,6 +143,18 @@ class FilePart: provider_metadata: dict[str, Any] | None = None +@dataclasses.dataclass +class ReasoningFilePart: + """A file generated as part of model reasoning.""" + + url: str + media_type: str + type: Literal["reasoning-file"] = dataclasses.field( + default="reasoning-file", init=False + ) + provider_metadata: dict[str, Any] | None = None + + @dataclasses.dataclass class DataPart: """Custom data part for arbitrary structured data. @@ -166,6 +187,8 @@ class ToolInputStartPart: default="tool-input-start", init=False ) provider_executed: bool | None = None + provider_metadata: dict[str, Any] | None = None + tool_metadata: dict[str, Any] | None = None dynamic: bool | None = None title: str | None = None @@ -193,6 +216,7 @@ class ToolInputAvailablePart: ) provider_executed: bool | None = None provider_metadata: dict[str, Any] | None = None + tool_metadata: dict[str, Any] | None = None dynamic: bool | None = None title: str | None = None @@ -210,6 +234,7 @@ class ToolInputErrorPart: ) provider_executed: bool | None = None provider_metadata: dict[str, Any] | None = None + tool_metadata: dict[str, Any] | None = None dynamic: bool | None = None title: str | None = None @@ -224,6 +249,8 @@ class ToolOutputAvailablePart: default="tool-output-available", init=False ) provider_executed: bool | None = None + provider_metadata: dict[str, Any] | None = None + tool_metadata: dict[str, Any] | None = None dynamic: bool | None = None preliminary: bool | None = None @@ -238,6 +265,8 @@ class ToolOutputErrorPart: default="tool-output-error", init=False ) provider_executed: bool | None = None + provider_metadata: dict[str, Any] | None = None + tool_metadata: dict[str, Any] | None = None dynamic: bool | None = None @@ -260,6 +289,21 @@ class ToolApprovalRequestPart: type: Literal["tool-approval-request"] = dataclasses.field( default="tool-approval-request", init=False ) + is_automatic: bool | None = None + + +@dataclasses.dataclass +class ToolApprovalResponsePart: + """Records an approval decision for a tool call.""" + + approval_id: str + approved: bool + type: Literal["tool-approval-response"] = dataclasses.field( + default="tool-approval-response", init=False + ) + reason: str | None = None + provider_executed: bool | None = None + provider_metadata: dict[str, Any] | None = None @dataclasses.dataclass @@ -294,6 +338,7 @@ class AbortPart: """Indicates the message was aborted.""" type: Literal["abort"] = dataclasses.field(default="abort", init=False) + reason: str | None = None @dataclasses.dataclass @@ -322,9 +367,11 @@ class ErrorPart: | ReasoningStartPart | ReasoningDeltaPart | ReasoningEndPart + | CustomPart | SourceUrlPart | SourceDocumentPart | FilePart + | ReasoningFilePart | DataPart | ToolInputStartPart | ToolInputDeltaPart @@ -334,6 +381,7 @@ class ErrorPart: | ToolOutputErrorPart | ToolOutputDeniedPart | ToolApprovalRequestPart + | ToolApprovalResponsePart | StartStepPart | FinishStepPart | FinishPart diff --git a/src/ai/agents/ui/ai_sdk/ui_message.py b/src/ai/agents/ui/ai_sdk/ui_message.py index ea3ccd0f..f48cef34 100644 --- a/src/ai/agents/ui/ai_sdk/ui_message.py +++ b/src/ai/agents/ui/ai_sdk/ui_message.py @@ -14,12 +14,20 @@ from ....types import messages as messages_ +_UI_MODEL_CONFIG = pydantic.ConfigDict(populate_by_name=True, extra="allow") + class UITextPart(pydantic.BaseModel): """Text content part in AI SDK v6 format.""" + model_config = _UI_MODEL_CONFIG + type: Literal["text"] text: str + state: Literal["streaming", "done"] | None = None + provider_metadata: dict[str, Any] | None = pydantic.Field( + default=None, alias="providerMetadata" + ) class UIReasoningPart(pydantic.BaseModel): @@ -31,9 +39,26 @@ class UIReasoningPart(pydantic.BaseModel): we accept it but don't currently route on it. """ + model_config = _UI_MODEL_CONFIG + type: Literal["reasoning"] text: str state: Literal["streaming", "done"] | None = None + provider_metadata: dict[str, Any] | None = pydantic.Field( + default=None, alias="providerMetadata" + ) + + +class UICustomPart(pydantic.BaseModel): + """Provider-specific content that does not fit standard UI parts.""" + + model_config = _UI_MODEL_CONFIG + + type: Literal["custom"] + kind: str + provider_metadata: dict[str, Any] | None = pydantic.Field( + default=None, alias="providerMetadata" + ) # Tool invocation states in AI SDK v6: @@ -65,7 +90,7 @@ class UIToolInvocationPart(pydantic.BaseModel): Reference: https://ai-sdk.dev/docs/reference/ai-sdk-core/ui-message """ - model_config = pydantic.ConfigDict(populate_by_name=True) + model_config = _UI_MODEL_CONFIG type: Literal["tool-invocation"] tool_invocation_id: str = pydantic.Field(alias="toolInvocationId") @@ -73,11 +98,16 @@ class UIToolInvocationPart(pydantic.BaseModel): args: dict[str, Any] = pydantic.Field(default_factory=dict) state: UIToolInvocationState = "input-available" result: Any | None = None + provider_executed: bool | None = pydantic.Field( + default=None, alias="providerExecuted" + ) class UIStepStartPart(pydantic.BaseModel): """Step boundary marker. Skipped during conversion to internal format.""" + model_config = _UI_MODEL_CONFIG + type: Literal["step-start"] @@ -89,11 +119,14 @@ class UIToolApproval(pydantic.BaseModel): ``approved`` is None while awaiting a response, True/False after. """ - model_config = pydantic.ConfigDict(populate_by_name=True) + model_config = _UI_MODEL_CONFIG id: str approved: bool | None = None reason: str | None = None + is_automatic: bool | None = pydantic.Field( + default=None, alias="isAutomatic" + ) class UIToolPart(pydantic.BaseModel): @@ -103,7 +136,7 @@ class UIToolPart(pydantic.BaseModel): where the tool name is embedded in the type string. """ - model_config = pydantic.ConfigDict(populate_by_name=True) + model_config = _UI_MODEL_CONFIG # The actual type string (e.g., "tool-talk_to_mothership") # We store this to extract the tool name @@ -112,8 +145,23 @@ class UIToolPart(pydantic.BaseModel): state: UIToolInvocationState input: str | dict[str, Any] | None = None # JSON string or parsed dict output: Any | None = None + raw_input: Any | None = pydantic.Field(default=None, alias="rawInput") error_text: str | None = pydantic.Field(default=None, alias="errorText") approval: UIToolApproval | None = None + provider_executed: bool | None = pydantic.Field( + default=None, alias="providerExecuted" + ) + call_provider_metadata: dict[str, Any] | None = pydantic.Field( + default=None, alias="callProviderMetadata" + ) + result_provider_metadata: dict[str, Any] | None = pydantic.Field( + default=None, alias="resultProviderMetadata" + ) + tool_metadata: dict[str, Any] | None = pydantic.Field( + default=None, alias="toolMetadata" + ) + preliminary: bool | None = None + title: str | None = None @property def tool_name(self) -> str: @@ -126,61 +174,134 @@ def tool_name(self) -> str: return self.type +class UIDynamicToolPart(pydantic.BaseModel): + """Dynamic tool part where the tool name is a field, not the type suffix.""" + + model_config = _UI_MODEL_CONFIG + + type: Literal["dynamic-tool"] + tool_name: str = pydantic.Field(alias="toolName") + tool_call_id: str = pydantic.Field(alias="toolCallId") + state: UIToolInvocationState + input: Any | None = None + output: Any | None = None + raw_input: Any | None = pydantic.Field(default=None, alias="rawInput") + error_text: str | None = pydantic.Field(default=None, alias="errorText") + approval: UIToolApproval | None = None + provider_executed: bool | None = pydantic.Field( + default=None, alias="providerExecuted" + ) + call_provider_metadata: dict[str, Any] | None = pydantic.Field( + default=None, alias="callProviderMetadata" + ) + result_provider_metadata: dict[str, Any] | None = pydantic.Field( + default=None, alias="resultProviderMetadata" + ) + tool_metadata: dict[str, Any] | None = pydantic.Field( + default=None, alias="toolMetadata" + ) + preliminary: bool | None = None + title: str | None = None + + class UIFilePart(pydantic.BaseModel): """File part. TODO: FilePart not yet supported in core messages.""" - model_config = pydantic.ConfigDict(populate_by_name=True) + model_config = _UI_MODEL_CONFIG type: Literal["file"] media_type: str = pydantic.Field(alias="mediaType") url: str filename: str | None = None + provider_reference: dict[str, Any] | None = pydantic.Field( + default=None, alias="providerReference" + ) + provider_metadata: dict[str, Any] | None = pydantic.Field( + default=None, alias="providerMetadata" + ) + + +class UIReasoningFilePart(pydantic.BaseModel): + """Reasoning file part generated as part of model reasoning.""" + + model_config = _UI_MODEL_CONFIG + + type: Literal["reasoning-file"] + media_type: str = pydantic.Field(alias="mediaType") + url: str + provider_metadata: dict[str, Any] | None = pydantic.Field( + default=None, alias="providerMetadata" + ) class UISourceUrlPart(pydantic.BaseModel): """Source URL part. TODO: SourceUrlPart not yet supported.""" - model_config = pydantic.ConfigDict(populate_by_name=True) + model_config = _UI_MODEL_CONFIG type: Literal["source-url"] source_id: str = pydantic.Field(alias="sourceId") url: str title: str | None = None + provider_metadata: dict[str, Any] | None = pydantic.Field( + default=None, alias="providerMetadata" + ) class UISourceDocumentPart(pydantic.BaseModel): """Source document part. TODO: SourceDocumentPart not yet supported.""" - model_config = pydantic.ConfigDict(populate_by_name=True) + model_config = _UI_MODEL_CONFIG type: Literal["source-document"] source_id: str = pydantic.Field(alias="sourceId") media_type: str = pydantic.Field(alias="mediaType") title: str filename: str | None = None + provider_metadata: dict[str, Any] | None = pydantic.Field( + default=None, alias="providerMetadata" + ) + + +class UIDataPart(pydantic.BaseModel): + """Custom data part with a dynamic ``data-*`` type.""" + + model_config = _UI_MODEL_CONFIG + + type: str + id: str | None = None + data: Any + transient: bool | None = None # Union of all supported part types (used for type hints) UIMessagePart = ( UITextPart | UIReasoningPart + | UICustomPart | UIToolInvocationPart | UIStepStartPart | UIToolPart + | UIDynamicToolPart | UIFilePart + | UIReasoningFilePart | UISourceUrlPart | UISourceDocumentPart + | UIDataPart ) _STATIC_UI_PART_TYPES: dict[str, type[pydantic.BaseModel]] = { "text": UITextPart, "reasoning": UIReasoningPart, + "custom": UICustomPart, "tool-invocation": UIToolInvocationPart, "step-start": UIStepStartPart, "file": UIFilePart, + "reasoning-file": UIReasoningFilePart, "source-url": UISourceUrlPart, "source-document": UISourceDocumentPart, + "dynamic-tool": UIDynamicToolPart, } @@ -198,9 +319,8 @@ def _parse_ui_part(part_data: dict[str, Any]) -> UIMessagePart | None: case str() as t if t.startswith("tool-"): # Dynamic tool type: tool-{toolName} (e.g., "tool-get_weather") return UIToolPart.model_validate(part_data) - case str() as t if t.startswith("data-") or t == "dynamic-tool": - # TODO: data-{name} and dynamic-tool not yet supported - return None + case str() as t if t.startswith("data-"): + return UIDataPart.model_validate(part_data) case _: # Unknown part type - skip gracefully return None @@ -212,12 +332,13 @@ class UIMessage(pydantic.BaseModel): Reference: https://ai-sdk.dev/docs/reference/ai-sdk-core/ui-message """ - model_config = pydantic.ConfigDict(populate_by_name=True) + model_config = _UI_MODEL_CONFIG id: str = pydantic.Field( default_factory=lambda: messages_.generate_id("msg") ) role: Literal["user", "assistant", "system"] + metadata: Any | None = None parts: list[UIMessagePart] = pydantic.Field(default_factory=list) @pydantic.field_validator("parts", mode="before") diff --git a/tests/agents/ui/ai_sdk/outbound/test_history.py b/tests/agents/ui/ai_sdk/outbound/test_history.py index 8f701373..b91bbe50 100644 --- a/tests/agents/ui/ai_sdk/outbound/test_history.py +++ b/tests/agents/ui/ai_sdk/outbound/test_history.py @@ -5,6 +5,8 @@ from ai.agents.ui import ai_sdk from ai.agents.ui.ai_sdk import to_ui_messages from ai.agents.ui.ai_sdk.ui_message import ( + UIDynamicToolPart, + UIFilePart, UITextPart, UIToolPart, ) @@ -174,6 +176,72 @@ def test_to_ui_messages_merges_assistant_tool_internal() -> None: assert ui_msg.parts[2].text == "done" +def test_to_ui_messages_records_source_messages_in_metadata() -> None: + msgs = [ + messages_.Message( + id="turn-1:assistant:0", + turn_id="turn-1", + role="assistant", + parts=[ + messages_.TextPart(id="text-0", text="calling"), + messages_.ToolCallPart( + id="call-0", + tool_call_id="tc1", + tool_name="search", + tool_args="{}", + ), + ], + ), + messages_.Message( + id="turn-1:tool:0", + turn_id="turn-1", + role="tool", + parts=[ + messages_.ToolResultPart( + id="result-0", + tool_call_id="tc1", + tool_name="search", + result={"hits": 2}, + ) + ], + ), + messages_.Message( + id="turn-1:assistant:1", + turn_id="turn-1", + role="assistant", + parts=[messages_.TextPart(id="text-1", text="done")], + ), + ] + + [ui_msg] = to_ui_messages(msgs) + + assert ui_msg.id == "turn-1" + assert ui_msg.metadata == { + "aiPython": { + "sourceMessages": [ + { + "id": "turn-1:assistant:0", + "role": "assistant", + "turnId": "turn-1", + "partIds": ["text-0", "call-0"], + }, + { + "id": "turn-1:tool:0", + "role": "tool", + "turnId": "turn-1", + "partIds": ["result-0"], + }, + { + "id": "turn-1:assistant:1", + "role": "assistant", + "turnId": "turn-1", + "partIds": ["text-1"], + }, + ] + } + } + + def test_to_ui_messages_internal_role_merges_approval() -> None: msgs = [ messages_.Message( @@ -235,6 +303,74 @@ def test_to_ui_messages_uses_first_assistant_id_as_bubble_id() -> None: assert result[0].id == "a1" +def test_to_ui_messages_preserves_provider_metadata_and_files() -> None: + msgs = [ + messages_.Message( + id="a1", + role="assistant", + parts=[ + messages_.TextPart( + text="hello", + provider_metadata={"provider": {"text": True}}, + ), + messages_.FilePart( + data=b"abc", + media_type="image/png", + filename="image.png", + provider_metadata={"provider": {"file": True}}, + ), + ], + ) + ] + + result = to_ui_messages(msgs) + + text_part = result[0].parts[0] + assert isinstance(text_part, UITextPart) + assert text_part.provider_metadata == {"provider": {"text": True}} + + file_part = result[0].parts[1] + assert isinstance(file_part, UIFilePart) + assert file_part.url == "data:image/png;base64,YWJj" + assert file_part.filename == "image.png" + assert file_part.provider_metadata == {"provider": {"file": True}} + + +def test_to_ui_messages_maps_builtin_tools_to_dynamic_parts() -> None: + msgs = [ + messages_.Message( + id="a1", + role="assistant", + parts=[ + messages_.BuiltinToolCallPart( + tool_call_id="tc1", + tool_name="web_search", + tool_args='{"q":"ai"}', + provider_metadata={"provider": {"call": True}}, + ), + messages_.BuiltinToolReturnPart( + tool_call_id="tc1", + tool_name="web_search", + result={"hits": 1}, + provider_metadata={"provider": {"result": True}}, + ), + ], + ) + ] + + result = to_ui_messages(msgs) + + assert len(result[0].parts) == 1 + tool_part = result[0].parts[0] + assert isinstance(tool_part, UIDynamicToolPart) + assert tool_part.provider_executed is True + assert tool_part.state == "output-available" + assert tool_part.input == {"q": "ai"} + assert tool_part.output == {"hits": 1} + assert tool_part.call_provider_metadata == {"provider": {"call": True}} + assert tool_part.result_provider_metadata == {"provider": {"result": True}} + + def test_common_id_upsert_persistence_is_idempotent_after_reload() -> None: store = IdUpsertStore() diff --git a/tests/agents/ui/ai_sdk/outbound/test_sse.py b/tests/agents/ui/ai_sdk/outbound/test_sse.py index 95d520ef..9f3a0cf7 100644 --- a/tests/agents/ui/ai_sdk/outbound/test_sse.py +++ b/tests/agents/ui/ai_sdk/outbound/test_sse.py @@ -4,7 +4,11 @@ from collections.abc import AsyncGenerator from ai.agents.ui.ai_sdk import protocol, to_sse -from ai.agents.ui.ai_sdk.outbound.sse import format_sse, serialize_part +from ai.agents.ui.ai_sdk.outbound.sse import ( + format_done_sse, + format_sse, + serialize_part, +) from ai.types import events as agent_events_ from ai.types import events as events_ @@ -29,6 +33,31 @@ def test_serialize_data_part_uses_type_with_prefix() -> None: assert "dataType" not in payload +def test_serialize_protocol_fields_use_ai_sdk_wire_names() -> None: + part = protocol.ToolApprovalResponsePart( + approval_id="approval-1", + approved=False, + reason="no", + provider_executed=True, + provider_metadata={"provider": {"k": "v"}}, + ) + + payload = json.loads(serialize_part(part)) + + assert payload == { + "type": "tool-approval-response", + "approvalId": "approval-1", + "approved": False, + "reason": "no", + "providerExecuted": True, + "providerMetadata": {"provider": {"k": "v"}}, + } + + +def test_format_done_sse_returns_done_sentinel() -> None: + assert format_done_sse() == "data: [DONE]\n\n" + + async def _gen( stream_events: list[agent_events_.AgentEvent], ) -> AsyncGenerator[agent_events_.AgentEvent]: @@ -53,3 +82,4 @@ async def test_to_sse_emits_data_prefixed_lines() -> None: # first line is the start part (lazy open) first = json.loads(lines[0].removeprefix("data: ").rstrip()) assert first["type"] == "start" + assert lines[-1] == "data: [DONE]\n\n" diff --git a/tests/agents/ui/ai_sdk/outbound/test_stream.py b/tests/agents/ui/ai_sdk/outbound/test_stream.py index ebaf85da..4e18fdd2 100644 --- a/tests/agents/ui/ai_sdk/outbound/test_stream.py +++ b/tests/agents/ui/ai_sdk/outbound/test_stream.py @@ -259,6 +259,112 @@ async def test_partial_tool_result_without_factory_is_skipped() -> None: assert not any(isinstance(p, protocol.ToolOutputAvailablePart) for p in out) +async def test_builtin_tool_stream_marks_provider_executed_dynamic() -> None: + out = await _collect( + [ + events_.BuiltinToolStart( + tool_call_id="tc1", + tool_name="web_search", + provider_metadata={"provider": {"start": True}}, + ), + events_.BuiltinToolDelta(tool_call_id="tc1", chunk='{"q":"ai"}'), + events_.BuiltinToolEnd( + tool_call_id="tc1", + tool_call=messages_.BuiltinToolCallPart( + tool_call_id="tc1", + tool_name="web_search", + tool_args='{"q":"ai"}', + provider_metadata={"provider": {"call": True}}, + ), + ), + events_.BuiltinToolResult( + tool_call_id="tc1", + result=messages_.BuiltinToolReturnPart( + tool_call_id="tc1", + tool_name="web_search", + result={"hits": 1}, + provider_metadata={"provider": {"result": True}}, + ), + ), + ] + ) + + start = next(p for p in out if isinstance(p, protocol.ToolInputStartPart)) + assert start.provider_executed is True + assert start.dynamic is True + assert start.provider_metadata == {"provider": {"start": True}} + + available = next( + p for p in out if isinstance(p, protocol.ToolInputAvailablePart) + ) + assert available.provider_executed is True + assert available.dynamic is True + assert available.input == {"q": "ai"} + assert available.provider_metadata == {"provider": {"call": True}} + + result = next( + p for p in out if isinstance(p, protocol.ToolOutputAvailablePart) + ) + assert result.provider_executed is True + assert result.dynamic is True + assert result.output == {"hits": 1} + assert result.provider_metadata == {"provider": {"result": True}} + + +async def test_file_event_emits_file_part_with_data_url_and_metadata() -> None: + out = await _collect( + [ + events_.FileEvent( + media_type="image/png", + data=b"abc", + provider_metadata={"provider": {"file": True}}, + ) + ] + ) + + file_part = next(p for p in out if isinstance(p, protocol.FilePart)) + assert file_part.url == "data:image/png;base64,YWJj" + assert file_part.media_type == "image/png" + assert file_part.provider_metadata == {"provider": {"file": True}} + + +async def test_resolved_approval_hook_emits_response_part() -> None: + hook = messages_.HookPart( + hook_id="approve_tc1", + hook_type="ToolApproval", + status="resolved", + metadata={ + "providerExecuted": True, + "callProviderMetadata": {"provider": {"approval": True}}, + }, + resolution={"granted": False, "reason": "not allowed"}, + ) + + out = await _collect( + [ + agent_events_.HookEvent( + message=messages_.Message( + id="turn-1:internal:0", + turn_id="turn-1", + role="internal", + parts=[hook], + ), + hook=hook, + ) + ] + ) + + response = next( + p for p in out if isinstance(p, protocol.ToolApprovalResponsePart) + ) + assert response.approval_id == "approve_tc1" + assert response.approved is False + assert response.reason == "not allowed" + assert response.provider_executed is True + assert response.provider_metadata == {"provider": {"approval": True}} + assert any(isinstance(p, protocol.ToolOutputDeniedPart) for p in out) + + # NOTE: agent-change boundary detection used to be driven by # Message.source_label. That field has been removed; agent-change # routing in the AI SDK adapter now needs to come from diff --git a/tests/agents/ui/ai_sdk/test_inbound.py b/tests/agents/ui/ai_sdk/test_inbound.py index 64f8a556..feb8b7d2 100644 --- a/tests/agents/ui/ai_sdk/test_inbound.py +++ b/tests/agents/ui/ai_sdk/test_inbound.py @@ -160,6 +160,36 @@ def test_extract_approvals_returns_approved_responses() -> None: assert approvals[0].reason == "nope" +def test_extract_approvals_handles_dynamic_tool_responses() -> None: + approvals = extract_approvals( + [ + _ui( + "assistant", + { + "type": "dynamic-tool", + "toolName": "web_search", + "toolCallId": "tc1", + "state": "approval-responded", + "input": {"query": "ai"}, + "approval": { + "id": "approve_tc1", + "approved": True, + "reason": "ok", + "isAutomatic": True, + }, + "providerExecuted": True, + }, + ) + ] + ) + + assert len(approvals) == 1 + assert approvals[0].hook_id == "approve_tc1" + assert approvals[0].granted is True + assert approvals[0].reason == "ok" + assert approvals[0].tool_call_id == "tc1" + + def test_normalize_ui_messages_heals_stale_tool_state() -> None: ui = [ _ui( @@ -237,3 +267,78 @@ def test_to_messages_passthrough_keeps_wire_shape() -> None: part = tool_msgs[0].tool_results[0] assert part.result == {"pong": True} assert part.get_model_input() == {"pong": True} + + +def test_to_messages_accepts_metadata_and_ui_only_parts() -> None: + ui = [ + UIMessage.model_validate( + { + "id": "a1", + "role": "assistant", + "metadata": {"trace": "t1"}, + "parts": [ + {"type": "custom", "kind": "openai.compaction"}, + { + "type": "data-weather", + "id": "weather-1", + "data": {"status": "loading"}, + }, + { + "type": "source-url", + "sourceId": "src-1", + "url": "https://example.com", + }, + { + "type": "reasoning-file", + "mediaType": "image/png", + "url": "data:image/png;base64,AAAA", + }, + { + "type": "text", + "text": "visible", + "providerMetadata": {"provider": {"k": "v"}}, + }, + ], + } + ) + ] + + messages, approvals = to_messages(ui) + + assert approvals == [] + assert len(messages) == 1 + assert messages[0].text == "visible" + text = messages[0].parts[0] + assert isinstance(text, messages_.TextPart) + assert text.provider_metadata == {"provider": {"k": "v"}} + + +def test_to_messages_dynamic_provider_executed_tool_becomes_builtin() -> None: + messages, _ = to_messages( + [ + _ui( + "assistant", + { + "type": "dynamic-tool", + "toolName": "web_search", + "toolCallId": "tc1", + "state": "output-available", + "input": {"query": "ai"}, + "output": [{"title": "result"}], + "providerExecuted": True, + "callProviderMetadata": {"provider": {"call": 1}}, + "resultProviderMetadata": {"provider": {"result": 1}}, + }, + ) + ] + ) + + assert len(messages) == 1 + assert messages[0].role == "assistant" + [call] = messages[0].builtin_tool_calls + [result] = messages[0].builtin_tool_returns + assert call.tool_name == "web_search" + assert call.tool_args == '{"query": "ai"}' + assert call.provider_metadata == {"provider": {"call": 1}} + assert result.result == [{"title": "result"}] + assert result.provider_metadata == {"provider": {"result": 1}} From ed2208132d8a27c9b8e4ed68ea9eab0cdcb6ae96 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Wed, 20 May 2026 11:16:23 -0700 Subject: [PATCH 04/13] Wire message id information through ui message metadata --- src/ai/agents/ui/ai_sdk/_roundtrip.py | 150 ++++++++++++++++++ src/ai/agents/ui/ai_sdk/inbound.py | 123 ++++---------- src/ai/agents/ui/ai_sdk/outbound/history.py | 39 +---- .../agents/ui/ai_sdk/outbound/test_history.py | 80 ++++++++++ 4 files changed, 264 insertions(+), 128 deletions(-) create mode 100644 src/ai/agents/ui/ai_sdk/_roundtrip.py diff --git a/src/ai/agents/ui/ai_sdk/_roundtrip.py b/src/ai/agents/ui/ai_sdk/_roundtrip.py new file mode 100644 index 00000000..56d23581 --- /dev/null +++ b/src/ai/agents/ui/ai_sdk/_roundtrip.py @@ -0,0 +1,150 @@ +"""Roundtrip metadata for preserving internal message identity.""" + +from __future__ import annotations + +import dataclasses +from typing import TYPE_CHECKING, Literal, cast + +if TYPE_CHECKING: + from ....types import messages as messages_ + +ADAPTER_METADATA_KEY = "aiPython" +SOURCE_MESSAGES_KEY = "sourceMessages" + +MessageRole = Literal["user", "assistant", "system", "tool", "internal"] +_VALID_ROLES = {"user", "assistant", "system", "tool", "internal"} + + +@dataclasses.dataclass(frozen=True) +class SourceMessage: + id: str + role: MessageRole + turn_id: str | None + part_ids: tuple[str, ...] + + +def source_message_entry(message: messages_.Message) -> dict[str, object]: + return { + "id": message.id, + "role": message.role, + "turnId": message.turn_id, + "partIds": [part.id for part in message.parts], + } + + +def metadata_for( + source_messages: list[messages_.Message], +) -> dict[str, object]: + return { + ADAPTER_METADATA_KEY: { + SOURCE_MESSAGES_KEY: [ + source_message_entry(message) for message in source_messages + ] + } + } + + +def source_messages_from(metadata: object) -> list[SourceMessage]: + if not isinstance(metadata, dict): + return [] + + adapter_metadata = metadata.get(ADAPTER_METADATA_KEY) + if not isinstance(adapter_metadata, dict): + return [] + + raw_source_messages = adapter_metadata.get(SOURCE_MESSAGES_KEY) + if not isinstance(raw_source_messages, list): + return [] + + result: list[SourceMessage] = [] + for raw in raw_source_messages: + source = _parse_source_message(raw) + if source is not None: + result.append(source) + return result + + +def restore_source_ids( + messages: list[messages_.Message], + source_messages: list[SourceMessage], +) -> list[messages_.Message]: + if not source_messages: + return messages + + restored: list[messages_.Message] = [] + source_index = 0 + + for message in messages: + match_index = _find_next_source( + source_messages, + role=message.role, + start=source_index, + ) + if match_index is None: + restored.append(message) + continue + + source = source_messages[match_index] + source_index = match_index + 1 + restored.append(_restore_message_ids(message, source)) + + return restored + + +def _parse_source_message(raw: object) -> SourceMessage | None: + if not isinstance(raw, dict): + return None + + message_id = raw.get("id") + role = raw.get("role") + if not isinstance(message_id, str) or role not in _VALID_ROLES: + return None + + raw_turn_id = raw.get("turnId") + turn_id = raw_turn_id if isinstance(raw_turn_id, str) else None + + raw_part_ids = raw.get("partIds") + part_ids = ( + tuple(part_id for part_id in raw_part_ids if isinstance(part_id, str)) + if isinstance(raw_part_ids, list) + else () + ) + + return SourceMessage( + id=message_id, + role=cast("MessageRole", role), + turn_id=turn_id, + part_ids=part_ids, + ) + + +def _find_next_source( + source_messages: list[SourceMessage], + *, + role: MessageRole, + start: int, +) -> int | None: + for index in range(start, len(source_messages)): + if source_messages[index].role == role: + return index + return None + + +def _restore_message_ids( + message: messages_.Message, + source: SourceMessage, +) -> messages_.Message: + updates: dict[str, object] = { + "id": source.id, + "turn_id": source.turn_id, + } + + if len(source.part_ids) == len(message.parts): + updates["parts"] = [ + part.model_copy(update={"id": part_id}) + for part, part_id in zip( + message.parts, source.part_ids, strict=True + ) + ] + + return message.model_copy(update=updates) diff --git a/src/ai/agents/ui/ai_sdk/inbound.py b/src/ai/agents/ui/ai_sdk/inbound.py index d4ad85cd..3273d899 100644 --- a/src/ai/agents/ui/ai_sdk/inbound.py +++ b/src/ai/agents/ui/ai_sdk/inbound.py @@ -13,7 +13,7 @@ from ....types import messages as messages_ from ...agent import MessageBundle from ...hooks import resolve_hook -from . import ui_message +from . import _roundtrip, ui_message logger = logging.getLogger(__name__) @@ -32,54 +32,6 @@ def _is_tool_error(state: ui_message.UIToolInvocationState) -> bool: return state in _TOOL_ERROR_STATES -class _AdapterIdPolicy: - """Stable ids for reconstructing canonical messages from UI bubbles.""" - - @staticmethod - def message_id(turn_id: str, role: str, index: int) -> str: - return f"{turn_id}:{role}:{index}" - - @classmethod - def with_stable_part_ids( - cls, - message_id: str, - parts: list[messages_.Part], - ) -> list[messages_.Part]: - counts: dict[str, int] = {} - result: list[messages_.Part] = [] - for index, part in enumerate(parts): - base = cls.part_id(message_id, part, index) - seen = counts.get(base, 0) - counts[base] = seen + 1 - part_id = base if seen == 0 else f"{base}:{seen}" - result.append(part.model_copy(update={"id": part_id})) - return result - - @staticmethod - def part_id( - message_id: str, - part: messages_.Part, - index: int, - ) -> str: - match part: - case messages_.TextPart(): - return f"{message_id}:text:{index}" - case messages_.ReasoningPart(): - return f"{message_id}:reasoning:{index}" - case messages_.ToolCallPart(tool_call_id=tool_call_id): - return f"{message_id}:call:{tool_call_id}" - case messages_.ToolResultPart(tool_call_id=tool_call_id): - return f"{message_id}:result:{tool_call_id}" - case messages_.BuiltinToolCallPart(tool_call_id=tool_call_id): - return f"{message_id}:builtin-call:{tool_call_id}" - case messages_.BuiltinToolReturnPart(tool_call_id=tool_call_id): - return f"{message_id}:builtin-result:{tool_call_id}" - case messages_.HookPart(): - return f"{message_id}:hook:{index}" - case messages_.FilePart(): - return f"{message_id}:file:{index}" - - # TODO(datamodel-rework §4): once tool args have a canonical shape, drop # these normalizers. def _normalize_tool_args(tool_input: Any) -> str: @@ -386,7 +338,6 @@ def _patch_pending_hook_aborts( continue new_parts.append( messages_.ToolResultPart( - id=f"{tool_msg.id}:result:{tc.tool_call_id}:hook-pending", tool_call_id=tc.tool_call_id, tool_name=tc.tool_name, result=f"Pending on hook '{hook.hook_id}'", @@ -456,6 +407,7 @@ def _build_builtin_return_part( result: list[messages_.Message] = [] for ui_msg in ui_messages: + source_messages = _roundtrip.source_messages_from(ui_msg.metadata) assistant_parts: list[messages_.Part] = [] tool_result_parts: list[messages_.ToolResultPart] = [] hook_parts: list[messages_.HookPart[Any]] = [] @@ -517,8 +469,11 @@ def _build_builtin_return_part( ) case ( - ui_message.UIToolPart() | ui_message.UIDynamicToolPart() - ) as tp: + ( + ui_message.UIToolPart() + | ui_message.UIDynamicToolPart() + ) as tp + ): tool_input = _tool_input_for_args(tp) tool_args = _normalize_tool_args(tool_input) @@ -610,35 +565,33 @@ def _build_builtin_return_part( # LLM APIs expect one message per iteration, so split into # assistant + tool message pairs at tool-result boundaries. if ui_msg.role == "assistant": - result.extend( - _split_assistant_parts( - assistant_parts, - tool_result_parts, - turn_id=ui_msg.id, - ) + parsed = _split_assistant_parts( + assistant_parts, + tool_result_parts, + turn_id=ui_msg.id, ) - for index, hp in enumerate(hook_parts): - msg_id = _AdapterIdPolicy.message_id( - ui_msg.id, "internal", index - ) - result.append( + for hp in hook_parts: + parsed.append( messages_.Message( - id=msg_id, turn_id=ui_msg.id, role="internal", - parts=_AdapterIdPolicy.with_stable_part_ids( - msg_id, [hp] - ), + parts=[hp], ) ) + result.extend( + _roundtrip.restore_source_ids(parsed, source_messages) + ) else: - result.append( - messages_.Message( - id=ui_msg.id, - role=ui_msg.role, - parts=_AdapterIdPolicy.with_stable_part_ids( - ui_msg.id, assistant_parts - ), + result.extend( + _roundtrip.restore_source_ids( + [ + messages_.Message( + id=ui_msg.id, + role=ui_msg.role, + parts=assistant_parts, + ) + ], + source_messages, ) ) @@ -663,12 +616,10 @@ def _split_assistant_parts( if not pending_results: if parts: - msg_id = _AdapterIdPolicy.message_id(turn_id, "assistant", 0) return [ messages_.Message( role="assistant", - parts=_AdapterIdPolicy.with_stable_part_ids(msg_id, parts), - id=msg_id, + parts=parts, turn_id=turn_id, ) ] @@ -678,38 +629,24 @@ def _split_assistant_parts( current: list[messages_.Part] = [] current_results: list[messages_.ToolResultPart] = [] seen_tool_call = False - assistant_index = 0 - tool_index = 0 def _append_assistant(parts_: list[messages_.Part]) -> None: - nonlocal assistant_index - msg_id = _AdapterIdPolicy.message_id( - turn_id, "assistant", assistant_index - ) messages.append( messages_.Message( role="assistant", - parts=_AdapterIdPolicy.with_stable_part_ids(msg_id, parts_), - id=msg_id, + parts=parts_, turn_id=turn_id, ) ) - assistant_index += 1 def _append_tool(parts_: list[messages_.ToolResultPart]) -> None: - nonlocal tool_index - msg_id = _AdapterIdPolicy.message_id(turn_id, "tool", tool_index) messages.append( messages_.Message( role="tool", - parts=_AdapterIdPolicy.with_stable_part_ids( - msg_id, list(parts_) - ), - id=msg_id, + parts=list(parts_), turn_id=turn_id, ) ) - tool_index += 1 for part in parts: if ( diff --git a/src/ai/agents/ui/ai_sdk/outbound/history.py b/src/ai/agents/ui/ai_sdk/outbound/history.py index 5dc289c0..51fe953f 100644 --- a/src/ai/agents/ui/ai_sdk/outbound/history.py +++ b/src/ai/agents/ui/ai_sdk/outbound/history.py @@ -4,24 +4,14 @@ from typing import TYPE_CHECKING -from .. import _parts, ui_message +from .. import _parts, _roundtrip, ui_message if TYPE_CHECKING: from .....types import messages as messages_ -_ADAPTER_METADATA_KEY = "aiPython" -_SOURCE_MESSAGES_KEY = "sourceMessages" - - -def _turn_id_from_message_id(message_id: str) -> str | None: - for marker in (":assistant:", ":tool:", ":internal:"): - if marker in message_id: - return message_id.split(marker, 1)[0] - return None - def _message_turn_key(message: messages_.Message) -> str | None: - return message.turn_id or _turn_id_from_message_id(message.id) + return message.turn_id def _assistant_bubble_id(message: messages_.Message) -> str: @@ -36,27 +26,6 @@ def _belongs_to_bubble( return key is None or key == bubble_id -def _source_message_entry(message: messages_.Message) -> dict[str, object]: - return { - "id": message.id, - "role": message.role, - "turnId": message.turn_id, - "partIds": [part.id for part in message.parts], - } - - -def _adapter_metadata( - source_messages: list[messages_.Message], -) -> dict[str, object]: - return { - _ADAPTER_METADATA_KEY: { - _SOURCE_MESSAGES_KEY: [ - _source_message_entry(message) for message in source_messages - ] - } - } - - def to_ui_messages( messages: list[messages_.Message], ) -> list[ui_message.UIMessage]: @@ -78,7 +47,7 @@ def to_ui_messages( ui_message.UIMessage( id=msg.id, role=msg.role, - metadata=_adapter_metadata([msg]), + metadata=_roundtrip.metadata_for([msg]), parts=_parts.to_ui_parts(msg.parts), ) ) @@ -113,7 +82,7 @@ def to_ui_messages( ui_message.UIMessage( id=bubble_id, role="assistant", - metadata=_adapter_metadata(source_messages), + metadata=_roundtrip.metadata_for(source_messages), parts=ui_parts, ) ) diff --git a/tests/agents/ui/ai_sdk/outbound/test_history.py b/tests/agents/ui/ai_sdk/outbound/test_history.py index b91bbe50..b896cfce 100644 --- a/tests/agents/ui/ai_sdk/outbound/test_history.py +++ b/tests/agents/ui/ai_sdk/outbound/test_history.py @@ -371,6 +371,86 @@ def test_to_ui_messages_maps_builtin_tools_to_dynamic_parts() -> None: assert tool_part.result_provider_metadata == {"provider": {"result": True}} +def test_collapsed_assistant_turn_roundtrips_internal_ids() -> None: + original = [ + messages_.Message( + id="assistant-alpha", + turn_id="turn-arbitrary", + role="assistant", + parts=[ + messages_.TextPart(id="text-alpha", text="calling first"), + messages_.ToolCallPart( + id="call-alpha", + tool_call_id="tc-first", + tool_name="search", + tool_args='{"q":"first"}', + ), + ], + ), + messages_.Message( + id="tool-beta", + turn_id="turn-arbitrary", + role="tool", + parts=[ + messages_.ToolResultPart( + id="result-beta", + tool_call_id="tc-first", + tool_name="search", + result={"hits": 1}, + ) + ], + ), + messages_.Message( + id="assistant-gamma", + turn_id="turn-arbitrary", + role="assistant", + parts=[ + messages_.TextPart(id="text-gamma", text="calling second"), + messages_.ToolCallPart( + id="call-gamma", + tool_call_id="tc-second", + tool_name="lookup", + tool_args='{"id":2}', + ), + ], + ), + messages_.Message( + id="tool-delta", + turn_id="turn-arbitrary", + role="tool", + parts=[ + messages_.ToolResultPart( + id="result-delta", + tool_call_id="tc-second", + tool_name="lookup", + result={"value": 2}, + ) + ], + ), + messages_.Message( + id="assistant-epsilon", + turn_id="turn-arbitrary", + role="assistant", + parts=[ + messages_.TextPart(id="text-epsilon", text="all done"), + ], + ), + ] + + [ui_msg] = ai_sdk.to_ui_messages(original) + roundtripped, approvals = ai_sdk.to_messages([ui_msg]) + + assert approvals == [] + assert ui_msg.role == "assistant" + assert ui_msg.id == "turn-arbitrary" + assert [m.role for m in roundtripped] == [m.role for m in original] + assert [m.id for m in roundtripped] == [m.id for m in original] + assert [m.turn_id for m in roundtripped] == [m.turn_id for m in original] + assert [[p.id for p in m.parts] for m in roundtripped] == [ + [p.id for p in m.parts] for m in original + ] + + def test_common_id_upsert_persistence_is_idempotent_after_reload() -> None: store = IdUpsertStore() From 2a18633559c3d0b247a1f0832616f6859e540770 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Wed, 20 May 2026 16:01:09 -0700 Subject: [PATCH 05/13] Rename ui_message.py to ui_messages.py and protocol.py to ui_events.py --- src/ai/agents/ui/ai_sdk/__init__.py | 4 +- src/ai/agents/ui/ai_sdk/_parts.py | 40 ++++---- src/ai/agents/ui/ai_sdk/inbound.py | 55 +++++------ src/ai/agents/ui/ai_sdk/outbound/_state.py | 96 +++++++++---------- src/ai/agents/ui/ai_sdk/outbound/history.py | 12 +-- src/ai/agents/ui/ai_sdk/outbound/sse.py | 8 +- src/ai/agents/ui/ai_sdk/outbound/stream.py | 4 +- .../ui/ai_sdk/{protocol.py => ui_events.py} | 0 .../ai_sdk/{ui_message.py => ui_messages.py} | 0 .../agents/ui/ai_sdk/outbound/test_history.py | 2 +- tests/agents/ui/ai_sdk/outbound/test_sse.py | 10 +- .../agents/ui/ai_sdk/outbound/test_stream.py | 47 ++++----- tests/agents/ui/ai_sdk/test_inbound.py | 2 +- tests/agents/ui/ai_sdk/test_parts.py | 2 +- 14 files changed, 143 insertions(+), 139 deletions(-) rename src/ai/agents/ui/ai_sdk/{protocol.py => ui_events.py} (100%) rename src/ai/agents/ui/ai_sdk/{ui_message.py => ui_messages.py} (100%) diff --git a/src/ai/agents/ui/ai_sdk/__init__.py b/src/ai/agents/ui/ai_sdk/__init__.py index 711e7892..d71c41d1 100644 --- a/src/ai/agents/ui/ai_sdk/__init__.py +++ b/src/ai/agents/ui/ai_sdk/__init__.py @@ -7,8 +7,8 @@ to_messages, ) from .outbound import to_sse, to_stream, to_ui_messages -from .protocol import UI_MESSAGE_STREAM_HEADERS -from .ui_message import UIMessage +from .ui_events import UI_MESSAGE_STREAM_HEADERS +from .ui_messages import UIMessage __all__ = [ "UI_MESSAGE_STREAM_HEADERS", diff --git a/src/ai/agents/ui/ai_sdk/_parts.py b/src/ai/agents/ui/ai_sdk/_parts.py index bc21e7b0..312c90fa 100644 --- a/src/ai/agents/ui/ai_sdk/_parts.py +++ b/src/ai/agents/ui/ai_sdk/_parts.py @@ -12,11 +12,11 @@ from ....types import media from ....types import messages as messages_ -from . import _approvals, ui_message +from . import _approvals, ui_messages -UIToolLike = ui_message.UIToolPart | ui_message.UIDynamicToolPart +UIToolLike = ui_messages.UIToolPart | ui_messages.UIDynamicToolPart -_TOOL_STATE_RANK: dict[ui_message.UIToolInvocationState, int] = { +_TOOL_STATE_RANK: dict[ui_messages.UIToolInvocationState, int] = { "input-streaming": 0, "input-available": 1, "approval-requested": 2, @@ -44,9 +44,9 @@ def _metadata_bool(metadata: dict[str, Any], key: str) -> bool | None: return value if isinstance(value, bool) else None -def _is_tool_part(part: ui_message.UIMessagePart) -> TypeGuard[UIToolLike]: +def _is_tool_part(part: ui_messages.UIMessagePart) -> TypeGuard[UIToolLike]: return isinstance( - part, ui_message.UIToolPart | ui_message.UIDynamicToolPart + part, ui_messages.UIToolPart | ui_messages.UIDynamicToolPart ) @@ -54,13 +54,13 @@ def _tool_call_id(part: UIToolLike) -> str: return part.tool_call_id -def to_ui_parts(parts: list[messages_.Part]) -> list[ui_message.UIMessagePart]: +def to_ui_parts(parts: list[messages_.Part]) -> list[ui_messages.UIMessagePart]: """Convert internal Part objects to UIMessagePart objects.""" - result: list[ui_message.UIMessagePart] = [] + result: list[ui_messages.UIMessagePart] = [] for part in parts: if isinstance(part, messages_.TextPart) and part.text: result.append( - ui_message.UITextPart.model_validate( + ui_messages.UITextPart.model_validate( { "type": "text", "text": part.text, @@ -70,7 +70,7 @@ def to_ui_parts(parts: list[messages_.Part]) -> list[ui_message.UIMessagePart]: ) elif isinstance(part, messages_.ReasoningPart) and part.text: result.append( - ui_message.UIReasoningPart.model_validate( + ui_messages.UIReasoningPart.model_validate( { "type": "reasoning", "text": part.text, @@ -80,7 +80,7 @@ def to_ui_parts(parts: list[messages_.Part]) -> list[ui_message.UIMessagePart]: ) elif isinstance(part, messages_.ToolCallPart): result.append( - ui_message.UIToolPart.model_validate( + ui_messages.UIToolPart.model_validate( { "type": f"tool-{part.tool_name}", "toolCallId": part.tool_call_id, @@ -92,7 +92,7 @@ def to_ui_parts(parts: list[messages_.Part]) -> list[ui_message.UIMessagePart]: ) elif isinstance(part, messages_.BuiltinToolCallPart): result.append( - ui_message.UIDynamicToolPart.model_validate( + ui_messages.UIDynamicToolPart.model_validate( { "type": "dynamic-tool", "toolName": part.tool_name, @@ -106,7 +106,7 @@ def to_ui_parts(parts: list[messages_.Part]) -> list[ui_message.UIMessagePart]: ) elif isinstance(part, messages_.BuiltinToolReturnPart): result.append( - ui_message.UIDynamicToolPart.model_validate( + ui_messages.UIDynamicToolPart.model_validate( { "type": "dynamic-tool", "toolName": part.tool_name, @@ -128,7 +128,7 @@ def to_ui_parts(parts: list[messages_.Part]) -> list[ui_message.UIMessagePart]: ) elif isinstance(part, messages_.FilePart): result.append( - ui_message.UIFilePart.model_validate( + ui_messages.UIFilePart.model_validate( { "type": "file", "mediaType": part.media_type, @@ -184,10 +184,10 @@ def _merge_tool_part( def dedupe_tool_parts( - ui_parts: list[ui_message.UIMessagePart], -) -> list[ui_message.UIMessagePart]: + ui_parts: list[ui_messages.UIMessagePart], +) -> list[ui_messages.UIMessagePart]: """Collapse duplicate UIToolParts by tool_call_id.""" - result: list[ui_message.UIMessagePart] = [] + result: list[ui_messages.UIMessagePart] = [] tool_index: dict[str, int] = {} for part in ui_parts: @@ -209,7 +209,7 @@ def dedupe_tool_parts( def merge_tool_results( - ui_parts: list[ui_message.UIMessagePart], + ui_parts: list[ui_messages.UIMessagePart], tool_parts: list[messages_.Part], ) -> None: """Merge ToolResultParts into existing UIToolParts in-place.""" @@ -263,7 +263,7 @@ def merge_tool_results( def merge_approval_signals( - ui_parts: list[ui_message.UIMessagePart], + ui_parts: list[ui_messages.UIMessagePart], internal_parts: list[messages_.Part], ) -> None: """Merge HookPart approval state into existing UIToolParts in-place.""" @@ -296,7 +296,7 @@ def merge_approval_signals( updates["provider_executed"] = provider_executed if part.status == "pending": updates["state"] = "approval-requested" - updates["approval"] = ui_message.UIToolApproval.model_validate( + updates["approval"] = ui_messages.UIToolApproval.model_validate( { "id": part.hook_id, "isAutomatic": _metadata_bool( @@ -309,7 +309,7 @@ def merge_approval_signals( "dict[str, Any]", part.resolution if isinstance(part.resolution, dict) else {}, ) - updates["approval"] = ui_message.UIToolApproval.model_validate( + updates["approval"] = ui_messages.UIToolApproval.model_validate( { "id": part.hook_id, "approved": resolution.get("granted"), diff --git a/src/ai/agents/ui/ai_sdk/inbound.py b/src/ai/agents/ui/ai_sdk/inbound.py index 3273d899..41870e8f 100644 --- a/src/ai/agents/ui/ai_sdk/inbound.py +++ b/src/ai/agents/ui/ai_sdk/inbound.py @@ -13,7 +13,8 @@ from ....types import messages as messages_ from ...agent import MessageBundle from ...hooks import resolve_hook -from . import _roundtrip, ui_message +from . import _roundtrip +from . import ui_messages as ui_messages_ logger = logging.getLogger(__name__) @@ -24,11 +25,11 @@ ) -def _is_tool_completed(state: ui_message.UIToolInvocationState) -> bool: +def _is_tool_completed(state: ui_messages_.UIToolInvocationState) -> bool: return state in _TOOL_RESULT_STATES or state in _TOOL_ERROR_STATES -def _is_tool_error(state: ui_message.UIToolInvocationState) -> bool: +def _is_tool_error(state: ui_messages_.UIToolInvocationState) -> bool: return state in _TOOL_ERROR_STATES @@ -46,7 +47,7 @@ def _normalize_tool_args(tool_input: Any) -> str: def _tool_input_for_args( - part: ui_message.UIToolPart | ui_message.UIDynamicToolPart, + part: ui_messages_.UIToolPart | ui_messages_.UIDynamicToolPart, ) -> Any: if part.state == "output-error" and part.input is None: return part.raw_input @@ -54,7 +55,7 @@ def _tool_input_for_args( def _tool_result_output( - part: ui_message.UIToolPart | ui_message.UIDynamicToolPart, + part: ui_messages_.UIToolPart | ui_messages_.UIDynamicToolPart, ) -> Any: if part.state == "output-error": return _error_result(part.error_text, part.output) @@ -98,7 +99,7 @@ def _decode_wire_output(output: Any) -> Any: if output.get("role") != "assistant" or "parts" not in output: return output try: - ui_msg = ui_message.UIMessage.model_validate(output) + ui_msg = ui_messages_.UIMessage.model_validate(output) except Exception: return output inner = list(_parse([ui_msg])) @@ -106,7 +107,7 @@ def _decode_wire_output(output: Any) -> Any: def _approval_hook_part( - tp: ui_message.UIToolPart | ui_message.UIDynamicToolPart, + tp: ui_messages_.UIToolPart | ui_messages_.UIDynamicToolPart, ) -> messages_.HookPart[Any] | None: """Reconstruct approval hook state from a UI tool part when possible.""" approval = tp.approval @@ -150,7 +151,7 @@ def _approval_hook_part( def _approval_metadata( - tp: ui_message.UIToolPart | ui_message.UIDynamicToolPart, + tp: ui_messages_.UIToolPart | ui_messages_.UIDynamicToolPart, ) -> dict[str, Any]: metadata: dict[str, Any] = {} if tp.approval is not None and tp.approval.is_automatic is not None: @@ -177,7 +178,7 @@ class ApprovalResponse(NamedTuple): def extract_approvals( - ui_messages: list[ui_message.UIMessage], + ui_messages: list[ui_messages_.UIMessage], ) -> list[ApprovalResponse]: """Return every approval response found in *ui_messages*. @@ -187,7 +188,7 @@ def extract_approvals( for ui_msg in ui_messages: for part in ui_msg.parts: if not isinstance( - part, ui_message.UIToolPart | ui_message.UIDynamicToolPart + part, ui_messages_.UIToolPart | ui_messages_.UIDynamicToolPart ): continue if ( @@ -221,10 +222,10 @@ def apply_approvals(approvals: list[ApprovalResponse]) -> None: def _normalize_ui_messages( - ui_messages: list[ui_message.UIMessage], -) -> list[ui_message.UIMessage]: + ui_messages: list[ui_messages_.UIMessage], +) -> list[ui_messages_.UIMessage]: """Heal stale tool-part states from persisted assistant history.""" - normalized: list[ui_message.UIMessage] = [] + normalized: list[ui_messages_.UIMessage] = [] for message in ui_messages: new_parts = [] changed = False @@ -270,7 +271,7 @@ def _normalize_ui_messages( def to_messages( - ui_messages: list[ui_message.UIMessage], + ui_messages: list[ui_messages_.UIMessage], ) -> tuple[list[messages_.Message], list[ApprovalResponse]]: """Parse a UI request into runtime messages + extracted approvals. @@ -363,7 +364,7 @@ def _is_approval_response(msg: messages_.Message) -> bool: def _parse( - ui_messages: list[ui_message.UIMessage], + ui_messages: list[ui_messages_.UIMessage], ) -> list[messages_.Message]: def _build_result_part( *, @@ -414,7 +415,7 @@ def _build_builtin_return_part( for part in ui_msg.parts: match part: - case ui_message.UITextPart(text=text) if text: + case ui_messages_.UITextPart(text=text) if text: assistant_parts.append( messages_.TextPart( text=text, @@ -422,7 +423,7 @@ def _build_builtin_return_part( ) ) - case ui_message.UIReasoningPart(text=reasoning) if reasoning: + case ui_messages_.UIReasoningPart(text=reasoning) if reasoning: assistant_parts.append( messages_.ReasoningPart( text=reasoning, @@ -430,7 +431,7 @@ def _build_builtin_return_part( ) ) - case ui_message.UIToolInvocationPart() as inv: + case ui_messages_.UIToolInvocationPart() as inv: tool_args = json.dumps(inv.args) if inv.args else "{}" if inv.provider_executed: assistant_parts.append( @@ -470,8 +471,8 @@ def _build_builtin_return_part( case ( ( - ui_message.UIToolPart() - | ui_message.UIDynamicToolPart() + ui_messages_.UIToolPart() + | ui_messages_.UIDynamicToolPart() ) as tp ): tool_input = _tool_input_for_args(tp) @@ -532,7 +533,7 @@ def _build_builtin_return_part( } ) - case ui_message.UIFilePart() as fp: + case ui_messages_.UIFilePart() as fp: assistant_parts.append( messages_.FilePart( data=fp.url, @@ -543,12 +544,12 @@ def _build_builtin_return_part( ) case ( - ui_message.UIStepStartPart() - | ui_message.UISourceUrlPart() - | ui_message.UISourceDocumentPart() - | ui_message.UIReasoningFilePart() - | ui_message.UICustomPart() - | ui_message.UIDataPart() + ui_messages_.UIStepStartPart() + | ui_messages_.UISourceUrlPart() + | ui_messages_.UISourceDocumentPart() + | ui_messages_.UIReasoningFilePart() + | ui_messages_.UICustomPart() + | ui_messages_.UIDataPart() ): pass diff --git a/src/ai/agents/ui/ai_sdk/outbound/_state.py b/src/ai/agents/ui/ai_sdk/outbound/_state.py index 7d480718..146b13fa 100644 --- a/src/ai/agents/ui/ai_sdk/outbound/_state.py +++ b/src/ai/agents/ui/ai_sdk/outbound/_state.py @@ -9,7 +9,7 @@ from .....types import media from .....types import messages as messages_ from ....agent import MessageBundle -from .. import _approvals, protocol +from .. import _approvals, ui_events from . import history @@ -101,22 +101,22 @@ def __init__(self) -> None: # -- boundary helpers ---------------------------------------------------- - def _close_open_blocks(self) -> list[protocol.UIMessageStreamPart]: - parts: list[protocol.UIMessageStreamPart] = [] + def _close_open_blocks(self) -> list[ui_events.UIMessageStreamPart]: + parts: list[ui_events.UIMessageStreamPart] = [] for rid in list(self.open_reasoning_ids): - parts.append(protocol.ReasoningEndPart(id=rid)) + parts.append(ui_events.ReasoningEndPart(id=rid)) self.completed_reasoning_ids.add(rid) self.open_reasoning_ids.clear() for tid in list(self.open_text_ids): - parts.append(protocol.TextEndPart(id=tid)) + parts.append(ui_events.TextEndPart(id=tid)) self.completed_text_ids.add(tid) self.open_text_ids.clear() return parts - def _finish_step(self) -> list[protocol.UIMessageStreamPart]: + def _finish_step(self) -> list[ui_events.UIMessageStreamPart]: parts = self._close_open_blocks() if self.in_step: - parts.append(protocol.FinishStepPart()) + parts.append(ui_events.FinishStepPart()) self.in_step = False return parts @@ -130,14 +130,14 @@ def _reset_step_tracking(self) -> None: def _ensure_started( self, message_id: str | None = None, - ) -> list[protocol.UIMessageStreamPart]: + ) -> list[ui_events.UIMessageStreamPart]: """Lazily emit StartPart / StartStepPart on the first event.""" - parts: list[protocol.UIMessageStreamPart] = [] + parts: list[ui_events.UIMessageStreamPart] = [] if not self.emitted_start: self.ui_message_id = message_id - parts.append(protocol.StartPart(message_id=self.ui_message_id)) - parts.append(protocol.StartStepPart()) + parts.append(ui_events.StartPart(message_id=self.ui_message_id)) + parts.append(ui_events.StartStepPart()) self.emitted_start = True self.in_step = True self._reset_step_tracking() @@ -148,8 +148,8 @@ def _ensure_started( def on_event( self, event: events_.Event - ) -> list[protocol.UIMessageStreamPart]: - out: list[protocol.UIMessageStreamPart] = [] + ) -> list[ui_events.UIMessageStreamPart]: + out: list[ui_events.UIMessageStreamPart] = [] # Lazily open the UI message on the first streaming event. if not self.emitted_start: @@ -159,7 +159,7 @@ def on_event( case events_.TextStart(block_id=pid): self.open_text_ids.add(pid) out.append( - protocol.TextStartPart( + ui_events.TextStartPart( id=pid, provider_metadata=event.provider_metadata, ) @@ -169,14 +169,14 @@ def on_event( if pid not in self.open_text_ids: self.open_text_ids.add(pid) out.append( - protocol.TextStartPart( + ui_events.TextStartPart( id=pid, provider_metadata=event.provider_metadata, ) ) self.text_delta_ids.add(pid) out.append( - protocol.TextDeltaPart( + ui_events.TextDeltaPart( id=pid, delta=chunk, provider_metadata=event.provider_metadata, @@ -188,7 +188,7 @@ def on_event( self.open_text_ids.discard(pid) self.completed_text_ids.add(pid) out.append( - protocol.TextEndPart( + ui_events.TextEndPart( id=pid, provider_metadata=event.provider_metadata, ) @@ -197,7 +197,7 @@ def on_event( case events_.ReasoningStart(block_id=pid): self.open_reasoning_ids.add(pid) out.append( - protocol.ReasoningStartPart( + ui_events.ReasoningStartPart( id=pid, provider_metadata=event.provider_metadata, ) @@ -207,14 +207,14 @@ def on_event( if pid not in self.open_reasoning_ids: self.open_reasoning_ids.add(pid) out.append( - protocol.ReasoningStartPart( + ui_events.ReasoningStartPart( id=pid, provider_metadata=event.provider_metadata, ) ) self.reasoning_delta_ids.add(pid) out.append( - protocol.ReasoningDeltaPart( + ui_events.ReasoningDeltaPart( id=pid, delta=chunk, provider_metadata=event.provider_metadata, @@ -226,7 +226,7 @@ def on_event( self.open_reasoning_ids.discard(pid) self.completed_reasoning_ids.add(pid) out.append( - protocol.ReasoningEndPart( + ui_events.ReasoningEndPart( id=pid, provider_metadata=event.provider_metadata, ) @@ -238,7 +238,7 @@ def on_event( return out self.started_tool_inputs.add(tcid) out.append( - protocol.ToolInputStartPart( + ui_events.ToolInputStartPart( tool_call_id=tcid, tool_name=name, provider_metadata=event.provider_metadata, @@ -249,14 +249,14 @@ def on_event( if tcid not in self.started_tool_inputs: self.started_tool_inputs.add(tcid) out.append( - protocol.ToolInputStartPart( + ui_events.ToolInputStartPart( tool_call_id=tcid, tool_name=self.tool_names.get(tcid, ""), provider_metadata=event.provider_metadata, ) ) out.append( - protocol.ToolInputDeltaPart( + ui_events.ToolInputDeltaPart( tool_call_id=tcid, input_text_delta=chunk, ) @@ -271,7 +271,7 @@ def on_event( return out self.started_tool_inputs.add(tcid) out.append( - protocol.ToolInputStartPart( + ui_events.ToolInputStartPart( tool_call_id=tcid, tool_name=name, provider_executed=True, @@ -284,7 +284,7 @@ def on_event( if tcid not in self.started_tool_inputs: self.started_tool_inputs.add(tcid) out.append( - protocol.ToolInputStartPart( + ui_events.ToolInputStartPart( tool_call_id=tcid, tool_name=self.tool_names.get(tcid, ""), provider_executed=True, @@ -293,7 +293,7 @@ def on_event( ) ) out.append( - protocol.ToolInputDeltaPart( + ui_events.ToolInputDeltaPart( tool_call_id=tcid, input_text_delta=chunk, ) @@ -303,7 +303,7 @@ def on_event( if tcid not in self.input_available_emitted: self.input_available_emitted.add(tcid) out.append( - protocol.ToolInputAvailablePart( + ui_events.ToolInputAvailablePart( tool_call_id=tcid, tool_name=tc.tool_name, input=_normalize_tool_input(tc.tool_args), @@ -320,7 +320,7 @@ def on_event( self.emitted_tool_results.add(tcid) if result.is_error: out.append( - protocol.ToolOutputErrorPart( + ui_events.ToolOutputErrorPart( tool_call_id=tcid, error_text=str(result.result), provider_executed=True, @@ -331,7 +331,7 @@ def on_event( ) else: out.append( - protocol.ToolOutputAvailablePart( + ui_events.ToolOutputAvailablePart( tool_call_id=tcid, output=result.result, provider_executed=True, @@ -346,7 +346,7 @@ def on_event( data=data, ): out.append( - protocol.FilePart( + ui_events.FilePart( url=media.data_to_data_url(data, media_type), media_type=media_type, provider_metadata=event.provider_metadata, @@ -359,10 +359,10 @@ def on_event( def on_tool_result( self, event: events_.ToolCallResult - ) -> list[protocol.UIMessageStreamPart]: + ) -> list[ui_events.UIMessageStreamPart]: """Handle a ``ToolCallResult`` — emit tool input/output parts.""" msg = event.message - out: list[protocol.UIMessageStreamPart] = [] + out: list[ui_events.UIMessageStreamPart] = [] out.extend(self._ensure_started(msg.turn_id)) @@ -376,14 +376,14 @@ def on_tool_result( if part.tool_call_id not in self.started_tool_inputs: self.started_tool_inputs.add(part.tool_call_id) out.append( - protocol.ToolInputStartPart( + ui_events.ToolInputStartPart( tool_call_id=part.tool_call_id, tool_name=part.tool_name, provider_metadata=part.provider_metadata, ) ) out.append( - protocol.ToolInputAvailablePart( + ui_events.ToolInputAvailablePart( tool_call_id=part.tool_call_id, tool_name=part.tool_name, input=_normalize_tool_input(part.tool_args), @@ -402,7 +402,7 @@ def on_tool_result( self.emitted_tool_results.add(part.tool_call_id) if part.is_error: out.append( - protocol.ToolOutputErrorPart( + ui_events.ToolOutputErrorPart( tool_call_id=part.tool_call_id, error_text=_tool_error_text(part), provider_metadata=part.provider_metadata, @@ -417,7 +417,7 @@ def on_tool_result( # streaming view if any. continue out.append( - protocol.ToolOutputAvailablePart( + ui_events.ToolOutputAvailablePart( tool_call_id=part.tool_call_id, output=wire_output, provider_metadata=part.provider_metadata, @@ -428,7 +428,7 @@ def on_tool_result( def on_partial_tool_result( self, event: events_.PartialToolCallResult - ) -> list[protocol.UIMessageStreamPart]: + ) -> list[ui_events.UIMessageStreamPart]: """Feed the value and emit a preliminary output. Each PartialToolCallResult carries one yielded value plus the @@ -438,7 +438,7 @@ def on_partial_tool_result( The AI SDK supersedes preliminary outputs with the final ``ToolCallResult`` when it arrives. """ - out: list[protocol.UIMessageStreamPart] = [] + out: list[ui_events.UIMessageStreamPart] = [] tcid = event.tool_call_id factory = event.aggregator_factory @@ -460,7 +460,7 @@ def on_partial_tool_result( return out out.append( - protocol.ToolOutputAvailablePart( + ui_events.ToolOutputAvailablePart( tool_call_id=tcid, output=wire_output, preliminary=True, @@ -472,10 +472,10 @@ def on_partial_tool_result( def on_hook( self, event: events_.HookEvent - ) -> list[protocol.UIMessageStreamPart]: + ) -> list[ui_events.UIMessageStreamPart]: """Handle a ``HookEvent`` — emit approval parts.""" hook_part = event.hook - out: list[protocol.UIMessageStreamPart] = [] + out: list[ui_events.UIMessageStreamPart] = [] # Ensure the UI message is started. out.extend(self._ensure_started(event.message.turn_id)) @@ -489,7 +489,7 @@ def on_hook( return out self.emitted_approval_requests.add(tc_id) out.append( - protocol.ToolApprovalRequestPart( + ui_events.ToolApprovalRequestPart( approval_id=hook_part.hook_id, tool_call_id=tc_id, is_automatic=_metadata_bool( @@ -500,7 +500,7 @@ def on_hook( elif hook_part.status == "resolved": resolution: dict[str, Any] = hook_part.resolution or {} out.append( - protocol.ToolApprovalResponsePart( + ui_events.ToolApprovalResponsePart( approval_id=hook_part.hook_id, approved=bool(resolution.get("granted")), reason=resolution.get("reason"), @@ -513,10 +513,10 @@ def on_hook( ) ) if not resolution.get("granted"): - out.append(protocol.ToolOutputDeniedPart(tool_call_id=tc_id)) + out.append(ui_events.ToolOutputDeniedPart(tool_call_id=tc_id)) elif hook_part.status == "cancelled": out.append( - protocol.ToolOutputErrorPart( + ui_events.ToolOutputErrorPart( tool_call_id=tc_id, error_text="Hook cancelled", ) @@ -526,8 +526,8 @@ def on_hook( # -- phase: stream finish ------------------------------------------------ - def finish(self) -> list[protocol.UIMessageStreamPart]: + def finish(self) -> list[ui_events.UIMessageStreamPart]: parts = self._finish_step() if self.emitted_start: - parts.append(protocol.FinishPart(finish_reason="stop")) + parts.append(ui_events.FinishPart(finish_reason="stop")) return parts diff --git a/src/ai/agents/ui/ai_sdk/outbound/history.py b/src/ai/agents/ui/ai_sdk/outbound/history.py index 51fe953f..5c4bfc42 100644 --- a/src/ai/agents/ui/ai_sdk/outbound/history.py +++ b/src/ai/agents/ui/ai_sdk/outbound/history.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING -from .. import _parts, _roundtrip, ui_message +from .. import _parts, _roundtrip, ui_messages if TYPE_CHECKING: from .....types import messages as messages_ @@ -28,7 +28,7 @@ def _belongs_to_bubble( def to_ui_messages( messages: list[messages_.Message], -) -> list[ui_message.UIMessage]: +) -> list[ui_messages.UIMessage]: """Group persisted messages into UIMessage bubbles. ``user``/``system`` messages become standalone UIMessages. Runs of @@ -36,7 +36,7 @@ def to_ui_messages( assistant UIMessage, with tool results and approval state folded into the corresponding tool-call parts. """ - result: list[ui_message.UIMessage] = [] + result: list[ui_messages.UIMessage] = [] i = 0 while i < len(messages): @@ -44,7 +44,7 @@ def to_ui_messages( if msg.role in ("user", "system"): result.append( - ui_message.UIMessage( + ui_messages.UIMessage( id=msg.id, role=msg.role, metadata=_roundtrip.metadata_for([msg]), @@ -55,7 +55,7 @@ def to_ui_messages( continue if msg.role == "assistant": - ui_parts: list[ui_message.UIMessagePart] = [] + ui_parts: list[ui_messages.UIMessagePart] = [] source_messages: list[messages_.Message] = [] bubble_id = _assistant_bubble_id(msg) @@ -79,7 +79,7 @@ def to_ui_messages( ui_parts = _parts.dedupe_tool_parts(ui_parts) result.append( - ui_message.UIMessage( + ui_messages.UIMessage( id=bubble_id, role="assistant", metadata=_roundtrip.metadata_for(source_messages), diff --git a/src/ai/agents/ui/ai_sdk/outbound/sse.py b/src/ai/agents/ui/ai_sdk/outbound/sse.py index b859bc2d..516c1377 100644 --- a/src/ai/agents/ui/ai_sdk/outbound/sse.py +++ b/src/ai/agents/ui/ai_sdk/outbound/sse.py @@ -8,7 +8,7 @@ import pydantic -from .. import protocol +from .. import ui_events from .stream import to_stream if TYPE_CHECKING: @@ -36,17 +36,17 @@ def _json_default(obj: Any) -> Any: ) -def serialize_part(part: protocol.UIMessageStreamPart) -> str: +def serialize_part(part: ui_events.UIMessageStreamPart) -> str: """Serialize a stream part to JSON with camelCase keys.""" d = dataclasses.asdict(part) - if isinstance(part, protocol.DataPart): + if isinstance(part, ui_events.DataPart): d["type"] = part.type del d["data_type"] camel_dict = {_to_camel_case(k): v for k, v in d.items() if v is not None} return json.dumps(camel_dict, default=_json_default) -def format_sse(part: protocol.UIMessageStreamPart) -> str: +def format_sse(part: ui_events.UIMessageStreamPart) -> str: """Format a stream part as an SSE data line.""" return f"data: {serialize_part(part)}\n\n" diff --git a/src/ai/agents/ui/ai_sdk/outbound/stream.py b/src/ai/agents/ui/ai_sdk/outbound/stream.py index 4b70f920..33d8c2de 100644 --- a/src/ai/agents/ui/ai_sdk/outbound/stream.py +++ b/src/ai/agents/ui/ai_sdk/outbound/stream.py @@ -10,12 +10,12 @@ if TYPE_CHECKING: from collections.abc import AsyncGenerator, AsyncIterable - from .. import protocol + from .. import ui_events async def to_stream( events: AsyncIterable[events_.AgentEvent], -) -> AsyncGenerator[protocol.UIMessageStreamPart]: +) -> AsyncGenerator[ui_events.UIMessageStreamPart]: """Walk ``events`` once, emitting AI SDK UI stream parts. Streaming text/reasoning/tool-input deltas come from model events. diff --git a/src/ai/agents/ui/ai_sdk/protocol.py b/src/ai/agents/ui/ai_sdk/ui_events.py similarity index 100% rename from src/ai/agents/ui/ai_sdk/protocol.py rename to src/ai/agents/ui/ai_sdk/ui_events.py diff --git a/src/ai/agents/ui/ai_sdk/ui_message.py b/src/ai/agents/ui/ai_sdk/ui_messages.py similarity index 100% rename from src/ai/agents/ui/ai_sdk/ui_message.py rename to src/ai/agents/ui/ai_sdk/ui_messages.py diff --git a/tests/agents/ui/ai_sdk/outbound/test_history.py b/tests/agents/ui/ai_sdk/outbound/test_history.py index b896cfce..67fe8619 100644 --- a/tests/agents/ui/ai_sdk/outbound/test_history.py +++ b/tests/agents/ui/ai_sdk/outbound/test_history.py @@ -4,7 +4,7 @@ from ai.agents.ui import ai_sdk from ai.agents.ui.ai_sdk import to_ui_messages -from ai.agents.ui.ai_sdk.ui_message import ( +from ai.agents.ui.ai_sdk.ui_messages import ( UIDynamicToolPart, UIFilePart, UITextPart, diff --git a/tests/agents/ui/ai_sdk/outbound/test_sse.py b/tests/agents/ui/ai_sdk/outbound/test_sse.py index 9f3a0cf7..eefc065a 100644 --- a/tests/agents/ui/ai_sdk/outbound/test_sse.py +++ b/tests/agents/ui/ai_sdk/outbound/test_sse.py @@ -3,7 +3,7 @@ import json from collections.abc import AsyncGenerator -from ai.agents.ui.ai_sdk import protocol, to_sse +from ai.agents.ui.ai_sdk import to_sse, ui_events from ai.agents.ui.ai_sdk.outbound.sse import ( format_done_sse, format_sse, @@ -14,27 +14,27 @@ def test_serialize_part_camelcases_keys() -> None: - part = protocol.StartPart(message_id="m1") + part = ui_events.StartPart(message_id="m1") payload = json.loads(serialize_part(part)) assert payload == {"type": "start", "messageId": "m1"} def test_format_sse_wraps_data_line() -> None: - part = protocol.TextDeltaPart(id="t1", delta="hi") + part = ui_events.TextDeltaPart(id="t1", delta="hi") line = format_sse(part) assert line.startswith("data: ") assert line.endswith("\n\n") def test_serialize_data_part_uses_type_with_prefix() -> None: - part = protocol.DataPart(data_type="custom", data={"k": 1}) + part = ui_events.DataPart(data_type="custom", data={"k": 1}) payload = json.loads(serialize_part(part)) assert payload["type"] == "data-custom" assert "dataType" not in payload def test_serialize_protocol_fields_use_ai_sdk_wire_names() -> None: - part = protocol.ToolApprovalResponsePart( + part = ui_events.ToolApprovalResponsePart( approval_id="approval-1", approved=False, reason="no", diff --git a/tests/agents/ui/ai_sdk/outbound/test_stream.py b/tests/agents/ui/ai_sdk/outbound/test_stream.py index 4e18fdd2..4ca5eea5 100644 --- a/tests/agents/ui/ai_sdk/outbound/test_stream.py +++ b/tests/agents/ui/ai_sdk/outbound/test_stream.py @@ -1,9 +1,10 @@ from __future__ import annotations from collections.abc import AsyncGenerator +from typing import Any import ai -from ai.agents.ui.ai_sdk import protocol, to_stream +from ai.agents.ui.ai_sdk import to_stream, ui_events from ai.types import events as agent_events_ from ai.types import events as events_ from ai.types import messages as messages_ @@ -18,7 +19,7 @@ async def _gen( async def _collect( stream_events: list[agent_events_.AgentEvent], -) -> list[protocol.UIMessageStreamPart]: +) -> list[ui_events.UIMessageStreamPart]: return [part async for part in to_stream(_gen(stream_events))] @@ -39,7 +40,7 @@ async def test_stream_start_uses_runtime_message_id() -> None: ] ) - start = next(part for part in out if isinstance(part, protocol.StartPart)) + start = next(part for part in out if isinstance(part, ui_events.StartPart)) assert start.message_id == "assistant-runtime-id" @@ -54,13 +55,13 @@ async def test_event_driven_text_streaming() -> None: ] ) - assert isinstance(out[0], protocol.StartPart) - assert isinstance(out[1], protocol.StartStepPart) - assert isinstance(out[2], protocol.TextStartPart) and out[2].id == text_id - assert isinstance(out[3], protocol.TextDeltaPart) and out[3].delta == "hi" - assert isinstance(out[4], protocol.TextEndPart) and out[4].id == text_id - assert isinstance(out[5], protocol.FinishStepPart) - assert isinstance(out[6], protocol.FinishPart) + assert isinstance(out[0], ui_events.StartPart) + assert isinstance(out[1], ui_events.StartStepPart) + assert isinstance(out[2], ui_events.TextStartPart) and out[2].id == text_id + assert isinstance(out[3], ui_events.TextDeltaPart) and out[3].delta == "hi" + assert isinstance(out[4], ui_events.TextEndPart) and out[4].id == text_id + assert isinstance(out[5], ui_events.FinishStepPart) + assert isinstance(out[6], ui_events.FinishPart) async def test_tool_call_and_result_emit_terminal_parts() -> None: @@ -168,7 +169,7 @@ async def test_approval_request_hook_emits_approval_part() -> None: ] ) approval_parts = [ - p for p in out if isinstance(p, protocol.ToolApprovalRequestPart) + p for p in out if isinstance(p, ui_events.ToolApprovalRequestPart) ] assert len(approval_parts) == 1 assert approval_parts[0].tool_call_id == "tc1" @@ -203,7 +204,7 @@ async def test_partial_tool_results_emit_preliminary_outputs() -> None: prelim = [ p for p in out - if isinstance(p, protocol.ToolOutputAvailablePart) and p.preliminary + if isinstance(p, ui_events.ToolOutputAvailablePart) and p.preliminary ] assert [p.output for p in prelim] == [ "hit 1, ", @@ -215,7 +216,7 @@ async def test_partial_tool_results_emit_preliminary_outputs() -> None: async def test_partial_message_bundle_becomes_ui_message() -> None: """MessageAggregator's snapshot collapses to one UIMessage.""" - from ai.agents.ui.ai_sdk.ui_message import UIMessage + from ai.agents.ui.ai_sdk.ui_messages import UIMessage inner_msg = messages_.Message( role="assistant", @@ -238,7 +239,7 @@ async def test_partial_message_bundle_becomes_ui_message() -> None: [prelim] = [ p for p in out - if isinstance(p, protocol.ToolOutputAvailablePart) and p.preliminary + if isinstance(p, ui_events.ToolOutputAvailablePart) and p.preliminary ] assert isinstance(prelim.output, UIMessage) assert prelim.output.role == "assistant" @@ -256,7 +257,9 @@ async def test_partial_tool_result_without_factory_is_skipped() -> None: ), ] ) - assert not any(isinstance(p, protocol.ToolOutputAvailablePart) for p in out) + assert not any( + isinstance(p, ui_events.ToolOutputAvailablePart) for p in out + ) async def test_builtin_tool_stream_marks_provider_executed_dynamic() -> None: @@ -289,13 +292,13 @@ async def test_builtin_tool_stream_marks_provider_executed_dynamic() -> None: ] ) - start = next(p for p in out if isinstance(p, protocol.ToolInputStartPart)) + start = next(p for p in out if isinstance(p, ui_events.ToolInputStartPart)) assert start.provider_executed is True assert start.dynamic is True assert start.provider_metadata == {"provider": {"start": True}} available = next( - p for p in out if isinstance(p, protocol.ToolInputAvailablePart) + p for p in out if isinstance(p, ui_events.ToolInputAvailablePart) ) assert available.provider_executed is True assert available.dynamic is True @@ -303,7 +306,7 @@ async def test_builtin_tool_stream_marks_provider_executed_dynamic() -> None: assert available.provider_metadata == {"provider": {"call": True}} result = next( - p for p in out if isinstance(p, protocol.ToolOutputAvailablePart) + p for p in out if isinstance(p, ui_events.ToolOutputAvailablePart) ) assert result.provider_executed is True assert result.dynamic is True @@ -322,14 +325,14 @@ async def test_file_event_emits_file_part_with_data_url_and_metadata() -> None: ] ) - file_part = next(p for p in out if isinstance(p, protocol.FilePart)) + file_part = next(p for p in out if isinstance(p, ui_events.FilePart)) assert file_part.url == "data:image/png;base64,YWJj" assert file_part.media_type == "image/png" assert file_part.provider_metadata == {"provider": {"file": True}} async def test_resolved_approval_hook_emits_response_part() -> None: - hook = messages_.HookPart( + hook: messages_.HookPart[Any] = messages_.HookPart( hook_id="approve_tc1", hook_type="ToolApproval", status="resolved", @@ -355,14 +358,14 @@ async def test_resolved_approval_hook_emits_response_part() -> None: ) response = next( - p for p in out if isinstance(p, protocol.ToolApprovalResponsePart) + p for p in out if isinstance(p, ui_events.ToolApprovalResponsePart) ) assert response.approval_id == "approve_tc1" assert response.approved is False assert response.reason == "not allowed" assert response.provider_executed is True assert response.provider_metadata == {"provider": {"approval": True}} - assert any(isinstance(p, protocol.ToolOutputDeniedPart) for p in out) + assert any(isinstance(p, ui_events.ToolOutputDeniedPart) for p in out) # NOTE: agent-change boundary detection used to be driven by diff --git a/tests/agents/ui/ai_sdk/test_inbound.py b/tests/agents/ui/ai_sdk/test_inbound.py index feb8b7d2..7bc89009 100644 --- a/tests/agents/ui/ai_sdk/test_inbound.py +++ b/tests/agents/ui/ai_sdk/test_inbound.py @@ -10,7 +10,7 @@ _normalize_ui_messages, extract_approvals, ) -from ai.agents.ui.ai_sdk.ui_message import UIMessage, UIToolPart +from ai.agents.ui.ai_sdk.ui_messages import UIMessage, UIToolPart from ai.types import messages as messages_ diff --git a/tests/agents/ui/ai_sdk/test_parts.py b/tests/agents/ui/ai_sdk/test_parts.py index 9876e9c8..1b581e75 100644 --- a/tests/agents/ui/ai_sdk/test_parts.py +++ b/tests/agents/ui/ai_sdk/test_parts.py @@ -1,7 +1,7 @@ from __future__ import annotations from ai.agents.ui.ai_sdk import _parts -from ai.agents.ui.ai_sdk.ui_message import ( +from ai.agents.ui.ai_sdk.ui_messages import ( UIReasoningPart, UITextPart, UIToolApproval, From 175909f81b2b3f7b75c22e77607d57bb8b0f3acf Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Wed, 20 May 2026 16:11:34 -0700 Subject: [PATCH 06/13] Rename ui events to UIFooEvent --- src/ai/agents/ui/ai_sdk/outbound/_state.py | 114 +++++++-------- src/ai/agents/ui/ai_sdk/outbound/sse.py | 20 +-- src/ai/agents/ui/ai_sdk/outbound/stream.py | 26 ++-- src/ai/agents/ui/ai_sdk/ui_events.py | 136 +++++++++--------- tests/agents/ui/ai_sdk/outbound/test_sse.py | 24 ++-- .../agents/ui/ai_sdk/outbound/test_stream.py | 94 ++++++------ 6 files changed, 212 insertions(+), 202 deletions(-) diff --git a/src/ai/agents/ui/ai_sdk/outbound/_state.py b/src/ai/agents/ui/ai_sdk/outbound/_state.py index 146b13fa..3d92b351 100644 --- a/src/ai/agents/ui/ai_sdk/outbound/_state.py +++ b/src/ai/agents/ui/ai_sdk/outbound/_state.py @@ -101,24 +101,24 @@ def __init__(self) -> None: # -- boundary helpers ---------------------------------------------------- - def _close_open_blocks(self) -> list[ui_events.UIMessageStreamPart]: - parts: list[ui_events.UIMessageStreamPart] = [] + def _close_open_blocks(self) -> list[ui_events.UIMessageStreamEvent]: + events: list[ui_events.UIMessageStreamEvent] = [] for rid in list(self.open_reasoning_ids): - parts.append(ui_events.ReasoningEndPart(id=rid)) + events.append(ui_events.UIReasoningEndEvent(id=rid)) self.completed_reasoning_ids.add(rid) self.open_reasoning_ids.clear() for tid in list(self.open_text_ids): - parts.append(ui_events.TextEndPart(id=tid)) + events.append(ui_events.UITextEndEvent(id=tid)) self.completed_text_ids.add(tid) self.open_text_ids.clear() - return parts + return events - def _finish_step(self) -> list[ui_events.UIMessageStreamPart]: - parts = self._close_open_blocks() + def _finish_step(self) -> list[ui_events.UIMessageStreamEvent]: + events = self._close_open_blocks() if self.in_step: - parts.append(ui_events.FinishStepPart()) + events.append(ui_events.UIFinishStepEvent()) self.in_step = False - return parts + return events def _reset_step_tracking(self) -> None: self.started_tool_inputs.clear() @@ -130,26 +130,26 @@ def _reset_step_tracking(self) -> None: def _ensure_started( self, message_id: str | None = None, - ) -> list[ui_events.UIMessageStreamPart]: - """Lazily emit StartPart / StartStepPart on the first event.""" - parts: list[ui_events.UIMessageStreamPart] = [] + ) -> list[ui_events.UIMessageStreamEvent]: + """Lazily emit UIStartEvent / UIStartStepEvent on the first event.""" + events: list[ui_events.UIMessageStreamEvent] = [] if not self.emitted_start: self.ui_message_id = message_id - parts.append(ui_events.StartPart(message_id=self.ui_message_id)) - parts.append(ui_events.StartStepPart()) + events.append(ui_events.UIStartEvent(message_id=self.ui_message_id)) + events.append(ui_events.UIStartStepEvent()) self.emitted_start = True self.in_step = True self._reset_step_tracking() - return parts + return events # -- phase: streaming events -------------------------------------------- def on_event( self, event: events_.Event - ) -> list[ui_events.UIMessageStreamPart]: - out: list[ui_events.UIMessageStreamPart] = [] + ) -> list[ui_events.UIMessageStreamEvent]: + out: list[ui_events.UIMessageStreamEvent] = [] # Lazily open the UI message on the first streaming event. if not self.emitted_start: @@ -159,7 +159,7 @@ def on_event( case events_.TextStart(block_id=pid): self.open_text_ids.add(pid) out.append( - ui_events.TextStartPart( + ui_events.UITextStartEvent( id=pid, provider_metadata=event.provider_metadata, ) @@ -169,14 +169,14 @@ def on_event( if pid not in self.open_text_ids: self.open_text_ids.add(pid) out.append( - ui_events.TextStartPart( + ui_events.UITextStartEvent( id=pid, provider_metadata=event.provider_metadata, ) ) self.text_delta_ids.add(pid) out.append( - ui_events.TextDeltaPart( + ui_events.UITextDeltaEvent( id=pid, delta=chunk, provider_metadata=event.provider_metadata, @@ -188,7 +188,7 @@ def on_event( self.open_text_ids.discard(pid) self.completed_text_ids.add(pid) out.append( - ui_events.TextEndPart( + ui_events.UITextEndEvent( id=pid, provider_metadata=event.provider_metadata, ) @@ -197,7 +197,7 @@ def on_event( case events_.ReasoningStart(block_id=pid): self.open_reasoning_ids.add(pid) out.append( - ui_events.ReasoningStartPart( + ui_events.UIReasoningStartEvent( id=pid, provider_metadata=event.provider_metadata, ) @@ -207,14 +207,14 @@ def on_event( if pid not in self.open_reasoning_ids: self.open_reasoning_ids.add(pid) out.append( - ui_events.ReasoningStartPart( + ui_events.UIReasoningStartEvent( id=pid, provider_metadata=event.provider_metadata, ) ) self.reasoning_delta_ids.add(pid) out.append( - ui_events.ReasoningDeltaPart( + ui_events.UIReasoningDeltaEvent( id=pid, delta=chunk, provider_metadata=event.provider_metadata, @@ -226,7 +226,7 @@ def on_event( self.open_reasoning_ids.discard(pid) self.completed_reasoning_ids.add(pid) out.append( - ui_events.ReasoningEndPart( + ui_events.UIReasoningEndEvent( id=pid, provider_metadata=event.provider_metadata, ) @@ -238,7 +238,7 @@ def on_event( return out self.started_tool_inputs.add(tcid) out.append( - ui_events.ToolInputStartPart( + ui_events.UIToolInputStartEvent( tool_call_id=tcid, tool_name=name, provider_metadata=event.provider_metadata, @@ -249,14 +249,14 @@ def on_event( if tcid not in self.started_tool_inputs: self.started_tool_inputs.add(tcid) out.append( - ui_events.ToolInputStartPart( + ui_events.UIToolInputStartEvent( tool_call_id=tcid, tool_name=self.tool_names.get(tcid, ""), provider_metadata=event.provider_metadata, ) ) out.append( - ui_events.ToolInputDeltaPart( + ui_events.UIToolInputDeltaEvent( tool_call_id=tcid, input_text_delta=chunk, ) @@ -271,7 +271,7 @@ def on_event( return out self.started_tool_inputs.add(tcid) out.append( - ui_events.ToolInputStartPart( + ui_events.UIToolInputStartEvent( tool_call_id=tcid, tool_name=name, provider_executed=True, @@ -284,7 +284,7 @@ def on_event( if tcid not in self.started_tool_inputs: self.started_tool_inputs.add(tcid) out.append( - ui_events.ToolInputStartPart( + ui_events.UIToolInputStartEvent( tool_call_id=tcid, tool_name=self.tool_names.get(tcid, ""), provider_executed=True, @@ -293,7 +293,7 @@ def on_event( ) ) out.append( - ui_events.ToolInputDeltaPart( + ui_events.UIToolInputDeltaEvent( tool_call_id=tcid, input_text_delta=chunk, ) @@ -303,7 +303,7 @@ def on_event( if tcid not in self.input_available_emitted: self.input_available_emitted.add(tcid) out.append( - ui_events.ToolInputAvailablePart( + ui_events.UIToolInputAvailableEvent( tool_call_id=tcid, tool_name=tc.tool_name, input=_normalize_tool_input(tc.tool_args), @@ -320,7 +320,7 @@ def on_event( self.emitted_tool_results.add(tcid) if result.is_error: out.append( - ui_events.ToolOutputErrorPart( + ui_events.UIToolOutputErrorEvent( tool_call_id=tcid, error_text=str(result.result), provider_executed=True, @@ -331,7 +331,7 @@ def on_event( ) else: out.append( - ui_events.ToolOutputAvailablePart( + ui_events.UIToolOutputAvailableEvent( tool_call_id=tcid, output=result.result, provider_executed=True, @@ -346,7 +346,7 @@ def on_event( data=data, ): out.append( - ui_events.FilePart( + ui_events.UIFileEvent( url=media.data_to_data_url(data, media_type), media_type=media_type, provider_metadata=event.provider_metadata, @@ -359,10 +359,10 @@ def on_event( def on_tool_result( self, event: events_.ToolCallResult - ) -> list[ui_events.UIMessageStreamPart]: - """Handle a ``ToolCallResult`` — emit tool input/output parts.""" + ) -> list[ui_events.UIMessageStreamEvent]: + """Handle a ``ToolCallResult`` — emit tool input/output events.""" msg = event.message - out: list[ui_events.UIMessageStreamPart] = [] + out: list[ui_events.UIMessageStreamEvent] = [] out.extend(self._ensure_started(msg.turn_id)) @@ -376,14 +376,14 @@ def on_tool_result( if part.tool_call_id not in self.started_tool_inputs: self.started_tool_inputs.add(part.tool_call_id) out.append( - ui_events.ToolInputStartPart( + ui_events.UIToolInputStartEvent( tool_call_id=part.tool_call_id, tool_name=part.tool_name, provider_metadata=part.provider_metadata, ) ) out.append( - ui_events.ToolInputAvailablePart( + ui_events.UIToolInputAvailableEvent( tool_call_id=part.tool_call_id, tool_name=part.tool_name, input=_normalize_tool_input(part.tool_args), @@ -402,7 +402,7 @@ def on_tool_result( self.emitted_tool_results.add(part.tool_call_id) if part.is_error: out.append( - ui_events.ToolOutputErrorPart( + ui_events.UIToolOutputErrorEvent( tool_call_id=part.tool_call_id, error_text=_tool_error_text(part), provider_metadata=part.provider_metadata, @@ -417,7 +417,7 @@ def on_tool_result( # streaming view if any. continue out.append( - ui_events.ToolOutputAvailablePart( + ui_events.UIToolOutputAvailableEvent( tool_call_id=part.tool_call_id, output=wire_output, provider_metadata=part.provider_metadata, @@ -428,17 +428,17 @@ def on_tool_result( def on_partial_tool_result( self, event: events_.PartialToolCallResult - ) -> list[ui_events.UIMessageStreamPart]: + ) -> list[ui_events.UIMessageStreamEvent]: """Feed the value and emit a preliminary output. Each PartialToolCallResult carries one yielded value plus the aggregator factory the tool was declared with. We instantiate the aggregator once per ``tool_call_id`` and use its snapshot - as the ``output`` of a preliminary ``ToolOutputAvailablePart``. + as the ``output`` of a preliminary ``UIToolOutputAvailableEvent``. The AI SDK supersedes preliminary outputs with the final ``ToolCallResult`` when it arrives. """ - out: list[ui_events.UIMessageStreamPart] = [] + out: list[ui_events.UIMessageStreamEvent] = [] tcid = event.tool_call_id factory = event.aggregator_factory @@ -460,7 +460,7 @@ def on_partial_tool_result( return out out.append( - ui_events.ToolOutputAvailablePart( + ui_events.UIToolOutputAvailableEvent( tool_call_id=tcid, output=wire_output, preliminary=True, @@ -472,10 +472,10 @@ def on_partial_tool_result( def on_hook( self, event: events_.HookEvent - ) -> list[ui_events.UIMessageStreamPart]: - """Handle a ``HookEvent`` — emit approval parts.""" + ) -> list[ui_events.UIMessageStreamEvent]: + """Handle a ``HookEvent`` — emit approval events.""" hook_part = event.hook - out: list[ui_events.UIMessageStreamPart] = [] + out: list[ui_events.UIMessageStreamEvent] = [] # Ensure the UI message is started. out.extend(self._ensure_started(event.message.turn_id)) @@ -489,7 +489,7 @@ def on_hook( return out self.emitted_approval_requests.add(tc_id) out.append( - ui_events.ToolApprovalRequestPart( + ui_events.UIToolApprovalRequestEvent( approval_id=hook_part.hook_id, tool_call_id=tc_id, is_automatic=_metadata_bool( @@ -500,7 +500,7 @@ def on_hook( elif hook_part.status == "resolved": resolution: dict[str, Any] = hook_part.resolution or {} out.append( - ui_events.ToolApprovalResponsePart( + ui_events.UIToolApprovalResponseEvent( approval_id=hook_part.hook_id, approved=bool(resolution.get("granted")), reason=resolution.get("reason"), @@ -513,10 +513,10 @@ def on_hook( ) ) if not resolution.get("granted"): - out.append(ui_events.ToolOutputDeniedPart(tool_call_id=tc_id)) + out.append(ui_events.UIToolOutputDeniedEvent(tool_call_id=tc_id)) elif hook_part.status == "cancelled": out.append( - ui_events.ToolOutputErrorPart( + ui_events.UIToolOutputErrorEvent( tool_call_id=tc_id, error_text="Hook cancelled", ) @@ -526,8 +526,8 @@ def on_hook( # -- phase: stream finish ------------------------------------------------ - def finish(self) -> list[ui_events.UIMessageStreamPart]: - parts = self._finish_step() + def finish(self) -> list[ui_events.UIMessageStreamEvent]: + events = self._finish_step() if self.emitted_start: - parts.append(ui_events.FinishPart(finish_reason="stop")) - return parts + events.append(ui_events.UIFinishEvent(finish_reason="stop")) + return events diff --git a/src/ai/agents/ui/ai_sdk/outbound/sse.py b/src/ai/agents/ui/ai_sdk/outbound/sse.py index 516c1377..9fb667b6 100644 --- a/src/ai/agents/ui/ai_sdk/outbound/sse.py +++ b/src/ai/agents/ui/ai_sdk/outbound/sse.py @@ -36,19 +36,19 @@ def _json_default(obj: Any) -> Any: ) -def serialize_part(part: ui_events.UIMessageStreamPart) -> str: - """Serialize a stream part to JSON with camelCase keys.""" - d = dataclasses.asdict(part) - if isinstance(part, ui_events.DataPart): - d["type"] = part.type +def serialize_event(event: ui_events.UIMessageStreamEvent) -> str: + """Serialize a stream event to JSON with camelCase keys.""" + d = dataclasses.asdict(event) + if isinstance(event, ui_events.UIDataEvent): + d["type"] = event.type del d["data_type"] camel_dict = {_to_camel_case(k): v for k, v in d.items() if v is not None} return json.dumps(camel_dict, default=_json_default) -def format_sse(part: ui_events.UIMessageStreamPart) -> str: - """Format a stream part as an SSE data line.""" - return f"data: {serialize_part(part)}\n\n" +def format_sse(event: ui_events.UIMessageStreamEvent) -> str: + """Format a stream event as an SSE data line.""" + return f"data: {serialize_event(event)}\n\n" def format_done_sse() -> str: @@ -60,6 +60,6 @@ async def to_sse( events: AsyncIterable[events_.AgentEvent], ) -> AsyncGenerator[str]: """Convert an internal event stream into SSE strings.""" - async for part in to_stream(events): - yield format_sse(part) + async for event in to_stream(events): + yield format_sse(event) yield format_done_sse() diff --git a/src/ai/agents/ui/ai_sdk/outbound/stream.py b/src/ai/agents/ui/ai_sdk/outbound/stream.py index 33d8c2de..4556fad4 100644 --- a/src/ai/agents/ui/ai_sdk/outbound/stream.py +++ b/src/ai/agents/ui/ai_sdk/outbound/stream.py @@ -1,4 +1,4 @@ -"""Convert internal event streams into AI SDK UI protocol parts.""" +"""Convert internal event streams into AI SDK UI protocol events.""" from __future__ import annotations @@ -15,8 +15,8 @@ async def to_stream( events: AsyncIterable[events_.AgentEvent], -) -> AsyncGenerator[ui_events.UIMessageStreamPart]: - """Walk ``events`` once, emitting AI SDK UI stream parts. +) -> AsyncGenerator[ui_events.UIMessageStreamEvent]: + """Walk ``events`` once, emitting AI SDK UI stream events. Streaming text/reasoning/tool-input deltas come from model events. Tool results come from ``ToolCallResult``. Hook signals come from @@ -26,17 +26,17 @@ async def to_stream( async for event in events: if isinstance(event, events_.ToolCallResult): - for part in state.on_tool_result(event): - yield part + for ui_event in state.on_tool_result(event): + yield ui_event elif isinstance(event, events_.PartialToolCallResult): - for part in state.on_partial_tool_result(event): - yield part + for ui_event in state.on_partial_tool_result(event): + yield ui_event elif isinstance(event, events_.HookEvent): - for part in state.on_hook(event): - yield part + for ui_event in state.on_hook(event): + yield ui_event else: - for part in state.on_event(event): - yield part + for ui_event in state.on_event(event): + yield ui_event - for part in state.finish(): - yield part + for ui_event in state.finish(): + yield ui_event diff --git a/src/ai/agents/ui/ai_sdk/ui_events.py b/src/ai/agents/ui/ai_sdk/ui_events.py index c8db7b3e..a3a130b2 100644 --- a/src/ai/agents/ui/ai_sdk/ui_events.py +++ b/src/ai/agents/ui/ai_sdk/ui_events.py @@ -13,7 +13,7 @@ } -# different kinds of messages expected by the frontend +# different kinds of stream events expected by the frontend FinishReason = Literal[ "stop", "length", "content-filter", "tool-calls", "error", "other" @@ -21,7 +21,7 @@ @dataclasses.dataclass -class StartPart: +class UIStartEvent: """Indicates the beginning of a new message with metadata.""" type: Literal["start"] = dataclasses.field(default="start", init=False) @@ -30,7 +30,7 @@ class StartPart: @dataclasses.dataclass -class TextStartPart: +class UITextStartEvent: """Indicates the beginning of a text block.""" id: str @@ -41,7 +41,7 @@ class TextStartPart: @dataclasses.dataclass -class TextDeltaPart: +class UITextDeltaEvent: """Contains incremental text content for the text block.""" id: str @@ -53,7 +53,7 @@ class TextDeltaPart: @dataclasses.dataclass -class TextEndPart: +class UITextEndEvent: """Indicates the completion of a text block.""" id: str @@ -64,7 +64,7 @@ class TextEndPart: @dataclasses.dataclass -class ReasoningStartPart: +class UIReasoningStartEvent: """Indicates the beginning of a reasoning block.""" id: str @@ -75,7 +75,7 @@ class ReasoningStartPart: @dataclasses.dataclass -class ReasoningDeltaPart: +class UIReasoningDeltaEvent: """Contains incremental reasoning content for the reasoning block.""" id: str @@ -87,7 +87,7 @@ class ReasoningDeltaPart: @dataclasses.dataclass -class ReasoningEndPart: +class UIReasoningEndEvent: """Indicates the completion of a reasoning block.""" id: str @@ -98,8 +98,8 @@ class ReasoningEndPart: @dataclasses.dataclass -class CustomPart: - """Provider-specific content that does not fit standard UI parts.""" +class UICustomEvent: + """Provider-specific content that does not fit standard UI events.""" kind: str type: Literal["custom"] = dataclasses.field(default="custom", init=False) @@ -107,7 +107,7 @@ class CustomPart: @dataclasses.dataclass -class SourceUrlPart: +class UISourceUrlEvent: """References to external URLs.""" source_id: str @@ -120,7 +120,7 @@ class SourceUrlPart: @dataclasses.dataclass -class SourceDocumentPart: +class UISourceDocumentEvent: """References to documents or files.""" source_id: str @@ -134,8 +134,8 @@ class SourceDocumentPart: @dataclasses.dataclass -class FilePart: - """The file parts contain references to files with their media type.""" +class UIFileEvent: + """References to files with their media type.""" url: str media_type: str @@ -144,7 +144,7 @@ class FilePart: @dataclasses.dataclass -class ReasoningFilePart: +class UIReasoningFileEvent: """A file generated as part of model reasoning.""" url: str @@ -156,14 +156,14 @@ class ReasoningFilePart: @dataclasses.dataclass -class DataPart: - """Custom data part for arbitrary structured data. +class UIDataEvent: + """Custom data event for arbitrary structured data. - Data parts support type-specific handling. + Data events support type-specific handling. The wire type is ``data-{data_type}`` (e.g. ``data-custom``), exposed - via the ``type`` property so that ``DataPart`` is uniform with every - other ``UIMessageStreamPart`` variant. + via the ``type`` property so that ``UIDataEvent`` is uniform with every + other ``UIMessageStreamEvent`` variant. """ data_type: str @@ -178,7 +178,7 @@ def type(self) -> str: @dataclasses.dataclass -class ToolInputStartPart: +class UIToolInputStartEvent: """Indicates the beginning of tool input streaming.""" tool_call_id: str @@ -194,7 +194,7 @@ class ToolInputStartPart: @dataclasses.dataclass -class ToolInputDeltaPart: +class UIToolInputDeltaEvent: """Incremental chunks of tool input as it's being generated.""" tool_call_id: str @@ -205,7 +205,7 @@ class ToolInputDeltaPart: @dataclasses.dataclass -class ToolInputAvailablePart: +class UIToolInputAvailableEvent: """Indicates that tool input is complete and ready for execution.""" tool_call_id: str @@ -222,7 +222,7 @@ class ToolInputAvailablePart: @dataclasses.dataclass -class ToolInputErrorPart: +class UIToolInputErrorEvent: """Indicates an error occurred during tool input processing.""" tool_call_id: str @@ -240,7 +240,7 @@ class ToolInputErrorPart: @dataclasses.dataclass -class ToolOutputAvailablePart: +class UIToolOutputAvailableEvent: """Contains the result of tool execution.""" tool_call_id: str @@ -256,7 +256,7 @@ class ToolOutputAvailablePart: @dataclasses.dataclass -class ToolOutputErrorPart: +class UIToolOutputErrorEvent: """Indicates an error occurred during tool execution.""" tool_call_id: str @@ -271,7 +271,7 @@ class ToolOutputErrorPart: @dataclasses.dataclass -class ToolOutputDeniedPart: +class UIToolOutputDeniedEvent: """Indicates tool execution was denied.""" tool_call_id: str @@ -281,7 +281,7 @@ class ToolOutputDeniedPart: @dataclasses.dataclass -class ToolApprovalRequestPart: +class UIToolApprovalRequestEvent: """Requests approval for tool execution.""" approval_id: str @@ -293,7 +293,7 @@ class ToolApprovalRequestPart: @dataclasses.dataclass -class ToolApprovalResponsePart: +class UIToolApprovalResponseEvent: """Records an approval decision for a tool call.""" approval_id: str @@ -307,8 +307,8 @@ class ToolApprovalResponsePart: @dataclasses.dataclass -class StartStepPart: - """A part indicating the start of a step.""" +class UIStartStepEvent: + """Indicates the start of a step.""" type: Literal["start-step"] = dataclasses.field( default="start-step", init=False @@ -316,8 +316,8 @@ class StartStepPart: @dataclasses.dataclass -class FinishStepPart: - """A part indicating that a step has been completed.""" +class UIFinishStepEvent: + """Indicates that a step has been completed.""" type: Literal["finish-step"] = dataclasses.field( default="finish-step", init=False @@ -325,8 +325,8 @@ class FinishStepPart: @dataclasses.dataclass -class FinishPart: - """A part indicating the completion of a message.""" +class UIFinishEvent: + """Indicates the completion of a message.""" type: Literal["finish"] = dataclasses.field(default="finish", init=False) finish_reason: FinishReason | None = None @@ -334,7 +334,7 @@ class FinishPart: @dataclasses.dataclass -class AbortPart: +class UIAbortEvent: """Indicates the message was aborted.""" type: Literal["abort"] = dataclasses.field(default="abort", init=False) @@ -342,7 +342,7 @@ class AbortPart: @dataclasses.dataclass -class MessageMetadataPart: +class UIMessageMetadataEvent: """Contains message metadata.""" message_metadata: Any @@ -352,40 +352,40 @@ class MessageMetadataPart: @dataclasses.dataclass -class ErrorPart: - """The error parts are appended to the message as they are received.""" +class UIErrorEvent: + """Errors appended to the message as they are received.""" error_text: str type: Literal["error"] = dataclasses.field(default="error", init=False) -UIMessageStreamPart = ( - StartPart - | TextStartPart - | TextDeltaPart - | TextEndPart - | ReasoningStartPart - | ReasoningDeltaPart - | ReasoningEndPart - | CustomPart - | SourceUrlPart - | SourceDocumentPart - | FilePart - | ReasoningFilePart - | DataPart - | ToolInputStartPart - | ToolInputDeltaPart - | ToolInputAvailablePart - | ToolInputErrorPart - | ToolOutputAvailablePart - | ToolOutputErrorPart - | ToolOutputDeniedPart - | ToolApprovalRequestPart - | ToolApprovalResponsePart - | StartStepPart - | FinishStepPart - | FinishPart - | AbortPart - | MessageMetadataPart - | ErrorPart +UIMessageStreamEvent = ( + UIStartEvent + | UITextStartEvent + | UITextDeltaEvent + | UITextEndEvent + | UIReasoningStartEvent + | UIReasoningDeltaEvent + | UIReasoningEndEvent + | UICustomEvent + | UISourceUrlEvent + | UISourceDocumentEvent + | UIFileEvent + | UIReasoningFileEvent + | UIDataEvent + | UIToolInputStartEvent + | UIToolInputDeltaEvent + | UIToolInputAvailableEvent + | UIToolInputErrorEvent + | UIToolOutputAvailableEvent + | UIToolOutputErrorEvent + | UIToolOutputDeniedEvent + | UIToolApprovalRequestEvent + | UIToolApprovalResponseEvent + | UIStartStepEvent + | UIFinishStepEvent + | UIFinishEvent + | UIAbortEvent + | UIMessageMetadataEvent + | UIErrorEvent ) diff --git a/tests/agents/ui/ai_sdk/outbound/test_sse.py b/tests/agents/ui/ai_sdk/outbound/test_sse.py index eefc065a..02308aee 100644 --- a/tests/agents/ui/ai_sdk/outbound/test_sse.py +++ b/tests/agents/ui/ai_sdk/outbound/test_sse.py @@ -7,34 +7,34 @@ from ai.agents.ui.ai_sdk.outbound.sse import ( format_done_sse, format_sse, - serialize_part, + serialize_event, ) from ai.types import events as agent_events_ from ai.types import events as events_ -def test_serialize_part_camelcases_keys() -> None: - part = ui_events.StartPart(message_id="m1") - payload = json.loads(serialize_part(part)) +def test_serialize_event_camelcases_keys() -> None: + event = ui_events.UIStartEvent(message_id="m1") + payload = json.loads(serialize_event(event)) assert payload == {"type": "start", "messageId": "m1"} def test_format_sse_wraps_data_line() -> None: - part = ui_events.TextDeltaPart(id="t1", delta="hi") - line = format_sse(part) + event = ui_events.UITextDeltaEvent(id="t1", delta="hi") + line = format_sse(event) assert line.startswith("data: ") assert line.endswith("\n\n") -def test_serialize_data_part_uses_type_with_prefix() -> None: - part = ui_events.DataPart(data_type="custom", data={"k": 1}) - payload = json.loads(serialize_part(part)) +def test_serialize_data_event_uses_type_with_prefix() -> None: + event = ui_events.UIDataEvent(data_type="custom", data={"k": 1}) + payload = json.loads(serialize_event(event)) assert payload["type"] == "data-custom" assert "dataType" not in payload def test_serialize_protocol_fields_use_ai_sdk_wire_names() -> None: - part = ui_events.ToolApprovalResponsePart( + event = ui_events.UIToolApprovalResponseEvent( approval_id="approval-1", approved=False, reason="no", @@ -42,7 +42,7 @@ def test_serialize_protocol_fields_use_ai_sdk_wire_names() -> None: provider_metadata={"provider": {"k": "v"}}, ) - payload = json.loads(serialize_part(part)) + payload = json.loads(serialize_event(event)) assert payload == { "type": "tool-approval-response", @@ -79,7 +79,7 @@ async def test_to_sse_emits_data_prefixed_lines() -> None: ) ] assert all(line.startswith("data: ") for line in lines) - # first line is the start part (lazy open) + # first line is the start event (lazy open) first = json.loads(lines[0].removeprefix("data: ").rstrip()) assert first["type"] == "start" assert lines[-1] == "data: [DONE]\n\n" diff --git a/tests/agents/ui/ai_sdk/outbound/test_stream.py b/tests/agents/ui/ai_sdk/outbound/test_stream.py index 4ca5eea5..42d2c991 100644 --- a/tests/agents/ui/ai_sdk/outbound/test_stream.py +++ b/tests/agents/ui/ai_sdk/outbound/test_stream.py @@ -19,8 +19,8 @@ async def _gen( async def _collect( stream_events: list[agent_events_.AgentEvent], -) -> list[ui_events.UIMessageStreamPart]: - return [part async for part in to_stream(_gen(stream_events))] +) -> list[ui_events.UIMessageStreamEvent]: + return [event async for event in to_stream(_gen(stream_events))] async def test_stream_start_uses_runtime_message_id() -> None: @@ -40,7 +40,9 @@ async def test_stream_start_uses_runtime_message_id() -> None: ] ) - start = next(part for part in out if isinstance(part, ui_events.StartPart)) + start = next( + event for event in out if isinstance(event, ui_events.UIStartEvent) + ) assert start.message_id == "assistant-runtime-id" @@ -55,17 +57,23 @@ async def test_event_driven_text_streaming() -> None: ] ) - assert isinstance(out[0], ui_events.StartPart) - assert isinstance(out[1], ui_events.StartStepPart) - assert isinstance(out[2], ui_events.TextStartPart) and out[2].id == text_id - assert isinstance(out[3], ui_events.TextDeltaPart) and out[3].delta == "hi" - assert isinstance(out[4], ui_events.TextEndPart) and out[4].id == text_id - assert isinstance(out[5], ui_events.FinishStepPart) - assert isinstance(out[6], ui_events.FinishPart) + assert isinstance(out[0], ui_events.UIStartEvent) + assert isinstance(out[1], ui_events.UIStartStepEvent) + assert ( + isinstance(out[2], ui_events.UITextStartEvent) + and out[2].id == text_id + ) + assert ( + isinstance(out[3], ui_events.UITextDeltaEvent) + and out[3].delta == "hi" + ) + assert isinstance(out[4], ui_events.UITextEndEvent) and out[4].id == text_id + assert isinstance(out[5], ui_events.UIFinishStepEvent) + assert isinstance(out[6], ui_events.UIFinishEvent) -async def test_tool_call_and_result_emit_terminal_parts() -> None: - """ToolCallResult emits tool input and output parts.""" +async def test_tool_call_and_result_emit_terminal_events() -> None: + """ToolCallResult emits tool input and output events.""" tool_result_msg = messages_.Message( role="tool", parts=[ @@ -96,13 +104,13 @@ async def test_tool_call_and_result_emit_terminal_parts() -> None: ), ] ) - types = [type(part).__name__ for part in out] - assert "ToolInputStartPart" in types - assert "ToolOutputAvailablePart" in types + types = [type(event).__name__ for event in out] + assert "UIToolInputStartEvent" in types + assert "UIToolOutputAvailableEvent" in types async def test_tool_result_without_streaming_emits_input_start() -> None: - """ToolCallResult for a non-streamed tool emits input + output parts.""" + """ToolCallResult for a non-streamed tool emits input + output events.""" tool_result_msg = messages_.Message( role="tool", parts=[ @@ -127,14 +135,14 @@ async def test_tool_result_without_streaming_emits_input_start() -> None: ), ] ) - types = [type(part).__name__ for part in out] - assert "ToolInputStartPart" in types - assert "ToolInputAvailablePart" in types - assert "ToolOutputAvailablePart" in types + types = [type(event).__name__ for event in out] + assert "UIToolInputStartEvent" in types + assert "UIToolInputAvailableEvent" in types + assert "UIToolOutputAvailableEvent" in types -async def test_approval_request_hook_emits_approval_part() -> None: - """HookEvent with pending status emits a ToolApprovalRequestPart.""" +async def test_approval_request_hook_emits_approval_event() -> None: + """HookEvent with pending status emits a UIToolApprovalRequestEvent.""" out = await _collect( [ # Streaming tool events first @@ -168,16 +176,16 @@ async def test_approval_request_hook_emits_approval_part() -> None: ), ] ) - approval_parts = [ - p for p in out if isinstance(p, ui_events.ToolApprovalRequestPart) + approval_events = [ + p for p in out if isinstance(p, ui_events.UIToolApprovalRequestEvent) ] - assert len(approval_parts) == 1 - assert approval_parts[0].tool_call_id == "tc1" - assert approval_parts[0].approval_id == "approve_tc1" + assert len(approval_events) == 1 + assert approval_events[0].tool_call_id == "tc1" + assert approval_events[0].approval_id == "approve_tc1" async def test_partial_tool_results_emit_preliminary_outputs() -> None: - """Each partial result yields a preliminary part.""" + """Each partial result yields a preliminary event.""" out = await _collect( [ agent_events_.PartialToolCallResult( @@ -204,7 +212,7 @@ async def test_partial_tool_results_emit_preliminary_outputs() -> None: prelim = [ p for p in out - if isinstance(p, ui_events.ToolOutputAvailablePart) and p.preliminary + if isinstance(p, ui_events.UIToolOutputAvailableEvent) and p.preliminary ] assert [p.output for p in prelim] == [ "hit 1, ", @@ -239,7 +247,7 @@ async def test_partial_message_bundle_becomes_ui_message() -> None: [prelim] = [ p for p in out - if isinstance(p, ui_events.ToolOutputAvailablePart) and p.preliminary + if isinstance(p, ui_events.UIToolOutputAvailableEvent) and p.preliminary ] assert isinstance(prelim.output, UIMessage) assert prelim.output.role == "assistant" @@ -258,7 +266,7 @@ async def test_partial_tool_result_without_factory_is_skipped() -> None: ] ) assert not any( - isinstance(p, ui_events.ToolOutputAvailablePart) for p in out + isinstance(p, ui_events.UIToolOutputAvailableEvent) for p in out ) @@ -292,13 +300,15 @@ async def test_builtin_tool_stream_marks_provider_executed_dynamic() -> None: ] ) - start = next(p for p in out if isinstance(p, ui_events.ToolInputStartPart)) + start = next( + p for p in out if isinstance(p, ui_events.UIToolInputStartEvent) + ) assert start.provider_executed is True assert start.dynamic is True assert start.provider_metadata == {"provider": {"start": True}} available = next( - p for p in out if isinstance(p, ui_events.ToolInputAvailablePart) + p for p in out if isinstance(p, ui_events.UIToolInputAvailableEvent) ) assert available.provider_executed is True assert available.dynamic is True @@ -306,7 +316,7 @@ async def test_builtin_tool_stream_marks_provider_executed_dynamic() -> None: assert available.provider_metadata == {"provider": {"call": True}} result = next( - p for p in out if isinstance(p, ui_events.ToolOutputAvailablePart) + p for p in out if isinstance(p, ui_events.UIToolOutputAvailableEvent) ) assert result.provider_executed is True assert result.dynamic is True @@ -314,7 +324,7 @@ async def test_builtin_tool_stream_marks_provider_executed_dynamic() -> None: assert result.provider_metadata == {"provider": {"result": True}} -async def test_file_event_emits_file_part_with_data_url_and_metadata() -> None: +async def test_file_event_emits_ui_file_event() -> None: out = await _collect( [ events_.FileEvent( @@ -325,13 +335,13 @@ async def test_file_event_emits_file_part_with_data_url_and_metadata() -> None: ] ) - file_part = next(p for p in out if isinstance(p, ui_events.FilePart)) - assert file_part.url == "data:image/png;base64,YWJj" - assert file_part.media_type == "image/png" - assert file_part.provider_metadata == {"provider": {"file": True}} + file_event = next(p for p in out if isinstance(p, ui_events.UIFileEvent)) + assert file_event.url == "data:image/png;base64,YWJj" + assert file_event.media_type == "image/png" + assert file_event.provider_metadata == {"provider": {"file": True}} -async def test_resolved_approval_hook_emits_response_part() -> None: +async def test_resolved_approval_hook_emits_response_event() -> None: hook: messages_.HookPart[Any] = messages_.HookPart( hook_id="approve_tc1", hook_type="ToolApproval", @@ -358,14 +368,14 @@ async def test_resolved_approval_hook_emits_response_part() -> None: ) response = next( - p for p in out if isinstance(p, ui_events.ToolApprovalResponsePart) + p for p in out if isinstance(p, ui_events.UIToolApprovalResponseEvent) ) assert response.approval_id == "approve_tc1" assert response.approved is False assert response.reason == "not allowed" assert response.provider_executed is True assert response.provider_metadata == {"provider": {"approval": True}} - assert any(isinstance(p, ui_events.ToolOutputDeniedPart) for p in out) + assert any(isinstance(p, ui_events.UIToolOutputDeniedEvent) for p in out) # NOTE: agent-change boundary detection used to be driven by From 137b9baaedb76a69e3628890e47d5aa090de5606 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Wed, 20 May 2026 16:49:33 -0700 Subject: [PATCH 07/13] Gather ui adapter code into fewer files --- src/ai/agents/ui/ai_sdk/__init__.py | 7 +- src/ai/agents/ui/ai_sdk/_approvals.py | 33 ---- src/ai/agents/ui/ai_sdk/approvals.py | 147 +++++++++++++++ .../ui/ai_sdk/{_roundtrip.py => id_utils.py} | 0 .../{inbound.py => inbound_messages.py} | 169 ++---------------- src/ai/agents/ui/ai_sdk/outbound/__init__.py | 7 - src/ai/agents/ui/ai_sdk/outbound/history.py | 94 ---------- src/ai/agents/ui/ai_sdk/outbound/sse.py | 65 ------- src/ai/agents/ui/ai_sdk/outbound/stream.py | 42 ----- .../{_parts.py => outbound_messages.py} | 129 +++++++++---- .../_state.py => outbound_stream.py} | 122 +++++++++---- src/ai/agents/ui/ai_sdk/tool_utils.py | 25 +++ tests/agents/ui/ai_sdk/outbound/test_sse.py | 2 +- tests/agents/ui/ai_sdk/test_approvals.py | 41 +---- tests/agents/ui/ai_sdk/test_inbound.py | 6 +- tests/agents/ui/ai_sdk/test_parts.py | 16 +- 16 files changed, 386 insertions(+), 519 deletions(-) delete mode 100644 src/ai/agents/ui/ai_sdk/_approvals.py create mode 100644 src/ai/agents/ui/ai_sdk/approvals.py rename src/ai/agents/ui/ai_sdk/{_roundtrip.py => id_utils.py} (100%) rename src/ai/agents/ui/ai_sdk/{inbound.py => inbound_messages.py} (79%) delete mode 100644 src/ai/agents/ui/ai_sdk/outbound/__init__.py delete mode 100644 src/ai/agents/ui/ai_sdk/outbound/history.py delete mode 100644 src/ai/agents/ui/ai_sdk/outbound/sse.py delete mode 100644 src/ai/agents/ui/ai_sdk/outbound/stream.py rename src/ai/agents/ui/ai_sdk/{_parts.py => outbound_messages.py} (77%) rename src/ai/agents/ui/ai_sdk/{outbound/_state.py => outbound_stream.py} (86%) create mode 100644 src/ai/agents/ui/ai_sdk/tool_utils.py diff --git a/src/ai/agents/ui/ai_sdk/__init__.py b/src/ai/agents/ui/ai_sdk/__init__.py index d71c41d1..224f806e 100644 --- a/src/ai/agents/ui/ai_sdk/__init__.py +++ b/src/ai/agents/ui/ai_sdk/__init__.py @@ -1,12 +1,13 @@ """AI SDK UI adapter for messages and SSE streams.""" -from .inbound import ( +from .approvals import ( ApprovalResponse, apply_approvals, extract_approvals, - to_messages, ) -from .outbound import to_sse, to_stream, to_ui_messages +from .inbound_messages import to_messages +from .outbound_messages import to_ui_messages +from .outbound_stream import to_sse, to_stream from .ui_events import UI_MESSAGE_STREAM_HEADERS from .ui_messages import UIMessage diff --git a/src/ai/agents/ui/ai_sdk/_approvals.py b/src/ai/agents/ui/ai_sdk/_approvals.py deleted file mode 100644 index f0694657..00000000 --- a/src/ai/agents/ui/ai_sdk/_approvals.py +++ /dev/null @@ -1,33 +0,0 @@ -"""Approval-prefix linkage between ToolApproval hooks and tool calls. - -TODO(datamodel-rework §4): once ``HookPart.target_tool_call_id`` is added, -delete this module and replace call sites with direct field access. -""" - -from __future__ import annotations - -from typing import Any - -from ....types import messages as messages_ -from ...hooks import TOOL_APPROVAL_HOOK_TYPE - -_PREFIX = "approve_" - - -def tool_call_id_for(hook_part: messages_.HookPart[Any]) -> str | None: - """Return the tool_call_id encoded in a ToolApproval hook id, or None.""" - if hook_part.hook_type != TOOL_APPROVAL_HOOK_TYPE: - return None - if hook_part.hook_id.startswith(_PREFIX): - return hook_part.hook_id[len(_PREFIX) :] - return None - - -def is_tool_approval_message(msg: messages_.Message) -> bool: - """Return whether every part of ``msg`` is a ToolApproval HookPart.""" - if not msg.parts: - return False - return all( - isinstance(p, messages_.HookPart) and tool_call_id_for(p) is not None - for p in msg.parts - ) diff --git a/src/ai/agents/ui/ai_sdk/approvals.py b/src/ai/agents/ui/ai_sdk/approvals.py new file mode 100644 index 00000000..e23dc163 --- /dev/null +++ b/src/ai/agents/ui/ai_sdk/approvals.py @@ -0,0 +1,147 @@ +"""Tool approval helpers for AI SDK UI adapters.""" + +from __future__ import annotations + +from typing import Any, NamedTuple + +from ....types import messages as messages_ +from ...hooks import TOOL_APPROVAL_HOOK_TYPE, resolve_hook +from . import ui_messages + +_PREFIX = "approve_" + + +ToolPart = ui_messages.UIToolPart | ui_messages.UIDynamicToolPart + + +class ApprovalResponse(NamedTuple): + """Approval response extracted from a responded UI tool part.""" + + hook_id: str + granted: bool + reason: str | None + tool_call_id: str + + +def tool_call_id_for(hook_part: messages_.HookPart[Any]) -> str | None: + """Return the tool_call_id encoded in a ToolApproval hook id, or None.""" + if hook_part.hook_type != TOOL_APPROVAL_HOOK_TYPE: + return None + if hook_part.hook_id.startswith(_PREFIX): + return hook_part.hook_id[len(_PREFIX) :] + return None + + +def metadata_bool(metadata: dict[str, Any], key: str) -> bool | None: + value = metadata.get(key) + return value if isinstance(value, bool) else None + + +def metadata_dict( + metadata: dict[str, Any], + key: str, +) -> dict[str, Any] | None: + value = metadata.get(key) + return value if isinstance(value, dict) else None + + +def metadata_from_tool_part(tp: ToolPart) -> dict[str, Any]: + metadata: dict[str, Any] = {} + if tp.approval is not None and tp.approval.is_automatic is not None: + metadata["isAutomatic"] = tp.approval.is_automatic + if tp.provider_executed is not None: + metadata["providerExecuted"] = tp.provider_executed + if tp.call_provider_metadata is not None: + metadata["callProviderMetadata"] = tp.call_provider_metadata + return metadata + + +def hook_part_from_tool_part(tp: ToolPart) -> messages_.HookPart[Any] | None: + """Reconstruct approval hook state from a UI tool part when possible.""" + approval = tp.approval + if approval is None: + return None + + metadata = metadata_from_tool_part(tp) + + if tp.state == "approval-requested": + return messages_.HookPart( + hook_id=approval.id, + hook_type=TOOL_APPROVAL_HOOK_TYPE, + status="pending", + metadata=metadata, + ) + + if tp.state == "approval-responded" and approval.approved is not None: + return messages_.HookPart( + hook_id=approval.id, + hook_type=TOOL_APPROVAL_HOOK_TYPE, + status="resolved", + metadata=metadata, + resolution={ + "granted": approval.approved, + "reason": approval.reason, + }, + ) + + if tp.state == "output-denied": + return messages_.HookPart( + hook_id=approval.id, + hook_type=TOOL_APPROVAL_HOOK_TYPE, + status="resolved", + metadata=metadata, + resolution={ + "granted": False, + "reason": approval.reason, + }, + ) + + return None + + +def extract_approvals( + ui_messages_list: list[ui_messages.UIMessage], +) -> list[ApprovalResponse]: + """Return every approval response found in UI messages.""" + approvals: list[ApprovalResponse] = [] + for ui_msg in ui_messages_list: + for part in ui_msg.parts: + if not isinstance( + part, ui_messages.UIToolPart | ui_messages.UIDynamicToolPart + ): + continue + if ( + part.state == "approval-responded" + and part.approval is not None + and part.approval.approved is not None + ): + approvals.append( + ApprovalResponse( + hook_id=part.approval.id, + granted=part.approval.approved, + reason=part.approval.reason, + tool_call_id=part.tool_call_id, + ) + ) + return approvals + + +def apply_approvals(approvals: list[ApprovalResponse]) -> None: + """Pre-register each approval resolution with the hooks registry.""" + for approval in approvals: + resolve_hook( + approval.hook_id, + {"granted": approval.granted, "reason": approval.reason}, + ) + + +def is_resolved_approval_message(msg: messages_.Message) -> bool: + """Return whether ``msg`` records a resolved tool approval hook.""" + if msg.role != "internal" or len(msg.parts) != 1: + return False + part = msg.parts[0] + return ( + isinstance(part, messages_.HookPart) + and part.hook_type == TOOL_APPROVAL_HOOK_TYPE + and part.status == "resolved" + ) diff --git a/src/ai/agents/ui/ai_sdk/_roundtrip.py b/src/ai/agents/ui/ai_sdk/id_utils.py similarity index 100% rename from src/ai/agents/ui/ai_sdk/_roundtrip.py rename to src/ai/agents/ui/ai_sdk/id_utils.py diff --git a/src/ai/agents/ui/ai_sdk/inbound.py b/src/ai/agents/ui/ai_sdk/inbound_messages.py similarity index 79% rename from src/ai/agents/ui/ai_sdk/inbound.py rename to src/ai/agents/ui/ai_sdk/inbound_messages.py index 41870e8f..e252e304 100644 --- a/src/ai/agents/ui/ai_sdk/inbound.py +++ b/src/ai/agents/ui/ai_sdk/inbound_messages.py @@ -8,13 +8,14 @@ import json import logging -from typing import Any, NamedTuple +from typing import Any from ....types import messages as messages_ from ...agent import MessageBundle -from ...hooks import resolve_hook -from . import _roundtrip +from . import approvals, id_utils from . import ui_messages as ui_messages_ +from .approvals import ApprovalResponse, extract_approvals +from .tool_utils import normalize_tool_args logger = logging.getLogger(__name__) @@ -33,19 +34,6 @@ def _is_tool_error(state: ui_messages_.UIToolInvocationState) -> bool: return state in _TOOL_ERROR_STATES -# TODO(datamodel-rework §4): once tool args have a canonical shape, drop -# these normalizers. -def _normalize_tool_args(tool_input: Any) -> str: - """Normalize tool input (JSON string, JSON value, or None) to a string.""" - match tool_input: - case str(): - return tool_input - case None: - return "{}" - case _: - return json.dumps(tool_input) - - def _tool_input_for_args( part: ui_messages_.UIToolPart | ui_messages_.UIDynamicToolPart, ) -> Any: @@ -106,121 +94,6 @@ def _decode_wire_output(output: Any) -> Any: return MessageBundle(messages=tuple(inner)) -def _approval_hook_part( - tp: ui_messages_.UIToolPart | ui_messages_.UIDynamicToolPart, -) -> messages_.HookPart[Any] | None: - """Reconstruct approval hook state from a UI tool part when possible.""" - approval = tp.approval - if approval is None: - return None - metadata = _approval_metadata(tp) - - if tp.state == "approval-requested": - return messages_.HookPart( - hook_id=approval.id, - hook_type="ToolApproval", - status="pending", - metadata=metadata, - ) - - if tp.state == "approval-responded" and approval.approved is not None: - return messages_.HookPart( - hook_id=approval.id, - hook_type="ToolApproval", - status="resolved", - metadata=metadata, - resolution={ - "granted": approval.approved, - "reason": approval.reason, - }, - ) - - if tp.state == "output-denied": - return messages_.HookPart( - hook_id=approval.id, - hook_type="ToolApproval", - status="resolved", - metadata=metadata, - resolution={ - "granted": False, - "reason": approval.reason, - }, - ) - - return None - - -def _approval_metadata( - tp: ui_messages_.UIToolPart | ui_messages_.UIDynamicToolPart, -) -> dict[str, Any]: - metadata: dict[str, Any] = {} - if tp.approval is not None and tp.approval.is_automatic is not None: - metadata["isAutomatic"] = tp.approval.is_automatic - if tp.provider_executed is not None: - metadata["providerExecuted"] = tp.provider_executed - if tp.call_provider_metadata is not None: - metadata["callProviderMetadata"] = tp.call_provider_metadata - return metadata - - -# ============================================================================ -# Approval extraction + bulk resolution -# ============================================================================ - - -class ApprovalResponse(NamedTuple): - """Approval response extracted from a responded UIToolPart.""" - - hook_id: str - granted: bool - reason: str | None - tool_call_id: str - - -def extract_approvals( - ui_messages: list[ui_messages_.UIMessage], -) -> list[ApprovalResponse]: - """Return every approval response found in *ui_messages*. - - Pure function — does not resolve hooks or trigger side effects. - """ - approvals: list[ApprovalResponse] = [] - for ui_msg in ui_messages: - for part in ui_msg.parts: - if not isinstance( - part, ui_messages_.UIToolPart | ui_messages_.UIDynamicToolPart - ): - continue - if ( - part.state == "approval-responded" - and part.approval is not None - and part.approval.approved is not None - ): - approvals.append( - ApprovalResponse( - hook_id=part.approval.id, - granted=part.approval.approved, - reason=part.approval.reason, - tool_call_id=part.tool_call_id, - ) - ) - return approvals - - -def apply_approvals(approvals: list[ApprovalResponse]) -> None: - """Pre-register each approval resolution with the hooks registry.""" - for approval in approvals: - resolve_hook( - approval.hook_id, - {"granted": approval.granted, "reason": approval.reason}, - ) - - -# ============================================================================ -# UI message normalization (heal stale tool states) -# ============================================================================ - - def _normalize_ui_messages( ui_messages: list[ui_messages_.UIMessage], ) -> list[ui_messages_.UIMessage]: @@ -292,10 +165,14 @@ def to_messages( :meth:`Agent.run` if the run should resume from a hook. """ normalized = _normalize_ui_messages(ui_messages) - approvals = extract_approvals(normalized) - messages = [m for m in _parse(normalized) if not _is_approval_response(m)] - _patch_pending_hook_aborts(messages, approvals) - return messages, approvals + approval_responses = extract_approvals(normalized) + messages = [ + m + for m in _parse(normalized) + if not approvals.is_resolved_approval_message(m) + ] + _patch_pending_hook_aborts(messages, approval_responses) + return messages, approval_responses def _patch_pending_hook_aborts( @@ -351,18 +228,6 @@ def _patch_pending_hook_aborts( messages[-1] = tool_msg.model_copy(update={"parts": new_parts}) -def _is_approval_response(msg: messages_.Message) -> bool: - """Return whether ``msg`` records a resolved tool-approval hook.""" - if msg.role != "internal" or len(msg.parts) != 1: - return False - part = msg.parts[0] - return ( - isinstance(part, messages_.HookPart) - and part.hook_type == "ToolApproval" - and part.status == "resolved" - ) - - def _parse( ui_messages: list[ui_messages_.UIMessage], ) -> list[messages_.Message]: @@ -408,7 +273,7 @@ def _build_builtin_return_part( result: list[messages_.Message] = [] for ui_msg in ui_messages: - source_messages = _roundtrip.source_messages_from(ui_msg.metadata) + source_messages = id_utils.source_messages_from(ui_msg.metadata) assistant_parts: list[messages_.Part] = [] tool_result_parts: list[messages_.ToolResultPart] = [] hook_parts: list[messages_.HookPart[Any]] = [] @@ -476,7 +341,7 @@ def _build_builtin_return_part( ) as tp ): tool_input = _tool_input_for_args(tp) - tool_args = _normalize_tool_args(tool_input) + tool_args = normalize_tool_args(tool_input) if tp.provider_executed: assistant_parts.append( @@ -496,7 +361,7 @@ def _build_builtin_return_part( provider_metadata=tp.call_provider_metadata, ) ) - approval_hook = _approval_hook_part(tp) + approval_hook = approvals.hook_part_from_tool_part(tp) if approval_hook is not None: hook_parts.append(approval_hook) @@ -580,11 +445,11 @@ def _build_builtin_return_part( ) ) result.extend( - _roundtrip.restore_source_ids(parsed, source_messages) + id_utils.restore_source_ids(parsed, source_messages) ) else: result.extend( - _roundtrip.restore_source_ids( + id_utils.restore_source_ids( [ messages_.Message( id=ui_msg.id, diff --git a/src/ai/agents/ui/ai_sdk/outbound/__init__.py b/src/ai/agents/ui/ai_sdk/outbound/__init__.py deleted file mode 100644 index abb6f69c..00000000 --- a/src/ai/agents/ui/ai_sdk/outbound/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Outbound adapter: ``ai.messages.Message`` stream → AI SDK UI protocol.""" - -from .history import to_ui_messages -from .sse import to_sse -from .stream import to_stream - -__all__ = ["to_sse", "to_stream", "to_ui_messages"] diff --git a/src/ai/agents/ui/ai_sdk/outbound/history.py b/src/ai/agents/ui/ai_sdk/outbound/history.py deleted file mode 100644 index 5c4bfc42..00000000 --- a/src/ai/agents/ui/ai_sdk/outbound/history.py +++ /dev/null @@ -1,94 +0,0 @@ -"""Persisted-message → UIMessage list for history endpoints.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -from .. import _parts, _roundtrip, ui_messages - -if TYPE_CHECKING: - from .....types import messages as messages_ - - -def _message_turn_key(message: messages_.Message) -> str | None: - return message.turn_id - - -def _assistant_bubble_id(message: messages_.Message) -> str: - return _message_turn_key(message) or message.id - - -def _belongs_to_bubble( - message: messages_.Message, - bubble_id: str, -) -> bool: - key = _message_turn_key(message) - return key is None or key == bubble_id - - -def to_ui_messages( - messages: list[messages_.Message], -) -> list[ui_messages.UIMessage]: - """Group persisted messages into UIMessage bubbles. - - ``user``/``system`` messages become standalone UIMessages. Runs of - ``assistant``/``tool``/``internal`` messages merge into a single - assistant UIMessage, with tool results and approval state folded into - the corresponding tool-call parts. - """ - result: list[ui_messages.UIMessage] = [] - - i = 0 - while i < len(messages): - msg = messages[i] - - if msg.role in ("user", "system"): - result.append( - ui_messages.UIMessage( - id=msg.id, - role=msg.role, - metadata=_roundtrip.metadata_for([msg]), - parts=_parts.to_ui_parts(msg.parts), - ) - ) - i += 1 - continue - - if msg.role == "assistant": - ui_parts: list[ui_messages.UIMessagePart] = [] - source_messages: list[messages_.Message] = [] - bubble_id = _assistant_bubble_id(msg) - - while i < len(messages) and messages[i].role in ( - "assistant", - "tool", - "internal", - ): - current = messages[i] - if not _belongs_to_bubble(current, bubble_id): - break - source_messages.append(current) - if current.role == "assistant": - ui_parts.extend(_parts.to_ui_parts(current.parts)) - ui_parts = _parts.dedupe_tool_parts(ui_parts) - elif current.role == "tool": - _parts.merge_tool_results(ui_parts, current.parts) - elif current.role == "internal": - _parts.merge_approval_signals(ui_parts, current.parts) - i += 1 - ui_parts = _parts.dedupe_tool_parts(ui_parts) - - result.append( - ui_messages.UIMessage( - id=bubble_id, - role="assistant", - metadata=_roundtrip.metadata_for(source_messages), - parts=ui_parts, - ) - ) - continue - - # Orphan tool / internal messages — skip; they have no assistant anchor. - i += 1 - - return result diff --git a/src/ai/agents/ui/ai_sdk/outbound/sse.py b/src/ai/agents/ui/ai_sdk/outbound/sse.py deleted file mode 100644 index 9fb667b6..00000000 --- a/src/ai/agents/ui/ai_sdk/outbound/sse.py +++ /dev/null @@ -1,65 +0,0 @@ -"""Serialize the UI message stream as Server-Sent Events.""" - -from __future__ import annotations - -import dataclasses -import json -from typing import TYPE_CHECKING, Any - -import pydantic - -from .. import ui_events -from .stream import to_stream - -if TYPE_CHECKING: - from collections.abc import AsyncGenerator, AsyncIterable - - from .....types import events as events_ - - -def _to_camel_case(snake_str: str) -> str: - components = snake_str.split("_") - return components[0] + "".join(x.title() for x in components[1:]) - - -def _json_default(obj: Any) -> Any: - """Fallback encoder for json.dumps — handle pydantic models recursively. - - Aggregator snapshots and tool outputs may carry pydantic models - (e.g. ``MessageBundle``, ``UIMessage``). ``model_dump(mode="json")`` - converts them to plain JSON-native dicts/lists. - """ - if isinstance(obj, pydantic.BaseModel): - return obj.model_dump(mode="json", by_alias=True) - raise TypeError( - f"Object of type {type(obj).__name__} is not JSON serializable" - ) - - -def serialize_event(event: ui_events.UIMessageStreamEvent) -> str: - """Serialize a stream event to JSON with camelCase keys.""" - d = dataclasses.asdict(event) - if isinstance(event, ui_events.UIDataEvent): - d["type"] = event.type - del d["data_type"] - camel_dict = {_to_camel_case(k): v for k, v in d.items() if v is not None} - return json.dumps(camel_dict, default=_json_default) - - -def format_sse(event: ui_events.UIMessageStreamEvent) -> str: - """Format a stream event as an SSE data line.""" - return f"data: {serialize_event(event)}\n\n" - - -def format_done_sse() -> str: - """Format the AI SDK UI stream termination marker.""" - return "data: [DONE]\n\n" - - -async def to_sse( - events: AsyncIterable[events_.AgentEvent], -) -> AsyncGenerator[str]: - """Convert an internal event stream into SSE strings.""" - async for event in to_stream(events): - yield format_sse(event) - yield format_done_sse() diff --git a/src/ai/agents/ui/ai_sdk/outbound/stream.py b/src/ai/agents/ui/ai_sdk/outbound/stream.py deleted file mode 100644 index 4556fad4..00000000 --- a/src/ai/agents/ui/ai_sdk/outbound/stream.py +++ /dev/null @@ -1,42 +0,0 @@ -"""Convert internal event streams into AI SDK UI protocol events.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -from .....types import events as events_ -from ._state import _StreamState - -if TYPE_CHECKING: - from collections.abc import AsyncGenerator, AsyncIterable - - from .. import ui_events - - -async def to_stream( - events: AsyncIterable[events_.AgentEvent], -) -> AsyncGenerator[ui_events.UIMessageStreamEvent]: - """Walk ``events`` once, emitting AI SDK UI stream events. - - Streaming text/reasoning/tool-input deltas come from model events. - Tool results come from ``ToolCallResult``. Hook signals come from - ``HookEvent``. - """ - state = _StreamState() - - async for event in events: - if isinstance(event, events_.ToolCallResult): - for ui_event in state.on_tool_result(event): - yield ui_event - elif isinstance(event, events_.PartialToolCallResult): - for ui_event in state.on_partial_tool_result(event): - yield ui_event - elif isinstance(event, events_.HookEvent): - for ui_event in state.on_hook(event): - yield ui_event - else: - for ui_event in state.on_event(event): - yield ui_event - - for ui_event in state.finish(): - yield ui_event diff --git a/src/ai/agents/ui/ai_sdk/_parts.py b/src/ai/agents/ui/ai_sdk/outbound_messages.py similarity index 77% rename from src/ai/agents/ui/ai_sdk/_parts.py rename to src/ai/agents/ui/ai_sdk/outbound_messages.py index 312c90fa..06e1a9f1 100644 --- a/src/ai/agents/ui/ai_sdk/_parts.py +++ b/src/ai/agents/ui/ai_sdk/outbound_messages.py @@ -1,18 +1,13 @@ -"""Shared conversions between internal Part objects and UIMessagePart objects. - -Used by ``outbound.history`` to reconstruct UIMessages from persisted -``ai.messages.Message`` lists. The live outbound stream does not use these; it -emits wire-protocol deltas directly from event streams. -""" +"""Persisted-message conversion for AI SDK UI messages.""" from __future__ import annotations -import json from typing import Any, TypeGuard, cast from ....types import media from ....types import messages as messages_ -from . import _approvals, ui_messages +from . import approvals, id_utils, ui_messages +from .tool_utils import normalize_tool_input UIToolLike = ui_messages.UIToolPart | ui_messages.UIDynamicToolPart @@ -27,23 +22,6 @@ } -def _normalize_tool_input(raw: str) -> Any: - """Parse tool args JSON string into a JSON value; fall back to raw string. - - TODO(datamodel-rework §4): once ``ToolCallPart.tool_args`` has a - canonical shape, drop this helper. - """ - try: - return json.loads(raw) - except (json.JSONDecodeError, TypeError): - return raw - - -def _metadata_bool(metadata: dict[str, Any], key: str) -> bool | None: - value = metadata.get(key) - return value if isinstance(value, bool) else None - - def _is_tool_part(part: ui_messages.UIMessagePart) -> TypeGuard[UIToolLike]: return isinstance( part, ui_messages.UIToolPart | ui_messages.UIDynamicToolPart @@ -54,8 +32,24 @@ def _tool_call_id(part: UIToolLike) -> str: return part.tool_call_id +def _message_turn_key(message: messages_.Message) -> str | None: + return message.turn_id + + +def _assistant_bubble_id(message: messages_.Message) -> str: + return _message_turn_key(message) or message.id + + +def _belongs_to_bubble( + message: messages_.Message, + bubble_id: str, +) -> bool: + key = _message_turn_key(message) + return key is None or key == bubble_id + + def to_ui_parts(parts: list[messages_.Part]) -> list[ui_messages.UIMessagePart]: - """Convert internal Part objects to UIMessagePart objects.""" + """Convert internal parts to UI message parts.""" result: list[ui_messages.UIMessagePart] = [] for part in parts: if isinstance(part, messages_.TextPart) and part.text: @@ -85,7 +79,7 @@ def to_ui_parts(parts: list[messages_.Part]) -> list[ui_messages.UIMessagePart]: "type": f"tool-{part.tool_name}", "toolCallId": part.tool_call_id, "state": "input-available", - "input": _normalize_tool_input(part.tool_args), + "input": normalize_tool_input(part.tool_args), "callProviderMetadata": part.provider_metadata, } ) @@ -98,7 +92,7 @@ def to_ui_parts(parts: list[messages_.Part]) -> list[ui_messages.UIMessagePart]: "toolName": part.tool_name, "toolCallId": part.tool_call_id, "state": "input-available", - "input": _normalize_tool_input(part.tool_args), + "input": normalize_tool_input(part.tool_args), "providerExecuted": True, "callProviderMetadata": part.provider_metadata, } @@ -186,7 +180,7 @@ def _merge_tool_part( def dedupe_tool_parts( ui_parts: list[ui_messages.UIMessagePart], ) -> list[ui_messages.UIMessagePart]: - """Collapse duplicate UIToolParts by tool_call_id.""" + """Collapse duplicate UI tool parts by tool_call_id.""" result: list[ui_messages.UIMessagePart] = [] tool_index: dict[str, int] = {} @@ -212,7 +206,7 @@ def merge_tool_results( ui_parts: list[ui_messages.UIMessagePart], tool_parts: list[messages_.Part], ) -> None: - """Merge ToolResultParts into existing UIToolParts in-place.""" + """Merge tool result parts into existing UI tool parts.""" tool_index: dict[str, int] = {} for idx, ui_part in enumerate(ui_parts): if _is_tool_part(ui_part): @@ -245,9 +239,7 @@ def merge_tool_results( updates["output"] = part.result else: continue - # Hook-abort placeholders are internal: the corresponding - # HookPart(pending) carries the user-visible state via - # merge_approval_signals. + if isinstance(part, messages_.ToolResultPart) and part.is_hook_pending: continue idx_opt = tool_index.get(tool_call_id) @@ -266,7 +258,7 @@ def merge_approval_signals( ui_parts: list[ui_messages.UIMessagePart], internal_parts: list[messages_.Part], ) -> None: - """Merge HookPart approval state into existing UIToolParts in-place.""" + """Merge approval hook state into existing UI tool parts.""" tool_index: dict[str, int] = {} for idx, ui_part in enumerate(ui_parts): if _is_tool_part(ui_part): @@ -276,7 +268,7 @@ def merge_approval_signals( if not isinstance(part, messages_.HookPart): continue - tool_call_id = _approvals.tool_call_id_for(part) + tool_call_id = approvals.tool_call_id_for(part) if tool_call_id is None: continue @@ -290,7 +282,7 @@ def merge_approval_signals( continue updates: dict[str, Any] = {} - if (provider_executed := _metadata_bool( + if (provider_executed := approvals.metadata_bool( part.metadata, "providerExecuted" )) is not None: updates["provider_executed"] = provider_executed @@ -299,7 +291,7 @@ def merge_approval_signals( updates["approval"] = ui_messages.UIToolApproval.model_validate( { "id": part.hook_id, - "isAutomatic": _metadata_bool( + "isAutomatic": approvals.metadata_bool( part.metadata, "isAutomatic" ), } @@ -314,7 +306,7 @@ def merge_approval_signals( "id": part.hook_id, "approved": resolution.get("granted"), "reason": resolution.get("reason"), - "isAutomatic": _metadata_bool( + "isAutomatic": approvals.metadata_bool( part.metadata, "isAutomatic" ), } @@ -330,3 +322,64 @@ def merge_approval_signals( if updates: ui_parts[idx] = existing.model_copy(update=updates) + + +def to_ui_messages( + messages: list[messages_.Message], +) -> list[ui_messages.UIMessage]: + """Group persisted messages into UI message bubbles.""" + result: list[ui_messages.UIMessage] = [] + + i = 0 + while i < len(messages): + msg = messages[i] + + if msg.role in ("user", "system"): + result.append( + ui_messages.UIMessage( + id=msg.id, + role=msg.role, + metadata=id_utils.metadata_for([msg]), + parts=to_ui_parts(msg.parts), + ) + ) + i += 1 + continue + + if msg.role == "assistant": + ui_parts: list[ui_messages.UIMessagePart] = [] + source_messages: list[messages_.Message] = [] + bubble_id = _assistant_bubble_id(msg) + + while i < len(messages) and messages[i].role in ( + "assistant", + "tool", + "internal", + ): + current = messages[i] + if not _belongs_to_bubble(current, bubble_id): + break + source_messages.append(current) + if current.role == "assistant": + ui_parts.extend(to_ui_parts(current.parts)) + ui_parts = dedupe_tool_parts(ui_parts) + elif current.role == "tool": + merge_tool_results(ui_parts, current.parts) + elif current.role == "internal": + merge_approval_signals(ui_parts, current.parts) + i += 1 + ui_parts = dedupe_tool_parts(ui_parts) + + result.append( + ui_messages.UIMessage( + id=bubble_id, + role="assistant", + metadata=id_utils.metadata_for(source_messages), + parts=ui_parts, + ) + ) + continue + + i += 1 + + return result diff --git a/src/ai/agents/ui/ai_sdk/outbound/_state.py b/src/ai/agents/ui/ai_sdk/outbound_stream.py similarity index 86% rename from src/ai/agents/ui/ai_sdk/outbound/_state.py rename to src/ai/agents/ui/ai_sdk/outbound_stream.py index 3d92b351..b522eaa1 100644 --- a/src/ai/agents/ui/ai_sdk/outbound/_state.py +++ b/src/ai/agents/ui/ai_sdk/outbound_stream.py @@ -1,16 +1,22 @@ -"""Stream state bookkeeping for the event-first outbound walk.""" +"""Live event stream conversion for the AI SDK UI protocol.""" from __future__ import annotations +import dataclasses import json -from typing import Any +from typing import TYPE_CHECKING, Any -from .....types import events as events_ -from .....types import media -from .....types import messages as messages_ -from ....agent import MessageBundle -from .. import _approvals, ui_events -from . import history +import pydantic + +from ....types import events as events_ +from ....types import media +from ....types import messages as messages_ +from ...agent import MessageBundle +from . import approvals, outbound_messages, ui_events +from .tool_utils import normalize_tool_input + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator, AsyncIterable def _tool_error_text(part: messages_.ToolResultPart) -> str: @@ -25,26 +31,6 @@ def _tool_error_text(part: messages_.ToolResultPart) -> str: return "Tool execution failed" -def _normalize_tool_input(raw: str) -> Any: - try: - return json.loads(raw) - except Exception: - return raw - - -def _metadata_bool(metadata: dict[str, Any], key: str) -> bool | None: - value = metadata.get(key) - return value if isinstance(value, bool) else None - - -def _metadata_dict( - metadata: dict[str, Any], - key: str, -) -> dict[str, Any] | None: - value = metadata.get(key) - return value if isinstance(value, dict) else None - - def _to_wire_output(snapshot: Any) -> Any: """Convert an aggregator snapshot to its UI wire representation. @@ -57,7 +43,7 @@ def _to_wire_output(snapshot: Any) -> Any: skip emitting in that case. """ if isinstance(snapshot, MessageBundle): - ui_msgs = history.to_ui_messages(list(snapshot.messages)) + ui_msgs = outbound_messages.to_ui_messages(list(snapshot.messages)) return ui_msgs[-1] if ui_msgs else None return snapshot @@ -306,7 +292,7 @@ def on_event( ui_events.UIToolInputAvailableEvent( tool_call_id=tcid, tool_name=tc.tool_name, - input=_normalize_tool_input(tc.tool_args), + input=normalize_tool_input(tc.tool_args), provider_executed=True, provider_metadata=tc.provider_metadata or event.provider_metadata, @@ -386,7 +372,7 @@ def on_tool_result( ui_events.UIToolInputAvailableEvent( tool_call_id=part.tool_call_id, tool_name=part.tool_name, - input=_normalize_tool_input(part.tool_args), + input=normalize_tool_input(part.tool_args), provider_metadata=part.provider_metadata, ) ) @@ -480,7 +466,7 @@ def on_hook( # Ensure the UI message is started. out.extend(self._ensure_started(event.message.turn_id)) - tc_id = _approvals.tool_call_id_for(hook_part) + tc_id = approvals.tool_call_id_for(hook_part) if tc_id is None: return out @@ -492,7 +478,7 @@ def on_hook( ui_events.UIToolApprovalRequestEvent( approval_id=hook_part.hook_id, tool_call_id=tc_id, - is_automatic=_metadata_bool( + is_automatic=approvals.metadata_bool( hook_part.metadata, "isAutomatic" ), ) @@ -504,10 +490,10 @@ def on_hook( approval_id=hook_part.hook_id, approved=bool(resolution.get("granted")), reason=resolution.get("reason"), - provider_executed=_metadata_bool( + provider_executed=approvals.metadata_bool( hook_part.metadata, "providerExecuted" ), - provider_metadata=_metadata_dict( + provider_metadata=approvals.metadata_dict( hook_part.metadata, "callProviderMetadata" ), ) @@ -531,3 +517,69 @@ def finish(self) -> list[ui_events.UIMessageStreamEvent]: if self.emitted_start: events.append(ui_events.UIFinishEvent(finish_reason="stop")) return events + + +async def to_stream( + events: AsyncIterable[events_.AgentEvent], +) -> AsyncGenerator[ui_events.UIMessageStreamEvent]: + """Walk internal events once, emitting AI SDK UI stream events.""" + state = _StreamState() + + async for event in events: + if isinstance(event, events_.ToolCallResult): + for ui_event in state.on_tool_result(event): + yield ui_event + elif isinstance(event, events_.PartialToolCallResult): + for ui_event in state.on_partial_tool_result(event): + yield ui_event + elif isinstance(event, events_.HookEvent): + for ui_event in state.on_hook(event): + yield ui_event + else: + for ui_event in state.on_event(event): + yield ui_event + + for ui_event in state.finish(): + yield ui_event + + +def _to_camel_case(snake_str: str) -> str: + components = snake_str.split("_") + return components[0] + "".join(x.title() for x in components[1:]) + + +def _json_default(obj: Any) -> Any: + if isinstance(obj, pydantic.BaseModel): + return obj.model_dump(mode="json", by_alias=True) + raise TypeError( + f"Object of type {type(obj).__name__} is not JSON serializable" + ) + + +def serialize_event(event: ui_events.UIMessageStreamEvent) -> str: + """Serialize a stream event to JSON with camelCase keys.""" + d = dataclasses.asdict(event) + if isinstance(event, ui_events.UIDataEvent): + d["type"] = event.type + del d["data_type"] + camel_dict = {_to_camel_case(k): v for k, v in d.items() if v is not None} + return json.dumps(camel_dict, default=_json_default) + + +def format_sse(event: ui_events.UIMessageStreamEvent) -> str: + """Format a stream event as an SSE data line.""" + return f"data: {serialize_event(event)}\n\n" + + +def format_done_sse() -> str: + """Format the AI SDK UI stream termination marker.""" + return "data: [DONE]\n\n" + + +async def to_sse( + events: AsyncIterable[events_.AgentEvent], +) -> AsyncGenerator[str]: + """Convert an internal event stream into SSE strings.""" + async for event in to_stream(events): + yield format_sse(event) + yield format_done_sse() diff --git a/src/ai/agents/ui/ai_sdk/tool_utils.py b/src/ai/agents/ui/ai_sdk/tool_utils.py new file mode 100644 index 00000000..b3f5d8a4 --- /dev/null +++ b/src/ai/agents/ui/ai_sdk/tool_utils.py @@ -0,0 +1,25 @@ +"""Tool value helpers shared by AI SDK UI adapters.""" + +from __future__ import annotations + +import json +from typing import Any + + +def normalize_tool_input(raw: str) -> Any: + """Parse serialized tool args into a JSON value when possible.""" + try: + return json.loads(raw) + except (json.JSONDecodeError, TypeError): + return raw + + +def normalize_tool_args(tool_input: Any) -> str: + """Normalize UI tool input to the internal serialized args form.""" + match tool_input: + case str(): + return tool_input + case None: + return "{}" + case _: + return json.dumps(tool_input) diff --git a/tests/agents/ui/ai_sdk/outbound/test_sse.py b/tests/agents/ui/ai_sdk/outbound/test_sse.py index 02308aee..d9a6a7c9 100644 --- a/tests/agents/ui/ai_sdk/outbound/test_sse.py +++ b/tests/agents/ui/ai_sdk/outbound/test_sse.py @@ -4,7 +4,7 @@ from collections.abc import AsyncGenerator from ai.agents.ui.ai_sdk import to_sse, ui_events -from ai.agents.ui.ai_sdk.outbound.sse import ( +from ai.agents.ui.ai_sdk.outbound_stream import ( format_done_sse, format_sse, serialize_event, diff --git a/tests/agents/ui/ai_sdk/test_approvals.py b/tests/agents/ui/ai_sdk/test_approvals.py index 993a9c25..545bfa6d 100644 --- a/tests/agents/ui/ai_sdk/test_approvals.py +++ b/tests/agents/ui/ai_sdk/test_approvals.py @@ -2,7 +2,7 @@ from typing import Any -from ai.agents.ui.ai_sdk import _approvals +from ai.agents.ui.ai_sdk import approvals from ai.types import messages as messages_ @@ -12,7 +12,7 @@ def test_tool_call_id_for_strips_prefix() -> None: hook_type="ToolApproval", status="pending", ) - assert _approvals.tool_call_id_for(hook) == "tc_42" + assert approvals.tool_call_id_for(hook) == "tc_42" def test_tool_call_id_for_rejects_non_approval_type() -> None: @@ -21,7 +21,7 @@ def test_tool_call_id_for_rejects_non_approval_type() -> None: hook_type="SomethingElse", status="pending", ) - assert _approvals.tool_call_id_for(hook) is None + assert approvals.tool_call_id_for(hook) is None def test_tool_call_id_for_rejects_bad_prefix() -> None: @@ -30,37 +30,4 @@ def test_tool_call_id_for_rejects_bad_prefix() -> None: hook_type="ToolApproval", status="pending", ) - assert _approvals.tool_call_id_for(hook) is None - - -def test_is_tool_approval_message_detects_all_approval_hooks() -> None: - msg = messages_.Message( - role="internal", - parts=[ - messages_.HookPart( - hook_id="approve_tc_1", - hook_type="ToolApproval", - status="pending", - ), - ], - ) - assert _approvals.is_tool_approval_message(msg) - - -def test_is_tool_approval_message_false_for_non_approval() -> None: - msg = messages_.Message( - role="internal", - parts=[ - messages_.HookPart( - hook_id="other", - hook_type="Something", - status="pending", - ), - ], - ) - assert not _approvals.is_tool_approval_message(msg) - - -def test_is_tool_approval_message_false_for_empty() -> None: - msg = messages_.Message(role="internal", parts=[]) - assert not _approvals.is_tool_approval_message(msg) + assert approvals.tool_call_id_for(hook) is None diff --git a/tests/agents/ui/ai_sdk/test_inbound.py b/tests/agents/ui/ai_sdk/test_inbound.py index 7bc89009..ff27a290 100644 --- a/tests/agents/ui/ai_sdk/test_inbound.py +++ b/tests/agents/ui/ai_sdk/test_inbound.py @@ -6,10 +6,8 @@ from ai.agents.agent import MessageBundle from ai.agents.ui.ai_sdk import to_messages -from ai.agents.ui.ai_sdk.inbound import ( - _normalize_ui_messages, - extract_approvals, -) +from ai.agents.ui.ai_sdk.approvals import extract_approvals +from ai.agents.ui.ai_sdk.inbound_messages import _normalize_ui_messages from ai.agents.ui.ai_sdk.ui_messages import UIMessage, UIToolPart from ai.types import messages as messages_ diff --git a/tests/agents/ui/ai_sdk/test_parts.py b/tests/agents/ui/ai_sdk/test_parts.py index 1b581e75..dcd6692d 100644 --- a/tests/agents/ui/ai_sdk/test_parts.py +++ b/tests/agents/ui/ai_sdk/test_parts.py @@ -1,6 +1,6 @@ from __future__ import annotations -from ai.agents.ui.ai_sdk import _parts +from ai.agents.ui.ai_sdk import outbound_messages from ai.agents.ui.ai_sdk.ui_messages import ( UIReasoningPart, UITextPart, @@ -15,7 +15,7 @@ def test_to_ui_parts_text_and_reasoning() -> None: messages_.ReasoningPart(text="thinking"), messages_.TextPart(text="hi"), ] - ui_parts = _parts.to_ui_parts(parts) + ui_parts = outbound_messages.to_ui_parts(parts) assert isinstance(ui_parts[0], UIReasoningPart) assert ui_parts[0].text == "thinking" assert isinstance(ui_parts[1], UITextPart) @@ -30,7 +30,7 @@ def test_to_ui_parts_tool_call_parses_json_args() -> None: tool_args='{"q": "x"}', ) ] - ui_parts = _parts.to_ui_parts(parts) + ui_parts = outbound_messages.to_ui_parts(parts) assert isinstance(ui_parts[0], UIToolPart) assert ui_parts[0].type == "tool-search" assert ui_parts[0].input == {"q": "x"} @@ -45,8 +45,8 @@ def test_merge_tool_results_updates_state_and_output() -> None: tool_args="{}", ) ] - ui_parts = _parts.to_ui_parts(parts) - _parts.merge_tool_results( + ui_parts = outbound_messages.to_ui_parts(parts) + outbound_messages.merge_tool_results( ui_parts, [ messages_.ToolResultPart( @@ -70,9 +70,9 @@ def test_merge_approval_signals_pending_then_resolved() -> None: tool_args="{}", ) ] - ui_parts = _parts.to_ui_parts(parts) + ui_parts = outbound_messages.to_ui_parts(parts) - _parts.merge_approval_signals( + outbound_messages.merge_approval_signals( ui_parts, [ messages_.HookPart( @@ -87,7 +87,7 @@ def test_merge_approval_signals_pending_then_resolved() -> None: assert requested.state == "approval-requested" assert isinstance(requested.approval, UIToolApproval) - _parts.merge_approval_signals( + outbound_messages.merge_approval_signals( ui_parts, [ messages_.HookPart( From db21c371b0e953f718f3ec5063dc953354fd4a8c Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Wed, 20 May 2026 17:08:01 -0700 Subject: [PATCH 08/13] Move tests around to match project file layout --- src/ai/agents/ui/ai_sdk/id_utils.py | 15 ++- src/ai/agents/ui/ai_sdk/inbound_messages.py | 4 +- src/ai/agents/ui/ai_sdk/outbound_messages.py | 8 +- src/ai/agents/ui/ai_sdk/outbound_stream.py | 4 +- tests/agents/ui/ai_sdk/outbound/__init__.py | 0 tests/agents/ui/ai_sdk/outbound/test_sse.py | 85 -------------- tests/agents/ui/ai_sdk/test_approvals.py | 75 +++++++++++++ ...st_inbound.py => test_inbound_messages.py} | 55 --------- ...t_history.py => test_outbound_messages.py} | 99 ++++++++++++++++- ...test_stream.py => test_outbound_stream.py} | 78 ++++++++++++- tests/agents/ui/ai_sdk/test_parts.py | 105 ------------------ 11 files changed, 264 insertions(+), 264 deletions(-) delete mode 100644 tests/agents/ui/ai_sdk/outbound/__init__.py delete mode 100644 tests/agents/ui/ai_sdk/outbound/test_sse.py rename tests/agents/ui/ai_sdk/{test_inbound.py => test_inbound_messages.py} (84%) rename tests/agents/ui/ai_sdk/{outbound/test_history.py => test_outbound_messages.py} (84%) rename tests/agents/ui/ai_sdk/{outbound/test_stream.py => test_outbound_stream.py} (84%) delete mode 100644 tests/agents/ui/ai_sdk/test_parts.py diff --git a/src/ai/agents/ui/ai_sdk/id_utils.py b/src/ai/agents/ui/ai_sdk/id_utils.py index 56d23581..fb20c093 100644 --- a/src/ai/agents/ui/ai_sdk/id_utils.py +++ b/src/ai/agents/ui/ai_sdk/id_utils.py @@ -48,11 +48,13 @@ def source_messages_from(metadata: object) -> list[SourceMessage]: if not isinstance(metadata, dict): return [] - adapter_metadata = metadata.get(ADAPTER_METADATA_KEY) + metadata_dict = cast("dict[str, object]", metadata) + adapter_metadata = metadata_dict.get(ADAPTER_METADATA_KEY) if not isinstance(adapter_metadata, dict): return [] - raw_source_messages = adapter_metadata.get(SOURCE_MESSAGES_KEY) + adapter_metadata_dict = cast("dict[str, object]", adapter_metadata) + raw_source_messages = adapter_metadata_dict.get(SOURCE_MESSAGES_KEY) if not isinstance(raw_source_messages, list): return [] @@ -95,15 +97,16 @@ def _parse_source_message(raw: object) -> SourceMessage | None: if not isinstance(raw, dict): return None - message_id = raw.get("id") - role = raw.get("role") + raw_dict = cast("dict[str, object]", raw) + message_id = raw_dict.get("id") + role = raw_dict.get("role") if not isinstance(message_id, str) or role not in _VALID_ROLES: return None - raw_turn_id = raw.get("turnId") + raw_turn_id = raw_dict.get("turnId") turn_id = raw_turn_id if isinstance(raw_turn_id, str) else None - raw_part_ids = raw.get("partIds") + raw_part_ids = raw_dict.get("partIds") part_ids = ( tuple(part_id for part_id in raw_part_ids if isinstance(part_id, str)) if isinstance(raw_part_ids, list) diff --git a/src/ai/agents/ui/ai_sdk/inbound_messages.py b/src/ai/agents/ui/ai_sdk/inbound_messages.py index e252e304..d14c7f33 100644 --- a/src/ai/agents/ui/ai_sdk/inbound_messages.py +++ b/src/ai/agents/ui/ai_sdk/inbound_messages.py @@ -444,9 +444,7 @@ def _build_builtin_return_part( parts=[hp], ) ) - result.extend( - id_utils.restore_source_ids(parsed, source_messages) - ) + result.extend(id_utils.restore_source_ids(parsed, source_messages)) else: result.extend( id_utils.restore_source_ids( diff --git a/src/ai/agents/ui/ai_sdk/outbound_messages.py b/src/ai/agents/ui/ai_sdk/outbound_messages.py index 06e1a9f1..c92e7bae 100644 --- a/src/ai/agents/ui/ai_sdk/outbound_messages.py +++ b/src/ai/agents/ui/ai_sdk/outbound_messages.py @@ -282,9 +282,11 @@ def merge_approval_signals( continue updates: dict[str, Any] = {} - if (provider_executed := approvals.metadata_bool( - part.metadata, "providerExecuted" - )) is not None: + if ( + provider_executed := approvals.metadata_bool( + part.metadata, "providerExecuted" + ) + ) is not None: updates["provider_executed"] = provider_executed if part.status == "pending": updates["state"] = "approval-requested" diff --git a/src/ai/agents/ui/ai_sdk/outbound_stream.py b/src/ai/agents/ui/ai_sdk/outbound_stream.py index b522eaa1..d33e689b 100644 --- a/src/ai/agents/ui/ai_sdk/outbound_stream.py +++ b/src/ai/agents/ui/ai_sdk/outbound_stream.py @@ -499,7 +499,9 @@ def on_hook( ) ) if not resolution.get("granted"): - out.append(ui_events.UIToolOutputDeniedEvent(tool_call_id=tc_id)) + out.append( + ui_events.UIToolOutputDeniedEvent(tool_call_id=tc_id) + ) elif hook_part.status == "cancelled": out.append( ui_events.UIToolOutputErrorEvent( diff --git a/tests/agents/ui/ai_sdk/outbound/__init__.py b/tests/agents/ui/ai_sdk/outbound/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/agents/ui/ai_sdk/outbound/test_sse.py b/tests/agents/ui/ai_sdk/outbound/test_sse.py deleted file mode 100644 index d9a6a7c9..00000000 --- a/tests/agents/ui/ai_sdk/outbound/test_sse.py +++ /dev/null @@ -1,85 +0,0 @@ -from __future__ import annotations - -import json -from collections.abc import AsyncGenerator - -from ai.agents.ui.ai_sdk import to_sse, ui_events -from ai.agents.ui.ai_sdk.outbound_stream import ( - format_done_sse, - format_sse, - serialize_event, -) -from ai.types import events as agent_events_ -from ai.types import events as events_ - - -def test_serialize_event_camelcases_keys() -> None: - event = ui_events.UIStartEvent(message_id="m1") - payload = json.loads(serialize_event(event)) - assert payload == {"type": "start", "messageId": "m1"} - - -def test_format_sse_wraps_data_line() -> None: - event = ui_events.UITextDeltaEvent(id="t1", delta="hi") - line = format_sse(event) - assert line.startswith("data: ") - assert line.endswith("\n\n") - - -def test_serialize_data_event_uses_type_with_prefix() -> None: - event = ui_events.UIDataEvent(data_type="custom", data={"k": 1}) - payload = json.loads(serialize_event(event)) - assert payload["type"] == "data-custom" - assert "dataType" not in payload - - -def test_serialize_protocol_fields_use_ai_sdk_wire_names() -> None: - event = ui_events.UIToolApprovalResponseEvent( - approval_id="approval-1", - approved=False, - reason="no", - provider_executed=True, - provider_metadata={"provider": {"k": "v"}}, - ) - - payload = json.loads(serialize_event(event)) - - assert payload == { - "type": "tool-approval-response", - "approvalId": "approval-1", - "approved": False, - "reason": "no", - "providerExecuted": True, - "providerMetadata": {"provider": {"k": "v"}}, - } - - -def test_format_done_sse_returns_done_sentinel() -> None: - assert format_done_sse() == "data: [DONE]\n\n" - - -async def _gen( - stream_events: list[agent_events_.AgentEvent], -) -> AsyncGenerator[agent_events_.AgentEvent]: - for event in stream_events: - yield event - - -async def test_to_sse_emits_data_prefixed_lines() -> None: - lines = [ - line - async for line in to_sse( - _gen( - [ - events_.TextStart(block_id="t1"), - events_.TextDelta(block_id="t1", chunk="hi"), - events_.TextEnd(block_id="t1"), - ] - ) - ) - ] - assert all(line.startswith("data: ") for line in lines) - # first line is the start event (lazy open) - first = json.loads(lines[0].removeprefix("data: ").rstrip()) - assert first["type"] == "start" - assert lines[-1] == "data: [DONE]\n\n" diff --git a/tests/agents/ui/ai_sdk/test_approvals.py b/tests/agents/ui/ai_sdk/test_approvals.py index 545bfa6d..447ef9db 100644 --- a/tests/agents/ui/ai_sdk/test_approvals.py +++ b/tests/agents/ui/ai_sdk/test_approvals.py @@ -3,9 +3,30 @@ from typing import Any from ai.agents.ui.ai_sdk import approvals +from ai.agents.ui.ai_sdk.ui_messages import UIMessage from ai.types import messages as messages_ +def _ui(role: str, *parts: dict[str, Any], id: str = "m1") -> UIMessage: + return UIMessage.model_validate( + {"id": id, "role": role, "parts": list(parts)} + ) + + +def _tool( + tool_name: str, + tool_call_id: str, + state: str, + **extra: Any, +) -> dict[str, Any]: + return { + "type": f"tool-{tool_name}", + "toolCallId": tool_call_id, + "state": state, + **extra, + } + + def test_tool_call_id_for_strips_prefix() -> None: hook: messages_.HookPart[Any] = messages_.HookPart( hook_id="approve_tc_42", @@ -24,6 +45,60 @@ def test_tool_call_id_for_rejects_non_approval_type() -> None: assert approvals.tool_call_id_for(hook) is None +def test_extract_approvals_returns_approved_responses() -> None: + approval_responses = approvals.extract_approvals( + [ + _ui( + "assistant", + _tool( + "x", + "tc1", + "approval-responded", + approval={ + "id": "approve_tc1", + "approved": False, + "reason": "nope", + }, + ), + ) + ] + ) + assert len(approval_responses) == 1 + assert approval_responses[0].hook_id == "approve_tc1" + assert approval_responses[0].granted is False + assert approval_responses[0].reason == "nope" + + +def test_extract_approvals_handles_dynamic_tool_responses() -> None: + approval_responses = approvals.extract_approvals( + [ + _ui( + "assistant", + { + "type": "dynamic-tool", + "toolName": "web_search", + "toolCallId": "tc1", + "state": "approval-responded", + "input": {"query": "ai"}, + "approval": { + "id": "approve_tc1", + "approved": True, + "reason": "ok", + "isAutomatic": True, + }, + "providerExecuted": True, + }, + ) + ] + ) + + assert len(approval_responses) == 1 + assert approval_responses[0].hook_id == "approve_tc1" + assert approval_responses[0].granted is True + assert approval_responses[0].reason == "ok" + assert approval_responses[0].tool_call_id == "tc1" + + def test_tool_call_id_for_rejects_bad_prefix() -> None: hook: messages_.HookPart[Any] = messages_.HookPart( hook_id="tc_42", diff --git a/tests/agents/ui/ai_sdk/test_inbound.py b/tests/agents/ui/ai_sdk/test_inbound_messages.py similarity index 84% rename from tests/agents/ui/ai_sdk/test_inbound.py rename to tests/agents/ui/ai_sdk/test_inbound_messages.py index ff27a290..f2d2abbc 100644 --- a/tests/agents/ui/ai_sdk/test_inbound.py +++ b/tests/agents/ui/ai_sdk/test_inbound_messages.py @@ -6,7 +6,6 @@ from ai.agents.agent import MessageBundle from ai.agents.ui.ai_sdk import to_messages -from ai.agents.ui.ai_sdk.approvals import extract_approvals from ai.agents.ui.ai_sdk.inbound_messages import _normalize_ui_messages from ai.agents.ui.ai_sdk.ui_messages import UIMessage, UIToolPart from ai.types import messages as messages_ @@ -134,60 +133,6 @@ def test_to_messages_keeps_trailing_assistant_when_approved() -> None: assert [a.hook_id for a in approvals] == ["approve_tc1"] -def test_extract_approvals_returns_approved_responses() -> None: - approvals = extract_approvals( - [ - _ui( - "assistant", - _tool( - "x", - "tc1", - "approval-responded", - approval={ - "id": "approve_tc1", - "approved": False, - "reason": "nope", - }, - ), - ) - ] - ) - assert len(approvals) == 1 - assert approvals[0].hook_id == "approve_tc1" - assert approvals[0].granted is False - assert approvals[0].reason == "nope" - - -def test_extract_approvals_handles_dynamic_tool_responses() -> None: - approvals = extract_approvals( - [ - _ui( - "assistant", - { - "type": "dynamic-tool", - "toolName": "web_search", - "toolCallId": "tc1", - "state": "approval-responded", - "input": {"query": "ai"}, - "approval": { - "id": "approve_tc1", - "approved": True, - "reason": "ok", - "isAutomatic": True, - }, - "providerExecuted": True, - }, - ) - ] - ) - - assert len(approvals) == 1 - assert approvals[0].hook_id == "approve_tc1" - assert approvals[0].granted is True - assert approvals[0].reason == "ok" - assert approvals[0].tool_call_id == "tc1" - - def test_normalize_ui_messages_heals_stale_tool_state() -> None: ui = [ _ui( diff --git a/tests/agents/ui/ai_sdk/outbound/test_history.py b/tests/agents/ui/ai_sdk/test_outbound_messages.py similarity index 84% rename from tests/agents/ui/ai_sdk/outbound/test_history.py rename to tests/agents/ui/ai_sdk/test_outbound_messages.py index 67fe8619..ad503e58 100644 --- a/tests/agents/ui/ai_sdk/outbound/test_history.py +++ b/tests/agents/ui/ai_sdk/test_outbound_messages.py @@ -3,11 +3,13 @@ from collections import Counter from ai.agents.ui import ai_sdk -from ai.agents.ui.ai_sdk import to_ui_messages +from ai.agents.ui.ai_sdk import outbound_messages, to_ui_messages from ai.agents.ui.ai_sdk.ui_messages import ( UIDynamicToolPart, UIFilePart, + UIReasoningPart, UITextPart, + UIToolApproval, UIToolPart, ) from ai.types import integrity @@ -80,6 +82,101 @@ def _parallel_tool_turn( ] +def test_to_ui_parts_text_and_reasoning() -> None: + parts: list[messages_.Part] = [ + messages_.ReasoningPart(text="thinking"), + messages_.TextPart(text="hi"), + ] + ui_parts = outbound_messages.to_ui_parts(parts) + assert isinstance(ui_parts[0], UIReasoningPart) + assert ui_parts[0].text == "thinking" + assert isinstance(ui_parts[1], UITextPart) + assert ui_parts[1].text == "hi" + + +def test_to_ui_parts_tool_call_parses_json_args() -> None: + parts: list[messages_.Part] = [ + messages_.ToolCallPart( + tool_call_id="tc1", + tool_name="search", + tool_args='{"q": "x"}', + ) + ] + ui_parts = outbound_messages.to_ui_parts(parts) + assert isinstance(ui_parts[0], UIToolPart) + assert ui_parts[0].type == "tool-search" + assert ui_parts[0].input == {"q": "x"} + assert ui_parts[0].state == "input-available" + + +def test_merge_tool_results_updates_state_and_output() -> None: + parts: list[messages_.Part] = [ + messages_.ToolCallPart( + tool_call_id="tc1", + tool_name="search", + tool_args="{}", + ) + ] + ui_parts = outbound_messages.to_ui_parts(parts) + outbound_messages.merge_tool_results( + ui_parts, + [ + messages_.ToolResultPart( + tool_call_id="tc1", + tool_name="search", + result={"hits": 3}, + ) + ], + ) + merged = ui_parts[0] + assert isinstance(merged, UIToolPart) + assert merged.state == "output-available" + assert merged.output == {"hits": 3} + + +def test_merge_approval_signals_pending_then_resolved() -> None: + parts: list[messages_.Part] = [ + messages_.ToolCallPart( + tool_call_id="tc1", + tool_name="delete", + tool_args="{}", + ) + ] + ui_parts = outbound_messages.to_ui_parts(parts) + + outbound_messages.merge_approval_signals( + ui_parts, + [ + messages_.HookPart( + hook_id="approve_tc1", + hook_type="ToolApproval", + status="pending", + ) + ], + ) + requested = ui_parts[0] + assert isinstance(requested, UIToolPart) + assert requested.state == "approval-requested" + assert isinstance(requested.approval, UIToolApproval) + + outbound_messages.merge_approval_signals( + ui_parts, + [ + messages_.HookPart( + hook_id="approve_tc1", + hook_type="ToolApproval", + status="resolved", + resolution={"granted": True, "reason": None}, + ) + ], + ) + responded = ui_parts[0] + assert isinstance(responded, UIToolPart) + assert responded.state == "approval-responded" + assert responded.approval is not None + assert responded.approval.approved is True + + def _tool_counts( messages: list[messages_.Message], ) -> Counter[tuple[str, str]]: diff --git a/tests/agents/ui/ai_sdk/outbound/test_stream.py b/tests/agents/ui/ai_sdk/test_outbound_stream.py similarity index 84% rename from tests/agents/ui/ai_sdk/outbound/test_stream.py rename to tests/agents/ui/ai_sdk/test_outbound_stream.py index 42d2c991..837619c0 100644 --- a/tests/agents/ui/ai_sdk/outbound/test_stream.py +++ b/tests/agents/ui/ai_sdk/test_outbound_stream.py @@ -1,10 +1,16 @@ from __future__ import annotations +import json from collections.abc import AsyncGenerator from typing import Any import ai -from ai.agents.ui.ai_sdk import to_stream, ui_events +from ai.agents.ui.ai_sdk import to_sse, to_stream, ui_events +from ai.agents.ui.ai_sdk.outbound_stream import ( + format_done_sse, + format_sse, + serialize_event, +) from ai.types import events as agent_events_ from ai.types import events as events_ from ai.types import messages as messages_ @@ -23,6 +29,70 @@ async def _collect( return [event async for event in to_stream(_gen(stream_events))] +def test_serialize_event_camelcases_keys() -> None: + event = ui_events.UIStartEvent(message_id="m1") + payload = json.loads(serialize_event(event)) + assert payload == {"type": "start", "messageId": "m1"} + + +def test_format_sse_wraps_data_line() -> None: + event = ui_events.UITextDeltaEvent(id="t1", delta="hi") + line = format_sse(event) + assert line.startswith("data: ") + assert line.endswith("\n\n") + + +def test_serialize_data_event_uses_type_with_prefix() -> None: + event = ui_events.UIDataEvent(data_type="custom", data={"k": 1}) + payload = json.loads(serialize_event(event)) + assert payload["type"] == "data-custom" + assert "dataType" not in payload + + +def test_serialize_protocol_fields_use_ai_sdk_wire_names() -> None: + event = ui_events.UIToolApprovalResponseEvent( + approval_id="approval-1", + approved=False, + reason="no", + provider_executed=True, + provider_metadata={"provider": {"k": "v"}}, + ) + + payload = json.loads(serialize_event(event)) + + assert payload == { + "type": "tool-approval-response", + "approvalId": "approval-1", + "approved": False, + "reason": "no", + "providerExecuted": True, + "providerMetadata": {"provider": {"k": "v"}}, + } + + +def test_format_done_sse_returns_done_sentinel() -> None: + assert format_done_sse() == "data: [DONE]\n\n" + + +async def test_to_sse_emits_data_prefixed_lines() -> None: + lines = [ + line + async for line in to_sse( + _gen( + [ + events_.TextStart(block_id="t1"), + events_.TextDelta(block_id="t1", chunk="hi"), + events_.TextEnd(block_id="t1"), + ] + ) + ) + ] + assert all(line.startswith("data: ") for line in lines) + first = json.loads(lines[0].removeprefix("data: ").rstrip()) + assert first["type"] == "start" + assert lines[-1] == "data: [DONE]\n\n" + + async def test_stream_start_uses_runtime_message_id() -> None: assistant = messages_.Message( id="assistant-runtime-id", @@ -60,12 +130,10 @@ async def test_event_driven_text_streaming() -> None: assert isinstance(out[0], ui_events.UIStartEvent) assert isinstance(out[1], ui_events.UIStartStepEvent) assert ( - isinstance(out[2], ui_events.UITextStartEvent) - and out[2].id == text_id + isinstance(out[2], ui_events.UITextStartEvent) and out[2].id == text_id ) assert ( - isinstance(out[3], ui_events.UITextDeltaEvent) - and out[3].delta == "hi" + isinstance(out[3], ui_events.UITextDeltaEvent) and out[3].delta == "hi" ) assert isinstance(out[4], ui_events.UITextEndEvent) and out[4].id == text_id assert isinstance(out[5], ui_events.UIFinishStepEvent) diff --git a/tests/agents/ui/ai_sdk/test_parts.py b/tests/agents/ui/ai_sdk/test_parts.py deleted file mode 100644 index dcd6692d..00000000 --- a/tests/agents/ui/ai_sdk/test_parts.py +++ /dev/null @@ -1,105 +0,0 @@ -from __future__ import annotations - -from ai.agents.ui.ai_sdk import outbound_messages -from ai.agents.ui.ai_sdk.ui_messages import ( - UIReasoningPart, - UITextPart, - UIToolApproval, - UIToolPart, -) -from ai.types import messages as messages_ - - -def test_to_ui_parts_text_and_reasoning() -> None: - parts: list[messages_.Part] = [ - messages_.ReasoningPart(text="thinking"), - messages_.TextPart(text="hi"), - ] - ui_parts = outbound_messages.to_ui_parts(parts) - assert isinstance(ui_parts[0], UIReasoningPart) - assert ui_parts[0].text == "thinking" - assert isinstance(ui_parts[1], UITextPart) - assert ui_parts[1].text == "hi" - - -def test_to_ui_parts_tool_call_parses_json_args() -> None: - parts: list[messages_.Part] = [ - messages_.ToolCallPart( - tool_call_id="tc1", - tool_name="search", - tool_args='{"q": "x"}', - ) - ] - ui_parts = outbound_messages.to_ui_parts(parts) - assert isinstance(ui_parts[0], UIToolPart) - assert ui_parts[0].type == "tool-search" - assert ui_parts[0].input == {"q": "x"} - assert ui_parts[0].state == "input-available" - - -def test_merge_tool_results_updates_state_and_output() -> None: - parts: list[messages_.Part] = [ - messages_.ToolCallPart( - tool_call_id="tc1", - tool_name="search", - tool_args="{}", - ) - ] - ui_parts = outbound_messages.to_ui_parts(parts) - outbound_messages.merge_tool_results( - ui_parts, - [ - messages_.ToolResultPart( - tool_call_id="tc1", - tool_name="search", - result={"hits": 3}, - ) - ], - ) - merged = ui_parts[0] - assert isinstance(merged, UIToolPart) - assert merged.state == "output-available" - assert merged.output == {"hits": 3} - - -def test_merge_approval_signals_pending_then_resolved() -> None: - parts: list[messages_.Part] = [ - messages_.ToolCallPart( - tool_call_id="tc1", - tool_name="delete", - tool_args="{}", - ) - ] - ui_parts = outbound_messages.to_ui_parts(parts) - - outbound_messages.merge_approval_signals( - ui_parts, - [ - messages_.HookPart( - hook_id="approve_tc1", - hook_type="ToolApproval", - status="pending", - ) - ], - ) - requested = ui_parts[0] - assert isinstance(requested, UIToolPart) - assert requested.state == "approval-requested" - assert isinstance(requested.approval, UIToolApproval) - - outbound_messages.merge_approval_signals( - ui_parts, - [ - messages_.HookPart( - hook_id="approve_tc1", - hook_type="ToolApproval", - status="resolved", - resolution={"granted": True, "reason": None}, - ) - ], - ) - responded = ui_parts[0] - assert isinstance(responded, UIToolPart) - assert responded.state == "approval-responded" - assert responded.approval is not None - assert responded.approval.approved is True From 739fbcc7226218d3eff3b5cf9274c5943e8ea700 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Wed, 20 May 2026 17:20:40 -0700 Subject: [PATCH 09/13] Inline useless helper functions to make the code more clear --- src/ai/agents/ui/ai_sdk/approvals.py | 32 ++--- src/ai/agents/ui/ai_sdk/id_utils.py | 40 +++---- src/ai/agents/ui/ai_sdk/inbound_messages.py | 120 +++++++++---------- src/ai/agents/ui/ai_sdk/outbound_messages.py | 81 +++++-------- src/ai/agents/ui/ai_sdk/outbound_stream.py | 69 ++++++----- 5 files changed, 140 insertions(+), 202 deletions(-) diff --git a/src/ai/agents/ui/ai_sdk/approvals.py b/src/ai/agents/ui/ai_sdk/approvals.py index e23dc163..9c931eeb 100644 --- a/src/ai/agents/ui/ai_sdk/approvals.py +++ b/src/ai/agents/ui/ai_sdk/approvals.py @@ -32,37 +32,19 @@ def tool_call_id_for(hook_part: messages_.HookPart[Any]) -> str | None: return None -def metadata_bool(metadata: dict[str, Any], key: str) -> bool | None: - value = metadata.get(key) - return value if isinstance(value, bool) else None - - -def metadata_dict( - metadata: dict[str, Any], - key: str, -) -> dict[str, Any] | None: - value = metadata.get(key) - return value if isinstance(value, dict) else None - - -def metadata_from_tool_part(tp: ToolPart) -> dict[str, Any]: - metadata: dict[str, Any] = {} - if tp.approval is not None and tp.approval.is_automatic is not None: - metadata["isAutomatic"] = tp.approval.is_automatic - if tp.provider_executed is not None: - metadata["providerExecuted"] = tp.provider_executed - if tp.call_provider_metadata is not None: - metadata["callProviderMetadata"] = tp.call_provider_metadata - return metadata - - def hook_part_from_tool_part(tp: ToolPart) -> messages_.HookPart[Any] | None: """Reconstruct approval hook state from a UI tool part when possible.""" approval = tp.approval if approval is None: return None - metadata = metadata_from_tool_part(tp) + metadata: dict[str, Any] = {} + if approval.is_automatic is not None: + metadata["isAutomatic"] = approval.is_automatic + if tp.provider_executed is not None: + metadata["providerExecuted"] = tp.provider_executed + if tp.call_provider_metadata is not None: + metadata["callProviderMetadata"] = tp.call_provider_metadata if tp.state == "approval-requested": return messages_.HookPart( diff --git a/src/ai/agents/ui/ai_sdk/id_utils.py b/src/ai/agents/ui/ai_sdk/id_utils.py index fb20c093..38f8d62a 100644 --- a/src/ai/agents/ui/ai_sdk/id_utils.py +++ b/src/ai/agents/ui/ai_sdk/id_utils.py @@ -23,22 +23,19 @@ class SourceMessage: part_ids: tuple[str, ...] -def source_message_entry(message: messages_.Message) -> dict[str, object]: - return { - "id": message.id, - "role": message.role, - "turnId": message.turn_id, - "partIds": [part.id for part in message.parts], - } - - def metadata_for( source_messages: list[messages_.Message], ) -> dict[str, object]: return { ADAPTER_METADATA_KEY: { SOURCE_MESSAGES_KEY: [ - source_message_entry(message) for message in source_messages + { + "id": message.id, + "role": message.role, + "turnId": message.turn_id, + "partIds": [part.id for part in message.parts], + } + for message in source_messages ] } } @@ -77,10 +74,13 @@ def restore_source_ids( source_index = 0 for message in messages: - match_index = _find_next_source( - source_messages, - role=message.role, - start=source_index, + match_index = next( + ( + index + for index in range(source_index, len(source_messages)) + if source_messages[index].role == message.role + ), + None, ) if match_index is None: restored.append(message) @@ -121,18 +121,6 @@ def _parse_source_message(raw: object) -> SourceMessage | None: ) -def _find_next_source( - source_messages: list[SourceMessage], - *, - role: MessageRole, - start: int, -) -> int | None: - for index in range(start, len(source_messages)): - if source_messages[index].role == role: - return index - return None - - def _restore_message_ids( message: messages_.Message, source: SourceMessage, diff --git a/src/ai/agents/ui/ai_sdk/inbound_messages.py b/src/ai/agents/ui/ai_sdk/inbound_messages.py index d14c7f33..cddf69d9 100644 --- a/src/ai/agents/ui/ai_sdk/inbound_messages.py +++ b/src/ai/agents/ui/ai_sdk/inbound_messages.py @@ -26,22 +26,6 @@ ) -def _is_tool_completed(state: ui_messages_.UIToolInvocationState) -> bool: - return state in _TOOL_RESULT_STATES or state in _TOOL_ERROR_STATES - - -def _is_tool_error(state: ui_messages_.UIToolInvocationState) -> bool: - return state in _TOOL_ERROR_STATES - - -def _tool_input_for_args( - part: ui_messages_.UIToolPart | ui_messages_.UIDynamicToolPart, -) -> Any: - if part.state == "output-error" and part.input is None: - return part.raw_input - return part.input - - def _tool_result_output( part: ui_messages_.UIToolPart | ui_messages_.UIDynamicToolPart, ) -> Any: @@ -254,22 +238,6 @@ def _build_result_part( is_error=is_error, ) - def _build_builtin_return_part( - *, - tool_call_id: str, - tool_name: str, - output: Any, - is_error: bool, - provider_metadata: dict[str, Any] | None, - ) -> messages_.BuiltinToolReturnPart: - return messages_.BuiltinToolReturnPart( - tool_call_id=tool_call_id, - tool_name=tool_name, - result=output, - is_error=is_error, - provider_metadata=provider_metadata, - ) - result: list[messages_.Message] = [] for ui_msg in ui_messages: @@ -298,6 +266,11 @@ def _build_builtin_return_part( case ui_messages_.UIToolInvocationPart() as inv: tool_args = json.dumps(inv.args) if inv.args else "{}" + is_completed = ( + inv.state in _TOOL_RESULT_STATES + or inv.state in _TOOL_ERROR_STATES + ) + is_error = inv.state in _TOOL_ERROR_STATES if inv.provider_executed: assistant_parts.append( messages_.BuiltinToolCallPart( @@ -306,13 +279,13 @@ def _build_builtin_return_part( tool_args=tool_args, ) ) - if _is_tool_completed(inv.state): + if is_completed: assistant_parts.append( - _build_builtin_return_part( + messages_.BuiltinToolReturnPart( tool_call_id=inv.tool_invocation_id, tool_name=inv.tool_name, - output=inv.result, - is_error=_is_tool_error(inv.state), + result=inv.result, + is_error=is_error, provider_metadata=None, ) ) @@ -324,13 +297,13 @@ def _build_builtin_return_part( tool_args=tool_args, ) ) - if _is_tool_completed(inv.state): + if is_completed: tool_result_parts.append( _build_result_part( tool_call_id=inv.tool_invocation_id, tool_name=inv.tool_name, output=inv.result, - is_error=_is_tool_error(inv.state), + is_error=is_error, ) ) @@ -340,8 +313,17 @@ def _build_builtin_return_part( | ui_messages_.UIDynamicToolPart() ) as tp ): - tool_input = _tool_input_for_args(tp) + tool_input = ( + tp.raw_input + if tp.state == "output-error" and tp.input is None + else tp.input + ) tool_args = normalize_tool_args(tool_input) + is_completed = ( + tp.state in _TOOL_RESULT_STATES + or tp.state in _TOOL_ERROR_STATES + ) + is_error = tp.state in _TOOL_ERROR_STATES if tp.provider_executed: assistant_parts.append( @@ -365,13 +347,13 @@ def _build_builtin_return_part( if approval_hook is not None: hook_parts.append(approval_hook) - if tp.provider_executed and _is_tool_completed(tp.state): + if tp.provider_executed and is_completed: assistant_parts.append( - _build_builtin_return_part( + messages_.BuiltinToolReturnPart( tool_call_id=tp.tool_call_id, tool_name=tp.tool_name, - output=_tool_result_output(tp), - is_error=_is_tool_error(tp.state), + result=_tool_result_output(tp), + is_error=is_error, provider_metadata=( tp.result_provider_metadata or tp.call_provider_metadata @@ -384,7 +366,7 @@ def _build_builtin_return_part( tool_call_id=tp.tool_call_id, tool_name=tp.tool_name, output=_tool_result_output(tp), - is_error=_is_tool_error(tp.state), + is_error=is_error, ) ) if tp.result_provider_metadata is not None: @@ -494,32 +476,26 @@ def _split_assistant_parts( current_results: list[messages_.ToolResultPart] = [] seen_tool_call = False - def _append_assistant(parts_: list[messages_.Part]) -> None: - messages.append( - messages_.Message( - role="assistant", - parts=parts_, - turn_id=turn_id, - ) - ) - - def _append_tool(parts_: list[messages_.ToolResultPart]) -> None: - messages.append( - messages_.Message( - role="tool", - parts=list(parts_), - turn_id=turn_id, - ) - ) - for part in parts: if ( seen_tool_call and current_results and not isinstance(part, messages_.ToolCallPart) ): - _append_assistant(current) - _append_tool(current_results) + messages.append( + messages_.Message( + role="assistant", + parts=current, + turn_id=turn_id, + ) + ) + messages.append( + messages_.Message( + role="tool", + parts=list(current_results), + turn_id=turn_id, + ) + ) current = [] current_results = [] seen_tool_call = False @@ -532,8 +508,20 @@ def _append_tool(parts_: list[messages_.ToolResultPart]) -> None: current_results.append(results_by_id[part.tool_call_id]) if current: - _append_assistant(current) + messages.append( + messages_.Message( + role="assistant", + parts=current, + turn_id=turn_id, + ) + ) if current_results: - _append_tool(current_results) + messages.append( + messages_.Message( + role="tool", + parts=list(current_results), + turn_id=turn_id, + ) + ) return messages diff --git a/src/ai/agents/ui/ai_sdk/outbound_messages.py b/src/ai/agents/ui/ai_sdk/outbound_messages.py index c92e7bae..33ad99b7 100644 --- a/src/ai/agents/ui/ai_sdk/outbound_messages.py +++ b/src/ai/agents/ui/ai_sdk/outbound_messages.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, TypeGuard, cast +from typing import Any, cast from ....types import media from ....types import messages as messages_ @@ -22,32 +22,6 @@ } -def _is_tool_part(part: ui_messages.UIMessagePart) -> TypeGuard[UIToolLike]: - return isinstance( - part, ui_messages.UIToolPart | ui_messages.UIDynamicToolPart - ) - - -def _tool_call_id(part: UIToolLike) -> str: - return part.tool_call_id - - -def _message_turn_key(message: messages_.Message) -> str | None: - return message.turn_id - - -def _assistant_bubble_id(message: messages_.Message) -> str: - return _message_turn_key(message) or message.id - - -def _belongs_to_bubble( - message: messages_.Message, - bubble_id: str, -) -> bool: - key = _message_turn_key(message) - return key is None or key == bubble_id - - def to_ui_parts(parts: list[messages_.Part]) -> list[ui_messages.UIMessagePart]: """Convert internal parts to UI message parts.""" result: list[ui_messages.UIMessagePart] = [] @@ -185,18 +159,22 @@ def dedupe_tool_parts( tool_index: dict[str, int] = {} for part in ui_parts: - if not _is_tool_part(part): + if not isinstance( + part, ui_messages.UIToolPart | ui_messages.UIDynamicToolPart + ): result.append(part) continue - idx = tool_index.get(_tool_call_id(part)) + idx = tool_index.get(part.tool_call_id) if idx is None: - tool_index[_tool_call_id(part)] = len(result) + tool_index[part.tool_call_id] = len(result) result.append(part) continue existing = result[idx] - if _is_tool_part(existing): + if isinstance( + existing, ui_messages.UIToolPart | ui_messages.UIDynamicToolPart + ): result[idx] = _merge_tool_part(existing, part) return result @@ -209,8 +187,10 @@ def merge_tool_results( """Merge tool result parts into existing UI tool parts.""" tool_index: dict[str, int] = {} for idx, ui_part in enumerate(ui_parts): - if _is_tool_part(ui_part): - tool_index[_tool_call_id(ui_part)] = idx + if isinstance( + ui_part, ui_messages.UIToolPart | ui_messages.UIDynamicToolPart + ): + tool_index[ui_part.tool_call_id] = idx for part in tool_parts: if isinstance(part, messages_.ToolResultPart): @@ -247,7 +227,9 @@ def merge_tool_results( continue idx = idx_opt existing = ui_parts[idx] - if not _is_tool_part(existing): + if not isinstance( + existing, ui_messages.UIToolPart | ui_messages.UIDynamicToolPart + ): continue if existing.state == "output-denied": continue @@ -261,8 +243,10 @@ def merge_approval_signals( """Merge approval hook state into existing UI tool parts.""" tool_index: dict[str, int] = {} for idx, ui_part in enumerate(ui_parts): - if _is_tool_part(ui_part): - tool_index[_tool_call_id(ui_part)] = idx + if isinstance( + ui_part, ui_messages.UIToolPart | ui_messages.UIDynamicToolPart + ): + tool_index[ui_part.tool_call_id] = idx for part in internal_parts: if not isinstance(part, messages_.HookPart): @@ -278,24 +262,23 @@ def merge_approval_signals( idx = idx_opt existing = ui_parts[idx] - if not _is_tool_part(existing): + if not isinstance( + existing, ui_messages.UIToolPart | ui_messages.UIDynamicToolPart + ): continue updates: dict[str, Any] = {} - if ( - provider_executed := approvals.metadata_bool( - part.metadata, "providerExecuted" - ) - ) is not None: + provider_executed = part.metadata.get("providerExecuted") + if isinstance(provider_executed, bool): updates["provider_executed"] = provider_executed + is_automatic = part.metadata.get("isAutomatic") + is_automatic = is_automatic if isinstance(is_automatic, bool) else None if part.status == "pending": updates["state"] = "approval-requested" updates["approval"] = ui_messages.UIToolApproval.model_validate( { "id": part.hook_id, - "isAutomatic": approvals.metadata_bool( - part.metadata, "isAutomatic" - ), + "isAutomatic": is_automatic, } ) elif part.status == "resolved": @@ -308,9 +291,7 @@ def merge_approval_signals( "id": part.hook_id, "approved": resolution.get("granted"), "reason": resolution.get("reason"), - "isAutomatic": approvals.metadata_bool( - part.metadata, "isAutomatic" - ), + "isAutomatic": is_automatic, } ) if resolution.get("granted", False): @@ -351,7 +332,7 @@ def to_ui_messages( if msg.role == "assistant": ui_parts: list[ui_messages.UIMessagePart] = [] source_messages: list[messages_.Message] = [] - bubble_id = _assistant_bubble_id(msg) + bubble_id = msg.turn_id or msg.id while i < len(messages) and messages[i].role in ( "assistant", @@ -359,7 +340,7 @@ def to_ui_messages( "internal", ): current = messages[i] - if not _belongs_to_bubble(current, bubble_id): + if current.turn_id is not None and current.turn_id != bubble_id: break source_messages.append(current) if current.role == "assistant": diff --git a/src/ai/agents/ui/ai_sdk/outbound_stream.py b/src/ai/agents/ui/ai_sdk/outbound_stream.py index d33e689b..30f832c8 100644 --- a/src/ai/agents/ui/ai_sdk/outbound_stream.py +++ b/src/ai/agents/ui/ai_sdk/outbound_stream.py @@ -48,15 +48,6 @@ def _to_wire_output(snapshot: Any) -> Any: return snapshot -def _stream_message_id(event: events_.Event) -> str | None: - message = event.message - if message.role != "assistant": - return None - if message.turn_id is not None: - return message.turn_id - return None if message.id == "" else message.id - - class _StreamState: """Single-pass state across one ``to_stream()`` call.""" @@ -99,20 +90,6 @@ def _close_open_blocks(self) -> list[ui_events.UIMessageStreamEvent]: self.open_text_ids.clear() return events - def _finish_step(self) -> list[ui_events.UIMessageStreamEvent]: - events = self._close_open_blocks() - if self.in_step: - events.append(ui_events.UIFinishStepEvent()) - self.in_step = False - return events - - def _reset_step_tracking(self) -> None: - self.started_tool_inputs.clear() - self.tool_names.clear() - self.input_available_emitted.clear() - self.emitted_tool_results.clear() - self.emitted_approval_requests.clear() - def _ensure_started( self, message_id: str | None = None, @@ -126,7 +103,11 @@ def _ensure_started( events.append(ui_events.UIStartStepEvent()) self.emitted_start = True self.in_step = True - self._reset_step_tracking() + self.started_tool_inputs.clear() + self.tool_names.clear() + self.input_available_emitted.clear() + self.emitted_tool_results.clear() + self.emitted_approval_requests.clear() return events @@ -139,7 +120,14 @@ def on_event( # Lazily open the UI message on the first streaming event. if not self.emitted_start: - out.extend(self._ensure_started(_stream_message_id(event))) + message = event.message + message_id = None + if message.role == "assistant": + if message.turn_id is not None: + message_id = message.turn_id + elif message.id != "": + message_id = message.id + out.extend(self._ensure_started(message_id)) match event: case events_.TextStart(block_id=pid): @@ -470,6 +458,8 @@ def on_hook( if tc_id is None: return out + is_automatic = hook_part.metadata.get("isAutomatic") + is_automatic = is_automatic if isinstance(is_automatic, bool) else None if hook_part.status == "pending": if tc_id in self.emitted_approval_requests: return out @@ -478,24 +468,30 @@ def on_hook( ui_events.UIToolApprovalRequestEvent( approval_id=hook_part.hook_id, tool_call_id=tc_id, - is_automatic=approvals.metadata_bool( - hook_part.metadata, "isAutomatic" - ), + is_automatic=is_automatic, ) ) elif hook_part.status == "resolved": resolution: dict[str, Any] = hook_part.resolution or {} + provider_executed = hook_part.metadata.get("providerExecuted") + provider_executed = ( + provider_executed + if isinstance(provider_executed, bool) + else None + ) + provider_metadata = hook_part.metadata.get("callProviderMetadata") + provider_metadata = ( + provider_metadata + if isinstance(provider_metadata, dict) + else None + ) out.append( ui_events.UIToolApprovalResponseEvent( approval_id=hook_part.hook_id, approved=bool(resolution.get("granted")), reason=resolution.get("reason"), - provider_executed=approvals.metadata_bool( - hook_part.metadata, "providerExecuted" - ), - provider_metadata=approvals.metadata_dict( - hook_part.metadata, "callProviderMetadata" - ), + provider_executed=provider_executed, + provider_metadata=provider_metadata, ) ) if not resolution.get("granted"): @@ -515,7 +511,10 @@ def on_hook( # -- phase: stream finish ------------------------------------------------ def finish(self) -> list[ui_events.UIMessageStreamEvent]: - events = self._finish_step() + events = self._close_open_blocks() + if self.in_step: + events.append(ui_events.UIFinishStepEvent()) + self.in_step = False if self.emitted_start: events.append(ui_events.UIFinishEvent(finish_reason="stop")) return events From 624e7f7957e8790ba7b1f656871bfac4d2a443ac Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Wed, 20 May 2026 17:29:31 -0700 Subject: [PATCH 10/13] Replace if elif dispatch with match case --- src/ai/agents/ui/ai_sdk/outbound_messages.py | 366 ++++++++++--------- src/ai/agents/ui/ai_sdk/outbound_stream.py | 108 +++--- 2 files changed, 245 insertions(+), 229 deletions(-) diff --git a/src/ai/agents/ui/ai_sdk/outbound_messages.py b/src/ai/agents/ui/ai_sdk/outbound_messages.py index 33ad99b7..ad70a0a8 100644 --- a/src/ai/agents/ui/ai_sdk/outbound_messages.py +++ b/src/ai/agents/ui/ai_sdk/outbound_messages.py @@ -26,88 +26,89 @@ def to_ui_parts(parts: list[messages_.Part]) -> list[ui_messages.UIMessagePart]: """Convert internal parts to UI message parts.""" result: list[ui_messages.UIMessagePart] = [] for part in parts: - if isinstance(part, messages_.TextPart) and part.text: - result.append( - ui_messages.UITextPart.model_validate( - { - "type": "text", - "text": part.text, - "providerMetadata": part.provider_metadata, - } + match part: + case messages_.TextPart(text=text) if text: + result.append( + ui_messages.UITextPart.model_validate( + { + "type": "text", + "text": text, + "providerMetadata": part.provider_metadata, + } + ) ) - ) - elif isinstance(part, messages_.ReasoningPart) and part.text: - result.append( - ui_messages.UIReasoningPart.model_validate( - { - "type": "reasoning", - "text": part.text, - "providerMetadata": part.provider_metadata, - } + case messages_.ReasoningPart(text=text) if text: + result.append( + ui_messages.UIReasoningPart.model_validate( + { + "type": "reasoning", + "text": text, + "providerMetadata": part.provider_metadata, + } + ) ) - ) - elif isinstance(part, messages_.ToolCallPart): - result.append( - ui_messages.UIToolPart.model_validate( - { - "type": f"tool-{part.tool_name}", - "toolCallId": part.tool_call_id, - "state": "input-available", - "input": normalize_tool_input(part.tool_args), - "callProviderMetadata": part.provider_metadata, - } + case messages_.ToolCallPart(): + result.append( + ui_messages.UIToolPart.model_validate( + { + "type": f"tool-{part.tool_name}", + "toolCallId": part.tool_call_id, + "state": "input-available", + "input": normalize_tool_input(part.tool_args), + "callProviderMetadata": part.provider_metadata, + } + ) ) - ) - elif isinstance(part, messages_.BuiltinToolCallPart): - result.append( - ui_messages.UIDynamicToolPart.model_validate( - { - "type": "dynamic-tool", - "toolName": part.tool_name, - "toolCallId": part.tool_call_id, - "state": "input-available", - "input": normalize_tool_input(part.tool_args), - "providerExecuted": True, - "callProviderMetadata": part.provider_metadata, - } + case messages_.BuiltinToolCallPart(): + result.append( + ui_messages.UIDynamicToolPart.model_validate( + { + "type": "dynamic-tool", + "toolName": part.tool_name, + "toolCallId": part.tool_call_id, + "state": "input-available", + "input": normalize_tool_input(part.tool_args), + "providerExecuted": True, + "callProviderMetadata": part.provider_metadata, + } + ) ) - ) - elif isinstance(part, messages_.BuiltinToolReturnPart): - result.append( - ui_messages.UIDynamicToolPart.model_validate( - { - "type": "dynamic-tool", - "toolName": part.tool_name, - "toolCallId": part.tool_call_id, - "state": ( - "output-error" - if part.is_error - else "output-available" - ), - "input": None, - "output": None if part.is_error else part.result, - "errorText": ( - str(part.result) if part.is_error else None - ), - "providerExecuted": True, - "resultProviderMetadata": part.provider_metadata, - } + case messages_.BuiltinToolReturnPart(): + result.append( + ui_messages.UIDynamicToolPart.model_validate( + { + "type": "dynamic-tool", + "toolName": part.tool_name, + "toolCallId": part.tool_call_id, + "state": ( + "output-error" + if part.is_error + else "output-available" + ), + "input": None, + "output": None if part.is_error else part.result, + "errorText": ( + str(part.result) if part.is_error else None + ), + "providerExecuted": True, + "resultProviderMetadata": part.provider_metadata, + } + ) ) - ) - elif isinstance(part, messages_.FilePart): - result.append( - ui_messages.UIFilePart.model_validate( - { - "type": "file", - "mediaType": part.media_type, - "url": media.data_to_data_url( - part.data, part.media_type - ), - "filename": part.filename, - "providerMetadata": part.provider_metadata, - } + case messages_.FilePart(): + result.append( + ui_messages.UIFilePart.model_validate( + { + "type": "file", + "mediaType": part.media_type, + "url": media.data_to_data_url( + part.data, part.media_type + ), + "filename": part.filename, + "providerMetadata": part.provider_metadata, + } + ) ) - ) return result @@ -193,35 +194,37 @@ def merge_tool_results( tool_index[ui_part.tool_call_id] = idx for part in tool_parts: - if isinstance(part, messages_.ToolResultPart): - tool_call_id = part.tool_call_id - state = "output-error" if part.is_error else "output-available" - updates: dict[str, Any] = { - "state": state, - "result_provider_metadata": part.provider_metadata, - } - if part.is_error: - updates["error_text"] = str(part.result) - else: - updates["output"] = part.result - elif isinstance(part, messages_.BuiltinToolReturnPart): - tool_call_id = part.tool_call_id - updates = { - "state": ( - "output-error" if part.is_error else "output-available" - ), - "provider_executed": True, - "result_provider_metadata": part.provider_metadata, - } - if part.is_error: - updates["error_text"] = str(part.result) - else: - updates["output"] = part.result - else: - continue + updates: dict[str, Any] + match part: + case messages_.ToolResultPart() if part.is_hook_pending: + continue + case messages_.ToolResultPart(): + tool_call_id = part.tool_call_id + state = "output-error" if part.is_error else "output-available" + updates = { + "state": state, + "result_provider_metadata": part.provider_metadata, + } + if part.is_error: + updates["error_text"] = str(part.result) + else: + updates["output"] = part.result + case messages_.BuiltinToolReturnPart(): + tool_call_id = part.tool_call_id + updates = { + "state": ( + "output-error" if part.is_error else "output-available" + ), + "provider_executed": True, + "result_provider_metadata": part.provider_metadata, + } + if part.is_error: + updates["error_text"] = str(part.result) + else: + updates["output"] = part.result + case _: + continue - if isinstance(part, messages_.ToolResultPart) and part.is_hook_pending: - continue idx_opt = tool_index.get(tool_call_id) if idx_opt is None: continue @@ -273,35 +276,42 @@ def merge_approval_signals( updates["provider_executed"] = provider_executed is_automatic = part.metadata.get("isAutomatic") is_automatic = is_automatic if isinstance(is_automatic, bool) else None - if part.status == "pending": - updates["state"] = "approval-requested" - updates["approval"] = ui_messages.UIToolApproval.model_validate( - { - "id": part.hook_id, - "isAutomatic": is_automatic, - } - ) - elif part.status == "resolved": - resolution = cast( - "dict[str, Any]", - part.resolution if isinstance(part.resolution, dict) else {}, - ) - updates["approval"] = ui_messages.UIToolApproval.model_validate( - { - "id": part.hook_id, - "approved": resolution.get("granted"), - "reason": resolution.get("reason"), - "isAutomatic": is_automatic, - } - ) - if resolution.get("granted", False): - updates["state"] = "approval-responded" - else: - updates["state"] = "output-denied" - updates["output"] = None - elif part.status == "cancelled": - updates["state"] = "output-error" - updates["error_text"] = "Hook cancelled" + match part.status: + case "pending": + updates["state"] = "approval-requested" + updates["approval"] = ( + ui_messages.UIToolApproval.model_validate( + { + "id": part.hook_id, + "isAutomatic": is_automatic, + } + ) + ) + case "resolved": + resolution = cast( + "dict[str, Any]", + part.resolution + if isinstance(part.resolution, dict) + else {}, + ) + updates["approval"] = ( + ui_messages.UIToolApproval.model_validate( + { + "id": part.hook_id, + "approved": resolution.get("granted"), + "reason": resolution.get("reason"), + "isAutomatic": is_automatic, + } + ) + ) + if resolution.get("granted", False): + updates["state"] = "approval-responded" + else: + updates["state"] = "output-denied" + updates["output"] = None + case "cancelled": + updates["state"] = "output-error" + updates["error_text"] = "Hook cancelled" if updates: ui_parts[idx] = existing.model_copy(update=updates) @@ -317,52 +327,54 @@ def to_ui_messages( while i < len(messages): msg = messages[i] - if msg.role in ("user", "system"): - result.append( - ui_messages.UIMessage( - id=msg.id, - role=msg.role, - metadata=id_utils.metadata_for([msg]), - parts=to_ui_parts(msg.parts), + match msg.role: + case "user" | "system": + result.append( + ui_messages.UIMessage( + id=msg.id, + role=msg.role, + metadata=id_utils.metadata_for([msg]), + parts=to_ui_parts(msg.parts), + ) ) - ) - i += 1 - continue - - if msg.role == "assistant": - ui_parts: list[ui_messages.UIMessagePart] = [] - source_messages: list[messages_.Message] = [] - bubble_id = msg.turn_id or msg.id - - while i < len(messages) and messages[i].role in ( - "assistant", - "tool", - "internal", - ): - current = messages[i] - if current.turn_id is not None and current.turn_id != bubble_id: - break - source_messages.append(current) - if current.role == "assistant": - ui_parts.extend(to_ui_parts(current.parts)) - ui_parts = dedupe_tool_parts(ui_parts) - elif current.role == "tool": - merge_tool_results(ui_parts, current.parts) - elif current.role == "internal": - merge_approval_signals(ui_parts, current.parts) i += 1 - ui_parts = dedupe_tool_parts(ui_parts) - - result.append( - ui_messages.UIMessage( - id=bubble_id, - role="assistant", - metadata=id_utils.metadata_for(source_messages), - parts=ui_parts, + case "assistant": + ui_parts: list[ui_messages.UIMessagePart] = [] + source_messages: list[messages_.Message] = [] + bubble_id = msg.turn_id or msg.id + + while i < len(messages) and messages[i].role in ( + "assistant", + "tool", + "internal", + ): + current = messages[i] + if ( + current.turn_id is not None + and current.turn_id != bubble_id + ): + break + source_messages.append(current) + match current.role: + case "assistant": + ui_parts.extend(to_ui_parts(current.parts)) + ui_parts = dedupe_tool_parts(ui_parts) + case "tool": + merge_tool_results(ui_parts, current.parts) + case "internal": + merge_approval_signals(ui_parts, current.parts) + i += 1 + ui_parts = dedupe_tool_parts(ui_parts) + + result.append( + ui_messages.UIMessage( + id=bubble_id, + role="assistant", + metadata=id_utils.metadata_for(source_messages), + parts=ui_parts, + ) ) - ) - continue - - i += 1 + case _: + i += 1 return result diff --git a/src/ai/agents/ui/ai_sdk/outbound_stream.py b/src/ai/agents/ui/ai_sdk/outbound_stream.py index 30f832c8..8e810507 100644 --- a/src/ai/agents/ui/ai_sdk/outbound_stream.py +++ b/src/ai/agents/ui/ai_sdk/outbound_stream.py @@ -460,51 +460,54 @@ def on_hook( is_automatic = hook_part.metadata.get("isAutomatic") is_automatic = is_automatic if isinstance(is_automatic, bool) else None - if hook_part.status == "pending": - if tc_id in self.emitted_approval_requests: - return out - self.emitted_approval_requests.add(tc_id) - out.append( - ui_events.UIToolApprovalRequestEvent( - approval_id=hook_part.hook_id, - tool_call_id=tc_id, - is_automatic=is_automatic, + match hook_part.status: + case "pending": + if tc_id in self.emitted_approval_requests: + return out + self.emitted_approval_requests.add(tc_id) + out.append( + ui_events.UIToolApprovalRequestEvent( + approval_id=hook_part.hook_id, + tool_call_id=tc_id, + is_automatic=is_automatic, + ) ) - ) - elif hook_part.status == "resolved": - resolution: dict[str, Any] = hook_part.resolution or {} - provider_executed = hook_part.metadata.get("providerExecuted") - provider_executed = ( - provider_executed - if isinstance(provider_executed, bool) - else None - ) - provider_metadata = hook_part.metadata.get("callProviderMetadata") - provider_metadata = ( - provider_metadata - if isinstance(provider_metadata, dict) - else None - ) - out.append( - ui_events.UIToolApprovalResponseEvent( - approval_id=hook_part.hook_id, - approved=bool(resolution.get("granted")), - reason=resolution.get("reason"), - provider_executed=provider_executed, - provider_metadata=provider_metadata, + case "resolved": + resolution: dict[str, Any] = hook_part.resolution or {} + provider_executed = hook_part.metadata.get("providerExecuted") + provider_executed = ( + provider_executed + if isinstance(provider_executed, bool) + else None + ) + provider_metadata = hook_part.metadata.get( + "callProviderMetadata" + ) + provider_metadata = ( + provider_metadata + if isinstance(provider_metadata, dict) + else None ) - ) - if not resolution.get("granted"): out.append( - ui_events.UIToolOutputDeniedEvent(tool_call_id=tc_id) + ui_events.UIToolApprovalResponseEvent( + approval_id=hook_part.hook_id, + approved=bool(resolution.get("granted")), + reason=resolution.get("reason"), + provider_executed=provider_executed, + provider_metadata=provider_metadata, + ) ) - elif hook_part.status == "cancelled": - out.append( - ui_events.UIToolOutputErrorEvent( - tool_call_id=tc_id, - error_text="Hook cancelled", + if not resolution.get("granted"): + out.append( + ui_events.UIToolOutputDeniedEvent(tool_call_id=tc_id) + ) + case "cancelled": + out.append( + ui_events.UIToolOutputErrorEvent( + tool_call_id=tc_id, + error_text="Hook cancelled", + ) ) - ) return out @@ -527,18 +530,19 @@ async def to_stream( state = _StreamState() async for event in events: - if isinstance(event, events_.ToolCallResult): - for ui_event in state.on_tool_result(event): - yield ui_event - elif isinstance(event, events_.PartialToolCallResult): - for ui_event in state.on_partial_tool_result(event): - yield ui_event - elif isinstance(event, events_.HookEvent): - for ui_event in state.on_hook(event): - yield ui_event - else: - for ui_event in state.on_event(event): - yield ui_event + match event: + case events_.ToolCallResult(): + for ui_event in state.on_tool_result(event): + yield ui_event + case events_.PartialToolCallResult(): + for ui_event in state.on_partial_tool_result(event): + yield ui_event + case events_.HookEvent(): + for ui_event in state.on_hook(event): + yield ui_event + case _: + for ui_event in state.on_event(event): + yield ui_event for ui_event in state.finish(): yield ui_event From 8106729d51f3d1ab8cae1d3f1fa6df2ffcc75370 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Wed, 20 May 2026 17:39:00 -0700 Subject: [PATCH 11/13] Move important functions towards the bottom --- src/ai/agents/ui/ai_sdk/id_utils.py | 96 ++++---- src/ai/agents/ui/ai_sdk/inbound_messages.py | 121 +++++----- src/ai/agents/ui/ai_sdk/outbound_messages.py | 234 +++++++++---------- src/ai/agents/ui/ai_sdk/outbound_stream.py | 50 ++-- 4 files changed, 250 insertions(+), 251 deletions(-) diff --git a/src/ai/agents/ui/ai_sdk/id_utils.py b/src/ai/agents/ui/ai_sdk/id_utils.py index 38f8d62a..57b5ad56 100644 --- a/src/ai/agents/ui/ai_sdk/id_utils.py +++ b/src/ai/agents/ui/ai_sdk/id_utils.py @@ -23,6 +23,54 @@ class SourceMessage: part_ids: tuple[str, ...] +def _parse_source_message(raw: object) -> SourceMessage | None: + if not isinstance(raw, dict): + return None + + raw_dict = cast("dict[str, object]", raw) + message_id = raw_dict.get("id") + role = raw_dict.get("role") + if not isinstance(message_id, str) or role not in _VALID_ROLES: + return None + + raw_turn_id = raw_dict.get("turnId") + turn_id = raw_turn_id if isinstance(raw_turn_id, str) else None + + raw_part_ids = raw_dict.get("partIds") + part_ids = ( + tuple(part_id for part_id in raw_part_ids if isinstance(part_id, str)) + if isinstance(raw_part_ids, list) + else () + ) + + return SourceMessage( + id=message_id, + role=cast("MessageRole", role), + turn_id=turn_id, + part_ids=part_ids, + ) + + +def _restore_message_ids( + message: messages_.Message, + source: SourceMessage, +) -> messages_.Message: + updates: dict[str, object] = { + "id": source.id, + "turn_id": source.turn_id, + } + + if len(source.part_ids) == len(message.parts): + updates["parts"] = [ + part.model_copy(update={"id": part_id}) + for part, part_id in zip( + message.parts, source.part_ids, strict=True + ) + ] + + return message.model_copy(update=updates) + + def metadata_for( source_messages: list[messages_.Message], ) -> dict[str, object]: @@ -91,51 +139,3 @@ def restore_source_ids( restored.append(_restore_message_ids(message, source)) return restored - - -def _parse_source_message(raw: object) -> SourceMessage | None: - if not isinstance(raw, dict): - return None - - raw_dict = cast("dict[str, object]", raw) - message_id = raw_dict.get("id") - role = raw_dict.get("role") - if not isinstance(message_id, str) or role not in _VALID_ROLES: - return None - - raw_turn_id = raw_dict.get("turnId") - turn_id = raw_turn_id if isinstance(raw_turn_id, str) else None - - raw_part_ids = raw_dict.get("partIds") - part_ids = ( - tuple(part_id for part_id in raw_part_ids if isinstance(part_id, str)) - if isinstance(raw_part_ids, list) - else () - ) - - return SourceMessage( - id=message_id, - role=cast("MessageRole", role), - turn_id=turn_id, - part_ids=part_ids, - ) - - -def _restore_message_ids( - message: messages_.Message, - source: SourceMessage, -) -> messages_.Message: - updates: dict[str, object] = { - "id": source.id, - "turn_id": source.turn_id, - } - - if len(source.part_ids) == len(message.parts): - updates["parts"] = [ - part.model_copy(update={"id": part_id}) - for part, part_id in zip( - message.parts, source.part_ids, strict=True - ) - ] - - return message.model_copy(update=updates) diff --git a/src/ai/agents/ui/ai_sdk/inbound_messages.py b/src/ai/agents/ui/ai_sdk/inbound_messages.py index cddf69d9..db9f4c3b 100644 --- a/src/ai/agents/ui/ai_sdk/inbound_messages.py +++ b/src/ai/agents/ui/ai_sdk/inbound_messages.py @@ -78,6 +78,30 @@ def _decode_wire_output(output: Any) -> Any: return MessageBundle(messages=tuple(inner)) +def _build_result_part( + *, + tool_call_id: str, + tool_name: str, + output: Any, + is_error: bool, +) -> messages_.ToolResultPart: + if is_error: + result: Any = output + else: + decoded = _decode_wire_output(output) + result = ( + decoded + if isinstance(decoded, MessageBundle) + else _normalize_tool_result(decoded) + ) + return messages_.ToolResultPart( + tool_call_id=tool_call_id, + tool_name=tool_name, + result=result, + is_error=is_error, + ) + + def _normalize_ui_messages( ui_messages: list[ui_messages_.UIMessage], ) -> list[ui_messages_.UIMessage]: @@ -122,43 +146,6 @@ def _normalize_ui_messages( return normalized -# ============================================================================ -# UI → internal message conversion -# ============================================================================ - - -def to_messages( - ui_messages: list[ui_messages_.UIMessage], -) -> tuple[list[messages_.Message], list[ApprovalResponse]]: - """Parse a UI request into runtime messages + extracted approvals. - - Pure: normalizes stale tool states, extracts approval responses, - parses UIMessages into an ``ai.messages.Message`` list (split at - tool boundaries), drops the internal tombstones for approval - responses, and patches the trailing tool message with - ``is_hook_pending`` placeholders for tool calls whose approval was - just responded to but never recorded a real tool result. - - Sub-agent tool outputs (UIMessage wire shape) are decoded back to - ``MessageBundle`` so the parent agent's message history carries the - rich snapshot. Per-tool model-facing values are populated by - :meth:`Agent.run` (which has the tool registry), not here. - - Returns ``(messages, approvals)``. The caller can pre-register - resolutions via :func:`apply_approvals` before calling - :meth:`Agent.run` if the run should resume from a hook. - """ - normalized = _normalize_ui_messages(ui_messages) - approval_responses = extract_approvals(normalized) - messages = [ - m - for m in _parse(normalized) - if not approvals.is_resolved_approval_message(m) - ] - _patch_pending_hook_aborts(messages, approval_responses) - return messages, approval_responses - - def _patch_pending_hook_aborts( messages: list[messages_.Message], approvals: list[ApprovalResponse], @@ -215,29 +202,6 @@ def _patch_pending_hook_aborts( def _parse( ui_messages: list[ui_messages_.UIMessage], ) -> list[messages_.Message]: - def _build_result_part( - *, - tool_call_id: str, - tool_name: str, - output: Any, - is_error: bool, - ) -> messages_.ToolResultPart: - if is_error: - result: Any = output - else: - decoded = _decode_wire_output(output) - result = ( - decoded - if isinstance(decoded, MessageBundle) - else (_normalize_tool_result(decoded)) - ) - return messages_.ToolResultPart( - tool_call_id=tool_call_id, - tool_name=tool_name, - result=result, - is_error=is_error, - ) - result: list[messages_.Message] = [] for ui_msg in ui_messages: @@ -525,3 +489,40 @@ def _split_assistant_parts( ) return messages + + +# ============================================================================ +# UI → internal message conversion +# ============================================================================ + + +def to_messages( + ui_messages: list[ui_messages_.UIMessage], +) -> tuple[list[messages_.Message], list[ApprovalResponse]]: + """Parse a UI request into runtime messages + extracted approvals. + + Pure: normalizes stale tool states, extracts approval responses, + parses UIMessages into an ``ai.messages.Message`` list (split at + tool boundaries), drops the internal tombstones for approval + responses, and patches the trailing tool message with + ``is_hook_pending`` placeholders for tool calls whose approval was + just responded to but never recorded a real tool result. + + Sub-agent tool outputs (UIMessage wire shape) are decoded back to + ``MessageBundle`` so the parent agent's message history carries the + rich snapshot. Per-tool model-facing values are populated by + :meth:`Agent.run` (which has the tool registry), not here. + + Returns ``(messages, approvals)``. The caller can pre-register + resolutions via :func:`apply_approvals` before calling + :meth:`Agent.run` if the run should resume from a hook. + """ + normalized = _normalize_ui_messages(ui_messages) + approval_responses = extract_approvals(normalized) + messages = [ + m + for m in _parse(normalized) + if not approvals.is_resolved_approval_message(m) + ] + _patch_pending_hook_aborts(messages, approval_responses) + return messages, approval_responses diff --git a/src/ai/agents/ui/ai_sdk/outbound_messages.py b/src/ai/agents/ui/ai_sdk/outbound_messages.py index ad70a0a8..a37f7777 100644 --- a/src/ai/agents/ui/ai_sdk/outbound_messages.py +++ b/src/ai/agents/ui/ai_sdk/outbound_messages.py @@ -22,96 +22,6 @@ } -def to_ui_parts(parts: list[messages_.Part]) -> list[ui_messages.UIMessagePart]: - """Convert internal parts to UI message parts.""" - result: list[ui_messages.UIMessagePart] = [] - for part in parts: - match part: - case messages_.TextPart(text=text) if text: - result.append( - ui_messages.UITextPart.model_validate( - { - "type": "text", - "text": text, - "providerMetadata": part.provider_metadata, - } - ) - ) - case messages_.ReasoningPart(text=text) if text: - result.append( - ui_messages.UIReasoningPart.model_validate( - { - "type": "reasoning", - "text": text, - "providerMetadata": part.provider_metadata, - } - ) - ) - case messages_.ToolCallPart(): - result.append( - ui_messages.UIToolPart.model_validate( - { - "type": f"tool-{part.tool_name}", - "toolCallId": part.tool_call_id, - "state": "input-available", - "input": normalize_tool_input(part.tool_args), - "callProviderMetadata": part.provider_metadata, - } - ) - ) - case messages_.BuiltinToolCallPart(): - result.append( - ui_messages.UIDynamicToolPart.model_validate( - { - "type": "dynamic-tool", - "toolName": part.tool_name, - "toolCallId": part.tool_call_id, - "state": "input-available", - "input": normalize_tool_input(part.tool_args), - "providerExecuted": True, - "callProviderMetadata": part.provider_metadata, - } - ) - ) - case messages_.BuiltinToolReturnPart(): - result.append( - ui_messages.UIDynamicToolPart.model_validate( - { - "type": "dynamic-tool", - "toolName": part.tool_name, - "toolCallId": part.tool_call_id, - "state": ( - "output-error" - if part.is_error - else "output-available" - ), - "input": None, - "output": None if part.is_error else part.result, - "errorText": ( - str(part.result) if part.is_error else None - ), - "providerExecuted": True, - "resultProviderMetadata": part.provider_metadata, - } - ) - ) - case messages_.FilePart(): - result.append( - ui_messages.UIFilePart.model_validate( - { - "type": "file", - "mediaType": part.media_type, - "url": media.data_to_data_url( - part.data, part.media_type - ), - "filename": part.filename, - "providerMetadata": part.provider_metadata, - } - ) - ) - return result - - def _merge_tool_part( existing: UIToolLike, candidate: UIToolLike, @@ -152,6 +62,18 @@ def _merge_tool_part( return existing.model_copy(update=updates) if updates else existing +def _tool_part_index_by_call_id( + ui_parts: list[ui_messages.UIMessagePart], +) -> dict[str, int]: + return { + ui_part.tool_call_id: idx + for idx, ui_part in enumerate(ui_parts) + if isinstance( + ui_part, ui_messages.UIToolPart | ui_messages.UIDynamicToolPart + ) + } + + def dedupe_tool_parts( ui_parts: list[ui_messages.UIMessagePart], ) -> list[ui_messages.UIMessagePart]: @@ -186,12 +108,7 @@ def merge_tool_results( tool_parts: list[messages_.Part], ) -> None: """Merge tool result parts into existing UI tool parts.""" - tool_index: dict[str, int] = {} - for idx, ui_part in enumerate(ui_parts): - if isinstance( - ui_part, ui_messages.UIToolPart | ui_messages.UIDynamicToolPart - ): - tool_index[ui_part.tool_call_id] = idx + tool_index = _tool_part_index_by_call_id(ui_parts) for part in tool_parts: updates: dict[str, Any] @@ -244,12 +161,7 @@ def merge_approval_signals( internal_parts: list[messages_.Part], ) -> None: """Merge approval hook state into existing UI tool parts.""" - tool_index: dict[str, int] = {} - for idx, ui_part in enumerate(ui_parts): - if isinstance( - ui_part, ui_messages.UIToolPart | ui_messages.UIDynamicToolPart - ): - tool_index[ui_part.tool_call_id] = idx + tool_index = _tool_part_index_by_call_id(ui_parts) for part in internal_parts: if not isinstance(part, messages_.HookPart): @@ -279,13 +191,11 @@ def merge_approval_signals( match part.status: case "pending": updates["state"] = "approval-requested" - updates["approval"] = ( - ui_messages.UIToolApproval.model_validate( - { - "id": part.hook_id, - "isAutomatic": is_automatic, - } - ) + updates["approval"] = ui_messages.UIToolApproval.model_validate( + { + "id": part.hook_id, + "isAutomatic": is_automatic, + } ) case "resolved": resolution = cast( @@ -294,15 +204,13 @@ def merge_approval_signals( if isinstance(part.resolution, dict) else {}, ) - updates["approval"] = ( - ui_messages.UIToolApproval.model_validate( - { - "id": part.hook_id, - "approved": resolution.get("granted"), - "reason": resolution.get("reason"), - "isAutomatic": is_automatic, - } - ) + updates["approval"] = ui_messages.UIToolApproval.model_validate( + { + "id": part.hook_id, + "approved": resolution.get("granted"), + "reason": resolution.get("reason"), + "isAutomatic": is_automatic, + } ) if resolution.get("granted", False): updates["state"] = "approval-responded" @@ -317,6 +225,96 @@ def merge_approval_signals( ui_parts[idx] = existing.model_copy(update=updates) +def to_ui_parts(parts: list[messages_.Part]) -> list[ui_messages.UIMessagePart]: + """Convert internal parts to UI message parts.""" + result: list[ui_messages.UIMessagePart] = [] + for part in parts: + match part: + case messages_.TextPart(text=text) if text: + result.append( + ui_messages.UITextPart.model_validate( + { + "type": "text", + "text": text, + "providerMetadata": part.provider_metadata, + } + ) + ) + case messages_.ReasoningPart(text=text) if text: + result.append( + ui_messages.UIReasoningPart.model_validate( + { + "type": "reasoning", + "text": text, + "providerMetadata": part.provider_metadata, + } + ) + ) + case messages_.ToolCallPart(): + result.append( + ui_messages.UIToolPart.model_validate( + { + "type": f"tool-{part.tool_name}", + "toolCallId": part.tool_call_id, + "state": "input-available", + "input": normalize_tool_input(part.tool_args), + "callProviderMetadata": part.provider_metadata, + } + ) + ) + case messages_.BuiltinToolCallPart(): + result.append( + ui_messages.UIDynamicToolPart.model_validate( + { + "type": "dynamic-tool", + "toolName": part.tool_name, + "toolCallId": part.tool_call_id, + "state": "input-available", + "input": normalize_tool_input(part.tool_args), + "providerExecuted": True, + "callProviderMetadata": part.provider_metadata, + } + ) + ) + case messages_.BuiltinToolReturnPart(): + result.append( + ui_messages.UIDynamicToolPart.model_validate( + { + "type": "dynamic-tool", + "toolName": part.tool_name, + "toolCallId": part.tool_call_id, + "state": ( + "output-error" + if part.is_error + else "output-available" + ), + "input": None, + "output": None if part.is_error else part.result, + "errorText": ( + str(part.result) if part.is_error else None + ), + "providerExecuted": True, + "resultProviderMetadata": part.provider_metadata, + } + ) + ) + case messages_.FilePart(): + result.append( + ui_messages.UIFilePart.model_validate( + { + "type": "file", + "mediaType": part.media_type, + "url": media.data_to_data_url( + part.data, part.media_type + ), + "filename": part.filename, + "providerMetadata": part.provider_metadata, + } + ) + ) + return result + + def to_ui_messages( messages: list[messages_.Message], ) -> list[ui_messages.UIMessage]: diff --git a/src/ai/agents/ui/ai_sdk/outbound_stream.py b/src/ai/agents/ui/ai_sdk/outbound_stream.py index 8e810507..5e728c2b 100644 --- a/src/ai/agents/ui/ai_sdk/outbound_stream.py +++ b/src/ai/agents/ui/ai_sdk/outbound_stream.py @@ -523,31 +523,6 @@ def finish(self) -> list[ui_events.UIMessageStreamEvent]: return events -async def to_stream( - events: AsyncIterable[events_.AgentEvent], -) -> AsyncGenerator[ui_events.UIMessageStreamEvent]: - """Walk internal events once, emitting AI SDK UI stream events.""" - state = _StreamState() - - async for event in events: - match event: - case events_.ToolCallResult(): - for ui_event in state.on_tool_result(event): - yield ui_event - case events_.PartialToolCallResult(): - for ui_event in state.on_partial_tool_result(event): - yield ui_event - case events_.HookEvent(): - for ui_event in state.on_hook(event): - yield ui_event - case _: - for ui_event in state.on_event(event): - yield ui_event - - for ui_event in state.finish(): - yield ui_event - - def _to_camel_case(snake_str: str) -> str: components = snake_str.split("_") return components[0] + "".join(x.title() for x in components[1:]) @@ -581,6 +556,31 @@ def format_done_sse() -> str: return "data: [DONE]\n\n" +async def to_stream( + events: AsyncIterable[events_.AgentEvent], +) -> AsyncGenerator[ui_events.UIMessageStreamEvent]: + """Walk internal events once, emitting AI SDK UI stream events.""" + state = _StreamState() + + async for event in events: + match event: + case events_.ToolCallResult(): + for ui_event in state.on_tool_result(event): + yield ui_event + case events_.PartialToolCallResult(): + for ui_event in state.on_partial_tool_result(event): + yield ui_event + case events_.HookEvent(): + for ui_event in state.on_hook(event): + yield ui_event + case _: + for ui_event in state.on_event(event): + yield ui_event + + for ui_event in state.finish(): + yield ui_event + + async def to_sse( events: AsyncIterable[events_.AgentEvent], ) -> AsyncGenerator[str]: From c2fae31a43295911bf3729121c9d6b71b3929e92 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Thu, 21 May 2026 11:55:43 -0700 Subject: [PATCH 12/13] Track messages inside the adapter to enable roundtrip reconstruction --- src/ai/agents/ui/ai_sdk/outbound_stream.py | 28 +++- .../agents/ui/ai_sdk/test_outbound_stream.py | 123 ++++++++++++++++++ 2 files changed, 150 insertions(+), 1 deletion(-) diff --git a/src/ai/agents/ui/ai_sdk/outbound_stream.py b/src/ai/agents/ui/ai_sdk/outbound_stream.py index 5e728c2b..288bfbe5 100644 --- a/src/ai/agents/ui/ai_sdk/outbound_stream.py +++ b/src/ai/agents/ui/ai_sdk/outbound_stream.py @@ -68,6 +68,7 @@ def __init__(self) -> None: self.completed_reasoning_ids: set[str] = set() self.text_delta_ids: set[str] = set() self.reasoning_delta_ids: set[str] = set() + self.source_messages: dict[str, messages_.Message] = {} # Per-tool-call aggregators for streaming generator tools. Each # PartialToolCallResult feeds its value into the aggregator and @@ -78,6 +79,23 @@ def __init__(self) -> None: # -- boundary helpers ---------------------------------------------------- + def _track_source_message(self, message: messages_.Message | None) -> None: + if message is None or message.id == "": + return + self.source_messages[message.id] = message + + def _latest_assistant_metadata(self) -> Any | None: + messages = [ + msg + for msg in self.source_messages.values() + if msg.role != "system" and msg.parts + ] + ui_messages = outbound_messages.to_ui_messages(messages) + for message in reversed(ui_messages): + if message.role == "assistant": + return message.metadata + return None + def _close_open_blocks(self) -> list[ui_events.UIMessageStreamEvent]: events: list[ui_events.UIMessageStreamEvent] = [] for rid in list(self.open_reasoning_ids): @@ -117,6 +135,7 @@ def on_event( self, event: events_.Event ) -> list[ui_events.UIMessageStreamEvent]: out: list[ui_events.UIMessageStreamEvent] = [] + self._track_source_message(event.message) # Lazily open the UI message on the first streaming event. if not self.emitted_start: @@ -338,6 +357,7 @@ def on_tool_result( msg = event.message out: list[ui_events.UIMessageStreamEvent] = [] + self._track_source_message(msg) out.extend(self._ensure_started(msg.turn_id)) # Emit ToolInputAvailable for each tool call that triggered @@ -451,6 +471,7 @@ def on_hook( hook_part = event.hook out: list[ui_events.UIMessageStreamEvent] = [] + self._track_source_message(event.message) # Ensure the UI message is started. out.extend(self._ensure_started(event.message.turn_id)) @@ -519,7 +540,12 @@ def finish(self) -> list[ui_events.UIMessageStreamEvent]: events.append(ui_events.UIFinishStepEvent()) self.in_step = False if self.emitted_start: - events.append(ui_events.UIFinishEvent(finish_reason="stop")) + events.append( + ui_events.UIFinishEvent( + finish_reason="stop", + message_metadata=self._latest_assistant_metadata(), + ) + ) return events diff --git a/tests/agents/ui/ai_sdk/test_outbound_stream.py b/tests/agents/ui/ai_sdk/test_outbound_stream.py index 837619c0..e496405c 100644 --- a/tests/agents/ui/ai_sdk/test_outbound_stream.py +++ b/tests/agents/ui/ai_sdk/test_outbound_stream.py @@ -29,6 +29,15 @@ async def _collect( return [event async for event in to_stream(_gen(stream_events))] +def _source_messages(metadata: Any | None) -> list[dict[str, Any]]: + assert isinstance(metadata, dict) + adapter_metadata = metadata.get("aiPython") + assert isinstance(adapter_metadata, dict) + source_messages = adapter_metadata.get("sourceMessages") + assert isinstance(source_messages, list) + return source_messages + + def test_serialize_event_camelcases_keys() -> None: event = ui_events.UIStartEvent(message_id="m1") payload = json.loads(serialize_event(event)) @@ -116,6 +125,99 @@ async def test_stream_start_uses_runtime_message_id() -> None: assert start.message_id == "assistant-runtime-id" +async def test_finish_metadata_tracks_streamed_assistant_message() -> None: + assistant = messages_.Message( + id="assistant-1", + role="assistant", + parts=[messages_.TextPart(id="text-1", text="hello")], + ) + + out = await _collect( + [ + events_.TextStart(block_id="text-1", message=assistant), + events_.TextDelta( + block_id="text-1", chunk="hello", message=assistant + ), + events_.TextEnd(block_id="text-1", message=assistant), + ] + ) + + finish = next( + event for event in out if isinstance(event, ui_events.UIFinishEvent) + ) + assert _source_messages(finish.message_metadata) == [ + { + "id": "assistant-1", + "role": "assistant", + "turnId": None, + "partIds": ["text-1"], + } + ] + + +async def test_finish_metadata_tracks_tool_and_internal_messages() -> None: + tool_call = messages_.ToolCallPart( + id="call-1", + tool_call_id="tc1", + tool_name="search", + tool_args="{}", + ) + assistant = messages_.Message( + id="assistant-1", + turn_id="turn-1", + role="assistant", + parts=[tool_call], + ) + tool = messages_.Message( + id="tool-1", + turn_id="turn-1", + role="tool", + parts=[ + messages_.ToolResultPart( + id="result-1", + tool_call_id="tc1", + tool_name="search", + result={"hits": 1}, + ) + ], + ) + hook = messages_.HookPart[Any]( + id="hook-1", + hook_id="approve_tc1", + hook_type="ToolApproval", + status="pending", + ) + internal = messages_.Message( + id="internal-1", + turn_id="turn-1", + role="internal", + parts=[hook], + ) + + out = await _collect( + [ + events_.ToolStart( + tool_call_id="tc1", tool_name="search", message=assistant + ), + events_.ToolEnd( + tool_call_id="tc1", tool_call=tool_call, message=assistant + ), + agent_events_.ToolCallResult( + message=tool, + results=tool.tool_results, + ), + agent_events_.HookEvent(message=internal, hook=hook), + ] + ) + + finish = next( + event for event in out if isinstance(event, ui_events.UIFinishEvent) + ) + assert [ + source["id"] for source in _source_messages(finish.message_metadata) + ] == ["assistant-1", "tool-1", "internal-1"] + + async def test_event_driven_text_streaming() -> None: """Streaming text events lazily open a UI message.""" text_id = "txt1" @@ -138,6 +240,27 @@ async def test_event_driven_text_streaming() -> None: assert isinstance(out[4], ui_events.UITextEndEvent) and out[4].id == text_id assert isinstance(out[5], ui_events.UIFinishStepEvent) assert isinstance(out[6], ui_events.UIFinishEvent) + assert out[6].message_metadata is None + + +async def test_finish_metadata_ignores_empty_messages() -> None: + assistant = messages_.Message( + id="assistant-empty", + role="assistant", + parts=[], + ) + + out = await _collect( + [ + events_.StreamStart(message=assistant), + events_.StreamEnd(message=assistant), + ] + ) + + finish = next( + event for event in out if isinstance(event, ui_events.UIFinishEvent) + ) + assert finish.message_metadata is None async def test_tool_call_and_result_emit_terminal_events() -> None: From 85d1c9e7b6d92189407ac325111fa5cf992ca180 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Fri, 22 May 2026 13:10:36 -0700 Subject: [PATCH 13/13] Comment on merging logic, refactor translation dispatch, comment on metadata shape --- src/ai/agents/ui/ai_sdk/id_utils.py | 11 ++++- src/ai/agents/ui/ai_sdk/inbound_messages.py | 2 +- src/ai/agents/ui/ai_sdk/outbound_messages.py | 44 +++++++++++--------- 3 files changed, 35 insertions(+), 22 deletions(-) diff --git a/src/ai/agents/ui/ai_sdk/id_utils.py b/src/ai/agents/ui/ai_sdk/id_utils.py index 57b5ad56..939d9bba 100644 --- a/src/ai/agents/ui/ai_sdk/id_utils.py +++ b/src/ai/agents/ui/ai_sdk/id_utils.py @@ -1,4 +1,10 @@ -"""Roundtrip metadata for preserving internal message identity.""" +"""Roundtrip metadata for preserving internal message identity. + +The adapter writes ``metadata["aiPython"]["sourceMessages"]`` with each +source message's ``id``, ``role``, ``turnId``, and ``partIds``. Outbound UI +bubbles can collapse assistant/tool/internal messages into one UI message; +inbound parsing uses this metadata to restore stable message and part ids. +""" from __future__ import annotations @@ -74,6 +80,7 @@ def _restore_message_ids( def metadata_for( source_messages: list[messages_.Message], ) -> dict[str, object]: + """Return adapter metadata for restoring collapsed source message ids.""" return { ADAPTER_METADATA_KEY: { SOURCE_MESSAGES_KEY: [ @@ -90,6 +97,7 @@ def metadata_for( def source_messages_from(metadata: object) -> list[SourceMessage]: + """Parse adapter metadata, ignoring missing or malformed entries.""" if not isinstance(metadata, dict): return [] @@ -115,6 +123,7 @@ def restore_source_ids( messages: list[messages_.Message], source_messages: list[SourceMessage], ) -> list[messages_.Message]: + """Restore message and part ids from matching source metadata.""" if not source_messages: return messages diff --git a/src/ai/agents/ui/ai_sdk/inbound_messages.py b/src/ai/agents/ui/ai_sdk/inbound_messages.py index db9f4c3b..c0c36631 100644 --- a/src/ai/agents/ui/ai_sdk/inbound_messages.py +++ b/src/ai/agents/ui/ai_sdk/inbound_messages.py @@ -324,7 +324,7 @@ def _parse( ), ) ) - elif tp.state in _TOOL_RESULT_STATES: + elif is_completed: tool_result_parts.append( _build_result_part( tool_call_id=tp.tool_call_id, diff --git a/src/ai/agents/ui/ai_sdk/outbound_messages.py b/src/ai/agents/ui/ai_sdk/outbound_messages.py index a37f7777..4520e20d 100644 --- a/src/ai/agents/ui/ai_sdk/outbound_messages.py +++ b/src/ai/agents/ui/ai_sdk/outbound_messages.py @@ -11,6 +11,12 @@ UIToolLike = ui_messages.UIToolPart | ui_messages.UIDynamicToolPart +# Internal history can contain separate records for one tool call +# (call, approval, result). AI SDK UI expects one tool part per +# toolCallId, so later/higher-ranked states update the first part +# https://ai-sdk.dev/docs/reference/ai-sdk-core/ui-message#tooluipart +# https://ai-sdk.dev/docs/ai-sdk-ui/stream-protocol + _TOOL_STATE_RANK: dict[ui_messages.UIToolInvocationState, int] = { "input-streaming": 0, "input-available": 1, @@ -21,6 +27,19 @@ "output-available": 6, } +_MERGEABLE_TOOL_PART_FIELDS = ( + "raw_input", + "output", + "error_text", + "approval", + "provider_executed", + "call_provider_metadata", + "result_provider_metadata", + "tool_metadata", + "preliminary", + "title", +) + def _merge_tool_part( existing: UIToolLike, @@ -38,26 +57,11 @@ def _merge_tool_part( if existing.input is None and candidate.input is not None: updates["input"] = candidate.input - if candidate.output is not None: - updates["output"] = candidate.output - if candidate.raw_input is not None: - updates["raw_input"] = candidate.raw_input - if candidate.error_text is not None: - updates["error_text"] = candidate.error_text - if candidate.approval is not None: - updates["approval"] = candidate.approval - if candidate.provider_executed is not None: - updates["provider_executed"] = candidate.provider_executed - if candidate.call_provider_metadata is not None: - updates["call_provider_metadata"] = candidate.call_provider_metadata - if candidate.result_provider_metadata is not None: - updates["result_provider_metadata"] = candidate.result_provider_metadata - if candidate.tool_metadata is not None: - updates["tool_metadata"] = candidate.tool_metadata - if candidate.preliminary is not None: - updates["preliminary"] = candidate.preliminary - if candidate.title is not None: - updates["title"] = candidate.title + + for field in _MERGEABLE_TOOL_PART_FIELDS: + value = getattr(candidate, field) + if value is not None: + updates[field] = value return existing.model_copy(update=updates) if updates else existing