diff --git a/src/ai/agents/ui/ai_sdk/__init__.py b/src/ai/agents/ui/ai_sdk/__init__.py index 711e7892..224f806e 100644 --- a/src/ai/agents/ui/ai_sdk/__init__.py +++ b/src/ai/agents/ui/ai_sdk/__init__.py @@ -1,14 +1,15 @@ """AI SDK UI adapter for messages and SSE streams.""" -from .inbound import ( +from .approvals import ( ApprovalResponse, apply_approvals, extract_approvals, - to_messages, ) -from .outbound import to_sse, to_stream, to_ui_messages -from .protocol import UI_MESSAGE_STREAM_HEADERS -from .ui_message import UIMessage +from .inbound_messages import to_messages +from .outbound_messages import to_ui_messages +from .outbound_stream import to_sse, to_stream +from .ui_events import UI_MESSAGE_STREAM_HEADERS +from .ui_messages import UIMessage __all__ = [ "UI_MESSAGE_STREAM_HEADERS", diff --git a/src/ai/agents/ui/ai_sdk/_approvals.py b/src/ai/agents/ui/ai_sdk/_approvals.py deleted file mode 100644 index f0694657..00000000 --- a/src/ai/agents/ui/ai_sdk/_approvals.py +++ /dev/null @@ -1,33 +0,0 @@ -"""Approval-prefix linkage between ToolApproval hooks and tool calls. - -TODO(datamodel-rework §4): once ``HookPart.target_tool_call_id`` is added, -delete this module and replace call sites with direct field access. -""" - -from __future__ import annotations - -from typing import Any - -from ....types import messages as messages_ -from ...hooks import TOOL_APPROVAL_HOOK_TYPE - -_PREFIX = "approve_" - - -def tool_call_id_for(hook_part: messages_.HookPart[Any]) -> str | None: - """Return the tool_call_id encoded in a ToolApproval hook id, or None.""" - if hook_part.hook_type != TOOL_APPROVAL_HOOK_TYPE: - return None - if hook_part.hook_id.startswith(_PREFIX): - return hook_part.hook_id[len(_PREFIX) :] - return None - - -def is_tool_approval_message(msg: messages_.Message) -> bool: - """Return whether every part of ``msg`` is a ToolApproval HookPart.""" - if not msg.parts: - return False - return all( - isinstance(p, messages_.HookPart) and tool_call_id_for(p) is not None - for p in msg.parts - ) diff --git a/src/ai/agents/ui/ai_sdk/_parts.py b/src/ai/agents/ui/ai_sdk/_parts.py deleted file mode 100644 index 3637174b..00000000 --- a/src/ai/agents/ui/ai_sdk/_parts.py +++ /dev/null @@ -1,149 +0,0 @@ -"""Shared conversions between internal Part objects and UIMessagePart objects. - -Used by ``outbound.history`` to reconstruct UIMessages from persisted -``ai.messages.Message`` lists. The live outbound stream does not use these; it -emits wire-protocol deltas directly from event streams. -""" - -from __future__ import annotations - -import json -from typing import Any, cast - -from ....types import messages as messages_ -from . import _approvals, ui_message - - -def _normalize_tool_input(raw: str) -> str | dict[str, Any]: - """Parse tool args JSON string into a dict; fall back to raw string. - - TODO(datamodel-rework §4): once ``ToolCallPart.tool_args`` has a - canonical shape, drop this helper. - """ - try: - parsed = json.loads(raw) - except (json.JSONDecodeError, TypeError): - return raw - return parsed if isinstance(parsed, dict) else raw - - -def to_ui_parts(parts: list[messages_.Part]) -> list[ui_message.UIMessagePart]: - """Convert internal Part objects to UIMessagePart objects.""" - result: list[ui_message.UIMessagePart] = [] - for part in parts: - if isinstance(part, messages_.TextPart) and part.text: - result.append(ui_message.UITextPart(type="text", text=part.text)) - elif isinstance(part, messages_.ReasoningPart) and part.text: - result.append( - ui_message.UIReasoningPart(type="reasoning", text=part.text) - ) - elif isinstance(part, messages_.ToolCallPart): - result.append( - ui_message.UIToolPart.model_validate( - { - "type": f"tool-{part.tool_name}", - "toolCallId": part.tool_call_id, - "state": "input-available", - "input": _normalize_tool_input(part.tool_args), - } - ) - ) - elif isinstance(part, messages_.FilePart): - result.append( - ui_message.UIFilePart.model_validate( - { - "type": "file", - "mediaType": part.media_type, - "url": part.data if isinstance(part.data, str) else "", - "filename": part.filename, - } - ) - ) - return result - - -def merge_tool_results( - ui_parts: list[ui_message.UIMessagePart], - tool_parts: list[messages_.Part], -) -> None: - """Merge ToolResultParts into existing UIToolParts in-place.""" - tool_index: dict[str, int] = {} - for idx, ui_part in enumerate(ui_parts): - if isinstance(ui_part, ui_message.UIToolPart): - tool_index[ui_part.tool_call_id] = idx - - for part in tool_parts: - if not isinstance(part, messages_.ToolResultPart): - continue - # Hook-abort placeholders are internal: the corresponding - # HookPart(pending) carries the user-visible state via - # merge_approval_signals. - if part.is_hook_pending: - continue - idx_opt = tool_index.get(part.tool_call_id) - if idx_opt is None: - continue - idx = idx_opt - existing = ui_parts[idx] - if not isinstance(existing, ui_message.UIToolPart): - continue - if existing.state == "output-denied": - continue - state = "output-error" if part.is_error else "output-available" - ui_parts[idx] = existing.model_copy( - update={"state": state, "output": part.result} - ) - - -def merge_approval_signals( - ui_parts: list[ui_message.UIMessagePart], - internal_parts: list[messages_.Part], -) -> None: - """Merge HookPart approval state into existing UIToolParts in-place.""" - tool_index: dict[str, int] = {} - for idx, ui_part in enumerate(ui_parts): - if isinstance(ui_part, ui_message.UIToolPart): - tool_index[ui_part.tool_call_id] = idx - - for part in internal_parts: - if not isinstance(part, messages_.HookPart): - continue - - tool_call_id = _approvals.tool_call_id_for(part) - if tool_call_id is None: - continue - - idx_opt = tool_index.get(tool_call_id) - if idx_opt is None: - continue - idx = idx_opt - - existing = ui_parts[idx] - if not isinstance(existing, ui_message.UIToolPart): - continue - - updates: dict[str, Any] = {} - if part.status == "pending": - updates["state"] = "approval-requested" - updates["approval"] = ui_message.UIToolApproval(id=part.hook_id) - elif part.status == "resolved": - resolution = cast( - "dict[str, Any]", - part.resolution if isinstance(part.resolution, dict) else {}, - ) - updates["approval"] = ui_message.UIToolApproval( - id=part.hook_id, - approved=resolution.get("granted"), - reason=resolution.get("reason"), - ) - if resolution.get("granted", False): - updates["state"] = "approval-responded" - else: - updates["state"] = "output-denied" - updates["output"] = None - elif part.status == "cancelled": - updates["state"] = "output-error" - updates["error_text"] = "Hook cancelled" - - if updates: - ui_parts[idx] = existing.model_copy(update=updates) diff --git a/src/ai/agents/ui/ai_sdk/approvals.py b/src/ai/agents/ui/ai_sdk/approvals.py new file mode 100644 index 00000000..9c931eeb --- /dev/null +++ b/src/ai/agents/ui/ai_sdk/approvals.py @@ -0,0 +1,129 @@ +"""Tool approval helpers for AI SDK UI adapters.""" + +from __future__ import annotations + +from typing import Any, NamedTuple + +from ....types import messages as messages_ +from ...hooks import TOOL_APPROVAL_HOOK_TYPE, resolve_hook +from . import ui_messages + +_PREFIX = "approve_" + + +ToolPart = ui_messages.UIToolPart | ui_messages.UIDynamicToolPart + + +class ApprovalResponse(NamedTuple): + """Approval response extracted from a responded UI tool part.""" + + hook_id: str + granted: bool + reason: str | None + tool_call_id: str + + +def tool_call_id_for(hook_part: messages_.HookPart[Any]) -> str | None: + """Return the tool_call_id encoded in a ToolApproval hook id, or None.""" + if hook_part.hook_type != TOOL_APPROVAL_HOOK_TYPE: + return None + if hook_part.hook_id.startswith(_PREFIX): + return hook_part.hook_id[len(_PREFIX) :] + return None + + +def hook_part_from_tool_part(tp: ToolPart) -> messages_.HookPart[Any] | None: + """Reconstruct approval hook state from a UI tool part when possible.""" + approval = tp.approval + if approval is None: + return None + + metadata: dict[str, Any] = {} + if approval.is_automatic is not None: + metadata["isAutomatic"] = approval.is_automatic + if tp.provider_executed is not None: + metadata["providerExecuted"] = tp.provider_executed + if tp.call_provider_metadata is not None: + metadata["callProviderMetadata"] = tp.call_provider_metadata + + if tp.state == "approval-requested": + return messages_.HookPart( + hook_id=approval.id, + hook_type=TOOL_APPROVAL_HOOK_TYPE, + status="pending", + metadata=metadata, + ) + + if tp.state == "approval-responded" and approval.approved is not None: + return messages_.HookPart( + hook_id=approval.id, + hook_type=TOOL_APPROVAL_HOOK_TYPE, + status="resolved", + metadata=metadata, + resolution={ + "granted": approval.approved, + "reason": approval.reason, + }, + ) + + if tp.state == "output-denied": + return messages_.HookPart( + hook_id=approval.id, + hook_type=TOOL_APPROVAL_HOOK_TYPE, + status="resolved", + metadata=metadata, + resolution={ + "granted": False, + "reason": approval.reason, + }, + ) + + return None + + +def extract_approvals( + ui_messages_list: list[ui_messages.UIMessage], +) -> list[ApprovalResponse]: + """Return every approval response found in UI messages.""" + approvals: list[ApprovalResponse] = [] + for ui_msg in ui_messages_list: + for part in ui_msg.parts: + if not isinstance( + part, ui_messages.UIToolPart | ui_messages.UIDynamicToolPart + ): + continue + if ( + part.state == "approval-responded" + and part.approval is not None + and part.approval.approved is not None + ): + approvals.append( + ApprovalResponse( + hook_id=part.approval.id, + granted=part.approval.approved, + reason=part.approval.reason, + tool_call_id=part.tool_call_id, + ) + ) + return approvals + + +def apply_approvals(approvals: list[ApprovalResponse]) -> None: + """Pre-register each approval resolution with the hooks registry.""" + for approval in approvals: + resolve_hook( + approval.hook_id, + {"granted": approval.granted, "reason": approval.reason}, + ) + + +def is_resolved_approval_message(msg: messages_.Message) -> bool: + """Return whether ``msg`` records a resolved tool approval hook.""" + if msg.role != "internal" or len(msg.parts) != 1: + return False + part = msg.parts[0] + return ( + isinstance(part, messages_.HookPart) + and part.hook_type == TOOL_APPROVAL_HOOK_TYPE + and part.status == "resolved" + ) diff --git a/src/ai/agents/ui/ai_sdk/id_utils.py b/src/ai/agents/ui/ai_sdk/id_utils.py new file mode 100644 index 00000000..939d9bba --- /dev/null +++ b/src/ai/agents/ui/ai_sdk/id_utils.py @@ -0,0 +1,150 @@ +"""Roundtrip metadata for preserving internal message identity. + +The adapter writes ``metadata["aiPython"]["sourceMessages"]`` with each +source message's ``id``, ``role``, ``turnId``, and ``partIds``. Outbound UI +bubbles can collapse assistant/tool/internal messages into one UI message; +inbound parsing uses this metadata to restore stable message and part ids. +""" + +from __future__ import annotations + +import dataclasses +from typing import TYPE_CHECKING, Literal, cast + +if TYPE_CHECKING: + from ....types import messages as messages_ + +ADAPTER_METADATA_KEY = "aiPython" +SOURCE_MESSAGES_KEY = "sourceMessages" + +MessageRole = Literal["user", "assistant", "system", "tool", "internal"] +_VALID_ROLES = {"user", "assistant", "system", "tool", "internal"} + + +@dataclasses.dataclass(frozen=True) +class SourceMessage: + id: str + role: MessageRole + turn_id: str | None + part_ids: tuple[str, ...] + + +def _parse_source_message(raw: object) -> SourceMessage | None: + if not isinstance(raw, dict): + return None + + raw_dict = cast("dict[str, object]", raw) + message_id = raw_dict.get("id") + role = raw_dict.get("role") + if not isinstance(message_id, str) or role not in _VALID_ROLES: + return None + + raw_turn_id = raw_dict.get("turnId") + turn_id = raw_turn_id if isinstance(raw_turn_id, str) else None + + raw_part_ids = raw_dict.get("partIds") + part_ids = ( + tuple(part_id for part_id in raw_part_ids if isinstance(part_id, str)) + if isinstance(raw_part_ids, list) + else () + ) + + return SourceMessage( + id=message_id, + role=cast("MessageRole", role), + turn_id=turn_id, + part_ids=part_ids, + ) + + +def _restore_message_ids( + message: messages_.Message, + source: SourceMessage, +) -> messages_.Message: + updates: dict[str, object] = { + "id": source.id, + "turn_id": source.turn_id, + } + + if len(source.part_ids) == len(message.parts): + updates["parts"] = [ + part.model_copy(update={"id": part_id}) + for part, part_id in zip( + message.parts, source.part_ids, strict=True + ) + ] + + return message.model_copy(update=updates) + + +def metadata_for( + source_messages: list[messages_.Message], +) -> dict[str, object]: + """Return adapter metadata for restoring collapsed source message ids.""" + return { + ADAPTER_METADATA_KEY: { + SOURCE_MESSAGES_KEY: [ + { + "id": message.id, + "role": message.role, + "turnId": message.turn_id, + "partIds": [part.id for part in message.parts], + } + for message in source_messages + ] + } + } + + +def source_messages_from(metadata: object) -> list[SourceMessage]: + """Parse adapter metadata, ignoring missing or malformed entries.""" + if not isinstance(metadata, dict): + return [] + + metadata_dict = cast("dict[str, object]", metadata) + adapter_metadata = metadata_dict.get(ADAPTER_METADATA_KEY) + if not isinstance(adapter_metadata, dict): + return [] + + adapter_metadata_dict = cast("dict[str, object]", adapter_metadata) + raw_source_messages = adapter_metadata_dict.get(SOURCE_MESSAGES_KEY) + if not isinstance(raw_source_messages, list): + return [] + + result: list[SourceMessage] = [] + for raw in raw_source_messages: + source = _parse_source_message(raw) + if source is not None: + result.append(source) + return result + + +def restore_source_ids( + messages: list[messages_.Message], + source_messages: list[SourceMessage], +) -> list[messages_.Message]: + """Restore message and part ids from matching source metadata.""" + if not source_messages: + return messages + + restored: list[messages_.Message] = [] + source_index = 0 + + for message in messages: + match_index = next( + ( + index + for index in range(source_index, len(source_messages)) + if source_messages[index].role == message.role + ), + None, + ) + if match_index is None: + restored.append(message) + continue + + source = source_messages[match_index] + source_index = match_index + 1 + restored.append(_restore_message_ids(message, source)) + + return restored diff --git a/src/ai/agents/ui/ai_sdk/inbound.py b/src/ai/agents/ui/ai_sdk/inbound_messages.py similarity index 54% rename from src/ai/agents/ui/ai_sdk/inbound.py rename to src/ai/agents/ui/ai_sdk/inbound_messages.py index e4918a22..c0c36631 100644 --- a/src/ai/agents/ui/ai_sdk/inbound.py +++ b/src/ai/agents/ui/ai_sdk/inbound_messages.py @@ -8,12 +8,14 @@ import json import logging -from typing import Any, NamedTuple +from typing import Any from ....types import messages as messages_ from ...agent import MessageBundle -from ...hooks import resolve_hook -from . import ui_message +from . import approvals, id_utils +from . import ui_messages as ui_messages_ +from .approvals import ApprovalResponse, extract_approvals +from .tool_utils import normalize_tool_args logger = logging.getLogger(__name__) @@ -24,25 +26,18 @@ ) -def _is_tool_completed(state: ui_message.UIToolInvocationState) -> bool: - return state in _TOOL_RESULT_STATES or state in _TOOL_ERROR_STATES - - -def _is_tool_error(state: ui_message.UIToolInvocationState) -> bool: - return state in _TOOL_ERROR_STATES - - -# TODO(datamodel-rework §4): once tool args have a canonical shape, drop -# these normalizers. -def _normalize_tool_args(tool_input: str | dict[str, Any] | None) -> str: - """Normalize tool input (JSON string, dict, or None) to a JSON string.""" - match tool_input: - case str(): - return tool_input - case dict(): - return json.dumps(tool_input) - case _: - return "{}" +def _tool_result_output( + part: ui_messages_.UIToolPart | ui_messages_.UIDynamicToolPart, +) -> Any: + if part.state == "output-error": + return _error_result(part.error_text, part.output) + if part.state == "output-denied": + reason = part.approval.reason if part.approval is not None else None + return { + "type": "error-text", + "value": reason or "Tool call execution denied.", + } + return part.output def _normalize_tool_result(output: Any) -> dict[str, Any] | None: @@ -76,121 +71,51 @@ def _decode_wire_output(output: Any) -> Any: if output.get("role") != "assistant" or "parts" not in output: return output try: - ui_msg = ui_message.UIMessage.model_validate(output) + ui_msg = ui_messages_.UIMessage.model_validate(output) except Exception: return output inner = list(_parse([ui_msg])) return MessageBundle(messages=tuple(inner)) -def _approval_hook_part( - tp: ui_message.UIToolPart, -) -> messages_.HookPart[Any] | None: - """Reconstruct approval hook state from a UI tool part when possible.""" - approval = tp.approval - if approval is None: - return None - - if tp.state == "approval-requested": - return messages_.HookPart( - hook_id=approval.id, - hook_type="ToolApproval", - status="pending", - ) - - if tp.state == "approval-responded" and approval.approved is not None: - return messages_.HookPart( - hook_id=approval.id, - hook_type="ToolApproval", - status="resolved", - resolution={ - "granted": approval.approved, - "reason": approval.reason, - }, - ) - - if tp.state == "output-denied": - return messages_.HookPart( - hook_id=approval.id, - hook_type="ToolApproval", - status="resolved", - resolution={ - "granted": False, - "reason": approval.reason, - }, +def _build_result_part( + *, + tool_call_id: str, + tool_name: str, + output: Any, + is_error: bool, +) -> messages_.ToolResultPart: + if is_error: + result: Any = output + else: + decoded = _decode_wire_output(output) + result = ( + decoded + if isinstance(decoded, MessageBundle) + else _normalize_tool_result(decoded) ) - - return None - - -# ============================================================================ -# Approval extraction + bulk resolution -# ============================================================================ - - -class ApprovalResponse(NamedTuple): - """Approval response extracted from a responded UIToolPart.""" - - hook_id: str - granted: bool - reason: str | None - tool_call_id: str - - -def extract_approvals( - ui_messages: list[ui_message.UIMessage], -) -> list[ApprovalResponse]: - """Return every approval response found in *ui_messages*. - - Pure function — does not resolve hooks or trigger side effects. - """ - approvals: list[ApprovalResponse] = [] - for ui_msg in ui_messages: - for part in ui_msg.parts: - if not isinstance(part, ui_message.UIToolPart): - continue - if ( - part.state == "approval-responded" - and part.approval is not None - and part.approval.approved is not None - ): - approvals.append( - ApprovalResponse( - hook_id=part.approval.id, - granted=part.approval.approved, - reason=part.approval.reason, - tool_call_id=part.tool_call_id, - ) - ) - return approvals - - -def apply_approvals(approvals: list[ApprovalResponse]) -> None: - """Pre-register each approval resolution with the hooks registry.""" - for approval in approvals: - resolve_hook( - approval.hook_id, - {"granted": approval.granted, "reason": approval.reason}, - ) - - -# ============================================================================ -# UI message normalization (heal stale tool states) -# ============================================================================ + return messages_.ToolResultPart( + tool_call_id=tool_call_id, + tool_name=tool_name, + result=result, + is_error=is_error, + ) def _normalize_ui_messages( - ui_messages: list[ui_message.UIMessage], -) -> list[ui_message.UIMessage]: + ui_messages: list[ui_messages_.UIMessage], +) -> list[ui_messages_.UIMessage]: """Heal stale tool-part states from persisted assistant history.""" - normalized: list[ui_message.UIMessage] = [] + normalized: list[ui_messages_.UIMessage] = [] for message in ui_messages: new_parts = [] changed = False for part in message.parts: part_type = getattr(part, "type", None) state = getattr(part, "state", None) - if isinstance(part_type, str) and part_type.startswith("tool-"): + if isinstance(part_type, str) and ( + part_type.startswith("tool-") or part_type == "dynamic-tool" + ): output = getattr(part, "output", None) approval = getattr(part, "approval", None) approved = approval.approved if approval is not None else None @@ -221,39 +146,6 @@ def _normalize_ui_messages( return normalized -# ============================================================================ -# UI → internal message conversion -# ============================================================================ - - -def to_messages( - ui_messages: list[ui_message.UIMessage], -) -> tuple[list[messages_.Message], list[ApprovalResponse]]: - """Parse a UI request into runtime messages + extracted approvals. - - Pure: normalizes stale tool states, extracts approval responses, - parses UIMessages into an ``ai.messages.Message`` list (split at - tool boundaries), drops the internal tombstones for approval - responses, and patches the trailing tool message with - ``is_hook_pending`` placeholders for tool calls whose approval was - just responded to but never recorded a real tool result. - - Sub-agent tool outputs (UIMessage wire shape) are decoded back to - ``MessageBundle`` so the parent agent's message history carries the - rich snapshot. Per-tool model-facing values are populated by - :meth:`Agent.run` (which has the tool registry), not here. - - Returns ``(messages, approvals)``. The caller can pre-register - resolutions via :func:`apply_approvals` before calling - :meth:`Agent.run` if the run should resume from a hook. - """ - normalized = _normalize_ui_messages(ui_messages) - approvals = extract_approvals(normalized) - messages = [m for m in _parse(normalized) if not _is_approval_response(m)] - _patch_pending_hook_aborts(messages, approvals) - return messages, approvals - - def _patch_pending_hook_aborts( messages: list[messages_.Message], approvals: list[ApprovalResponse], @@ -307,124 +199,168 @@ def _patch_pending_hook_aborts( messages[-1] = tool_msg.model_copy(update={"parts": new_parts}) -def _is_approval_response(msg: messages_.Message) -> bool: - """Return whether ``msg`` records a resolved tool-approval hook.""" - if msg.role != "internal" or len(msg.parts) != 1: - return False - part = msg.parts[0] - return ( - isinstance(part, messages_.HookPart) - and part.hook_type == "ToolApproval" - and part.status == "resolved" - ) - - def _parse( - ui_messages: list[ui_message.UIMessage], + ui_messages: list[ui_messages_.UIMessage], ) -> list[messages_.Message]: - def _build_result_part( - *, - tool_call_id: str, - tool_name: str, - output: Any, - is_error: bool, - ) -> messages_.ToolResultPart: - if is_error: - result: Any = output - else: - decoded = _decode_wire_output(output) - result = ( - decoded - if isinstance(decoded, MessageBundle) - else (_normalize_tool_result(decoded)) - ) - return messages_.ToolResultPart( - tool_call_id=tool_call_id, - tool_name=tool_name, - result=result, - is_error=is_error, - ) - result: list[messages_.Message] = [] for ui_msg in ui_messages: + source_messages = id_utils.source_messages_from(ui_msg.metadata) assistant_parts: list[messages_.Part] = [] tool_result_parts: list[messages_.ToolResultPart] = [] hook_parts: list[messages_.HookPart[Any]] = [] for part in ui_msg.parts: match part: - case ui_message.UITextPart(text=text) if text: - assistant_parts.append(messages_.TextPart(text=text)) - - case ui_message.UIReasoningPart(text=reasoning) if reasoning: + case ui_messages_.UITextPart(text=text) if text: assistant_parts.append( - messages_.ReasoningPart(text=reasoning) + messages_.TextPart( + text=text, + provider_metadata=part.provider_metadata, + ) ) - case ui_message.UIToolInvocationPart() as inv: - tool_args = json.dumps(inv.args) if inv.args else "{}" + case ui_messages_.UIReasoningPart(text=reasoning) if reasoning: assistant_parts.append( - messages_.ToolCallPart( - tool_call_id=inv.tool_invocation_id, - tool_name=inv.tool_name, - tool_args=tool_args, + messages_.ReasoningPart( + text=reasoning, + provider_metadata=part.provider_metadata, ) ) - if _is_tool_completed(inv.state): - tool_result_parts.append( - _build_result_part( + + case ui_messages_.UIToolInvocationPart() as inv: + tool_args = json.dumps(inv.args) if inv.args else "{}" + is_completed = ( + inv.state in _TOOL_RESULT_STATES + or inv.state in _TOOL_ERROR_STATES + ) + is_error = inv.state in _TOOL_ERROR_STATES + if inv.provider_executed: + assistant_parts.append( + messages_.BuiltinToolCallPart( tool_call_id=inv.tool_invocation_id, tool_name=inv.tool_name, - output=inv.result, - is_error=_is_tool_error(inv.state), + tool_args=tool_args, ) ) - - case ui_message.UIToolPart() as tp: - assistant_parts.append( - messages_.ToolCallPart( - tool_call_id=tp.tool_call_id, - tool_name=tp.tool_name, - tool_args=_normalize_tool_args(tp.input), + if is_completed: + assistant_parts.append( + messages_.BuiltinToolReturnPart( + tool_call_id=inv.tool_invocation_id, + tool_name=inv.tool_name, + result=inv.result, + is_error=is_error, + provider_metadata=None, + ) + ) + else: + assistant_parts.append( + messages_.ToolCallPart( + tool_call_id=inv.tool_invocation_id, + tool_name=inv.tool_name, + tool_args=tool_args, + ) ) + if is_completed: + tool_result_parts.append( + _build_result_part( + tool_call_id=inv.tool_invocation_id, + tool_name=inv.tool_name, + output=inv.result, + is_error=is_error, + ) + ) + + case ( + ( + ui_messages_.UIToolPart() + | ui_messages_.UIDynamicToolPart() + ) as tp + ): + tool_input = ( + tp.raw_input + if tp.state == "output-error" and tp.input is None + else tp.input ) - approval_hook = _approval_hook_part(tp) + tool_args = normalize_tool_args(tool_input) + is_completed = ( + tp.state in _TOOL_RESULT_STATES + or tp.state in _TOOL_ERROR_STATES + ) + is_error = tp.state in _TOOL_ERROR_STATES + + if tp.provider_executed: + assistant_parts.append( + messages_.BuiltinToolCallPart( + tool_call_id=tp.tool_call_id, + tool_name=tp.tool_name, + tool_args=tool_args, + provider_metadata=tp.call_provider_metadata, + ) + ) + else: + assistant_parts.append( + messages_.ToolCallPart( + tool_call_id=tp.tool_call_id, + tool_name=tp.tool_name, + tool_args=tool_args, + provider_metadata=tp.call_provider_metadata, + ) + ) + approval_hook = approvals.hook_part_from_tool_part(tp) if approval_hook is not None: hook_parts.append(approval_hook) - if tp.state in _TOOL_RESULT_STATES: - tool_result_parts.append( - _build_result_part( + if tp.provider_executed and is_completed: + assistant_parts.append( + messages_.BuiltinToolReturnPart( tool_call_id=tp.tool_call_id, tool_name=tp.tool_name, - output=tp.output, - is_error=False, + result=_tool_result_output(tp), + is_error=is_error, + provider_metadata=( + tp.result_provider_metadata + or tp.call_provider_metadata + ), ) ) - elif tp.state == "output-error": + elif is_completed: tool_result_parts.append( - messages_.ToolResultPart( + _build_result_part( tool_call_id=tp.tool_call_id, tool_name=tp.tool_name, - result=_error_result(tp.error_text, tp.output), - is_error=True, + output=_tool_result_output(tp), + is_error=is_error, ) ) + if tp.result_provider_metadata is not None: + tool_result_parts[-1] = tool_result_parts[ + -1 + ].model_copy( + update={ + "provider_metadata": ( + tp.result_provider_metadata + ) + } + ) - case ui_message.UIFilePart() as fp: + case ui_messages_.UIFilePart() as fp: assistant_parts.append( messages_.FilePart( data=fp.url, media_type=fp.media_type, filename=fp.filename, + provider_metadata=fp.provider_metadata, ) ) case ( - ui_message.UIStepStartPart() - | ui_message.UISourceUrlPart() - | ui_message.UISourceDocumentPart() + ui_messages_.UIStepStartPart() + | ui_messages_.UISourceUrlPart() + | ui_messages_.UISourceDocumentPart() + | ui_messages_.UIReasoningFilePart() + | ui_messages_.UICustomPart() + | ui_messages_.UIDataPart() ): pass @@ -441,25 +377,31 @@ def _build_result_part( # LLM APIs expect one message per iteration, so split into # assistant + tool message pairs at tool-result boundaries. if ui_msg.role == "assistant": - result.extend( - _split_assistant_parts( - assistant_parts, tool_result_parts, msg_id=ui_msg.id - ) + parsed = _split_assistant_parts( + assistant_parts, + tool_result_parts, + turn_id=ui_msg.id, ) for hp in hook_parts: - result.append( + parsed.append( messages_.Message( - id=ui_msg.id, + turn_id=ui_msg.id, role="internal", parts=[hp], ) ) + result.extend(id_utils.restore_source_ids(parsed, source_messages)) else: - result.append( - messages_.Message( - id=ui_msg.id, - role=ui_msg.role, - parts=assistant_parts, + result.extend( + id_utils.restore_source_ids( + [ + messages_.Message( + id=ui_msg.id, + role=ui_msg.role, + parts=assistant_parts, + ) + ], + source_messages, ) ) @@ -469,7 +411,7 @@ def _build_result_part( def _split_assistant_parts( parts: list[messages_.Part], tool_results: list[messages_.ToolResultPart], - msg_id: str, + turn_id: str, ) -> list[messages_.Message]: """Split assistant parts into assistant + tool message pairs.""" results_by_id = {tr.tool_call_id: tr for tr in tool_results} @@ -484,7 +426,13 @@ def _split_assistant_parts( if not pending_results: if parts: - return [messages_.Message(role="assistant", parts=parts, id=msg_id)] + return [ + messages_.Message( + role="assistant", + parts=parts, + turn_id=turn_id, + ) + ] return [] messages: list[messages_.Message] = [] @@ -499,10 +447,18 @@ def _split_assistant_parts( and not isinstance(part, messages_.ToolCallPart) ): messages.append( - messages_.Message(role="assistant", parts=current, id=msg_id) + messages_.Message( + role="assistant", + parts=current, + turn_id=turn_id, + ) ) messages.append( - messages_.Message(role="tool", parts=list(current_results)) + messages_.Message( + role="tool", + parts=list(current_results), + turn_id=turn_id, + ) ) current = [] current_results = [] @@ -517,11 +473,56 @@ def _split_assistant_parts( if current: messages.append( - messages_.Message(role="assistant", parts=current, id=msg_id) + messages_.Message( + role="assistant", + parts=current, + turn_id=turn_id, + ) ) if current_results: messages.append( - messages_.Message(role="tool", parts=list(current_results)) + messages_.Message( + role="tool", + parts=list(current_results), + turn_id=turn_id, + ) ) return messages + + +# ============================================================================ +# UI → internal message conversion +# ============================================================================ + + +def to_messages( + ui_messages: list[ui_messages_.UIMessage], +) -> tuple[list[messages_.Message], list[ApprovalResponse]]: + """Parse a UI request into runtime messages + extracted approvals. + + Pure: normalizes stale tool states, extracts approval responses, + parses UIMessages into an ``ai.messages.Message`` list (split at + tool boundaries), drops the internal tombstones for approval + responses, and patches the trailing tool message with + ``is_hook_pending`` placeholders for tool calls whose approval was + just responded to but never recorded a real tool result. + + Sub-agent tool outputs (UIMessage wire shape) are decoded back to + ``MessageBundle`` so the parent agent's message history carries the + rich snapshot. Per-tool model-facing values are populated by + :meth:`Agent.run` (which has the tool registry), not here. + + Returns ``(messages, approvals)``. The caller can pre-register + resolutions via :func:`apply_approvals` before calling + :meth:`Agent.run` if the run should resume from a hook. + """ + normalized = _normalize_ui_messages(ui_messages) + approval_responses = extract_approvals(normalized) + messages = [ + m + for m in _parse(normalized) + if not approvals.is_resolved_approval_message(m) + ] + _patch_pending_hook_aborts(messages, approval_responses) + return messages, approval_responses diff --git a/src/ai/agents/ui/ai_sdk/outbound/__init__.py b/src/ai/agents/ui/ai_sdk/outbound/__init__.py deleted file mode 100644 index abb6f69c..00000000 --- a/src/ai/agents/ui/ai_sdk/outbound/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Outbound adapter: ``ai.messages.Message`` stream → AI SDK UI protocol.""" - -from .history import to_ui_messages -from .sse import to_sse -from .stream import to_stream - -__all__ = ["to_sse", "to_stream", "to_ui_messages"] diff --git a/src/ai/agents/ui/ai_sdk/outbound/_state.py b/src/ai/agents/ui/ai_sdk/outbound/_state.py deleted file mode 100644 index baf8ab1c..00000000 --- a/src/ai/agents/ui/ai_sdk/outbound/_state.py +++ /dev/null @@ -1,346 +0,0 @@ -"""Stream state bookkeeping for the event-first outbound walk.""" - -from __future__ import annotations - -from typing import Any - -from .....types import events as events_ -from .....types import messages as messages_ -from ....agent import MessageBundle -from .. import _approvals, protocol -from . import history - - -def _tool_error_text(part: messages_.ToolResultPart) -> str: - """Best-effort error text extraction from a failed tool result.""" - if isinstance(part.result, str) and part.result: - return part.result - if isinstance(part.result, dict): - for key in ("error", "message", "detail"): - value = part.result.get(key) - if isinstance(value, str) and value: - return value - return "Tool execution failed" - - -def _to_wire_output(snapshot: Any) -> Any: - """Convert an aggregator snapshot to its UI wire representation. - - For ``MessageBundle`` (sub-agent transcripts) this produces a single - ``UIMessage`` assistant bubble — the canonical AI SDK shape. Other - snapshot types pass through unchanged. - - Returns ``None`` if the bundle has no assistant anchor yet (e.g. a - streaming sub-agent that has produced no messages); callers should - skip emitting in that case. - """ - if isinstance(snapshot, MessageBundle): - ui_msgs = history.to_ui_messages(list(snapshot.messages)) - return ui_msgs[-1] if ui_msgs else None - return snapshot - - -class _StreamState: - """Single-pass state across one ``to_stream()`` call.""" - - def __init__(self) -> None: - self.ui_message_id: str | None = None - self.emitted_start: bool = False - self.in_step: bool = False - - self.started_tool_inputs: set[str] = set() - self.tool_names: dict[str, str] = {} - self.input_available_emitted: set[str] = set() - self.emitted_tool_results: set[str] = set() - self.emitted_approval_requests: set[str] = set() - - self.open_text_ids: set[str] = set() - self.open_reasoning_ids: set[str] = set() - self.completed_text_ids: set[str] = set() - self.completed_reasoning_ids: set[str] = set() - self.text_delta_ids: set[str] = set() - self.reasoning_delta_ids: set[str] = set() - - # Per-tool-call aggregators for streaming generator tools. Each - # PartialToolCallResult feeds its value into the aggregator and - # the snapshot goes out as a preliminary tool output. - self.partial_aggregators: dict[ - str, events_.Aggregator[Any, Any, Any] - ] = {} - - # -- boundary helpers ---------------------------------------------------- - - def _close_open_blocks(self) -> list[protocol.UIMessageStreamPart]: - parts: list[protocol.UIMessageStreamPart] = [] - for rid in list(self.open_reasoning_ids): - parts.append(protocol.ReasoningEndPart(id=rid)) - self.completed_reasoning_ids.add(rid) - self.open_reasoning_ids.clear() - for tid in list(self.open_text_ids): - parts.append(protocol.TextEndPart(id=tid)) - self.completed_text_ids.add(tid) - self.open_text_ids.clear() - return parts - - def _finish_step(self) -> list[protocol.UIMessageStreamPart]: - parts = self._close_open_blocks() - if self.in_step: - parts.append(protocol.FinishStepPart()) - self.in_step = False - return parts - - def _reset_step_tracking(self) -> None: - self.started_tool_inputs.clear() - self.tool_names.clear() - self.input_available_emitted.clear() - self.emitted_tool_results.clear() - self.emitted_approval_requests.clear() - - def _ensure_started(self) -> list[protocol.UIMessageStreamPart]: - """Lazily emit StartPart / StartStepPart on the first event.""" - parts: list[protocol.UIMessageStreamPart] = [] - - if not self.emitted_start: - parts.append(protocol.StartPart(message_id=None)) - parts.append(protocol.StartStepPart()) - self.emitted_start = True - self.in_step = True - self._reset_step_tracking() - - return parts - - # -- phase: streaming events -------------------------------------------- - - def on_event( - self, event: events_.Event - ) -> list[protocol.UIMessageStreamPart]: - out: list[protocol.UIMessageStreamPart] = [] - - # Lazily open the UI message on the first streaming event. - if not self.emitted_start: - out.extend(self._ensure_started()) - - match event: - case events_.TextStart(block_id=pid): - self.open_text_ids.add(pid) - out.append(protocol.TextStartPart(id=pid)) - - case events_.TextDelta(block_id=pid, chunk=chunk): - if pid not in self.open_text_ids: - self.open_text_ids.add(pid) - out.append(protocol.TextStartPart(id=pid)) - self.text_delta_ids.add(pid) - out.append(protocol.TextDeltaPart(id=pid, delta=chunk)) - - case events_.TextEnd(block_id=pid): - if pid in self.open_text_ids: - self.open_text_ids.discard(pid) - self.completed_text_ids.add(pid) - out.append(protocol.TextEndPart(id=pid)) - - case events_.ReasoningStart(block_id=pid): - self.open_reasoning_ids.add(pid) - out.append(protocol.ReasoningStartPart(id=pid)) - - case events_.ReasoningDelta(block_id=pid, chunk=chunk): - if pid not in self.open_reasoning_ids: - self.open_reasoning_ids.add(pid) - out.append(protocol.ReasoningStartPart(id=pid)) - self.reasoning_delta_ids.add(pid) - out.append(protocol.ReasoningDeltaPart(id=pid, delta=chunk)) - - case events_.ReasoningEnd(block_id=pid): - if pid in self.open_reasoning_ids: - self.open_reasoning_ids.discard(pid) - self.completed_reasoning_ids.add(pid) - out.append(protocol.ReasoningEndPart(id=pid)) - - case events_.ToolStart(tool_call_id=tcid, tool_name=name): - self.tool_names[tcid] = name - if tcid in self.started_tool_inputs: - return out - self.started_tool_inputs.add(tcid) - out.append( - protocol.ToolInputStartPart( - tool_call_id=tcid, - tool_name=name, - ) - ) - - case events_.ToolDelta(tool_call_id=tcid, chunk=chunk): - if tcid not in self.started_tool_inputs: - self.started_tool_inputs.add(tcid) - out.append( - protocol.ToolInputStartPart( - tool_call_id=tcid, - tool_name=self.tool_names.get(tcid, ""), - ) - ) - out.append( - protocol.ToolInputDeltaPart( - tool_call_id=tcid, - input_text_delta=chunk, - ) - ) - - case events_.ToolEnd(): - pass - - return out - - # -- phase: tool results ------------------------------------------------ - - def on_tool_result( - self, event: events_.ToolCallResult - ) -> list[protocol.UIMessageStreamPart]: - """Handle a ``ToolCallResult`` — emit tool input/output parts.""" - msg = event.message - out: list[protocol.UIMessageStreamPart] = [] - - out.extend(self._ensure_started()) - - # Emit ToolInputAvailable for each tool call that triggered - # these results (from the assistant message's ToolCallParts). - for part in msg.parts: - if isinstance(part, messages_.ToolCallPart): - if part.tool_call_id in self.input_available_emitted: - continue - self.input_available_emitted.add(part.tool_call_id) - if part.tool_call_id not in self.started_tool_inputs: - self.started_tool_inputs.add(part.tool_call_id) - out.append( - protocol.ToolInputStartPart( - tool_call_id=part.tool_call_id, - tool_name=part.tool_name, - ) - ) - out.append( - protocol.ToolInputAvailablePart( - tool_call_id=part.tool_call_id, - tool_name=part.tool_name, - input=part.tool_args, - ) - ) - - # Emit tool results. - for part in event.results: - if part.tool_call_id in self.emitted_tool_results: - continue - # Hook-abort placeholders are internal bookkeeping: the - # corresponding HookPart(pending) drives the UI state. - if part.is_hook_pending: - continue - self.emitted_tool_results.add(part.tool_call_id) - if part.is_error: - out.append( - protocol.ToolOutputErrorPart( - tool_call_id=part.tool_call_id, - error_text=_tool_error_text(part), - ) - ) - else: - wire_output = _to_wire_output(part.result) - if wire_output is None: - # Aggregator produced no anchor (e.g. sub-agent - # tool that yielded nothing). Skip the final - # output emit; preliminaries already covered the - # streaming view if any. - continue - out.append( - protocol.ToolOutputAvailablePart( - tool_call_id=part.tool_call_id, - output=wire_output, - ) - ) - - return out - - def on_partial_tool_result( - self, event: events_.PartialToolCallResult - ) -> list[protocol.UIMessageStreamPart]: - """Feed the value and emit a preliminary output. - - Each PartialToolCallResult carries one yielded value plus the - aggregator factory the tool was declared with. We instantiate - the aggregator once per ``tool_call_id`` and use its snapshot - as the ``output`` of a preliminary ``ToolOutputAvailablePart``. - The AI SDK supersedes preliminary outputs with the final - ``ToolCallResult`` when it arrives. - """ - out: list[protocol.UIMessageStreamPart] = [] - - tcid = event.tool_call_id - factory = event.aggregator_factory - if tcid is None or factory is None: - return out - - out.extend(self._ensure_started()) - - agg = self.partial_aggregators.get(tcid) - if agg is None: - agg = factory() - self.partial_aggregators[tcid] = agg - agg.feed(event.value) - - wire_output = _to_wire_output(agg.snapshot()) - if wire_output is None: - # Sub-agent bundle without an assistant anchor yet — wait - # for more events before emitting. - return out - - out.append( - protocol.ToolOutputAvailablePart( - tool_call_id=tcid, - output=wire_output, - preliminary=True, - ) - ) - return out - - # -- phase: hooks ------------------------------------------------------- - - def on_hook( - self, event: events_.HookEvent - ) -> list[protocol.UIMessageStreamPart]: - """Handle a ``HookEvent`` — emit approval parts.""" - hook_part = event.hook - out: list[protocol.UIMessageStreamPart] = [] - - # Ensure the UI message is started. - out.extend(self._ensure_started()) - - tc_id = _approvals.tool_call_id_for(hook_part) - if tc_id is None: - return out - - if hook_part.status == "pending": - if tc_id in self.emitted_approval_requests: - return out - self.emitted_approval_requests.add(tc_id) - out.append( - protocol.ToolApprovalRequestPart( - approval_id=hook_part.hook_id, - tool_call_id=tc_id, - ) - ) - elif hook_part.status == "resolved": - resolution: dict[str, Any] = hook_part.resolution or {} - if not resolution.get("granted"): - out.append(protocol.ToolOutputDeniedPart(tool_call_id=tc_id)) - elif hook_part.status == "cancelled": - out.append( - protocol.ToolOutputErrorPart( - tool_call_id=tc_id, - error_text="Hook cancelled", - ) - ) - - return out - - # -- phase: stream finish ------------------------------------------------ - - def finish(self) -> list[protocol.UIMessageStreamPart]: - parts = self._finish_step() - if self.emitted_start: - parts.append(protocol.FinishPart(finish_reason="stop")) - return parts diff --git a/src/ai/agents/ui/ai_sdk/outbound/history.py b/src/ai/agents/ui/ai_sdk/outbound/history.py deleted file mode 100644 index eb5a20a4..00000000 --- a/src/ai/agents/ui/ai_sdk/outbound/history.py +++ /dev/null @@ -1,70 +0,0 @@ -"""Persisted-message → UIMessage list for history endpoints.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -from .. import _parts, ui_message - -if TYPE_CHECKING: - from .....types import messages as messages_ - - -def to_ui_messages( - messages: list[messages_.Message], -) -> list[ui_message.UIMessage]: - """Group persisted messages into UIMessage bubbles. - - ``user``/``system`` messages become standalone UIMessages. Runs of - ``assistant``/``tool``/``internal`` messages merge into a single - assistant UIMessage, with tool results and approval state folded into - the corresponding tool-call parts. - """ - result: list[ui_message.UIMessage] = [] - - i = 0 - while i < len(messages): - msg = messages[i] - - if msg.role in ("user", "system"): - result.append( - ui_message.UIMessage( - id=msg.id, - role=msg.role, - parts=_parts.to_ui_parts(msg.parts), - ) - ) - i += 1 - continue - - if msg.role == "assistant": - ui_parts: list[ui_message.UIMessagePart] = [] - bubble_id = msg.id - - while i < len(messages) and messages[i].role in ( - "assistant", - "tool", - "internal", - ): - current = messages[i] - if current.role == "assistant": - ui_parts.extend(_parts.to_ui_parts(current.parts)) - elif current.role == "tool": - _parts.merge_tool_results(ui_parts, current.parts) - elif current.role == "internal": - _parts.merge_approval_signals(ui_parts, current.parts) - i += 1 - - result.append( - ui_message.UIMessage( - id=bubble_id, - role="assistant", - parts=ui_parts, - ) - ) - continue - - # Orphan tool / internal messages — skip; they have no assistant anchor. - i += 1 - - return result diff --git a/src/ai/agents/ui/ai_sdk/outbound/sse.py b/src/ai/agents/ui/ai_sdk/outbound/sse.py deleted file mode 100644 index 88207c54..00000000 --- a/src/ai/agents/ui/ai_sdk/outbound/sse.py +++ /dev/null @@ -1,59 +0,0 @@ -"""Serialize the UI message stream as Server-Sent Events.""" - -from __future__ import annotations - -import dataclasses -import json -from typing import TYPE_CHECKING, Any - -import pydantic - -from .. import protocol -from .stream import to_stream - -if TYPE_CHECKING: - from collections.abc import AsyncGenerator, AsyncIterable - - from .....types import events as events_ - - -def _to_camel_case(snake_str: str) -> str: - components = snake_str.split("_") - return components[0] + "".join(x.title() for x in components[1:]) - - -def _json_default(obj: Any) -> Any: - """Fallback encoder for json.dumps — handle pydantic models recursively. - - Aggregator snapshots and tool outputs may carry pydantic models - (e.g. ``MessageBundle``, ``UIMessage``). ``model_dump(mode="json")`` - converts them to plain JSON-native dicts/lists. - """ - if isinstance(obj, pydantic.BaseModel): - return obj.model_dump(mode="json", by_alias=True) - raise TypeError( - f"Object of type {type(obj).__name__} is not JSON serializable" - ) - - -def serialize_part(part: protocol.UIMessageStreamPart) -> str: - """Serialize a stream part to JSON with camelCase keys.""" - d = dataclasses.asdict(part) - if isinstance(part, protocol.DataPart): - d["type"] = part.type - del d["data_type"] - camel_dict = {_to_camel_case(k): v for k, v in d.items() if v is not None} - return json.dumps(camel_dict, default=_json_default) - - -def format_sse(part: protocol.UIMessageStreamPart) -> str: - """Format a stream part as an SSE data line.""" - return f"data: {serialize_part(part)}\n\n" - - -async def to_sse( - events: AsyncIterable[events_.AgentEvent], -) -> AsyncGenerator[str]: - """Convert an internal event stream into SSE strings.""" - async for part in to_stream(events): - yield format_sse(part) diff --git a/src/ai/agents/ui/ai_sdk/outbound/stream.py b/src/ai/agents/ui/ai_sdk/outbound/stream.py deleted file mode 100644 index 4b70f920..00000000 --- a/src/ai/agents/ui/ai_sdk/outbound/stream.py +++ /dev/null @@ -1,42 +0,0 @@ -"""Convert internal event streams into AI SDK UI protocol parts.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -from .....types import events as events_ -from ._state import _StreamState - -if TYPE_CHECKING: - from collections.abc import AsyncGenerator, AsyncIterable - - from .. import protocol - - -async def to_stream( - events: AsyncIterable[events_.AgentEvent], -) -> AsyncGenerator[protocol.UIMessageStreamPart]: - """Walk ``events`` once, emitting AI SDK UI stream parts. - - Streaming text/reasoning/tool-input deltas come from model events. - Tool results come from ``ToolCallResult``. Hook signals come from - ``HookEvent``. - """ - state = _StreamState() - - async for event in events: - if isinstance(event, events_.ToolCallResult): - for part in state.on_tool_result(event): - yield part - elif isinstance(event, events_.PartialToolCallResult): - for part in state.on_partial_tool_result(event): - yield part - elif isinstance(event, events_.HookEvent): - for part in state.on_hook(event): - yield part - else: - for part in state.on_event(event): - yield part - - for part in state.finish(): - yield part diff --git a/src/ai/agents/ui/ai_sdk/outbound_messages.py b/src/ai/agents/ui/ai_sdk/outbound_messages.py new file mode 100644 index 00000000..4520e20d --- /dev/null +++ b/src/ai/agents/ui/ai_sdk/outbound_messages.py @@ -0,0 +1,382 @@ +"""Persisted-message conversion for AI SDK UI messages.""" + +from __future__ import annotations + +from typing import Any, cast + +from ....types import media +from ....types import messages as messages_ +from . import approvals, id_utils, ui_messages +from .tool_utils import normalize_tool_input + +UIToolLike = ui_messages.UIToolPart | ui_messages.UIDynamicToolPart + +# Internal history can contain separate records for one tool call +# (call, approval, result). AI SDK UI expects one tool part per +# toolCallId, so later/higher-ranked states update the first part +# https://ai-sdk.dev/docs/reference/ai-sdk-core/ui-message#tooluipart +# https://ai-sdk.dev/docs/ai-sdk-ui/stream-protocol + +_TOOL_STATE_RANK: dict[ui_messages.UIToolInvocationState, int] = { + "input-streaming": 0, + "input-available": 1, + "approval-requested": 2, + "approval-responded": 3, + "output-denied": 4, + "output-error": 5, + "output-available": 6, +} + +_MERGEABLE_TOOL_PART_FIELDS = ( + "raw_input", + "output", + "error_text", + "approval", + "provider_executed", + "call_provider_metadata", + "result_provider_metadata", + "tool_metadata", + "preliminary", + "title", +) + + +def _merge_tool_part( + existing: UIToolLike, + candidate: UIToolLike, +) -> UIToolLike: + """Merge duplicate UI tool parts, keeping the first display position.""" + existing_rank = _TOOL_STATE_RANK.get(existing.state, 0) + candidate_rank = _TOOL_STATE_RANK.get(candidate.state, 0) + updates: dict[str, Any] = {} + + if candidate_rank >= existing_rank: + updates["state"] = candidate.state + if candidate.state == "output-denied": + updates["output"] = None + + if existing.input is None and candidate.input is not None: + updates["input"] = candidate.input + + for field in _MERGEABLE_TOOL_PART_FIELDS: + value = getattr(candidate, field) + if value is not None: + updates[field] = value + + return existing.model_copy(update=updates) if updates else existing + + +def _tool_part_index_by_call_id( + ui_parts: list[ui_messages.UIMessagePart], +) -> dict[str, int]: + return { + ui_part.tool_call_id: idx + for idx, ui_part in enumerate(ui_parts) + if isinstance( + ui_part, ui_messages.UIToolPart | ui_messages.UIDynamicToolPart + ) + } + + +def dedupe_tool_parts( + ui_parts: list[ui_messages.UIMessagePart], +) -> list[ui_messages.UIMessagePart]: + """Collapse duplicate UI tool parts by tool_call_id.""" + result: list[ui_messages.UIMessagePart] = [] + tool_index: dict[str, int] = {} + + for part in ui_parts: + if not isinstance( + part, ui_messages.UIToolPart | ui_messages.UIDynamicToolPart + ): + result.append(part) + continue + + idx = tool_index.get(part.tool_call_id) + if idx is None: + tool_index[part.tool_call_id] = len(result) + result.append(part) + continue + + existing = result[idx] + if isinstance( + existing, ui_messages.UIToolPart | ui_messages.UIDynamicToolPart + ): + result[idx] = _merge_tool_part(existing, part) + + return result + + +def merge_tool_results( + ui_parts: list[ui_messages.UIMessagePart], + tool_parts: list[messages_.Part], +) -> None: + """Merge tool result parts into existing UI tool parts.""" + tool_index = _tool_part_index_by_call_id(ui_parts) + + for part in tool_parts: + updates: dict[str, Any] + match part: + case messages_.ToolResultPart() if part.is_hook_pending: + continue + case messages_.ToolResultPart(): + tool_call_id = part.tool_call_id + state = "output-error" if part.is_error else "output-available" + updates = { + "state": state, + "result_provider_metadata": part.provider_metadata, + } + if part.is_error: + updates["error_text"] = str(part.result) + else: + updates["output"] = part.result + case messages_.BuiltinToolReturnPart(): + tool_call_id = part.tool_call_id + updates = { + "state": ( + "output-error" if part.is_error else "output-available" + ), + "provider_executed": True, + "result_provider_metadata": part.provider_metadata, + } + if part.is_error: + updates["error_text"] = str(part.result) + else: + updates["output"] = part.result + case _: + continue + + idx_opt = tool_index.get(tool_call_id) + if idx_opt is None: + continue + idx = idx_opt + existing = ui_parts[idx] + if not isinstance( + existing, ui_messages.UIToolPart | ui_messages.UIDynamicToolPart + ): + continue + if existing.state == "output-denied": + continue + ui_parts[idx] = existing.model_copy(update=updates) + + +def merge_approval_signals( + ui_parts: list[ui_messages.UIMessagePart], + internal_parts: list[messages_.Part], +) -> None: + """Merge approval hook state into existing UI tool parts.""" + tool_index = _tool_part_index_by_call_id(ui_parts) + + for part in internal_parts: + if not isinstance(part, messages_.HookPart): + continue + + tool_call_id = approvals.tool_call_id_for(part) + if tool_call_id is None: + continue + + idx_opt = tool_index.get(tool_call_id) + if idx_opt is None: + continue + idx = idx_opt + + existing = ui_parts[idx] + if not isinstance( + existing, ui_messages.UIToolPart | ui_messages.UIDynamicToolPart + ): + continue + + updates: dict[str, Any] = {} + provider_executed = part.metadata.get("providerExecuted") + if isinstance(provider_executed, bool): + updates["provider_executed"] = provider_executed + is_automatic = part.metadata.get("isAutomatic") + is_automatic = is_automatic if isinstance(is_automatic, bool) else None + match part.status: + case "pending": + updates["state"] = "approval-requested" + updates["approval"] = ui_messages.UIToolApproval.model_validate( + { + "id": part.hook_id, + "isAutomatic": is_automatic, + } + ) + case "resolved": + resolution = cast( + "dict[str, Any]", + part.resolution + if isinstance(part.resolution, dict) + else {}, + ) + updates["approval"] = ui_messages.UIToolApproval.model_validate( + { + "id": part.hook_id, + "approved": resolution.get("granted"), + "reason": resolution.get("reason"), + "isAutomatic": is_automatic, + } + ) + if resolution.get("granted", False): + updates["state"] = "approval-responded" + else: + updates["state"] = "output-denied" + updates["output"] = None + case "cancelled": + updates["state"] = "output-error" + updates["error_text"] = "Hook cancelled" + + if updates: + ui_parts[idx] = existing.model_copy(update=updates) + + +def to_ui_parts(parts: list[messages_.Part]) -> list[ui_messages.UIMessagePart]: + """Convert internal parts to UI message parts.""" + result: list[ui_messages.UIMessagePart] = [] + for part in parts: + match part: + case messages_.TextPart(text=text) if text: + result.append( + ui_messages.UITextPart.model_validate( + { + "type": "text", + "text": text, + "providerMetadata": part.provider_metadata, + } + ) + ) + case messages_.ReasoningPart(text=text) if text: + result.append( + ui_messages.UIReasoningPart.model_validate( + { + "type": "reasoning", + "text": text, + "providerMetadata": part.provider_metadata, + } + ) + ) + case messages_.ToolCallPart(): + result.append( + ui_messages.UIToolPart.model_validate( + { + "type": f"tool-{part.tool_name}", + "toolCallId": part.tool_call_id, + "state": "input-available", + "input": normalize_tool_input(part.tool_args), + "callProviderMetadata": part.provider_metadata, + } + ) + ) + case messages_.BuiltinToolCallPart(): + result.append( + ui_messages.UIDynamicToolPart.model_validate( + { + "type": "dynamic-tool", + "toolName": part.tool_name, + "toolCallId": part.tool_call_id, + "state": "input-available", + "input": normalize_tool_input(part.tool_args), + "providerExecuted": True, + "callProviderMetadata": part.provider_metadata, + } + ) + ) + case messages_.BuiltinToolReturnPart(): + result.append( + ui_messages.UIDynamicToolPart.model_validate( + { + "type": "dynamic-tool", + "toolName": part.tool_name, + "toolCallId": part.tool_call_id, + "state": ( + "output-error" + if part.is_error + else "output-available" + ), + "input": None, + "output": None if part.is_error else part.result, + "errorText": ( + str(part.result) if part.is_error else None + ), + "providerExecuted": True, + "resultProviderMetadata": part.provider_metadata, + } + ) + ) + case messages_.FilePart(): + result.append( + ui_messages.UIFilePart.model_validate( + { + "type": "file", + "mediaType": part.media_type, + "url": media.data_to_data_url( + part.data, part.media_type + ), + "filename": part.filename, + "providerMetadata": part.provider_metadata, + } + ) + ) + return result + + +def to_ui_messages( + messages: list[messages_.Message], +) -> list[ui_messages.UIMessage]: + """Group persisted messages into UI message bubbles.""" + result: list[ui_messages.UIMessage] = [] + + i = 0 + while i < len(messages): + msg = messages[i] + + match msg.role: + case "user" | "system": + result.append( + ui_messages.UIMessage( + id=msg.id, + role=msg.role, + metadata=id_utils.metadata_for([msg]), + parts=to_ui_parts(msg.parts), + ) + ) + i += 1 + case "assistant": + ui_parts: list[ui_messages.UIMessagePart] = [] + source_messages: list[messages_.Message] = [] + bubble_id = msg.turn_id or msg.id + + while i < len(messages) and messages[i].role in ( + "assistant", + "tool", + "internal", + ): + current = messages[i] + if ( + current.turn_id is not None + and current.turn_id != bubble_id + ): + break + source_messages.append(current) + match current.role: + case "assistant": + ui_parts.extend(to_ui_parts(current.parts)) + ui_parts = dedupe_tool_parts(ui_parts) + case "tool": + merge_tool_results(ui_parts, current.parts) + case "internal": + merge_approval_signals(ui_parts, current.parts) + i += 1 + ui_parts = dedupe_tool_parts(ui_parts) + + result.append( + ui_messages.UIMessage( + id=bubble_id, + role="assistant", + metadata=id_utils.metadata_for(source_messages), + parts=ui_parts, + ) + ) + case _: + i += 1 + + return result diff --git a/src/ai/agents/ui/ai_sdk/outbound_stream.py b/src/ai/agents/ui/ai_sdk/outbound_stream.py new file mode 100644 index 00000000..288bfbe5 --- /dev/null +++ b/src/ai/agents/ui/ai_sdk/outbound_stream.py @@ -0,0 +1,616 @@ +"""Live event stream conversion for the AI SDK UI protocol.""" + +from __future__ import annotations + +import dataclasses +import json +from typing import TYPE_CHECKING, Any + +import pydantic + +from ....types import events as events_ +from ....types import media +from ....types import messages as messages_ +from ...agent import MessageBundle +from . import approvals, outbound_messages, ui_events +from .tool_utils import normalize_tool_input + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator, AsyncIterable + + +def _tool_error_text(part: messages_.ToolResultPart) -> str: + """Best-effort error text extraction from a failed tool result.""" + if isinstance(part.result, str) and part.result: + return part.result + if isinstance(part.result, dict): + for key in ("error", "message", "detail"): + value = part.result.get(key) + if isinstance(value, str) and value: + return value + return "Tool execution failed" + + +def _to_wire_output(snapshot: Any) -> Any: + """Convert an aggregator snapshot to its UI wire representation. + + For ``MessageBundle`` (sub-agent transcripts) this produces a single + ``UIMessage`` assistant bubble — the canonical AI SDK shape. Other + snapshot types pass through unchanged. + + Returns ``None`` if the bundle has no assistant anchor yet (e.g. a + streaming sub-agent that has produced no messages); callers should + skip emitting in that case. + """ + if isinstance(snapshot, MessageBundle): + ui_msgs = outbound_messages.to_ui_messages(list(snapshot.messages)) + return ui_msgs[-1] if ui_msgs else None + return snapshot + + +class _StreamState: + """Single-pass state across one ``to_stream()`` call.""" + + def __init__(self) -> None: + self.ui_message_id: str | None = None + self.emitted_start: bool = False + self.in_step: bool = False + + self.started_tool_inputs: set[str] = set() + self.tool_names: dict[str, str] = {} + self.input_available_emitted: set[str] = set() + self.emitted_tool_results: set[str] = set() + self.emitted_approval_requests: set[str] = set() + + self.open_text_ids: set[str] = set() + self.open_reasoning_ids: set[str] = set() + self.completed_text_ids: set[str] = set() + self.completed_reasoning_ids: set[str] = set() + self.text_delta_ids: set[str] = set() + self.reasoning_delta_ids: set[str] = set() + self.source_messages: dict[str, messages_.Message] = {} + + # Per-tool-call aggregators for streaming generator tools. Each + # PartialToolCallResult feeds its value into the aggregator and + # the snapshot goes out as a preliminary tool output. + self.partial_aggregators: dict[ + str, events_.Aggregator[Any, Any, Any] + ] = {} + + # -- boundary helpers ---------------------------------------------------- + + def _track_source_message(self, message: messages_.Message | None) -> None: + if message is None or message.id == "": + return + self.source_messages[message.id] = message + + def _latest_assistant_metadata(self) -> Any | None: + messages = [ + msg + for msg in self.source_messages.values() + if msg.role != "system" and msg.parts + ] + ui_messages = outbound_messages.to_ui_messages(messages) + for message in reversed(ui_messages): + if message.role == "assistant": + return message.metadata + return None + + def _close_open_blocks(self) -> list[ui_events.UIMessageStreamEvent]: + events: list[ui_events.UIMessageStreamEvent] = [] + for rid in list(self.open_reasoning_ids): + events.append(ui_events.UIReasoningEndEvent(id=rid)) + self.completed_reasoning_ids.add(rid) + self.open_reasoning_ids.clear() + for tid in list(self.open_text_ids): + events.append(ui_events.UITextEndEvent(id=tid)) + self.completed_text_ids.add(tid) + self.open_text_ids.clear() + return events + + def _ensure_started( + self, + message_id: str | None = None, + ) -> list[ui_events.UIMessageStreamEvent]: + """Lazily emit UIStartEvent / UIStartStepEvent on the first event.""" + events: list[ui_events.UIMessageStreamEvent] = [] + + if not self.emitted_start: + self.ui_message_id = message_id + events.append(ui_events.UIStartEvent(message_id=self.ui_message_id)) + events.append(ui_events.UIStartStepEvent()) + self.emitted_start = True + self.in_step = True + self.started_tool_inputs.clear() + self.tool_names.clear() + self.input_available_emitted.clear() + self.emitted_tool_results.clear() + self.emitted_approval_requests.clear() + + return events + + # -- phase: streaming events -------------------------------------------- + + def on_event( + self, event: events_.Event + ) -> list[ui_events.UIMessageStreamEvent]: + out: list[ui_events.UIMessageStreamEvent] = [] + self._track_source_message(event.message) + + # Lazily open the UI message on the first streaming event. + if not self.emitted_start: + message = event.message + message_id = None + if message.role == "assistant": + if message.turn_id is not None: + message_id = message.turn_id + elif message.id != "": + message_id = message.id + out.extend(self._ensure_started(message_id)) + + match event: + case events_.TextStart(block_id=pid): + self.open_text_ids.add(pid) + out.append( + ui_events.UITextStartEvent( + id=pid, + provider_metadata=event.provider_metadata, + ) + ) + + case events_.TextDelta(block_id=pid, chunk=chunk): + if pid not in self.open_text_ids: + self.open_text_ids.add(pid) + out.append( + ui_events.UITextStartEvent( + id=pid, + provider_metadata=event.provider_metadata, + ) + ) + self.text_delta_ids.add(pid) + out.append( + ui_events.UITextDeltaEvent( + id=pid, + delta=chunk, + provider_metadata=event.provider_metadata, + ) + ) + + case events_.TextEnd(block_id=pid): + if pid in self.open_text_ids: + self.open_text_ids.discard(pid) + self.completed_text_ids.add(pid) + out.append( + ui_events.UITextEndEvent( + id=pid, + provider_metadata=event.provider_metadata, + ) + ) + + case events_.ReasoningStart(block_id=pid): + self.open_reasoning_ids.add(pid) + out.append( + ui_events.UIReasoningStartEvent( + id=pid, + provider_metadata=event.provider_metadata, + ) + ) + + case events_.ReasoningDelta(block_id=pid, chunk=chunk): + if pid not in self.open_reasoning_ids: + self.open_reasoning_ids.add(pid) + out.append( + ui_events.UIReasoningStartEvent( + id=pid, + provider_metadata=event.provider_metadata, + ) + ) + self.reasoning_delta_ids.add(pid) + out.append( + ui_events.UIReasoningDeltaEvent( + id=pid, + delta=chunk, + provider_metadata=event.provider_metadata, + ) + ) + + case events_.ReasoningEnd(block_id=pid): + if pid in self.open_reasoning_ids: + self.open_reasoning_ids.discard(pid) + self.completed_reasoning_ids.add(pid) + out.append( + ui_events.UIReasoningEndEvent( + id=pid, + provider_metadata=event.provider_metadata, + ) + ) + + case events_.ToolStart(tool_call_id=tcid, tool_name=name): + self.tool_names[tcid] = name + if tcid in self.started_tool_inputs: + return out + self.started_tool_inputs.add(tcid) + out.append( + ui_events.UIToolInputStartEvent( + tool_call_id=tcid, + tool_name=name, + provider_metadata=event.provider_metadata, + ) + ) + + case events_.ToolDelta(tool_call_id=tcid, chunk=chunk): + if tcid not in self.started_tool_inputs: + self.started_tool_inputs.add(tcid) + out.append( + ui_events.UIToolInputStartEvent( + tool_call_id=tcid, + tool_name=self.tool_names.get(tcid, ""), + provider_metadata=event.provider_metadata, + ) + ) + out.append( + ui_events.UIToolInputDeltaEvent( + tool_call_id=tcid, + input_text_delta=chunk, + ) + ) + + case events_.ToolEnd(): + pass + + case events_.BuiltinToolStart(tool_call_id=tcid, tool_name=name): + self.tool_names[tcid] = name + if tcid in self.started_tool_inputs: + return out + self.started_tool_inputs.add(tcid) + out.append( + ui_events.UIToolInputStartEvent( + tool_call_id=tcid, + tool_name=name, + provider_executed=True, + provider_metadata=event.provider_metadata, + dynamic=True, + ) + ) + + case events_.BuiltinToolDelta(tool_call_id=tcid, chunk=chunk): + if tcid not in self.started_tool_inputs: + self.started_tool_inputs.add(tcid) + out.append( + ui_events.UIToolInputStartEvent( + tool_call_id=tcid, + tool_name=self.tool_names.get(tcid, ""), + provider_executed=True, + provider_metadata=event.provider_metadata, + dynamic=True, + ) + ) + out.append( + ui_events.UIToolInputDeltaEvent( + tool_call_id=tcid, + input_text_delta=chunk, + ) + ) + + case events_.BuiltinToolEnd(tool_call_id=tcid, tool_call=tc): + if tcid not in self.input_available_emitted: + self.input_available_emitted.add(tcid) + out.append( + ui_events.UIToolInputAvailableEvent( + tool_call_id=tcid, + tool_name=tc.tool_name, + input=normalize_tool_input(tc.tool_args), + provider_executed=True, + provider_metadata=tc.provider_metadata + or event.provider_metadata, + dynamic=True, + ) + ) + + case events_.BuiltinToolResult(tool_call_id=tcid, result=result): + if tcid in self.emitted_tool_results: + return out + self.emitted_tool_results.add(tcid) + if result.is_error: + out.append( + ui_events.UIToolOutputErrorEvent( + tool_call_id=tcid, + error_text=str(result.result), + provider_executed=True, + provider_metadata=result.provider_metadata + or event.provider_metadata, + dynamic=True, + ) + ) + else: + out.append( + ui_events.UIToolOutputAvailableEvent( + tool_call_id=tcid, + output=result.result, + provider_executed=True, + provider_metadata=result.provider_metadata + or event.provider_metadata, + dynamic=True, + ) + ) + + case events_.FileEvent( + media_type=media_type, + data=data, + ): + out.append( + ui_events.UIFileEvent( + url=media.data_to_data_url(data, media_type), + media_type=media_type, + provider_metadata=event.provider_metadata, + ) + ) + + return out + + # -- phase: tool results ------------------------------------------------ + + def on_tool_result( + self, event: events_.ToolCallResult + ) -> list[ui_events.UIMessageStreamEvent]: + """Handle a ``ToolCallResult`` — emit tool input/output events.""" + msg = event.message + out: list[ui_events.UIMessageStreamEvent] = [] + + self._track_source_message(msg) + out.extend(self._ensure_started(msg.turn_id)) + + # Emit ToolInputAvailable for each tool call that triggered + # these results (from the assistant message's ToolCallParts). + for part in msg.parts: + if isinstance(part, messages_.ToolCallPart): + if part.tool_call_id in self.input_available_emitted: + continue + self.input_available_emitted.add(part.tool_call_id) + if part.tool_call_id not in self.started_tool_inputs: + self.started_tool_inputs.add(part.tool_call_id) + out.append( + ui_events.UIToolInputStartEvent( + tool_call_id=part.tool_call_id, + tool_name=part.tool_name, + provider_metadata=part.provider_metadata, + ) + ) + out.append( + ui_events.UIToolInputAvailableEvent( + tool_call_id=part.tool_call_id, + tool_name=part.tool_name, + input=normalize_tool_input(part.tool_args), + provider_metadata=part.provider_metadata, + ) + ) + + # Emit tool results. + for part in event.results: + if part.tool_call_id in self.emitted_tool_results: + continue + # Hook-abort placeholders are internal bookkeeping: the + # corresponding HookPart(pending) drives the UI state. + if part.is_hook_pending: + continue + self.emitted_tool_results.add(part.tool_call_id) + if part.is_error: + out.append( + ui_events.UIToolOutputErrorEvent( + tool_call_id=part.tool_call_id, + error_text=_tool_error_text(part), + provider_metadata=part.provider_metadata, + ) + ) + else: + wire_output = _to_wire_output(part.result) + if wire_output is None: + # Aggregator produced no anchor (e.g. sub-agent + # tool that yielded nothing). Skip the final + # output emit; preliminaries already covered the + # streaming view if any. + continue + out.append( + ui_events.UIToolOutputAvailableEvent( + tool_call_id=part.tool_call_id, + output=wire_output, + provider_metadata=part.provider_metadata, + ) + ) + + return out + + def on_partial_tool_result( + self, event: events_.PartialToolCallResult + ) -> list[ui_events.UIMessageStreamEvent]: + """Feed the value and emit a preliminary output. + + Each PartialToolCallResult carries one yielded value plus the + aggregator factory the tool was declared with. We instantiate + the aggregator once per ``tool_call_id`` and use its snapshot + as the ``output`` of a preliminary ``UIToolOutputAvailableEvent``. + The AI SDK supersedes preliminary outputs with the final + ``ToolCallResult`` when it arrives. + """ + out: list[ui_events.UIMessageStreamEvent] = [] + + tcid = event.tool_call_id + factory = event.aggregator_factory + if tcid is None or factory is None: + return out + + out.extend(self._ensure_started()) + + agg = self.partial_aggregators.get(tcid) + if agg is None: + agg = factory() + self.partial_aggregators[tcid] = agg + agg.feed(event.value) + + wire_output = _to_wire_output(agg.snapshot()) + if wire_output is None: + # Sub-agent bundle without an assistant anchor yet — wait + # for more events before emitting. + return out + + out.append( + ui_events.UIToolOutputAvailableEvent( + tool_call_id=tcid, + output=wire_output, + preliminary=True, + ) + ) + return out + + # -- phase: hooks ------------------------------------------------------- + + def on_hook( + self, event: events_.HookEvent + ) -> list[ui_events.UIMessageStreamEvent]: + """Handle a ``HookEvent`` — emit approval events.""" + hook_part = event.hook + out: list[ui_events.UIMessageStreamEvent] = [] + + self._track_source_message(event.message) + # Ensure the UI message is started. + out.extend(self._ensure_started(event.message.turn_id)) + + tc_id = approvals.tool_call_id_for(hook_part) + if tc_id is None: + return out + + is_automatic = hook_part.metadata.get("isAutomatic") + is_automatic = is_automatic if isinstance(is_automatic, bool) else None + match hook_part.status: + case "pending": + if tc_id in self.emitted_approval_requests: + return out + self.emitted_approval_requests.add(tc_id) + out.append( + ui_events.UIToolApprovalRequestEvent( + approval_id=hook_part.hook_id, + tool_call_id=tc_id, + is_automatic=is_automatic, + ) + ) + case "resolved": + resolution: dict[str, Any] = hook_part.resolution or {} + provider_executed = hook_part.metadata.get("providerExecuted") + provider_executed = ( + provider_executed + if isinstance(provider_executed, bool) + else None + ) + provider_metadata = hook_part.metadata.get( + "callProviderMetadata" + ) + provider_metadata = ( + provider_metadata + if isinstance(provider_metadata, dict) + else None + ) + out.append( + ui_events.UIToolApprovalResponseEvent( + approval_id=hook_part.hook_id, + approved=bool(resolution.get("granted")), + reason=resolution.get("reason"), + provider_executed=provider_executed, + provider_metadata=provider_metadata, + ) + ) + if not resolution.get("granted"): + out.append( + ui_events.UIToolOutputDeniedEvent(tool_call_id=tc_id) + ) + case "cancelled": + out.append( + ui_events.UIToolOutputErrorEvent( + tool_call_id=tc_id, + error_text="Hook cancelled", + ) + ) + + return out + + # -- phase: stream finish ------------------------------------------------ + + def finish(self) -> list[ui_events.UIMessageStreamEvent]: + events = self._close_open_blocks() + if self.in_step: + events.append(ui_events.UIFinishStepEvent()) + self.in_step = False + if self.emitted_start: + events.append( + ui_events.UIFinishEvent( + finish_reason="stop", + message_metadata=self._latest_assistant_metadata(), + ) + ) + return events + + +def _to_camel_case(snake_str: str) -> str: + components = snake_str.split("_") + return components[0] + "".join(x.title() for x in components[1:]) + + +def _json_default(obj: Any) -> Any: + if isinstance(obj, pydantic.BaseModel): + return obj.model_dump(mode="json", by_alias=True) + raise TypeError( + f"Object of type {type(obj).__name__} is not JSON serializable" + ) + + +def serialize_event(event: ui_events.UIMessageStreamEvent) -> str: + """Serialize a stream event to JSON with camelCase keys.""" + d = dataclasses.asdict(event) + if isinstance(event, ui_events.UIDataEvent): + d["type"] = event.type + del d["data_type"] + camel_dict = {_to_camel_case(k): v for k, v in d.items() if v is not None} + return json.dumps(camel_dict, default=_json_default) + + +def format_sse(event: ui_events.UIMessageStreamEvent) -> str: + """Format a stream event as an SSE data line.""" + return f"data: {serialize_event(event)}\n\n" + + +def format_done_sse() -> str: + """Format the AI SDK UI stream termination marker.""" + return "data: [DONE]\n\n" + + +async def to_stream( + events: AsyncIterable[events_.AgentEvent], +) -> AsyncGenerator[ui_events.UIMessageStreamEvent]: + """Walk internal events once, emitting AI SDK UI stream events.""" + state = _StreamState() + + async for event in events: + match event: + case events_.ToolCallResult(): + for ui_event in state.on_tool_result(event): + yield ui_event + case events_.PartialToolCallResult(): + for ui_event in state.on_partial_tool_result(event): + yield ui_event + case events_.HookEvent(): + for ui_event in state.on_hook(event): + yield ui_event + case _: + for ui_event in state.on_event(event): + yield ui_event + + for ui_event in state.finish(): + yield ui_event + + +async def to_sse( + events: AsyncIterable[events_.AgentEvent], +) -> AsyncGenerator[str]: + """Convert an internal event stream into SSE strings.""" + async for event in to_stream(events): + yield format_sse(event) + yield format_done_sse() diff --git a/src/ai/agents/ui/ai_sdk/tool_utils.py b/src/ai/agents/ui/ai_sdk/tool_utils.py new file mode 100644 index 00000000..b3f5d8a4 --- /dev/null +++ b/src/ai/agents/ui/ai_sdk/tool_utils.py @@ -0,0 +1,25 @@ +"""Tool value helpers shared by AI SDK UI adapters.""" + +from __future__ import annotations + +import json +from typing import Any + + +def normalize_tool_input(raw: str) -> Any: + """Parse serialized tool args into a JSON value when possible.""" + try: + return json.loads(raw) + except (json.JSONDecodeError, TypeError): + return raw + + +def normalize_tool_args(tool_input: Any) -> str: + """Normalize UI tool input to the internal serialized args form.""" + match tool_input: + case str(): + return tool_input + case None: + return "{}" + case _: + return json.dumps(tool_input) diff --git a/src/ai/agents/ui/ai_sdk/protocol.py b/src/ai/agents/ui/ai_sdk/ui_events.py similarity index 67% rename from src/ai/agents/ui/ai_sdk/protocol.py rename to src/ai/agents/ui/ai_sdk/ui_events.py index e0eb8902..a3a130b2 100644 --- a/src/ai/agents/ui/ai_sdk/protocol.py +++ b/src/ai/agents/ui/ai_sdk/ui_events.py @@ -13,7 +13,7 @@ } -# different kinds of messages expected by the frontend +# different kinds of stream events expected by the frontend FinishReason = Literal[ "stop", "length", "content-filter", "tool-calls", "error", "other" @@ -21,7 +21,7 @@ @dataclasses.dataclass -class StartPart: +class UIStartEvent: """Indicates the beginning of a new message with metadata.""" type: Literal["start"] = dataclasses.field(default="start", init=False) @@ -30,7 +30,7 @@ class StartPart: @dataclasses.dataclass -class TextStartPart: +class UITextStartEvent: """Indicates the beginning of a text block.""" id: str @@ -41,7 +41,7 @@ class TextStartPart: @dataclasses.dataclass -class TextDeltaPart: +class UITextDeltaEvent: """Contains incremental text content for the text block.""" id: str @@ -53,7 +53,7 @@ class TextDeltaPart: @dataclasses.dataclass -class TextEndPart: +class UITextEndEvent: """Indicates the completion of a text block.""" id: str @@ -64,7 +64,7 @@ class TextEndPart: @dataclasses.dataclass -class ReasoningStartPart: +class UIReasoningStartEvent: """Indicates the beginning of a reasoning block.""" id: str @@ -75,7 +75,7 @@ class ReasoningStartPart: @dataclasses.dataclass -class ReasoningDeltaPart: +class UIReasoningDeltaEvent: """Contains incremental reasoning content for the reasoning block.""" id: str @@ -87,7 +87,7 @@ class ReasoningDeltaPart: @dataclasses.dataclass -class ReasoningEndPart: +class UIReasoningEndEvent: """Indicates the completion of a reasoning block.""" id: str @@ -98,7 +98,16 @@ class ReasoningEndPart: @dataclasses.dataclass -class SourceUrlPart: +class UICustomEvent: + """Provider-specific content that does not fit standard UI events.""" + + kind: str + type: Literal["custom"] = dataclasses.field(default="custom", init=False) + provider_metadata: dict[str, Any] | None = None + + +@dataclasses.dataclass +class UISourceUrlEvent: """References to external URLs.""" source_id: str @@ -111,7 +120,7 @@ class SourceUrlPart: @dataclasses.dataclass -class SourceDocumentPart: +class UISourceDocumentEvent: """References to documents or files.""" source_id: str @@ -125,8 +134,8 @@ class SourceDocumentPart: @dataclasses.dataclass -class FilePart: - """The file parts contain references to files with their media type.""" +class UIFileEvent: + """References to files with their media type.""" url: str media_type: str @@ -135,14 +144,26 @@ class FilePart: @dataclasses.dataclass -class DataPart: - """Custom data part for arbitrary structured data. +class UIReasoningFileEvent: + """A file generated as part of model reasoning.""" + + url: str + media_type: str + type: Literal["reasoning-file"] = dataclasses.field( + default="reasoning-file", init=False + ) + provider_metadata: dict[str, Any] | None = None + + +@dataclasses.dataclass +class UIDataEvent: + """Custom data event for arbitrary structured data. - Data parts support type-specific handling. + Data events support type-specific handling. The wire type is ``data-{data_type}`` (e.g. ``data-custom``), exposed - via the ``type`` property so that ``DataPart`` is uniform with every - other ``UIMessageStreamPart`` variant. + via the ``type`` property so that ``UIDataEvent`` is uniform with every + other ``UIMessageStreamEvent`` variant. """ data_type: str @@ -157,7 +178,7 @@ def type(self) -> str: @dataclasses.dataclass -class ToolInputStartPart: +class UIToolInputStartEvent: """Indicates the beginning of tool input streaming.""" tool_call_id: str @@ -166,12 +187,14 @@ class ToolInputStartPart: default="tool-input-start", init=False ) provider_executed: bool | None = None + provider_metadata: dict[str, Any] | None = None + tool_metadata: dict[str, Any] | None = None dynamic: bool | None = None title: str | None = None @dataclasses.dataclass -class ToolInputDeltaPart: +class UIToolInputDeltaEvent: """Incremental chunks of tool input as it's being generated.""" tool_call_id: str @@ -182,7 +205,7 @@ class ToolInputDeltaPart: @dataclasses.dataclass -class ToolInputAvailablePart: +class UIToolInputAvailableEvent: """Indicates that tool input is complete and ready for execution.""" tool_call_id: str @@ -193,12 +216,13 @@ class ToolInputAvailablePart: ) provider_executed: bool | None = None provider_metadata: dict[str, Any] | None = None + tool_metadata: dict[str, Any] | None = None dynamic: bool | None = None title: str | None = None @dataclasses.dataclass -class ToolInputErrorPart: +class UIToolInputErrorEvent: """Indicates an error occurred during tool input processing.""" tool_call_id: str @@ -210,12 +234,13 @@ class ToolInputErrorPart: ) provider_executed: bool | None = None provider_metadata: dict[str, Any] | None = None + tool_metadata: dict[str, Any] | None = None dynamic: bool | None = None title: str | None = None @dataclasses.dataclass -class ToolOutputAvailablePart: +class UIToolOutputAvailableEvent: """Contains the result of tool execution.""" tool_call_id: str @@ -224,12 +249,14 @@ class ToolOutputAvailablePart: default="tool-output-available", init=False ) provider_executed: bool | None = None + provider_metadata: dict[str, Any] | None = None + tool_metadata: dict[str, Any] | None = None dynamic: bool | None = None preliminary: bool | None = None @dataclasses.dataclass -class ToolOutputErrorPart: +class UIToolOutputErrorEvent: """Indicates an error occurred during tool execution.""" tool_call_id: str @@ -238,11 +265,13 @@ class ToolOutputErrorPart: default="tool-output-error", init=False ) provider_executed: bool | None = None + provider_metadata: dict[str, Any] | None = None + tool_metadata: dict[str, Any] | None = None dynamic: bool | None = None @dataclasses.dataclass -class ToolOutputDeniedPart: +class UIToolOutputDeniedEvent: """Indicates tool execution was denied.""" tool_call_id: str @@ -252,7 +281,7 @@ class ToolOutputDeniedPart: @dataclasses.dataclass -class ToolApprovalRequestPart: +class UIToolApprovalRequestEvent: """Requests approval for tool execution.""" approval_id: str @@ -260,11 +289,26 @@ class ToolApprovalRequestPart: type: Literal["tool-approval-request"] = dataclasses.field( default="tool-approval-request", init=False ) + is_automatic: bool | None = None + + +@dataclasses.dataclass +class UIToolApprovalResponseEvent: + """Records an approval decision for a tool call.""" + + approval_id: str + approved: bool + type: Literal["tool-approval-response"] = dataclasses.field( + default="tool-approval-response", init=False + ) + reason: str | None = None + provider_executed: bool | None = None + provider_metadata: dict[str, Any] | None = None @dataclasses.dataclass -class StartStepPart: - """A part indicating the start of a step.""" +class UIStartStepEvent: + """Indicates the start of a step.""" type: Literal["start-step"] = dataclasses.field( default="start-step", init=False @@ -272,8 +316,8 @@ class StartStepPart: @dataclasses.dataclass -class FinishStepPart: - """A part indicating that a step has been completed.""" +class UIFinishStepEvent: + """Indicates that a step has been completed.""" type: Literal["finish-step"] = dataclasses.field( default="finish-step", init=False @@ -281,8 +325,8 @@ class FinishStepPart: @dataclasses.dataclass -class FinishPart: - """A part indicating the completion of a message.""" +class UIFinishEvent: + """Indicates the completion of a message.""" type: Literal["finish"] = dataclasses.field(default="finish", init=False) finish_reason: FinishReason | None = None @@ -290,14 +334,15 @@ class FinishPart: @dataclasses.dataclass -class AbortPart: +class UIAbortEvent: """Indicates the message was aborted.""" type: Literal["abort"] = dataclasses.field(default="abort", init=False) + reason: str | None = None @dataclasses.dataclass -class MessageMetadataPart: +class UIMessageMetadataEvent: """Contains message metadata.""" message_metadata: Any @@ -307,37 +352,40 @@ class MessageMetadataPart: @dataclasses.dataclass -class ErrorPart: - """The error parts are appended to the message as they are received.""" +class UIErrorEvent: + """Errors appended to the message as they are received.""" error_text: str type: Literal["error"] = dataclasses.field(default="error", init=False) -UIMessageStreamPart = ( - StartPart - | TextStartPart - | TextDeltaPart - | TextEndPart - | ReasoningStartPart - | ReasoningDeltaPart - | ReasoningEndPart - | SourceUrlPart - | SourceDocumentPart - | FilePart - | DataPart - | ToolInputStartPart - | ToolInputDeltaPart - | ToolInputAvailablePart - | ToolInputErrorPart - | ToolOutputAvailablePart - | ToolOutputErrorPart - | ToolOutputDeniedPart - | ToolApprovalRequestPart - | StartStepPart - | FinishStepPart - | FinishPart - | AbortPart - | MessageMetadataPart - | ErrorPart +UIMessageStreamEvent = ( + UIStartEvent + | UITextStartEvent + | UITextDeltaEvent + | UITextEndEvent + | UIReasoningStartEvent + | UIReasoningDeltaEvent + | UIReasoningEndEvent + | UICustomEvent + | UISourceUrlEvent + | UISourceDocumentEvent + | UIFileEvent + | UIReasoningFileEvent + | UIDataEvent + | UIToolInputStartEvent + | UIToolInputDeltaEvent + | UIToolInputAvailableEvent + | UIToolInputErrorEvent + | UIToolOutputAvailableEvent + | UIToolOutputErrorEvent + | UIToolOutputDeniedEvent + | UIToolApprovalRequestEvent + | UIToolApprovalResponseEvent + | UIStartStepEvent + | UIFinishStepEvent + | UIFinishEvent + | UIAbortEvent + | UIMessageMetadataEvent + | UIErrorEvent ) diff --git a/src/ai/agents/ui/ai_sdk/ui_message.py b/src/ai/agents/ui/ai_sdk/ui_messages.py similarity index 61% rename from src/ai/agents/ui/ai_sdk/ui_message.py rename to src/ai/agents/ui/ai_sdk/ui_messages.py index ea3ccd0f..f48cef34 100644 --- a/src/ai/agents/ui/ai_sdk/ui_message.py +++ b/src/ai/agents/ui/ai_sdk/ui_messages.py @@ -14,12 +14,20 @@ from ....types import messages as messages_ +_UI_MODEL_CONFIG = pydantic.ConfigDict(populate_by_name=True, extra="allow") + class UITextPart(pydantic.BaseModel): """Text content part in AI SDK v6 format.""" + model_config = _UI_MODEL_CONFIG + type: Literal["text"] text: str + state: Literal["streaming", "done"] | None = None + provider_metadata: dict[str, Any] | None = pydantic.Field( + default=None, alias="providerMetadata" + ) class UIReasoningPart(pydantic.BaseModel): @@ -31,9 +39,26 @@ class UIReasoningPart(pydantic.BaseModel): we accept it but don't currently route on it. """ + model_config = _UI_MODEL_CONFIG + type: Literal["reasoning"] text: str state: Literal["streaming", "done"] | None = None + provider_metadata: dict[str, Any] | None = pydantic.Field( + default=None, alias="providerMetadata" + ) + + +class UICustomPart(pydantic.BaseModel): + """Provider-specific content that does not fit standard UI parts.""" + + model_config = _UI_MODEL_CONFIG + + type: Literal["custom"] + kind: str + provider_metadata: dict[str, Any] | None = pydantic.Field( + default=None, alias="providerMetadata" + ) # Tool invocation states in AI SDK v6: @@ -65,7 +90,7 @@ class UIToolInvocationPart(pydantic.BaseModel): Reference: https://ai-sdk.dev/docs/reference/ai-sdk-core/ui-message """ - model_config = pydantic.ConfigDict(populate_by_name=True) + model_config = _UI_MODEL_CONFIG type: Literal["tool-invocation"] tool_invocation_id: str = pydantic.Field(alias="toolInvocationId") @@ -73,11 +98,16 @@ class UIToolInvocationPart(pydantic.BaseModel): args: dict[str, Any] = pydantic.Field(default_factory=dict) state: UIToolInvocationState = "input-available" result: Any | None = None + provider_executed: bool | None = pydantic.Field( + default=None, alias="providerExecuted" + ) class UIStepStartPart(pydantic.BaseModel): """Step boundary marker. Skipped during conversion to internal format.""" + model_config = _UI_MODEL_CONFIG + type: Literal["step-start"] @@ -89,11 +119,14 @@ class UIToolApproval(pydantic.BaseModel): ``approved`` is None while awaiting a response, True/False after. """ - model_config = pydantic.ConfigDict(populate_by_name=True) + model_config = _UI_MODEL_CONFIG id: str approved: bool | None = None reason: str | None = None + is_automatic: bool | None = pydantic.Field( + default=None, alias="isAutomatic" + ) class UIToolPart(pydantic.BaseModel): @@ -103,7 +136,7 @@ class UIToolPart(pydantic.BaseModel): where the tool name is embedded in the type string. """ - model_config = pydantic.ConfigDict(populate_by_name=True) + model_config = _UI_MODEL_CONFIG # The actual type string (e.g., "tool-talk_to_mothership") # We store this to extract the tool name @@ -112,8 +145,23 @@ class UIToolPart(pydantic.BaseModel): state: UIToolInvocationState input: str | dict[str, Any] | None = None # JSON string or parsed dict output: Any | None = None + raw_input: Any | None = pydantic.Field(default=None, alias="rawInput") error_text: str | None = pydantic.Field(default=None, alias="errorText") approval: UIToolApproval | None = None + provider_executed: bool | None = pydantic.Field( + default=None, alias="providerExecuted" + ) + call_provider_metadata: dict[str, Any] | None = pydantic.Field( + default=None, alias="callProviderMetadata" + ) + result_provider_metadata: dict[str, Any] | None = pydantic.Field( + default=None, alias="resultProviderMetadata" + ) + tool_metadata: dict[str, Any] | None = pydantic.Field( + default=None, alias="toolMetadata" + ) + preliminary: bool | None = None + title: str | None = None @property def tool_name(self) -> str: @@ -126,61 +174,134 @@ def tool_name(self) -> str: return self.type +class UIDynamicToolPart(pydantic.BaseModel): + """Dynamic tool part where the tool name is a field, not the type suffix.""" + + model_config = _UI_MODEL_CONFIG + + type: Literal["dynamic-tool"] + tool_name: str = pydantic.Field(alias="toolName") + tool_call_id: str = pydantic.Field(alias="toolCallId") + state: UIToolInvocationState + input: Any | None = None + output: Any | None = None + raw_input: Any | None = pydantic.Field(default=None, alias="rawInput") + error_text: str | None = pydantic.Field(default=None, alias="errorText") + approval: UIToolApproval | None = None + provider_executed: bool | None = pydantic.Field( + default=None, alias="providerExecuted" + ) + call_provider_metadata: dict[str, Any] | None = pydantic.Field( + default=None, alias="callProviderMetadata" + ) + result_provider_metadata: dict[str, Any] | None = pydantic.Field( + default=None, alias="resultProviderMetadata" + ) + tool_metadata: dict[str, Any] | None = pydantic.Field( + default=None, alias="toolMetadata" + ) + preliminary: bool | None = None + title: str | None = None + + class UIFilePart(pydantic.BaseModel): """File part. TODO: FilePart not yet supported in core messages.""" - model_config = pydantic.ConfigDict(populate_by_name=True) + model_config = _UI_MODEL_CONFIG type: Literal["file"] media_type: str = pydantic.Field(alias="mediaType") url: str filename: str | None = None + provider_reference: dict[str, Any] | None = pydantic.Field( + default=None, alias="providerReference" + ) + provider_metadata: dict[str, Any] | None = pydantic.Field( + default=None, alias="providerMetadata" + ) + + +class UIReasoningFilePart(pydantic.BaseModel): + """Reasoning file part generated as part of model reasoning.""" + + model_config = _UI_MODEL_CONFIG + + type: Literal["reasoning-file"] + media_type: str = pydantic.Field(alias="mediaType") + url: str + provider_metadata: dict[str, Any] | None = pydantic.Field( + default=None, alias="providerMetadata" + ) class UISourceUrlPart(pydantic.BaseModel): """Source URL part. TODO: SourceUrlPart not yet supported.""" - model_config = pydantic.ConfigDict(populate_by_name=True) + model_config = _UI_MODEL_CONFIG type: Literal["source-url"] source_id: str = pydantic.Field(alias="sourceId") url: str title: str | None = None + provider_metadata: dict[str, Any] | None = pydantic.Field( + default=None, alias="providerMetadata" + ) class UISourceDocumentPart(pydantic.BaseModel): """Source document part. TODO: SourceDocumentPart not yet supported.""" - model_config = pydantic.ConfigDict(populate_by_name=True) + model_config = _UI_MODEL_CONFIG type: Literal["source-document"] source_id: str = pydantic.Field(alias="sourceId") media_type: str = pydantic.Field(alias="mediaType") title: str filename: str | None = None + provider_metadata: dict[str, Any] | None = pydantic.Field( + default=None, alias="providerMetadata" + ) + + +class UIDataPart(pydantic.BaseModel): + """Custom data part with a dynamic ``data-*`` type.""" + + model_config = _UI_MODEL_CONFIG + + type: str + id: str | None = None + data: Any + transient: bool | None = None # Union of all supported part types (used for type hints) UIMessagePart = ( UITextPart | UIReasoningPart + | UICustomPart | UIToolInvocationPart | UIStepStartPart | UIToolPart + | UIDynamicToolPart | UIFilePart + | UIReasoningFilePart | UISourceUrlPart | UISourceDocumentPart + | UIDataPart ) _STATIC_UI_PART_TYPES: dict[str, type[pydantic.BaseModel]] = { "text": UITextPart, "reasoning": UIReasoningPart, + "custom": UICustomPart, "tool-invocation": UIToolInvocationPart, "step-start": UIStepStartPart, "file": UIFilePart, + "reasoning-file": UIReasoningFilePart, "source-url": UISourceUrlPart, "source-document": UISourceDocumentPart, + "dynamic-tool": UIDynamicToolPart, } @@ -198,9 +319,8 @@ def _parse_ui_part(part_data: dict[str, Any]) -> UIMessagePart | None: case str() as t if t.startswith("tool-"): # Dynamic tool type: tool-{toolName} (e.g., "tool-get_weather") return UIToolPart.model_validate(part_data) - case str() as t if t.startswith("data-") or t == "dynamic-tool": - # TODO: data-{name} and dynamic-tool not yet supported - return None + case str() as t if t.startswith("data-"): + return UIDataPart.model_validate(part_data) case _: # Unknown part type - skip gracefully return None @@ -212,12 +332,13 @@ class UIMessage(pydantic.BaseModel): Reference: https://ai-sdk.dev/docs/reference/ai-sdk-core/ui-message """ - model_config = pydantic.ConfigDict(populate_by_name=True) + model_config = _UI_MODEL_CONFIG id: str = pydantic.Field( default_factory=lambda: messages_.generate_id("msg") ) role: Literal["user", "assistant", "system"] + metadata: Any | None = None parts: list[UIMessagePart] = pydantic.Field(default_factory=list) @pydantic.field_validator("parts", mode="before") diff --git a/src/ai/types/messages.py b/src/ai/types/messages.py index bae4c3c0..217b4287 100644 --- a/src/ai/types/messages.py +++ b/src/ai/types/messages.py @@ -14,7 +14,7 @@ def generate_id(prefix: str | None = None) -> str: class TextPart(pydantic.BaseModel): - id: str = pydantic.Field(default_factory=generate_id) + id: str = pydantic.Field(default_factory=lambda: generate_id("part")) text: str provider_metadata: dict[str, Any] | None = None @@ -25,7 +25,7 @@ class TextPart(pydantic.BaseModel): class ToolResultPart(pydantic.BaseModel): - id: str = pydantic.Field(default_factory=generate_id) + id: str = pydantic.Field(default_factory=lambda: generate_id("part")) tool_call_id: str tool_name: str is_error: bool = False @@ -67,7 +67,7 @@ def has_model_input(self) -> bool: class ToolCallPart(pydantic.BaseModel): - id: str = pydantic.Field(default_factory=generate_id) + id: str = pydantic.Field(default_factory=lambda: generate_id("part")) tool_call_id: str tool_name: str tool_args: str @@ -98,7 +98,7 @@ class BuiltinToolCallPart(pydantic.BaseModel): host. Adapters emit them when a model uses a built-in tool. """ - id: str = pydantic.Field(default_factory=generate_id) + id: str = pydantic.Field(default_factory=lambda: generate_id("part")) tool_call_id: str tool_name: str tool_args: str = "" @@ -110,7 +110,7 @@ class BuiltinToolCallPart(pydantic.BaseModel): class BuiltinToolReturnPart(pydantic.BaseModel): """The provider's result for a :class:`BuiltinToolCallPart`.""" - id: str = pydantic.Field(default_factory=generate_id) + id: str = pydantic.Field(default_factory=lambda: generate_id("part")) tool_call_id: str tool_name: str result: Any = None @@ -122,7 +122,7 @@ class BuiltinToolReturnPart(pydantic.BaseModel): class ReasoningPart(pydantic.BaseModel): - id: str = pydantic.Field(default_factory=generate_id) + id: str = pydantic.Field(default_factory=lambda: generate_id("part")) text: str provider_metadata: dict[str, Any] | None = None @@ -130,7 +130,7 @@ class ReasoningPart(pydantic.BaseModel): class HookPart[T](pydantic.BaseModel): - id: str = pydantic.Field(default_factory=generate_id) + id: str = pydantic.Field(default_factory=lambda: generate_id("part")) hook_id: str hook_type: str status: Literal["pending", "resolved", "cancelled"] @@ -158,7 +158,7 @@ class FilePart(pydantic.BaseModel): model_config = pydantic.ConfigDict(frozen=True) - id: str = pydantic.Field(default_factory=generate_id) + id: str = pydantic.Field(default_factory=lambda: generate_id("part")) data: str | bytes media_type: str # IANA media type, e.g. "image/png", "audio/wav" filename: str | None = None @@ -220,7 +220,7 @@ def from_bytes( class Message(pydantic.BaseModel): role: Literal["user", "assistant", "system", "tool", "internal"] parts: list[Part] - id: str = pydantic.Field(default_factory=generate_id) + id: str = pydantic.Field(default_factory=lambda: generate_id("msg")) turn_id: str | None = None usage: usage_.Usage | None = None provider_metadata: dict[str, Any] | None = None diff --git a/tests/agents/ui/ai_sdk/outbound/__init__.py b/tests/agents/ui/ai_sdk/outbound/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/agents/ui/ai_sdk/outbound/test_history.py b/tests/agents/ui/ai_sdk/outbound/test_history.py deleted file mode 100644 index 7665bc13..00000000 --- a/tests/agents/ui/ai_sdk/outbound/test_history.py +++ /dev/null @@ -1,130 +0,0 @@ -from __future__ import annotations - -from ai.agents.ui.ai_sdk import to_ui_messages -from ai.agents.ui.ai_sdk.ui_message import ( - UITextPart, - UIToolPart, -) -from ai.types import messages as messages_ - - -def test_to_ui_messages_user_and_assistant() -> None: - msgs = [ - messages_.Message( - id="u1", role="user", parts=[messages_.TextPart(text="hi")] - ), - messages_.Message( - id="a1", - role="assistant", - parts=[messages_.TextPart(text="hello back")], - ), - ] - result = to_ui_messages(msgs) - assert len(result) == 2 - assert result[0].role == "user" - assert result[1].role == "assistant" - assert result[1].id == "a1" - - -def test_to_ui_messages_merges_assistant_tool_internal() -> None: - msgs = [ - messages_.Message( - id="a1", - role="assistant", - parts=[ - messages_.TextPart(text="calling"), - messages_.ToolCallPart( - tool_call_id="tc1", - tool_name="search", - tool_args='{"q":"x"}', - ), - ], - ), - messages_.Message( - role="tool", - parts=[ - messages_.ToolResultPart( - tool_call_id="tc1", - tool_name="search", - result={"hits": 2}, - ) - ], - ), - messages_.Message( - role="assistant", - parts=[messages_.TextPart(text="done")], - ), - ] - result = to_ui_messages(msgs) - assert len(result) == 1 - ui_msg = result[0] - assert ui_msg.role == "assistant" - assert ui_msg.id == "a1" - assert isinstance(ui_msg.parts[0], UITextPart) - assert ui_msg.parts[0].text == "calling" - assert isinstance(ui_msg.parts[1], UIToolPart) - assert ui_msg.parts[1].state == "output-available" - assert ui_msg.parts[1].output == {"hits": 2} - assert isinstance(ui_msg.parts[2], UITextPart) - assert ui_msg.parts[2].text == "done" - - -def test_to_ui_messages_internal_role_merges_approval() -> None: - msgs = [ - messages_.Message( - id="a1", - role="assistant", - parts=[ - messages_.ToolCallPart( - tool_call_id="tc1", - tool_name="delete", - tool_args="{}", - ) - ], - ), - messages_.Message( - role="internal", - parts=[ - messages_.HookPart( - hook_id="approve_tc1", - hook_type="ToolApproval", - status="pending", - ) - ], - ), - ] - result = to_ui_messages(msgs) - ui_msg = result[0] - tool_part = ui_msg.parts[0] - assert isinstance(tool_part, UIToolPart) - assert tool_part.state == "approval-requested" - assert tool_part.approval is not None - assert tool_part.approval.id == "approve_tc1" - - -def test_to_ui_messages_user_message_uses_own_id() -> None: - msgs = [ - messages_.Message( - id="u1", role="user", parts=[messages_.TextPart(text="a")] - ) - ] - result = to_ui_messages(msgs) - assert result[0].id == "u1" - - -def test_to_ui_messages_uses_first_assistant_id_as_bubble_id() -> None: - msgs = [ - messages_.Message( - id="a1", - role="assistant", - parts=[messages_.TextPart(text="first")], - ), - messages_.Message( - id="a2", - role="assistant", - parts=[messages_.TextPart(text="second")], - ), - ] - result = to_ui_messages(msgs) - assert len(result) == 1 - assert result[0].id == "a1" diff --git a/tests/agents/ui/ai_sdk/outbound/test_sse.py b/tests/agents/ui/ai_sdk/outbound/test_sse.py deleted file mode 100644 index 95d520ef..00000000 --- a/tests/agents/ui/ai_sdk/outbound/test_sse.py +++ /dev/null @@ -1,55 +0,0 @@ -from __future__ import annotations - -import json -from collections.abc import AsyncGenerator - -from ai.agents.ui.ai_sdk import protocol, to_sse -from ai.agents.ui.ai_sdk.outbound.sse import format_sse, serialize_part -from ai.types import events as agent_events_ -from ai.types import events as events_ - - -def test_serialize_part_camelcases_keys() -> None: - part = protocol.StartPart(message_id="m1") - payload = json.loads(serialize_part(part)) - assert payload == {"type": "start", "messageId": "m1"} - - -def test_format_sse_wraps_data_line() -> None: - part = protocol.TextDeltaPart(id="t1", delta="hi") - line = format_sse(part) - assert line.startswith("data: ") - assert line.endswith("\n\n") - - -def test_serialize_data_part_uses_type_with_prefix() -> None: - part = protocol.DataPart(data_type="custom", data={"k": 1}) - payload = json.loads(serialize_part(part)) - assert payload["type"] == "data-custom" - assert "dataType" not in payload - - -async def _gen( - stream_events: list[agent_events_.AgentEvent], -) -> AsyncGenerator[agent_events_.AgentEvent]: - for event in stream_events: - yield event - - -async def test_to_sse_emits_data_prefixed_lines() -> None: - lines = [ - line - async for line in to_sse( - _gen( - [ - events_.TextStart(block_id="t1"), - events_.TextDelta(block_id="t1", chunk="hi"), - events_.TextEnd(block_id="t1"), - ] - ) - ) - ] - assert all(line.startswith("data: ") for line in lines) - # first line is the start part (lazy open) - first = json.loads(lines[0].removeprefix("data: ").rstrip()) - assert first["type"] == "start" diff --git a/tests/agents/ui/ai_sdk/outbound/test_stream.py b/tests/agents/ui/ai_sdk/outbound/test_stream.py deleted file mode 100644 index 1022de46..00000000 --- a/tests/agents/ui/ai_sdk/outbound/test_stream.py +++ /dev/null @@ -1,244 +0,0 @@ -from __future__ import annotations - -from collections.abc import AsyncGenerator - -import ai -from ai.agents.ui.ai_sdk import protocol, to_stream -from ai.types import events as agent_events_ -from ai.types import events as events_ -from ai.types import messages as messages_ - - -async def _gen( - stream_events: list[agent_events_.AgentEvent], -) -> AsyncGenerator[agent_events_.AgentEvent]: - for event in stream_events: - yield event - - -async def _collect( - stream_events: list[agent_events_.AgentEvent], -) -> list[protocol.UIMessageStreamPart]: - return [part async for part in to_stream(_gen(stream_events))] - - -async def test_event_driven_text_streaming() -> None: - """Streaming text events lazily open a UI message.""" - text_id = "txt1" - out = await _collect( - [ - events_.TextStart(block_id=text_id), - events_.TextDelta(block_id=text_id, chunk="hi"), - events_.TextEnd(block_id=text_id), - ] - ) - - assert isinstance(out[0], protocol.StartPart) - assert isinstance(out[1], protocol.StartStepPart) - assert isinstance(out[2], protocol.TextStartPart) and out[2].id == text_id - assert isinstance(out[3], protocol.TextDeltaPart) and out[3].delta == "hi" - assert isinstance(out[4], protocol.TextEndPart) and out[4].id == text_id - assert isinstance(out[5], protocol.FinishStepPart) - assert isinstance(out[6], protocol.FinishPart) - - -async def test_tool_call_and_result_emit_terminal_parts() -> None: - """ToolCallResult emits tool input and output parts.""" - tool_result_msg = messages_.Message( - role="tool", - parts=[ - messages_.ToolResultPart( - tool_call_id="tc1", - tool_name="search", - result={"hits": 1}, - ) - ], - ) - out = await _collect( - [ - # Streaming tool input events from the model - events_.ToolStart(tool_call_id="tc1", tool_name="search"), - events_.ToolDelta(tool_call_id="tc1", chunk='{"q":"x"}'), - events_.ToolEnd( - tool_call_id="tc1", - tool_call=messages_.ToolCallPart( - tool_call_id="tc1", - tool_name="search", - tool_args='{"q":"x"}', - ), - ), - # Tool execution result - agent_events_.ToolCallResult( - message=tool_result_msg, - results=tool_result_msg.tool_results, - ), - ] - ) - types = [type(part).__name__ for part in out] - assert "ToolInputStartPart" in types - assert "ToolOutputAvailablePart" in types - - -async def test_tool_result_without_streaming_emits_input_start() -> None: - """ToolCallResult for a non-streamed tool emits input + output parts.""" - tool_result_msg = messages_.Message( - role="tool", - parts=[ - messages_.ToolCallPart( - id="tc1", - tool_call_id="tc1", - tool_name="search", - tool_args='{"q":"x"}', - ), - messages_.ToolResultPart( - tool_call_id="tc1", - tool_name="search", - result={"hits": 1}, - ), - ], - ) - out = await _collect( - [ - agent_events_.ToolCallResult( - message=tool_result_msg, - results=tool_result_msg.tool_results, - ), - ] - ) - types = [type(part).__name__ for part in out] - assert "ToolInputStartPart" in types - assert "ToolInputAvailablePart" in types - assert "ToolOutputAvailablePart" in types - - -async def test_approval_request_hook_emits_approval_part() -> None: - """HookEvent with pending status emits a ToolApprovalRequestPart.""" - out = await _collect( - [ - # Streaming tool events first - events_.ToolStart(tool_call_id="tc1", tool_name="delete"), - events_.ToolDelta(tool_call_id="tc1", chunk="{}"), - events_.ToolEnd( - tool_call_id="tc1", - tool_call=messages_.ToolCallPart( - tool_call_id="tc1", - tool_name="delete", - tool_args="{}", - ), - ), - # Hook requesting approval - agent_events_.HookEvent( - message=messages_.Message( - role="internal", - parts=[ - messages_.HookPart( - hook_id="approve_tc1", - hook_type="ToolApproval", - status="pending", - ) - ], - ), - hook=messages_.HookPart( - hook_id="approve_tc1", - hook_type="ToolApproval", - status="pending", - ), - ), - ] - ) - approval_parts = [ - p for p in out if isinstance(p, protocol.ToolApprovalRequestPart) - ] - assert len(approval_parts) == 1 - assert approval_parts[0].tool_call_id == "tc1" - assert approval_parts[0].approval_id == "approve_tc1" - - -async def test_partial_tool_results_emit_preliminary_outputs() -> None: - """Each partial result yields a preliminary part.""" - out = await _collect( - [ - agent_events_.PartialToolCallResult( - tool_call_id="tc1", - tool_name="search", - value="hit 1, ", - aggregator_factory=ai.agents.ConcatAggregator, - ), - agent_events_.PartialToolCallResult( - tool_call_id="tc1", - tool_name="search", - value="hit 2, ", - aggregator_factory=ai.agents.ConcatAggregator, - ), - agent_events_.PartialToolCallResult( - tool_call_id="tc1", - tool_name="search", - value="hit 3", - aggregator_factory=ai.agents.ConcatAggregator, - ), - ] - ) - - prelim = [ - p - for p in out - if isinstance(p, protocol.ToolOutputAvailablePart) and p.preliminary - ] - assert [p.output for p in prelim] == [ - "hit 1, ", - "hit 1, hit 2, ", - "hit 1, hit 2, hit 3", - ] - assert all(p.tool_call_id == "tc1" for p in prelim) - - -async def test_partial_message_bundle_becomes_ui_message() -> None: - """MessageAggregator's snapshot collapses to one UIMessage.""" - from ai.agents.ui.ai_sdk.ui_message import UIMessage - - inner_msg = messages_.Message( - role="assistant", - parts=[messages_.TextPart(text="hi from sub-agent")], - ) - - out = await _collect( - [ - agent_events_.PartialToolCallResult( - tool_call_id="tc1", - tool_name="research", - value=agent_events_.ToolCallResult( - message=inner_msg, results=[] - ), - aggregator_factory=ai.agents.MessageAggregator, - ), - ] - ) - - [prelim] = [ - p - for p in out - if isinstance(p, protocol.ToolOutputAvailablePart) and p.preliminary - ] - assert isinstance(prelim.output, UIMessage) - assert prelim.output.role == "assistant" - assert prelim.output.parts[0].type == "text" - - -async def test_partial_tool_result_without_factory_is_skipped() -> None: - """Without an aggregator_factory there's nothing to snapshot.""" - out = await _collect( - [ - agent_events_.PartialToolCallResult( - tool_call_id="tc1", - tool_name="search", - value="ignored", - ), - ] - ) - assert not any(isinstance(p, protocol.ToolOutputAvailablePart) for p in out) - - -# NOTE: agent-change boundary detection used to be driven by -# Message.source_label. That field has been removed; agent-change -# routing in the AI SDK adapter now needs to come from -# PartialToolCallResult, which is a separate piece of work. diff --git a/tests/agents/ui/ai_sdk/test_approvals.py b/tests/agents/ui/ai_sdk/test_approvals.py index 993a9c25..447ef9db 100644 --- a/tests/agents/ui/ai_sdk/test_approvals.py +++ b/tests/agents/ui/ai_sdk/test_approvals.py @@ -2,17 +2,38 @@ from typing import Any -from ai.agents.ui.ai_sdk import _approvals +from ai.agents.ui.ai_sdk import approvals +from ai.agents.ui.ai_sdk.ui_messages import UIMessage from ai.types import messages as messages_ +def _ui(role: str, *parts: dict[str, Any], id: str = "m1") -> UIMessage: + return UIMessage.model_validate( + {"id": id, "role": role, "parts": list(parts)} + ) + + +def _tool( + tool_name: str, + tool_call_id: str, + state: str, + **extra: Any, +) -> dict[str, Any]: + return { + "type": f"tool-{tool_name}", + "toolCallId": tool_call_id, + "state": state, + **extra, + } + + def test_tool_call_id_for_strips_prefix() -> None: hook: messages_.HookPart[Any] = messages_.HookPart( hook_id="approve_tc_42", hook_type="ToolApproval", status="pending", ) - assert _approvals.tool_call_id_for(hook) == "tc_42" + assert approvals.tool_call_id_for(hook) == "tc_42" def test_tool_call_id_for_rejects_non_approval_type() -> None: @@ -21,7 +42,61 @@ def test_tool_call_id_for_rejects_non_approval_type() -> None: hook_type="SomethingElse", status="pending", ) - assert _approvals.tool_call_id_for(hook) is None + assert approvals.tool_call_id_for(hook) is None + + +def test_extract_approvals_returns_approved_responses() -> None: + approval_responses = approvals.extract_approvals( + [ + _ui( + "assistant", + _tool( + "x", + "tc1", + "approval-responded", + approval={ + "id": "approve_tc1", + "approved": False, + "reason": "nope", + }, + ), + ) + ] + ) + assert len(approval_responses) == 1 + assert approval_responses[0].hook_id == "approve_tc1" + assert approval_responses[0].granted is False + assert approval_responses[0].reason == "nope" + + +def test_extract_approvals_handles_dynamic_tool_responses() -> None: + approval_responses = approvals.extract_approvals( + [ + _ui( + "assistant", + { + "type": "dynamic-tool", + "toolName": "web_search", + "toolCallId": "tc1", + "state": "approval-responded", + "input": {"query": "ai"}, + "approval": { + "id": "approve_tc1", + "approved": True, + "reason": "ok", + "isAutomatic": True, + }, + "providerExecuted": True, + }, + ) + ] + ) + + assert len(approval_responses) == 1 + assert approval_responses[0].hook_id == "approve_tc1" + assert approval_responses[0].granted is True + assert approval_responses[0].reason == "ok" + assert approval_responses[0].tool_call_id == "tc1" def test_tool_call_id_for_rejects_bad_prefix() -> None: @@ -30,37 +105,4 @@ def test_tool_call_id_for_rejects_bad_prefix() -> None: hook_type="ToolApproval", status="pending", ) - assert _approvals.tool_call_id_for(hook) is None - - -def test_is_tool_approval_message_detects_all_approval_hooks() -> None: - msg = messages_.Message( - role="internal", - parts=[ - messages_.HookPart( - hook_id="approve_tc_1", - hook_type="ToolApproval", - status="pending", - ), - ], - ) - assert _approvals.is_tool_approval_message(msg) - - -def test_is_tool_approval_message_false_for_non_approval() -> None: - msg = messages_.Message( - role="internal", - parts=[ - messages_.HookPart( - hook_id="other", - hook_type="Something", - status="pending", - ), - ], - ) - assert not _approvals.is_tool_approval_message(msg) - - -def test_is_tool_approval_message_false_for_empty() -> None: - msg = messages_.Message(role="internal", parts=[]) - assert not _approvals.is_tool_approval_message(msg) + assert approvals.tool_call_id_for(hook) is None diff --git a/tests/agents/ui/ai_sdk/test_inbound.py b/tests/agents/ui/ai_sdk/test_inbound_messages.py similarity index 70% rename from tests/agents/ui/ai_sdk/test_inbound.py rename to tests/agents/ui/ai_sdk/test_inbound_messages.py index 64f8a556..f2d2abbc 100644 --- a/tests/agents/ui/ai_sdk/test_inbound.py +++ b/tests/agents/ui/ai_sdk/test_inbound_messages.py @@ -6,11 +6,8 @@ from ai.agents.agent import MessageBundle from ai.agents.ui.ai_sdk import to_messages -from ai.agents.ui.ai_sdk.inbound import ( - _normalize_ui_messages, - extract_approvals, -) -from ai.agents.ui.ai_sdk.ui_message import UIMessage, UIToolPart +from ai.agents.ui.ai_sdk.inbound_messages import _normalize_ui_messages +from ai.agents.ui.ai_sdk.ui_messages import UIMessage, UIToolPart from ai.types import messages as messages_ @@ -136,30 +133,6 @@ def test_to_messages_keeps_trailing_assistant_when_approved() -> None: assert [a.hook_id for a in approvals] == ["approve_tc1"] -def test_extract_approvals_returns_approved_responses() -> None: - approvals = extract_approvals( - [ - _ui( - "assistant", - _tool( - "x", - "tc1", - "approval-responded", - approval={ - "id": "approve_tc1", - "approved": False, - "reason": "nope", - }, - ), - ) - ] - ) - assert len(approvals) == 1 - assert approvals[0].hook_id == "approve_tc1" - assert approvals[0].granted is False - assert approvals[0].reason == "nope" - - def test_normalize_ui_messages_heals_stale_tool_state() -> None: ui = [ _ui( @@ -237,3 +210,78 @@ def test_to_messages_passthrough_keeps_wire_shape() -> None: part = tool_msgs[0].tool_results[0] assert part.result == {"pong": True} assert part.get_model_input() == {"pong": True} + + +def test_to_messages_accepts_metadata_and_ui_only_parts() -> None: + ui = [ + UIMessage.model_validate( + { + "id": "a1", + "role": "assistant", + "metadata": {"trace": "t1"}, + "parts": [ + {"type": "custom", "kind": "openai.compaction"}, + { + "type": "data-weather", + "id": "weather-1", + "data": {"status": "loading"}, + }, + { + "type": "source-url", + "sourceId": "src-1", + "url": "https://example.com", + }, + { + "type": "reasoning-file", + "mediaType": "image/png", + "url": "data:image/png;base64,AAAA", + }, + { + "type": "text", + "text": "visible", + "providerMetadata": {"provider": {"k": "v"}}, + }, + ], + } + ) + ] + + messages, approvals = to_messages(ui) + + assert approvals == [] + assert len(messages) == 1 + assert messages[0].text == "visible" + text = messages[0].parts[0] + assert isinstance(text, messages_.TextPart) + assert text.provider_metadata == {"provider": {"k": "v"}} + + +def test_to_messages_dynamic_provider_executed_tool_becomes_builtin() -> None: + messages, _ = to_messages( + [ + _ui( + "assistant", + { + "type": "dynamic-tool", + "toolName": "web_search", + "toolCallId": "tc1", + "state": "output-available", + "input": {"query": "ai"}, + "output": [{"title": "result"}], + "providerExecuted": True, + "callProviderMetadata": {"provider": {"call": 1}}, + "resultProviderMetadata": {"provider": {"result": 1}}, + }, + ) + ] + ) + + assert len(messages) == 1 + assert messages[0].role == "assistant" + [call] = messages[0].builtin_tool_calls + [result] = messages[0].builtin_tool_returns + assert call.tool_name == "web_search" + assert call.tool_args == '{"query": "ai"}' + assert call.provider_metadata == {"provider": {"call": 1}} + assert result.result == [{"title": "result"}] + assert result.provider_metadata == {"provider": {"result": 1}} diff --git a/tests/agents/ui/ai_sdk/test_outbound_messages.py b/tests/agents/ui/ai_sdk/test_outbound_messages.py new file mode 100644 index 00000000..ad503e58 --- /dev/null +++ b/tests/agents/ui/ai_sdk/test_outbound_messages.py @@ -0,0 +1,602 @@ +from __future__ import annotations + +from collections import Counter + +from ai.agents.ui import ai_sdk +from ai.agents.ui.ai_sdk import outbound_messages, to_ui_messages +from ai.agents.ui.ai_sdk.ui_messages import ( + UIDynamicToolPart, + UIFilePart, + UIReasoningPart, + UITextPart, + UIToolApproval, + UIToolPart, +) +from ai.types import integrity +from ai.types import messages as messages_ + + +def _parallel_tool_turn( + *, + turn_id: str, + assistant_prefix: str | None = None, + tool_call_ids: tuple[str, str] = ("tc-bash", "tc-web"), +) -> list[messages_.Message]: + prefix = assistant_prefix or turn_id + tc_bash, tc_web = tool_call_ids + + return [ + messages_.Message( + id=f"{prefix}:assistant:0", + turn_id=turn_id, + role="assistant", + parts=[ + messages_.TextPart( + id=f"{prefix}:text:0", + text="I will run two tools.", + ), + messages_.ToolCallPart( + id=f"{prefix}:call:bash", + tool_call_id=tc_bash, + tool_name="bash", + tool_args='{"command":"date"}', + ), + messages_.ToolCallPart( + id=f"{prefix}:call:web", + tool_call_id=tc_web, + tool_name="web_fetch", + tool_args='{"url":"https://httpbin.org/get"}', + ), + ], + ), + messages_.Message( + id=f"{prefix}:tool:0", + turn_id=turn_id, + role="tool", + parts=[ + messages_.ToolResultPart( + id=f"{prefix}:result:bash", + tool_call_id=tc_bash, + tool_name="bash", + result="Tue May 19 2026", + ), + messages_.ToolResultPart( + id=f"{prefix}:result:web", + tool_call_id=tc_web, + tool_name="web_fetch", + result={"status": 200}, + ), + ], + ), + messages_.Message( + id=f"{prefix}:assistant:1", + turn_id=turn_id, + role="assistant", + parts=[ + messages_.TextPart( + id=f"{prefix}:text:1", + text="Both tools finished.", + ), + ], + ), + ] + + +def test_to_ui_parts_text_and_reasoning() -> None: + parts: list[messages_.Part] = [ + messages_.ReasoningPart(text="thinking"), + messages_.TextPart(text="hi"), + ] + ui_parts = outbound_messages.to_ui_parts(parts) + assert isinstance(ui_parts[0], UIReasoningPart) + assert ui_parts[0].text == "thinking" + assert isinstance(ui_parts[1], UITextPart) + assert ui_parts[1].text == "hi" + + +def test_to_ui_parts_tool_call_parses_json_args() -> None: + parts: list[messages_.Part] = [ + messages_.ToolCallPart( + tool_call_id="tc1", + tool_name="search", + tool_args='{"q": "x"}', + ) + ] + ui_parts = outbound_messages.to_ui_parts(parts) + assert isinstance(ui_parts[0], UIToolPart) + assert ui_parts[0].type == "tool-search" + assert ui_parts[0].input == {"q": "x"} + assert ui_parts[0].state == "input-available" + + +def test_merge_tool_results_updates_state_and_output() -> None: + parts: list[messages_.Part] = [ + messages_.ToolCallPart( + tool_call_id="tc1", + tool_name="search", + tool_args="{}", + ) + ] + ui_parts = outbound_messages.to_ui_parts(parts) + outbound_messages.merge_tool_results( + ui_parts, + [ + messages_.ToolResultPart( + tool_call_id="tc1", + tool_name="search", + result={"hits": 3}, + ) + ], + ) + merged = ui_parts[0] + assert isinstance(merged, UIToolPart) + assert merged.state == "output-available" + assert merged.output == {"hits": 3} + + +def test_merge_approval_signals_pending_then_resolved() -> None: + parts: list[messages_.Part] = [ + messages_.ToolCallPart( + tool_call_id="tc1", + tool_name="delete", + tool_args="{}", + ) + ] + ui_parts = outbound_messages.to_ui_parts(parts) + + outbound_messages.merge_approval_signals( + ui_parts, + [ + messages_.HookPart( + hook_id="approve_tc1", + hook_type="ToolApproval", + status="pending", + ) + ], + ) + requested = ui_parts[0] + assert isinstance(requested, UIToolPart) + assert requested.state == "approval-requested" + assert isinstance(requested.approval, UIToolApproval) + + outbound_messages.merge_approval_signals( + ui_parts, + [ + messages_.HookPart( + hook_id="approve_tc1", + hook_type="ToolApproval", + status="resolved", + resolution={"granted": True, "reason": None}, + ) + ], + ) + responded = ui_parts[0] + assert isinstance(responded, UIToolPart) + assert responded.state == "approval-responded" + assert responded.approval is not None + assert responded.approval.approved is True + + +def _tool_counts( + messages: list[messages_.Message], +) -> Counter[tuple[str, str]]: + counts: Counter[tuple[str, str]] = Counter() + for message in messages: + for part in message.parts: + if isinstance(part, messages_.ToolCallPart): + counts["tool_call", part.tool_call_id] += 1 + elif isinstance(part, messages_.ToolResultPart): + counts["tool_result", part.tool_call_id] += 1 + return counts + + +class IdUpsertStore: + """Small app-like store: persist full history by message id.""" + + def __init__(self) -> None: + self._rows: list[messages_.Message] = [] + + def save_full_history(self, messages: list[messages_.Message]) -> None: + for message in messages: + if message.role == "system": + continue + + for index, existing in enumerate(self._rows): + if existing.id == message.id: + self._rows[index] = message + break + else: + self._rows.append(message) + + def load(self) -> list[messages_.Message]: + return list(self._rows) + + +def test_to_ui_messages_user_and_assistant() -> None: + msgs = [ + messages_.Message( + id="u1", role="user", parts=[messages_.TextPart(text="hi")] + ), + messages_.Message( + id="a1", + role="assistant", + parts=[messages_.TextPart(text="hello back")], + ), + ] + result = to_ui_messages(msgs) + assert len(result) == 2 + assert result[0].role == "user" + assert result[1].role == "assistant" + assert result[1].id == "a1" + + +def test_to_ui_messages_merges_assistant_tool_internal() -> None: + msgs = [ + messages_.Message( + id="a1", + role="assistant", + parts=[ + messages_.TextPart(text="calling"), + messages_.ToolCallPart( + tool_call_id="tc1", + tool_name="search", + tool_args='{"q":"x"}', + ), + ], + ), + messages_.Message( + role="tool", + parts=[ + messages_.ToolResultPart( + tool_call_id="tc1", + tool_name="search", + result={"hits": 2}, + ) + ], + ), + messages_.Message( + role="assistant", + parts=[messages_.TextPart(text="done")], + ), + ] + result = to_ui_messages(msgs) + assert len(result) == 1 + ui_msg = result[0] + assert ui_msg.role == "assistant" + assert ui_msg.id == "a1" + assert isinstance(ui_msg.parts[0], UITextPart) + assert ui_msg.parts[0].text == "calling" + assert isinstance(ui_msg.parts[1], UIToolPart) + assert ui_msg.parts[1].state == "output-available" + assert ui_msg.parts[1].output == {"hits": 2} + assert isinstance(ui_msg.parts[2], UITextPart) + assert ui_msg.parts[2].text == "done" + + +def test_to_ui_messages_records_source_messages_in_metadata() -> None: + msgs = [ + messages_.Message( + id="turn-1:assistant:0", + turn_id="turn-1", + role="assistant", + parts=[ + messages_.TextPart(id="text-0", text="calling"), + messages_.ToolCallPart( + id="call-0", + tool_call_id="tc1", + tool_name="search", + tool_args="{}", + ), + ], + ), + messages_.Message( + id="turn-1:tool:0", + turn_id="turn-1", + role="tool", + parts=[ + messages_.ToolResultPart( + id="result-0", + tool_call_id="tc1", + tool_name="search", + result={"hits": 2}, + ) + ], + ), + messages_.Message( + id="turn-1:assistant:1", + turn_id="turn-1", + role="assistant", + parts=[messages_.TextPart(id="text-1", text="done")], + ), + ] + + [ui_msg] = to_ui_messages(msgs) + + assert ui_msg.id == "turn-1" + assert ui_msg.metadata == { + "aiPython": { + "sourceMessages": [ + { + "id": "turn-1:assistant:0", + "role": "assistant", + "turnId": "turn-1", + "partIds": ["text-0", "call-0"], + }, + { + "id": "turn-1:tool:0", + "role": "tool", + "turnId": "turn-1", + "partIds": ["result-0"], + }, + { + "id": "turn-1:assistant:1", + "role": "assistant", + "turnId": "turn-1", + "partIds": ["text-1"], + }, + ] + } + } + + +def test_to_ui_messages_internal_role_merges_approval() -> None: + msgs = [ + messages_.Message( + id="a1", + role="assistant", + parts=[ + messages_.ToolCallPart( + tool_call_id="tc1", + tool_name="delete", + tool_args="{}", + ) + ], + ), + messages_.Message( + role="internal", + parts=[ + messages_.HookPart( + hook_id="approve_tc1", + hook_type="ToolApproval", + status="pending", + ) + ], + ), + ] + result = to_ui_messages(msgs) + ui_msg = result[0] + tool_part = ui_msg.parts[0] + assert isinstance(tool_part, UIToolPart) + assert tool_part.state == "approval-requested" + assert tool_part.approval is not None + assert tool_part.approval.id == "approve_tc1" + + +def test_to_ui_messages_user_message_uses_own_id() -> None: + msgs = [ + messages_.Message( + id="u1", role="user", parts=[messages_.TextPart(text="a")] + ) + ] + result = to_ui_messages(msgs) + assert result[0].id == "u1" + + +def test_to_ui_messages_uses_first_assistant_id_as_bubble_id() -> None: + msgs = [ + messages_.Message( + id="a1", + role="assistant", + parts=[messages_.TextPart(text="first")], + ), + messages_.Message( + id="a2", + role="assistant", + parts=[messages_.TextPart(text="second")], + ), + ] + result = to_ui_messages(msgs) + assert len(result) == 1 + assert result[0].id == "a1" + + +def test_to_ui_messages_preserves_provider_metadata_and_files() -> None: + msgs = [ + messages_.Message( + id="a1", + role="assistant", + parts=[ + messages_.TextPart( + text="hello", + provider_metadata={"provider": {"text": True}}, + ), + messages_.FilePart( + data=b"abc", + media_type="image/png", + filename="image.png", + provider_metadata={"provider": {"file": True}}, + ), + ], + ) + ] + + result = to_ui_messages(msgs) + + text_part = result[0].parts[0] + assert isinstance(text_part, UITextPart) + assert text_part.provider_metadata == {"provider": {"text": True}} + + file_part = result[0].parts[1] + assert isinstance(file_part, UIFilePart) + assert file_part.url == "data:image/png;base64,YWJj" + assert file_part.filename == "image.png" + assert file_part.provider_metadata == {"provider": {"file": True}} + + +def test_to_ui_messages_maps_builtin_tools_to_dynamic_parts() -> None: + msgs = [ + messages_.Message( + id="a1", + role="assistant", + parts=[ + messages_.BuiltinToolCallPart( + tool_call_id="tc1", + tool_name="web_search", + tool_args='{"q":"ai"}', + provider_metadata={"provider": {"call": True}}, + ), + messages_.BuiltinToolReturnPart( + tool_call_id="tc1", + tool_name="web_search", + result={"hits": 1}, + provider_metadata={"provider": {"result": True}}, + ), + ], + ) + ] + + result = to_ui_messages(msgs) + + assert len(result[0].parts) == 1 + tool_part = result[0].parts[0] + assert isinstance(tool_part, UIDynamicToolPart) + assert tool_part.provider_executed is True + assert tool_part.state == "output-available" + assert tool_part.input == {"q": "ai"} + assert tool_part.output == {"hits": 1} + assert tool_part.call_provider_metadata == {"provider": {"call": True}} + assert tool_part.result_provider_metadata == {"provider": {"result": True}} + + +def test_collapsed_assistant_turn_roundtrips_internal_ids() -> None: + original = [ + messages_.Message( + id="assistant-alpha", + turn_id="turn-arbitrary", + role="assistant", + parts=[ + messages_.TextPart(id="text-alpha", text="calling first"), + messages_.ToolCallPart( + id="call-alpha", + tool_call_id="tc-first", + tool_name="search", + tool_args='{"q":"first"}', + ), + ], + ), + messages_.Message( + id="tool-beta", + turn_id="turn-arbitrary", + role="tool", + parts=[ + messages_.ToolResultPart( + id="result-beta", + tool_call_id="tc-first", + tool_name="search", + result={"hits": 1}, + ) + ], + ), + messages_.Message( + id="assistant-gamma", + turn_id="turn-arbitrary", + role="assistant", + parts=[ + messages_.TextPart(id="text-gamma", text="calling second"), + messages_.ToolCallPart( + id="call-gamma", + tool_call_id="tc-second", + tool_name="lookup", + tool_args='{"id":2}', + ), + ], + ), + messages_.Message( + id="tool-delta", + turn_id="turn-arbitrary", + role="tool", + parts=[ + messages_.ToolResultPart( + id="result-delta", + tool_call_id="tc-second", + tool_name="lookup", + result={"value": 2}, + ) + ], + ), + messages_.Message( + id="assistant-epsilon", + turn_id="turn-arbitrary", + role="assistant", + parts=[ + messages_.TextPart(id="text-epsilon", text="all done"), + ], + ), + ] + + [ui_msg] = ai_sdk.to_ui_messages(original) + roundtripped, approvals = ai_sdk.to_messages([ui_msg]) + + assert approvals == [] + assert ui_msg.role == "assistant" + assert ui_msg.id == "turn-arbitrary" + assert [m.role for m in roundtripped] == [m.role for m in original] + assert [m.id for m in roundtripped] == [m.id for m in original] + assert [m.turn_id for m in roundtripped] == [m.turn_id for m in original] + assert [[p.id for p in m.parts] for m in roundtripped] == [ + [p.id for p in m.parts] for m in original + ] + + +def test_common_id_upsert_persistence_is_idempotent_after_reload() -> None: + store = IdUpsertStore() + + first_run = [ + messages_.Message( + id="user-1", + role="user", + parts=[messages_.TextPart(id="user-1:text", text="run two tools")], + ), + *_parallel_tool_turn(turn_id="turn-1"), + ] + store.save_full_history(first_run) + + reloaded_ui = ai_sdk.to_ui_messages(store.load()) + request_history, _ = ai_sdk.to_messages(reloaded_ui) + + second_run_result = [ + *request_history, + messages_.Message( + id="user-2", + role="user", + parts=[messages_.TextPart(id="user-2:text", text="do nothing")], + ), + messages_.Message( + id="turn-2:assistant:0", + turn_id="turn-2", + role="assistant", + parts=[messages_.TextPart(id="turn-2:text:0", text="standing by")], + ), + ] + store.save_full_history(second_run_result) + + loaded = store.load() + integrity.prepare_messages(loaded) + + counts = _tool_counts(loaded) + assert counts["tool_call", "tc-bash"] == 1 + assert counts["tool_result", "tc-bash"] == 1 + assert counts["tool_call", "tc-web"] == 1 + assert counts["tool_result", "tc-web"] == 1 + + +def test_duplicate_tool_copies_do_not_reach_model_integrity() -> None: + history = [ + *_parallel_tool_turn(turn_id="turn-1", assistant_prefix="server"), + *_parallel_tool_turn(turn_id="turn-1", assistant_prefix="client"), + ] + + reloaded_ui = ai_sdk.to_ui_messages(history) + next_request_history, _ = ai_sdk.to_messages(reloaded_ui) + + integrity.prepare_messages(next_request_history) diff --git a/tests/agents/ui/ai_sdk/test_outbound_stream.py b/tests/agents/ui/ai_sdk/test_outbound_stream.py new file mode 100644 index 00000000..e496405c --- /dev/null +++ b/tests/agents/ui/ai_sdk/test_outbound_stream.py @@ -0,0 +1,575 @@ +from __future__ import annotations + +import json +from collections.abc import AsyncGenerator +from typing import Any + +import ai +from ai.agents.ui.ai_sdk import to_sse, to_stream, ui_events +from ai.agents.ui.ai_sdk.outbound_stream import ( + format_done_sse, + format_sse, + serialize_event, +) +from ai.types import events as agent_events_ +from ai.types import events as events_ +from ai.types import messages as messages_ + + +async def _gen( + stream_events: list[agent_events_.AgentEvent], +) -> AsyncGenerator[agent_events_.AgentEvent]: + for event in stream_events: + yield event + + +async def _collect( + stream_events: list[agent_events_.AgentEvent], +) -> list[ui_events.UIMessageStreamEvent]: + return [event async for event in to_stream(_gen(stream_events))] + + +def _source_messages(metadata: Any | None) -> list[dict[str, Any]]: + assert isinstance(metadata, dict) + adapter_metadata = metadata.get("aiPython") + assert isinstance(adapter_metadata, dict) + source_messages = adapter_metadata.get("sourceMessages") + assert isinstance(source_messages, list) + return source_messages + + +def test_serialize_event_camelcases_keys() -> None: + event = ui_events.UIStartEvent(message_id="m1") + payload = json.loads(serialize_event(event)) + assert payload == {"type": "start", "messageId": "m1"} + + +def test_format_sse_wraps_data_line() -> None: + event = ui_events.UITextDeltaEvent(id="t1", delta="hi") + line = format_sse(event) + assert line.startswith("data: ") + assert line.endswith("\n\n") + + +def test_serialize_data_event_uses_type_with_prefix() -> None: + event = ui_events.UIDataEvent(data_type="custom", data={"k": 1}) + payload = json.loads(serialize_event(event)) + assert payload["type"] == "data-custom" + assert "dataType" not in payload + + +def test_serialize_protocol_fields_use_ai_sdk_wire_names() -> None: + event = ui_events.UIToolApprovalResponseEvent( + approval_id="approval-1", + approved=False, + reason="no", + provider_executed=True, + provider_metadata={"provider": {"k": "v"}}, + ) + + payload = json.loads(serialize_event(event)) + + assert payload == { + "type": "tool-approval-response", + "approvalId": "approval-1", + "approved": False, + "reason": "no", + "providerExecuted": True, + "providerMetadata": {"provider": {"k": "v"}}, + } + + +def test_format_done_sse_returns_done_sentinel() -> None: + assert format_done_sse() == "data: [DONE]\n\n" + + +async def test_to_sse_emits_data_prefixed_lines() -> None: + lines = [ + line + async for line in to_sse( + _gen( + [ + events_.TextStart(block_id="t1"), + events_.TextDelta(block_id="t1", chunk="hi"), + events_.TextEnd(block_id="t1"), + ] + ) + ) + ] + assert all(line.startswith("data: ") for line in lines) + first = json.loads(lines[0].removeprefix("data: ").rstrip()) + assert first["type"] == "start" + assert lines[-1] == "data: [DONE]\n\n" + + +async def test_stream_start_uses_runtime_message_id() -> None: + assistant = messages_.Message( + id="assistant-runtime-id", + role="assistant", + parts=[messages_.TextPart(id="text-1", text="hello")], + ) + + out = await _collect( + [ + events_.TextStart(block_id="text-1", message=assistant), + events_.TextDelta( + block_id="text-1", chunk="hello", message=assistant + ), + events_.TextEnd(block_id="text-1", message=assistant), + ] + ) + + start = next( + event for event in out if isinstance(event, ui_events.UIStartEvent) + ) + assert start.message_id == "assistant-runtime-id" + + +async def test_finish_metadata_tracks_streamed_assistant_message() -> None: + assistant = messages_.Message( + id="assistant-1", + role="assistant", + parts=[messages_.TextPart(id="text-1", text="hello")], + ) + + out = await _collect( + [ + events_.TextStart(block_id="text-1", message=assistant), + events_.TextDelta( + block_id="text-1", chunk="hello", message=assistant + ), + events_.TextEnd(block_id="text-1", message=assistant), + ] + ) + + finish = next( + event for event in out if isinstance(event, ui_events.UIFinishEvent) + ) + assert _source_messages(finish.message_metadata) == [ + { + "id": "assistant-1", + "role": "assistant", + "turnId": None, + "partIds": ["text-1"], + } + ] + + +async def test_finish_metadata_tracks_tool_and_internal_messages() -> None: + tool_call = messages_.ToolCallPart( + id="call-1", + tool_call_id="tc1", + tool_name="search", + tool_args="{}", + ) + assistant = messages_.Message( + id="assistant-1", + turn_id="turn-1", + role="assistant", + parts=[tool_call], + ) + tool = messages_.Message( + id="tool-1", + turn_id="turn-1", + role="tool", + parts=[ + messages_.ToolResultPart( + id="result-1", + tool_call_id="tc1", + tool_name="search", + result={"hits": 1}, + ) + ], + ) + hook = messages_.HookPart[Any]( + id="hook-1", + hook_id="approve_tc1", + hook_type="ToolApproval", + status="pending", + ) + internal = messages_.Message( + id="internal-1", + turn_id="turn-1", + role="internal", + parts=[hook], + ) + + out = await _collect( + [ + events_.ToolStart( + tool_call_id="tc1", tool_name="search", message=assistant + ), + events_.ToolEnd( + tool_call_id="tc1", tool_call=tool_call, message=assistant + ), + agent_events_.ToolCallResult( + message=tool, + results=tool.tool_results, + ), + agent_events_.HookEvent(message=internal, hook=hook), + ] + ) + + finish = next( + event for event in out if isinstance(event, ui_events.UIFinishEvent) + ) + assert [ + source["id"] for source in _source_messages(finish.message_metadata) + ] == ["assistant-1", "tool-1", "internal-1"] + + +async def test_event_driven_text_streaming() -> None: + """Streaming text events lazily open a UI message.""" + text_id = "txt1" + out = await _collect( + [ + events_.TextStart(block_id=text_id), + events_.TextDelta(block_id=text_id, chunk="hi"), + events_.TextEnd(block_id=text_id), + ] + ) + + assert isinstance(out[0], ui_events.UIStartEvent) + assert isinstance(out[1], ui_events.UIStartStepEvent) + assert ( + isinstance(out[2], ui_events.UITextStartEvent) and out[2].id == text_id + ) + assert ( + isinstance(out[3], ui_events.UITextDeltaEvent) and out[3].delta == "hi" + ) + assert isinstance(out[4], ui_events.UITextEndEvent) and out[4].id == text_id + assert isinstance(out[5], ui_events.UIFinishStepEvent) + assert isinstance(out[6], ui_events.UIFinishEvent) + assert out[6].message_metadata is None + + +async def test_finish_metadata_ignores_empty_messages() -> None: + assistant = messages_.Message( + id="assistant-empty", + role="assistant", + parts=[], + ) + + out = await _collect( + [ + events_.StreamStart(message=assistant), + events_.StreamEnd(message=assistant), + ] + ) + + finish = next( + event for event in out if isinstance(event, ui_events.UIFinishEvent) + ) + assert finish.message_metadata is None + + +async def test_tool_call_and_result_emit_terminal_events() -> None: + """ToolCallResult emits tool input and output events.""" + tool_result_msg = messages_.Message( + role="tool", + parts=[ + messages_.ToolResultPart( + tool_call_id="tc1", + tool_name="search", + result={"hits": 1}, + ) + ], + ) + out = await _collect( + [ + # Streaming tool input events from the model + events_.ToolStart(tool_call_id="tc1", tool_name="search"), + events_.ToolDelta(tool_call_id="tc1", chunk='{"q":"x"}'), + events_.ToolEnd( + tool_call_id="tc1", + tool_call=messages_.ToolCallPart( + tool_call_id="tc1", + tool_name="search", + tool_args='{"q":"x"}', + ), + ), + # Tool execution result + agent_events_.ToolCallResult( + message=tool_result_msg, + results=tool_result_msg.tool_results, + ), + ] + ) + types = [type(event).__name__ for event in out] + assert "UIToolInputStartEvent" in types + assert "UIToolOutputAvailableEvent" in types + + +async def test_tool_result_without_streaming_emits_input_start() -> None: + """ToolCallResult for a non-streamed tool emits input + output events.""" + tool_result_msg = messages_.Message( + role="tool", + parts=[ + messages_.ToolCallPart( + id="tc1", + tool_call_id="tc1", + tool_name="search", + tool_args='{"q":"x"}', + ), + messages_.ToolResultPart( + tool_call_id="tc1", + tool_name="search", + result={"hits": 1}, + ), + ], + ) + out = await _collect( + [ + agent_events_.ToolCallResult( + message=tool_result_msg, + results=tool_result_msg.tool_results, + ), + ] + ) + types = [type(event).__name__ for event in out] + assert "UIToolInputStartEvent" in types + assert "UIToolInputAvailableEvent" in types + assert "UIToolOutputAvailableEvent" in types + + +async def test_approval_request_hook_emits_approval_event() -> None: + """HookEvent with pending status emits a UIToolApprovalRequestEvent.""" + out = await _collect( + [ + # Streaming tool events first + events_.ToolStart(tool_call_id="tc1", tool_name="delete"), + events_.ToolDelta(tool_call_id="tc1", chunk="{}"), + events_.ToolEnd( + tool_call_id="tc1", + tool_call=messages_.ToolCallPart( + tool_call_id="tc1", + tool_name="delete", + tool_args="{}", + ), + ), + # Hook requesting approval + agent_events_.HookEvent( + message=messages_.Message( + role="internal", + parts=[ + messages_.HookPart( + hook_id="approve_tc1", + hook_type="ToolApproval", + status="pending", + ) + ], + ), + hook=messages_.HookPart( + hook_id="approve_tc1", + hook_type="ToolApproval", + status="pending", + ), + ), + ] + ) + approval_events = [ + p for p in out if isinstance(p, ui_events.UIToolApprovalRequestEvent) + ] + assert len(approval_events) == 1 + assert approval_events[0].tool_call_id == "tc1" + assert approval_events[0].approval_id == "approve_tc1" + + +async def test_partial_tool_results_emit_preliminary_outputs() -> None: + """Each partial result yields a preliminary event.""" + out = await _collect( + [ + agent_events_.PartialToolCallResult( + tool_call_id="tc1", + tool_name="search", + value="hit 1, ", + aggregator_factory=ai.agents.ConcatAggregator, + ), + agent_events_.PartialToolCallResult( + tool_call_id="tc1", + tool_name="search", + value="hit 2, ", + aggregator_factory=ai.agents.ConcatAggregator, + ), + agent_events_.PartialToolCallResult( + tool_call_id="tc1", + tool_name="search", + value="hit 3", + aggregator_factory=ai.agents.ConcatAggregator, + ), + ] + ) + + prelim = [ + p + for p in out + if isinstance(p, ui_events.UIToolOutputAvailableEvent) and p.preliminary + ] + assert [p.output for p in prelim] == [ + "hit 1, ", + "hit 1, hit 2, ", + "hit 1, hit 2, hit 3", + ] + assert all(p.tool_call_id == "tc1" for p in prelim) + + +async def test_partial_message_bundle_becomes_ui_message() -> None: + """MessageAggregator's snapshot collapses to one UIMessage.""" + from ai.agents.ui.ai_sdk.ui_messages import UIMessage + + inner_msg = messages_.Message( + role="assistant", + parts=[messages_.TextPart(text="hi from sub-agent")], + ) + + out = await _collect( + [ + agent_events_.PartialToolCallResult( + tool_call_id="tc1", + tool_name="research", + value=agent_events_.ToolCallResult( + message=inner_msg, results=[] + ), + aggregator_factory=ai.agents.MessageAggregator, + ), + ] + ) + + [prelim] = [ + p + for p in out + if isinstance(p, ui_events.UIToolOutputAvailableEvent) and p.preliminary + ] + assert isinstance(prelim.output, UIMessage) + assert prelim.output.role == "assistant" + assert prelim.output.parts[0].type == "text" + + +async def test_partial_tool_result_without_factory_is_skipped() -> None: + """Without an aggregator_factory there's nothing to snapshot.""" + out = await _collect( + [ + agent_events_.PartialToolCallResult( + tool_call_id="tc1", + tool_name="search", + value="ignored", + ), + ] + ) + assert not any( + isinstance(p, ui_events.UIToolOutputAvailableEvent) for p in out + ) + + +async def test_builtin_tool_stream_marks_provider_executed_dynamic() -> None: + out = await _collect( + [ + events_.BuiltinToolStart( + tool_call_id="tc1", + tool_name="web_search", + provider_metadata={"provider": {"start": True}}, + ), + events_.BuiltinToolDelta(tool_call_id="tc1", chunk='{"q":"ai"}'), + events_.BuiltinToolEnd( + tool_call_id="tc1", + tool_call=messages_.BuiltinToolCallPart( + tool_call_id="tc1", + tool_name="web_search", + tool_args='{"q":"ai"}', + provider_metadata={"provider": {"call": True}}, + ), + ), + events_.BuiltinToolResult( + tool_call_id="tc1", + result=messages_.BuiltinToolReturnPart( + tool_call_id="tc1", + tool_name="web_search", + result={"hits": 1}, + provider_metadata={"provider": {"result": True}}, + ), + ), + ] + ) + + start = next( + p for p in out if isinstance(p, ui_events.UIToolInputStartEvent) + ) + assert start.provider_executed is True + assert start.dynamic is True + assert start.provider_metadata == {"provider": {"start": True}} + + available = next( + p for p in out if isinstance(p, ui_events.UIToolInputAvailableEvent) + ) + assert available.provider_executed is True + assert available.dynamic is True + assert available.input == {"q": "ai"} + assert available.provider_metadata == {"provider": {"call": True}} + + result = next( + p for p in out if isinstance(p, ui_events.UIToolOutputAvailableEvent) + ) + assert result.provider_executed is True + assert result.dynamic is True + assert result.output == {"hits": 1} + assert result.provider_metadata == {"provider": {"result": True}} + + +async def test_file_event_emits_ui_file_event() -> None: + out = await _collect( + [ + events_.FileEvent( + media_type="image/png", + data=b"abc", + provider_metadata={"provider": {"file": True}}, + ) + ] + ) + + file_event = next(p for p in out if isinstance(p, ui_events.UIFileEvent)) + assert file_event.url == "data:image/png;base64,YWJj" + assert file_event.media_type == "image/png" + assert file_event.provider_metadata == {"provider": {"file": True}} + + +async def test_resolved_approval_hook_emits_response_event() -> None: + hook: messages_.HookPart[Any] = messages_.HookPart( + hook_id="approve_tc1", + hook_type="ToolApproval", + status="resolved", + metadata={ + "providerExecuted": True, + "callProviderMetadata": {"provider": {"approval": True}}, + }, + resolution={"granted": False, "reason": "not allowed"}, + ) + + out = await _collect( + [ + agent_events_.HookEvent( + message=messages_.Message( + id="turn-1:internal:0", + turn_id="turn-1", + role="internal", + parts=[hook], + ), + hook=hook, + ) + ] + ) + + response = next( + p for p in out if isinstance(p, ui_events.UIToolApprovalResponseEvent) + ) + assert response.approval_id == "approve_tc1" + assert response.approved is False + assert response.reason == "not allowed" + assert response.provider_executed is True + assert response.provider_metadata == {"provider": {"approval": True}} + assert any(isinstance(p, ui_events.UIToolOutputDeniedEvent) for p in out) + + +# NOTE: agent-change boundary detection used to be driven by +# Message.source_label. That field has been removed; agent-change +# routing in the AI SDK adapter now needs to come from +# PartialToolCallResult, which is a separate piece of work. diff --git a/tests/agents/ui/ai_sdk/test_parts.py b/tests/agents/ui/ai_sdk/test_parts.py deleted file mode 100644 index 9876e9c8..00000000 --- a/tests/agents/ui/ai_sdk/test_parts.py +++ /dev/null @@ -1,105 +0,0 @@ -from __future__ import annotations - -from ai.agents.ui.ai_sdk import _parts -from ai.agents.ui.ai_sdk.ui_message import ( - UIReasoningPart, - UITextPart, - UIToolApproval, - UIToolPart, -) -from ai.types import messages as messages_ - - -def test_to_ui_parts_text_and_reasoning() -> None: - parts: list[messages_.Part] = [ - messages_.ReasoningPart(text="thinking"), - messages_.TextPart(text="hi"), - ] - ui_parts = _parts.to_ui_parts(parts) - assert isinstance(ui_parts[0], UIReasoningPart) - assert ui_parts[0].text == "thinking" - assert isinstance(ui_parts[1], UITextPart) - assert ui_parts[1].text == "hi" - - -def test_to_ui_parts_tool_call_parses_json_args() -> None: - parts: list[messages_.Part] = [ - messages_.ToolCallPart( - tool_call_id="tc1", - tool_name="search", - tool_args='{"q": "x"}', - ) - ] - ui_parts = _parts.to_ui_parts(parts) - assert isinstance(ui_parts[0], UIToolPart) - assert ui_parts[0].type == "tool-search" - assert ui_parts[0].input == {"q": "x"} - assert ui_parts[0].state == "input-available" - - -def test_merge_tool_results_updates_state_and_output() -> None: - parts: list[messages_.Part] = [ - messages_.ToolCallPart( - tool_call_id="tc1", - tool_name="search", - tool_args="{}", - ) - ] - ui_parts = _parts.to_ui_parts(parts) - _parts.merge_tool_results( - ui_parts, - [ - messages_.ToolResultPart( - tool_call_id="tc1", - tool_name="search", - result={"hits": 3}, - ) - ], - ) - merged = ui_parts[0] - assert isinstance(merged, UIToolPart) - assert merged.state == "output-available" - assert merged.output == {"hits": 3} - - -def test_merge_approval_signals_pending_then_resolved() -> None: - parts: list[messages_.Part] = [ - messages_.ToolCallPart( - tool_call_id="tc1", - tool_name="delete", - tool_args="{}", - ) - ] - ui_parts = _parts.to_ui_parts(parts) - - _parts.merge_approval_signals( - ui_parts, - [ - messages_.HookPart( - hook_id="approve_tc1", - hook_type="ToolApproval", - status="pending", - ) - ], - ) - requested = ui_parts[0] - assert isinstance(requested, UIToolPart) - assert requested.state == "approval-requested" - assert isinstance(requested.approval, UIToolApproval) - - _parts.merge_approval_signals( - ui_parts, - [ - messages_.HookPart( - hook_id="approve_tc1", - hook_type="ToolApproval", - status="resolved", - resolution={"granted": True, "reason": None}, - ) - ], - ) - responded = ui_parts[0] - assert isinstance(responded, UIToolPart) - assert responded.state == "approval-responded" - assert responded.approval is not None - assert responded.approval.approved is True diff --git a/tests/types/test_integrity.py b/tests/types/test_integrity.py index 8e41101d..d65eecc9 100644 --- a/tests/types/test_integrity.py +++ b/tests/types/test_integrity.py @@ -57,6 +57,56 @@ def _tool_result( ) +def _parallel_tool_turn( + *, + turn_id: str, + assistant_prefix: str, + tool_call_ids: tuple[str, str] = ("tc-bash", "tc-web"), +) -> list[messages.Message]: + tc_bash, tc_web = tool_call_ids + + return [ + messages.Message( + id=f"{assistant_prefix}:assistant:0", + turn_id=turn_id, + role="assistant", + parts=[ + messages.ToolCallPart( + id=f"{assistant_prefix}:call:bash", + tool_call_id=tc_bash, + tool_name="bash", + tool_args='{"command":"date"}', + ), + messages.ToolCallPart( + id=f"{assistant_prefix}:call:web", + tool_call_id=tc_web, + tool_name="web_fetch", + tool_args='{"url":"https://httpbin.org/get"}', + ), + ], + ), + messages.Message( + id=f"{assistant_prefix}:tool:0", + turn_id=turn_id, + role="tool", + parts=[ + messages.ToolResultPart( + id=f"{assistant_prefix}:result:bash", + tool_call_id=tc_bash, + tool_name="bash", + result="Tue May 19 2026", + ), + messages.ToolResultPart( + id=f"{assistant_prefix}:result:web", + tool_call_id=tc_web, + tool_name="web_fetch", + result={"status": 200}, + ), + ], + ), + ] + + def _assert_raises_issue( msgs: list[messages.Message], issue: str,