diff --git a/examples/read_file_tool.py b/examples/read_file_tool.py new file mode 100644 index 00000000..6ee3aab9 --- /dev/null +++ b/examples/read_file_tool.py @@ -0,0 +1,86 @@ +"""Tool that returns a ContentOutput so the model can see image files directly. + +The ``read_file`` tool reads a path from disk and inspects the bytes: + +* If the file is an image, it returns a :class:`ContentOutput` carrying + a summary line and an image :class:`FilePart`. All three providers + turn that into a real image content block on the next model turn, so + the model actually *sees* the picture. +* Otherwise it returns the decoded text -- the framework wraps that in + a :class:`TextOutput` automatically. + +A single tool covers both code-reading and image-reading duties in an +agentic loop. +""" + +import asyncio +import json +import pathlib + +import ai +from ai.types import media + +# Restrict the tool to a directory we trust the model to roam in. +# `.resolve()` collapses symlinks so a path inside ALLOWED_ROOT cannot +# escape via a symlink that points elsewhere. +ALLOWED_ROOT = pathlib.Path(__file__).parent.resolve() + + +def _resolve_within_allowed(path: str) -> pathlib.Path: + resolved = pathlib.Path(path).resolve() + if not resolved.is_relative_to(ALLOWED_ROOT): + raise ValueError( + f"Refusing to read {path!r}: outside allowed root {ALLOWED_ROOT}" + ) + return resolved + + +@ai.tool +async def read_file(path: str) -> str | ai.messages.ContentOutput: + """Read a file from disk. + + Image files come back as a ContentOutput so the model can view them. + """ + data = _resolve_within_allowed(path).read_bytes() + image_type = media.detect_image_media_type(data) + if image_type is not None: + return ai.content_output( + f"Loaded {path} ({image_type}, {len(data)} bytes).", + ai.file_part(data, media_type=image_type), + ) + return data.decode("utf-8", errors="replace") + + +async def main() -> None: + model = ai.get_model("gateway:anthropic/claude-sonnet-4.6") + my_agent = ai.agent(tools=[read_file]) + + here = pathlib.Path(__file__).parent + image_path = here / "sample_image.jpg" + text_path = here / "agent_simple.py" + + messages = [ + ai.system_message( + "Use the read_file tool to inspect any files the user mentions." + ), + ai.user_message( + f"First read {image_path} and describe what you see in the " + f"picture. Then read {text_path} and summarize what the " + f"script does in one sentence." + ), + ] + + async with my_agent.run(model, messages) as stream: + async for event in stream: + if isinstance(event, ai.events.TextDelta): + print(event.chunk, end="", flush=True) + elif isinstance(event, ai.events.ToolEnd): + args = json.loads(event.tool_call.tool_args or "{}") + print(f"\n[read_file({args.get('path')!r})]") + elif isinstance(event, ai.events.StreamEnd): + print() + print() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/ai/__init__.py b/src/ai/__init__.py index 066d9aaa..6d6ab8f2 100644 --- a/src/ai/__init__.py +++ b/src/ai/__init__.py @@ -63,8 +63,10 @@ from .types import events, messages, tools from .types.builders import ( assistant_message, + content_output, file_part, system_message, + text_part, thinking, tool_message, tool_result_part, @@ -119,6 +121,7 @@ "agent", "assistant_message", "cancel_hook", + "content_output", "errors", # Submodules "events", @@ -137,6 +140,7 @@ "resolve_hook", "stream", "system_message", + "text_part", "thinking", "tool", "tool_message", diff --git a/src/ai/agents/agent.py b/src/ai/agents/agent.py index 90bfd449..6ef03ce6 100644 --- a/src/ai/agents/agent.py +++ b/src/ai/agents/agent.py @@ -74,8 +74,9 @@ def _error_tool_result( types.messages.ToolResultPart( tool_call_id=tool_call_id, tool_name=tool_name, - result=f"{type(unwrapped).__name__}: {unwrapped}", - is_error=True, + result=types.messages.ErrorTextOutput( + value=f"{type(unwrapped).__name__}: {unwrapped}" + ), ), exception=unwrapped, ) @@ -174,6 +175,8 @@ def _populate_model_inputs( Tool execution sets ``model_input`` directly; this fills in the value for tool results that were reconstructed from a wire round- trip (e.g. the AI SDK UI inbound path) and never had it computed. + The aggregator's ``model_input_from_result`` does any snapshot + unwrapping internally. """ for msg in messages: if msg.role != "tool": @@ -187,7 +190,7 @@ def _populate_model_inputs( agg_cls = _aggregator_cls(tool.aggregator) if agg_cls is None: continue - part.set_model_input(agg_cls.to_model_input(part.result)) + part.set_model_input(agg_cls.model_input_from_result(part.result)) class SimpleAggregator[Item, Result](events_.Aggregator[Item, Result, Result]): @@ -1038,8 +1041,9 @@ def pending_tool_result( part = types.messages.ToolResultPart( tool_call_id=tool_call_id, tool_name=tool_name, - result=f"Pending on hook {hook.hook_id!r}", - is_error=True, + result=types.messages.ErrorTextOutput( + value=f"Pending on hook {hook.hook_id!r}" + ), is_hook_pending=True, ) msg = types.messages.Message(role="tool", parts=[part]) diff --git a/src/ai/agents/ui/ai_sdk/inbound_messages.py b/src/ai/agents/ui/ai_sdk/inbound_messages.py index c0c36631..d27cf3d9 100644 --- a/src/ai/agents/ui/ai_sdk/inbound_messages.py +++ b/src/ai/agents/ui/ai_sdk/inbound_messages.py @@ -85,20 +85,22 @@ def _build_result_part( output: Any, is_error: bool, ) -> messages_.ToolResultPart: + result: messages_.ToolResultOutput if is_error: - result: Any = output + text = str(output) if output is not None else "" + result = messages_.ErrorTextOutput(value=text) else: decoded = _decode_wire_output(output) - result = ( + raw = ( decoded if isinstance(decoded, MessageBundle) else _normalize_tool_result(decoded) ) + result = messages_.coerce_to_output(raw) return messages_.ToolResultPart( tool_call_id=tool_call_id, tool_name=tool_name, result=result, - is_error=is_error, ) @@ -189,8 +191,9 @@ def _patch_pending_hook_aborts( messages_.ToolResultPart( tool_call_id=tc.tool_call_id, tool_name=tc.tool_name, - result=f"Pending on hook '{hook.hook_id}'", - is_error=True, + result=messages_.ErrorTextOutput( + value=f"Pending on hook '{hook.hook_id}'" + ), is_hook_pending=True, ) ) diff --git a/src/ai/agents/ui/ai_sdk/outbound_messages.py b/src/ai/agents/ui/ai_sdk/outbound_messages.py index 4520e20d..019b183e 100644 --- a/src/ai/agents/ui/ai_sdk/outbound_messages.py +++ b/src/ai/agents/ui/ai_sdk/outbound_messages.py @@ -2,6 +2,7 @@ from __future__ import annotations +import json from typing import Any, cast from ....types import media @@ -107,6 +108,29 @@ def dedupe_tool_parts( return result +def _output_view( + output: messages_.ToolResultOutput, +) -> tuple[str, dict[str, Any]]: + """Map a :class:`ToolResultOutput` to ``(state, field_updates)``.""" + match output: + case messages_.TextOutput(value=value): + return "output-available", {"output": value} + case messages_.JsonOutput(value=value): + return "output-available", {"output": value} + case messages_.ContentOutput(value=items): + return "output-available", { + "output": [item.model_dump(mode="json") for item in items] + } + case messages_.ErrorTextOutput(value=value): + return "output-error", {"error_text": value} + case messages_.ErrorJsonOutput(value=value): + return "output-error", {"error_text": json.dumps(value)} + case messages_.ExecutionDeniedOutput(reason=reason): + return "output-denied", { + "error_text": reason or "Tool execution denied." + } + + def merge_tool_results( ui_parts: list[ui_messages.UIMessagePart], tool_parts: list[messages_.Part], @@ -121,15 +145,12 @@ def merge_tool_results( continue case messages_.ToolResultPart(): tool_call_id = part.tool_call_id - state = "output-error" if part.is_error else "output-available" + state, field_updates = _output_view(part.result) updates = { "state": state, "result_provider_metadata": part.provider_metadata, + **field_updates, } - if part.is_error: - updates["error_text"] = str(part.result) - else: - updates["output"] = part.result case messages_.BuiltinToolReturnPart(): tool_call_id = part.tool_call_id updates = { diff --git a/src/ai/agents/ui/ai_sdk/outbound_stream.py b/src/ai/agents/ui/ai_sdk/outbound_stream.py index 288bfbe5..2399d313 100644 --- a/src/ai/agents/ui/ai_sdk/outbound_stream.py +++ b/src/ai/agents/ui/ai_sdk/outbound_stream.py @@ -21,13 +21,20 @@ def _tool_error_text(part: messages_.ToolResultPart) -> str: """Best-effort error text extraction from a failed tool result.""" - if isinstance(part.result, str) and part.result: - return part.result - if isinstance(part.result, dict): - for key in ("error", "message", "detail"): - value = part.result.get(key) - if isinstance(value, str) and value: - return value + output = part.result + if isinstance(output, messages_.ErrorTextOutput): + return output.value or "Tool execution failed" + if isinstance(output, messages_.ErrorJsonOutput): + value = output.value + if isinstance(value, str) and value: + return value + if isinstance(value, dict): + for key in ("error", "message", "detail"): + inner = value.get(key) + if isinstance(inner, str) and inner: + return inner + if isinstance(output, messages_.ExecutionDeniedOutput): + return output.reason or "Tool execution denied" return "Tool execution failed" @@ -403,7 +410,16 @@ def on_tool_result( ) ) else: - wire_output = _to_wire_output(part.result) + output = part.result + raw = ( + output.value + if isinstance( + output, + messages_.TextOutput | messages_.JsonOutput, + ) + else output + ) + wire_output = _to_wire_output(raw) if wire_output is None: # Aggregator produced no anchor (e.g. sub-agent # tool that yielded nothing). Skip the final diff --git a/src/ai/providers/ai_gateway/protocol.py b/src/ai/providers/ai_gateway/protocol.py index a19e8a8d..74393f8c 100644 --- a/src/ai/providers/ai_gateway/protocol.py +++ b/src/ai/providers/ai_gateway/protocol.py @@ -65,6 +65,57 @@ def _file_part_to_wire(part: types.messages.FilePart) -> dict[str, Any]: return {"type": "file", "data": b64, "mediaType": part.media_type} +# --------------------------------------------------------------------------- +# Tool result output -> v3 wire +# --------------------------------------------------------------------------- + + +def _file_part_to_v3_inline(part: types.messages.FilePart) -> dict[str, Any]: + """Convert a :class:`FilePart` to an inline v3 content element. + + Images become ``image-data``; everything else becomes ``file-data``. + """ + b64 = types.media.data_to_base64(part.data) + if part.media_type.startswith("image/"): + return {"type": "image-data", "data": b64, "mediaType": part.media_type} + entry: dict[str, Any] = { + "type": "file-data", + "data": b64, + "mediaType": part.media_type, + } + if part.filename is not None: + entry["filename"] = part.filename + return entry + + +def _tool_result_output( + output: types.messages.ToolResultOutput, +) -> dict[str, Any]: + """Convert a :class:`ToolResultOutput` to its v3 ``output`` wire form.""" + match output: + case types.messages.TextOutput(value=value): + return {"type": "text", "value": value} + case types.messages.JsonOutput(value=value): + return {"type": "json", "value": value} + case types.messages.ErrorTextOutput(value=value): + return {"type": "error-text", "value": value} + case types.messages.ErrorJsonOutput(value=value): + return {"type": "error-json", "value": value} + case types.messages.ExecutionDeniedOutput(reason=reason): + entry: dict[str, Any] = {"type": "execution-denied"} + if reason is not None: + entry["reason"] = reason + return entry + case types.messages.ContentOutput(value=items): + parts: list[dict[str, Any]] = [] + for item in items: + if isinstance(item, types.messages.FilePart): + parts.append(_file_part_to_v3_inline(item)) + else: + parts.append({"type": "text", "text": item.text}) + return {"type": "content", "value": parts} + + # --------------------------------------------------------------------------- # Streaming request building — Message list → v3 prompt # --------------------------------------------------------------------------- @@ -172,22 +223,7 @@ async def _messages_to_prompt( tool_results: list[dict[str, Any]] = [] for part in msg.parts: if isinstance(part, types.messages.ToolResultPart): - model_input = part.get_model_input() - output = ( - { - "type": "error-text", - "value": ( - str(model_input) - if model_input is not None - else "" - ), - } - if part.is_error - else { - "type": "json", - "value": model_input, - } - ) + output = _tool_result_output(part.get_model_input()) tool_results.append( { "type": "tool-result", diff --git a/src/ai/providers/anthropic/protocol.py b/src/ai/providers/anthropic/protocol.py index 05351b21..88fa599c 100644 --- a/src/ai/providers/anthropic/protocol.py +++ b/src/ai/providers/anthropic/protocol.py @@ -179,6 +179,46 @@ def _file_part_to_anthropic( raise ValueError(f"Unsupported media type for Anthropic: {mt}") +def _tool_result_to_anthropic( + output: types.messages.ToolResultOutput, +) -> str | list[dict[str, Any]]: + """Convert a :class:`ToolResultOutput` to Anthropic tool_result content. + + :class:`ContentOutput` expands into Anthropic content blocks + (image/document) so the model sees actual media; all other variants + are stringified (the Anthropic API accepts a string as tool_result + content). + """ + match output: + case types.messages.TextOutput(value=value): + return value + case types.messages.ErrorTextOutput(value=value): + return value + case ( + types.messages.JsonOutput(value=value) + | types.messages.ErrorJsonOutput(value=value) + ): + return ( + json.dumps(value, separators=(",", ":"), default=str) + if value is not None + else "" + ) + case types.messages.ExecutionDeniedOutput(reason=reason): + return ( + f"Tool execution denied: {reason}" + if reason + else "Tool execution denied" + ) + case types.messages.ContentOutput(value=items): + blocks: list[dict[str, Any]] = [] + for item in items: + if isinstance(item, types.messages.FilePart): + blocks.append(_file_part_to_anthropic(item)) + else: + blocks.append({"type": "text", "text": item.text}) + return blocks + + async def _messages_to_anthropic( messages: list[types.messages.Message], ) -> tuple[str | None, list[dict[str, Any]]]: @@ -271,13 +311,13 @@ async def _messages_to_anthropic( tool_results: list[dict[str, Any]] = [] for part in msg.parts: if isinstance(part, types.messages.ToolResultPart): - model_input = part.get_model_input() + tool_content = _tool_result_to_anthropic( + part.get_model_input() + ) entry: dict[str, Any] = { "type": "tool_result", "tool_use_id": part.tool_call_id, - "content": str(model_input) - if model_input is not None - else "", + "content": tool_content, } if part.is_error: entry["is_error"] = True diff --git a/src/ai/providers/openai/protocol.py b/src/ai/providers/openai/protocol.py index 41be970e..e62add3b 100644 --- a/src/ai/providers/openai/protocol.py +++ b/src/ai/providers/openai/protocol.py @@ -108,6 +108,51 @@ async def _file_part_to_openai( raise ValueError(f"Unsupported media type for OpenAI: {mt}") +def _tool_result_to_openai( + output: types.messages.ToolResultOutput, +) -> str | list[dict[str, Any]]: + """Convert a :class:`ToolResultOutput` to OpenAI tool-message content. + + :class:`ContentOutput` expands into a content array with ``text`` + and ``image_url`` parts (chat-completions API). All other variants + are stringified. + """ + match output: + case types.messages.TextOutput(value=value): + return value + case types.messages.ErrorTextOutput(value=value): + return value + case ( + types.messages.JsonOutput(value=value) + | types.messages.ErrorJsonOutput(value=value) + ): + return _json_dumps(value) if value is not None else "" + case types.messages.ExecutionDeniedOutput(reason=reason): + return ( + f"Tool execution denied: {reason}" + if reason + else "Tool execution denied" + ) + case types.messages.ContentOutput(value=items): + parts: list[dict[str, Any]] = [] + for item in items: + if isinstance(item, types.messages.FilePart): + mt = item.media_type + if mt.startswith("image/"): + data_url = types.media.data_to_data_url(item.data, mt) + parts.append( + { + "type": "image_url", + "image_url": {"url": data_url}, + } + ) + else: + parts.append({"type": "text", "text": f"[file: {mt}]"}) + else: + parts.append({"type": "text", "text": item.text}) + return parts + + async def _messages_to_openai( messages: list[types.messages.Message], ) -> list[dict[str, Any]]: @@ -165,14 +210,14 @@ async def _messages_to_openai( case "tool": for part in msg.parts: if isinstance(part, types.messages.ToolResultPart): - model_input = part.get_model_input() + tool_content = _tool_result_to_openai( + part.get_model_input() + ) result.append( { "role": "tool", "tool_call_id": part.tool_call_id, - "content": str(model_input) - if model_input is not None - else "", + "content": tool_content, } ) @@ -511,12 +556,55 @@ def _raw_item_from_metadata(part: Any) -> dict[str, Any] | None: return None -def _stringify_tool_result(result: Any) -> str: - if result is None: - return "" - if isinstance(result, str): - return result - return _json_dumps(result) +def _tool_result_to_responses( + output: types.messages.ToolResultOutput, +) -> str | list[dict[str, Any]]: + """Convert a :class:`ToolResultOutput` to a Responses ``output`` value. + + Returns a plain string for text/json/error/denied variants, or an + array of ``input_text`` / ``input_image`` / ``input_file`` parts + for :class:`ContentOutput` (the Responses API accepts both shapes + on ``function_call_output.output``). + """ + match output: + case ( + types.messages.TextOutput(value=value) + | types.messages.ErrorTextOutput(value=value) + ): + return value + case ( + types.messages.JsonOutput(value=value) + | types.messages.ErrorJsonOutput(value=value) + ): + return _json_dumps(value) if value is not None else "" + case types.messages.ExecutionDeniedOutput(reason=reason): + return ( + f"Tool execution denied: {reason}" + if reason + else "Tool execution denied" + ) + case types.messages.ContentOutput(value=items): + parts: list[dict[str, Any]] = [] + for item in items: + if isinstance(item, types.messages.FilePart): + data_url = types.media.data_to_data_url( + item.data, item.media_type + ) + if item.media_type.startswith("image/"): + parts.append( + {"type": "input_image", "image_url": data_url} + ) + else: + entry: dict[str, Any] = { + "type": "input_file", + "file_data": data_url, + } + if item.filename is not None: + entry["filename"] = item.filename + parts.append(entry) + else: + parts.append({"type": "input_text", "text": item.text}) + return parts async def _file_part_to_responses( @@ -650,7 +738,7 @@ async def _messages_to_responses( { "type": "function_call_output", "call_id": part.tool_call_id, - "output": _stringify_tool_result( + "output": _tool_result_to_responses( part.get_model_input() ), } diff --git a/src/ai/types/builders.py b/src/ai/types/builders.py index eeb83c10..dfce753c 100644 --- a/src/ai/types/builders.py +++ b/src/ai/types/builders.py @@ -13,6 +13,11 @@ from . import events as events_ from .messages import ( + ContentOutput, + ContentPart, + ErrorJsonOutput, + ErrorTextOutput, + ExecutionDeniedOutput, FilePart, HookPart, Message, @@ -89,6 +94,35 @@ def file_part( return FilePart.from_bytes(data, media_type=media_type, filename=filename) +def text_part( + text: str, + *, + provider_metadata: dict[str, Any] | None = None, +) -> TextPart: + """Create a :class:`TextPart`. + + Bare strings passed to the ``*_message`` builders are coerced into + text parts automatically; reach for this when you need to attach + ``provider_metadata`` or build a part list directly. + """ + return TextPart(text=text, provider_metadata=provider_metadata) + + +def content_output(*content: str | TextPart | FilePart) -> ContentOutput: + """Create a multipart :class:`ContentOutput` tool result. + + Bare strings become :class:`TextPart` objects, mirroring the + ``*_message`` builders, so a tool can return mixed text and files + without constructing the part list by hand. + + >>> ai.content_output("Here is the chart:", ai.file_part(png_bytes)) + """ + parts: list[ContentPart] = [] + for item in content: + parts.append(TextPart(text=item) if isinstance(item, str) else item) + return ContentOutput(value=parts) + + def thinking( text: str, *, @@ -208,11 +242,28 @@ def tool_result_part( ) -> ToolResultPart: """Create a :class:`ToolResultPart`. + ``result`` is coerced into a :class:`ToolResultOutput` variant. + With ``is_error=True``, a plain value becomes :class:`ErrorTextOutput` + (stringifying non-string values); pass an :class:`ErrorJsonOutput` + or :class:`ExecutionDeniedOutput` directly for richer error shapes. + >>> ai.tool_result_part("tc-1", result={"temp": 72}, tool_name="weather") """ + if is_error: + # Promote plain values to the error variant; pass-through existing + # error / denial outputs. + if isinstance( + result, + ErrorTextOutput | ErrorJsonOutput | ExecutionDeniedOutput, + ): + output: Any = result + else: + text = str(result) if result is not None else "" + output = ErrorTextOutput(value=text) + else: + output = result return ToolResultPart( tool_call_id=tool_call_id, tool_name=tool_name, - result=result, - is_error=is_error, + result=output, ) diff --git a/src/ai/types/events.py b/src/ai/types/events.py index 14aa220c..2cbfd3ed 100644 --- a/src/ai/types/events.py +++ b/src/ai/types/events.py @@ -1,6 +1,6 @@ import abc from collections.abc import AsyncGenerator, Callable, Sequence -from typing import Annotated, Any, Literal +from typing import Annotated, Any, Literal, cast import pydantic @@ -290,6 +290,28 @@ def to_model_input(cls, snapshot: Result) -> ModelInput: """ ... + @classmethod + def model_input_from_result( + cls, result: messages.ToolResultOutput + ) -> messages.ToolResultOutput: + """Re-derive model input from a persisted ``ToolResultPart.result``. + + Default: unwrap text/json variants, run :meth:`to_model_input`, + then coerce the result back into a :class:`ToolResultOutput`. + Subclasses where ``result`` already carries the desired shape + (e.g. text-only aggregators) can override to return ``result`` + unchanged. + """ + if isinstance(result, messages.TextOutput | messages.JsonOutput): + # The persisted ``result.value`` is whatever the aggregator's + # snapshot serialized to (often a plain dict for pydantic + # snapshots); the aggregator is responsible for revalidating + # if it cares about a richer in-memory shape. + return messages.coerce_to_output( + cls.to_model_input(cast("Any", result.value)) + ) + return result + class PartialToolCallResult(pydantic.BaseModel): """Emitted when tool calls or other yield_from callers yield values.""" diff --git a/src/ai/types/integrity.py b/src/ai/types/integrity.py index a5328cc4..22dd85c1 100644 --- a/src/ai/types/integrity.py +++ b/src/ai/types/integrity.py @@ -179,8 +179,9 @@ def _flush_pending() -> None: messages_.ToolResultPart( tool_call_id=tc.tool_call_id, tool_name=tc.tool_name, - result="Tool result not available", - is_error=True, + result=messages_.ErrorTextOutput( + value="Tool result not available" + ), ) for tc in pending.values() ) diff --git a/src/ai/types/messages.py b/src/ai/types/messages.py index 217b4287..86be8f16 100644 --- a/src/ai/types/messages.py +++ b/src/ai/types/messages.py @@ -1,3 +1,4 @@ +import base64 import uuid from typing import Annotated, Any, Literal, Self, overload @@ -21,28 +22,234 @@ class TextPart(pydantic.BaseModel): kind: Literal["text"] = "text" +class FilePart(pydantic.BaseModel): + """File, image, or audio content part. + + Covers images (``image/*``), documents (``application/pdf``, ``text/*``), + and audio (``audio/*``). The ``media_type`` field tells provider + converters how to format this part for each API. + + ``data`` accepts: + + * **str** -- a URL (``http(s)://...`` or ``data:...``) *or* raw + base-64 text. + * **bytes** -- raw binary data (will be base-64 encoded when serialized + to JSON for providers that need it). + """ + + model_config = pydantic.ConfigDict(frozen=True) + + id: str = pydantic.Field(default_factory=lambda: generate_id("part")) + data: str | bytes + media_type: str # IANA media type, e.g. "image/png", "audio/wav" + filename: str | None = None + kind: Literal["file"] = "file" + provider_metadata: dict[str, Any] | None = None + + @pydantic.field_serializer("data", when_used="json") + @classmethod + def _serialize_data(cls, v: str | bytes, _info: Any) -> str: + """Encode ``bytes`` as standard base-64 for JSON serialization. + + Pydantic's built-in ``ser_json_bytes`` uses URL-safe base-64 + (``-`` and ``_``) which LLM provider APIs reject. This + serializer uses standard base-64 (``+`` and ``/``) instead. + ``str`` values (URLs, existing base-64) pass through unchanged. + """ + if isinstance(v, bytes): + return base64.b64encode(v).decode("ascii") + return v + + @classmethod + def from_url(cls, url: str, *, media_type: str | None = None) -> Self: + """Create from a URL, inferring ``media_type`` from the URL if omitted. + + Inference handles ``data:`` URLs (the media type is embedded in the + prefix) and ``http(s)://`` URLs (via :func:`mimetypes.guess_type`). + Raises :class:`ValueError` if inference fails and no explicit + ``media_type`` is provided. + """ + if media_type is None: + media_type = media.infer_media_type(url) + return cls(data=url, media_type=media_type) + + @classmethod + def from_bytes( + cls, + data: bytes, + *, + media_type: str | None = None, + filename: str | None = None, + ) -> Self: + """Create from raw bytes, detecting ``media_type`` via magic bytes. + + Attempts image detection first, then audio. Raises + :class:`ValueError` if no ``media_type`` is provided and + detection fails. + """ + if media_type is None: + media_type = media.detect_image_media_type( + data + ) or media.detect_audio_media_type(data) + if media_type is None: + raise ValueError( + "Cannot detect media_type from bytes. " + "Provide media_type explicitly." + ) + return cls(data=data, media_type=media_type, filename=filename) + + +# --------------------------------------------------------------------------- +# Tool result output -- discriminated union mirroring the AI SDK v3 spec. +# A tool's return value is coerced into one of these variants before it +# lands on a ``ToolResultPart``. Providers switch on ``type`` to build +# their wire format. +# --------------------------------------------------------------------------- + + +class TextOutput(pydantic.BaseModel): + type: Literal["text"] = "text" + value: str + + model_config = pydantic.ConfigDict(frozen=True) + + +class JsonOutput(pydantic.BaseModel): + type: Literal["json"] = "json" + value: Any = None + + model_config = pydantic.ConfigDict(frozen=True) + + +class ErrorTextOutput(pydantic.BaseModel): + type: Literal["error-text"] = "error-text" + value: str + + model_config = pydantic.ConfigDict(frozen=True) + + +class ErrorJsonOutput(pydantic.BaseModel): + type: Literal["error-json"] = "error-json" + value: Any = None + + model_config = pydantic.ConfigDict(frozen=True) + + +class ExecutionDeniedOutput(pydantic.BaseModel): + type: Literal["execution-denied"] = "execution-denied" + reason: str | None = None + + model_config = pydantic.ConfigDict(frozen=True) + + +ContentPart = Annotated[ + TextPart | FilePart, + pydantic.Field(discriminator="kind"), +] + + +class ContentOutput(pydantic.BaseModel): + """Multipart tool result -- mix of text and file/image parts.""" + + type: Literal["content"] = "content" + value: list[ContentPart] + + model_config = pydantic.ConfigDict(frozen=True) + + +ToolResultOutput = Annotated[ + TextOutput + | JsonOutput + | ContentOutput + | ErrorTextOutput + | ErrorJsonOutput + | ExecutionDeniedOutput, + pydantic.Field(discriminator="type"), +] + + +def coerce_to_output(value: Any) -> ToolResultOutput: + """Map a tool return value onto a :class:`ToolResultOutput` variant. + + * ``ToolResultOutput`` instance -- passed through unchanged. + * ``str`` -- wrapped in :class:`TextOutput`. + * Anything else -- :class:`JsonOutput` with the value as-is. + + The value stored on :class:`JsonOutput` is not eagerly serialized: + pydantic models, ``MessageBundle`` snapshots, etc. survive in memory + so UI converters can introspect them. On JSON round-trip the value + is dumped/loaded normally and comes back as a plain dict/list/... + """ + if isinstance( + value, + TextOutput + | JsonOutput + | ContentOutput + | ErrorTextOutput + | ErrorJsonOutput + | ExecutionDeniedOutput, + ): + return value + if isinstance(value, str): + return TextOutput(value=value) + return JsonOutput(value=value) + + _MODEL_INPUT_UNSET: Any = object() +def _coerce_result_field(value: Any) -> Any: + """``BeforeValidator`` for ``ToolResultPart.result``. + + Pass-through for ``ToolResultOutput`` instances and wire-shape dicts + (with a known ``type`` discriminator). Anything else is routed + through :func:`coerce_to_output` so plain tool returns and stored + raw values still construct cleanly. + """ + if isinstance( + value, + TextOutput + | JsonOutput + | ContentOutput + | ErrorTextOutput + | ErrorJsonOutput + | ExecutionDeniedOutput, + ): + return value + if isinstance(value, dict) and value.get("type") in { + "text", + "json", + "content", + "error-text", + "error-json", + "execution-denied", + }: + return value + return coerce_to_output(value) + + class ToolResultPart(pydantic.BaseModel): id: str = pydantic.Field(default_factory=lambda: generate_id("part")) tool_call_id: str tool_name: str - is_error: bool = False is_hook_pending: bool = False provider_metadata: dict[str, Any] | None = None - # The "real" result of the tool call - result: Any = None - - # Value the LLM sees on its next turn. For most tools this is - # identical to ``result``; for aggregator-backed tools (sub-agents, - # streaming-text) it's derived from the aggregator's - # ``get_model_input``. Not part of the wire model: it's populated - # by tool execution and by ``Agent.run`` (which has the tool - # registry) rather than carried across serialization. ``default_factory`` - # preserves singleton identity so the unset sentinel survives pydantic's - # default-copying. + # The model-facing tool result, always a :class:`ToolResultOutput` + # variant. Plain values (str, dict, BaseModel, ...) are coerced on + # construction via :func:`coerce_to_output`, so existing call sites + # can still pass raw values and stored messages from prior versions + # round-trip. + result: Annotated[ + ToolResultOutput, pydantic.BeforeValidator(_coerce_result_field) + ] + + # Override for the model-facing value. Set explicitly by tool + # execution for streaming/aggregator tools (where ``result`` holds + # the rich snapshot) and reconstructed from ``result`` via the + # tool's aggregator in :func:`_populate_model_inputs` after a JSON + # round-trip. When unset, providers use ``result`` directly. + # ``PrivateAttr`` so it doesn't appear in serialized messages. _model_input: Any = pydantic.PrivateAttr( default_factory=lambda: _MODEL_INPUT_UNSET ) @@ -50,15 +257,25 @@ class ToolResultPart(pydantic.BaseModel): kind: Literal["tool_result"] = "tool_result" model_config = pydantic.ConfigDict(frozen=True) - def get_model_input(self) -> Any: - """Return the value the LLM should see, falling back to ``result``.""" + @property + def is_error(self) -> bool: + """Whether this result represents an error to the model.""" + output = self.get_model_input() + return output.type in ("error-text", "error-json", "execution-denied") + + def get_model_input(self) -> ToolResultOutput: + """Return the converted value the LLM should see. + + Returns the explicit ``_model_input`` override when set; + otherwise falls back to the typed ``result`` field. + """ if self._model_input is _MODEL_INPUT_UNSET: return self.result - return self._model_input + return self._model_input # type: ignore[no-any-return] def set_model_input(self, value: Any) -> None: - """Set the model-facing value (overrides the ``result`` fallback).""" - self._model_input = value + """Set the model-facing value, coercing to :class:`ToolResultOutput`.""" + self._model_input = coerce_to_output(value) @property def has_model_input(self) -> bool: @@ -141,69 +358,6 @@ class HookPart[T](pydantic.BaseModel): model_config = pydantic.ConfigDict(frozen=True) -class FilePart(pydantic.BaseModel): - """File, image, or audio content part. - - Covers images (``image/*``), documents (``application/pdf``, ``text/*``), - and audio (``audio/*``). The ``media_type`` field tells provider - converters how to format this part for each API. - - ``data`` accepts: - - * **str** -- a URL (``http(s)://...`` or ``data:...``) *or* raw - base-64 text. - * **bytes** -- raw binary data (will be base-64 encoded when serialized - to JSON for providers that need it). - """ - - model_config = pydantic.ConfigDict(frozen=True) - - id: str = pydantic.Field(default_factory=lambda: generate_id("part")) - data: str | bytes - media_type: str # IANA media type, e.g. "image/png", "audio/wav" - filename: str | None = None - kind: Literal["file"] = "file" - provider_metadata: dict[str, Any] | None = None - - @classmethod - def from_url(cls, url: str, *, media_type: str | None = None) -> Self: - """Create from a URL, inferring ``media_type`` from the URL if omitted. - - Inference handles ``data:`` URLs (the media type is embedded in the - prefix) and ``http(s)://`` URLs (via :func:`mimetypes.guess_type`). - Raises :class:`ValueError` if inference fails and no explicit - ``media_type`` is provided. - """ - if media_type is None: - media_type = media.infer_media_type(url) - return cls(data=url, media_type=media_type) - - @classmethod - def from_bytes( - cls, - data: bytes, - *, - media_type: str | None = None, - filename: str | None = None, - ) -> Self: - """Create from raw bytes, detecting ``media_type`` via magic bytes. - - Attempts image detection first, then audio. Raises - :class:`ValueError` if no ``media_type`` is provided and - detection fails. - """ - if media_type is None: - media_type = media.detect_image_media_type( - data - ) or media.detect_audio_media_type(data) - if media_type is None: - raise ValueError( - "Cannot detect media_type from bytes. " - "Provide media_type explicitly." - ) - return cls(data=data, media_type=media_type, filename=filename) - - Part = Annotated[ TextPart | ToolCallPart diff --git a/tests/agents/mcp/test_client.py b/tests/agents/mcp/test_client.py index 13a73cc7..e21e6229 100644 --- a/tests/agents/mcp/test_client.py +++ b/tests/agents/mcp/test_client.py @@ -137,7 +137,9 @@ async def fake_fn(**kwargs: str) -> str: # Tool result is visible in messages. tool_results = [m for m in msgs if m.role == "tool" and m.tool_results] assert len(tool_results) >= 1 - assert tool_results[0].tool_results[0].result == "echoed: hello" + tr = tool_results[0].tool_results[0].result + assert isinstance(tr, ai.messages.TextOutput | ai.messages.JsonOutput) + assert tr.value == "echoed: hello" # LLM was called twice (tool call + final text). assert llm.call_count == 2 diff --git a/tests/agents/test_aggregate_marker.py b/tests/agents/test_aggregate_marker.py index e0fe90a2..87deb4c4 100644 --- a/tests/agents/test_aggregate_marker.py +++ b/tests/agents/test_aggregate_marker.py @@ -153,4 +153,6 @@ async def test_alias_declared_tool_runs_end_to_end() -> None: tool_results = [ e for e in all_events if isinstance(e, agent_events_.ToolCallResult) ] - assert tool_results[0].results[0].result == "Answer for test" + tr = tool_results[0].results[0].result + assert isinstance(tr, ai.messages.TextOutput | ai.messages.JsonOutput) + assert tr.value == "Answer for test" diff --git a/tests/agents/test_generator_tools.py b/tests/agents/test_generator_tools.py index 0f2bc9d6..cce1ac45 100644 --- a/tests/agents/test_generator_tools.py +++ b/tests/agents/test_generator_tools.py @@ -72,7 +72,9 @@ async def test_generator_tool_streams_and_returns_result() -> None: e for e in all_events if isinstance(e, agent_events_.ToolCallResult) ] assert len(tool_results) >= 1 - assert tool_results[0].results[0].result == "Answer for test" + tr = tool_results[0].results[0].result + assert isinstance(tr, ai.messages.TextOutput | ai.messages.JsonOutput) + assert tr.value == "Answer for test" # --------------------------------------------------------------------------- @@ -179,13 +181,16 @@ async def test_yield_from_nested_agent() -> None: tool_results = [ e for e in all_events if isinstance(e, agent_events_.ToolCallResult) ] - # MessageAggregator stores the rich MessageBundle as `result` and the - # extracted assistant text as the model input (the value the parent - # LLM sees on its next turn). + # MessageAggregator stores the rich MessageBundle inside the typed + # ``result`` (a JsonOutput wrapping the bundle) and the extracted + # assistant text as the model input the parent LLM sees on its next turn. sub_part = tool_results[0].results[0] - assert isinstance(sub_part.result, MessageBundle) - assert sub_part.result.messages[0].text == "Mars has two moons." - assert sub_part.get_model_input() == "Mars has two moons." + assert isinstance(sub_part.result, ai.messages.JsonOutput) + assert isinstance(sub_part.result.value, MessageBundle) + assert sub_part.result.value.messages[0].text == "Mars has two moons." + model_input = sub_part.get_model_input() + assert isinstance(model_input, ai.messages.TextOutput) + assert model_input.value == "Mars has two moons." # The outer LLM's second call (index 2) must NOT contain any inner # agent messages. It should only see: the original user message, diff --git a/tests/agents/test_runtime.py b/tests/agents/test_runtime.py index 9a4096b9..e720c9b4 100644 --- a/tests/agents/test_runtime.py +++ b/tests/agents/test_runtime.py @@ -60,7 +60,9 @@ async def test_agent_tool_then_text() -> None: assert llm.call_count == 2 tool_results = [m for m in msgs if m.role == "tool" and m.tool_results] assert len(tool_results) >= 1 - assert tool_results[0].tool_results[0].result == 10 + tr = tool_results[0].tool_results[0].result + assert isinstance(tr, ai.messages.TextOutput | ai.messages.JsonOutput) + assert tr.value == 10 # -- Agent default loop: multiple tool calls in one message ---------------- diff --git a/tests/agents/test_tools.py b/tests/agents/test_tools.py index eee0b332..6d821b4a 100644 --- a/tests/agents/test_tools.py +++ b/tests/agents/test_tools.py @@ -75,7 +75,9 @@ async def add(a: int, b: int) -> int: tool_args='{"a": 1, "b": 2}', ) result = await ai.agents.BoundToolCall(part=part, tool=add)() - assert result.results[0].result == 3 + out = result.results[0].result + assert isinstance(out, ai.messages.JsonOutput) + assert out.value == 3 # -- ToolCall binds a ToolCallPart to a Tool and returns tool messages ---- @@ -101,7 +103,9 @@ async def double(x: int) -> int: assert len(result.results) == 1 assert result.results[0].tool_call_id == "tc-1" assert result.results[0].tool_name == "double" - assert result.results[0].result == 10 + out = result.results[0].result + assert isinstance(out, ai.messages.JsonOutput) + assert out.value == 10 assert not result.results[0].is_error @@ -175,7 +179,9 @@ async def double(x: int) -> int: result = await tc(x=7) - assert result.results[0].result == 14 + out = result.results[0].result + assert isinstance(out, ai.messages.JsonOutput) + assert out.value == 14 async def test_tool_call_override_validation_failure() -> None: diff --git a/tests/agents/ui/ai_sdk/test_inbound_messages.py b/tests/agents/ui/ai_sdk/test_inbound_messages.py index f2d2abbc..43109fcb 100644 --- a/tests/agents/ui/ai_sdk/test_inbound_messages.py +++ b/tests/agents/ui/ai_sdk/test_inbound_messages.py @@ -185,7 +185,8 @@ def test_to_messages_decodes_subagent_tool_output() -> None: tool_msgs = [m for m in messages if m.role == "tool"] assert len(tool_msgs) == 1 result_part = tool_msgs[0].tool_results[0] - assert isinstance(result_part.result, MessageBundle) + assert isinstance(result_part.result, messages_.JsonOutput) + assert isinstance(result_part.result.value, MessageBundle) assert not result_part.has_model_input @@ -208,8 +209,8 @@ def test_to_messages_passthrough_keeps_wire_shape() -> None: messages, _ = to_messages(ui) tool_msgs = [m for m in messages if m.role == "tool"] part = tool_msgs[0].tool_results[0] - assert part.result == {"pong": True} - assert part.get_model_input() == {"pong": True} + assert isinstance(part.result, messages_.JsonOutput) + assert part.result.value == {"pong": True} def test_to_messages_accepts_metadata_and_ui_only_parts() -> None: diff --git a/tests/agents/ui/ai_sdk/test_outbound_messages.py b/tests/agents/ui/ai_sdk/test_outbound_messages.py index ad503e58..b384384d 100644 --- a/tests/agents/ui/ai_sdk/test_outbound_messages.py +++ b/tests/agents/ui/ai_sdk/test_outbound_messages.py @@ -58,13 +58,13 @@ def _parallel_tool_turn( id=f"{prefix}:result:bash", tool_call_id=tc_bash, tool_name="bash", - result="Tue May 19 2026", + result=messages_.TextOutput(value="Tue May 19 2026"), ), messages_.ToolResultPart( id=f"{prefix}:result:web", tool_call_id=tc_web, tool_name="web_fetch", - result={"status": 200}, + result=messages_.JsonOutput(value={"status": 200}), ), ], ), @@ -124,7 +124,7 @@ def test_merge_tool_results_updates_state_and_output() -> None: messages_.ToolResultPart( tool_call_id="tc1", tool_name="search", - result={"hits": 3}, + result=messages_.JsonOutput(value={"hits": 3}), ) ], ) @@ -250,7 +250,7 @@ def test_to_ui_messages_merges_assistant_tool_internal() -> None: messages_.ToolResultPart( tool_call_id="tc1", tool_name="search", - result={"hits": 2}, + result=messages_.JsonOutput(value={"hits": 2}), ) ], ), @@ -298,7 +298,7 @@ def test_to_ui_messages_records_source_messages_in_metadata() -> None: id="result-0", tool_call_id="tc1", tool_name="search", - result={"hits": 2}, + result=messages_.JsonOutput(value={"hits": 2}), ) ], ), @@ -493,7 +493,7 @@ def test_collapsed_assistant_turn_roundtrips_internal_ids() -> None: id="result-beta", tool_call_id="tc-first", tool_name="search", - result={"hits": 1}, + result=messages_.JsonOutput(value={"hits": 1}), ) ], ), @@ -520,7 +520,7 @@ def test_collapsed_assistant_turn_roundtrips_internal_ids() -> None: id="result-delta", tool_call_id="tc-second", tool_name="lookup", - result={"value": 2}, + result=messages_.JsonOutput(value={"value": 2}), ) ], ), diff --git a/tests/agents/ui/ai_sdk/test_outbound_stream.py b/tests/agents/ui/ai_sdk/test_outbound_stream.py index e496405c..58448102 100644 --- a/tests/agents/ui/ai_sdk/test_outbound_stream.py +++ b/tests/agents/ui/ai_sdk/test_outbound_stream.py @@ -177,7 +177,7 @@ async def test_finish_metadata_tracks_tool_and_internal_messages() -> None: id="result-1", tool_call_id="tc1", tool_name="search", - result={"hits": 1}, + result=messages_.JsonOutput(value={"hits": 1}), ) ], ) @@ -271,7 +271,7 @@ async def test_tool_call_and_result_emit_terminal_events() -> None: messages_.ToolResultPart( tool_call_id="tc1", tool_name="search", - result={"hits": 1}, + result=messages_.JsonOutput(value={"hits": 1}), ) ], ) @@ -314,7 +314,7 @@ async def test_tool_result_without_streaming_emits_input_start() -> None: messages_.ToolResultPart( tool_call_id="tc1", tool_name="search", - result={"hits": 1}, + result=messages_.JsonOutput(value={"hits": 1}), ), ], ) diff --git a/tests/providers/ai_gateway/test_protocol.py b/tests/providers/ai_gateway/test_protocol.py index 5c881bc2..5d95d10e 100644 --- a/tests/providers/ai_gateway/test_protocol.py +++ b/tests/providers/ai_gateway/test_protocol.py @@ -87,7 +87,7 @@ async def test_tool_call_with_result_produces_two_messages(self) -> None: messages.ToolResultPart( tool_call_id="tc-1", tool_name="get_weather", - result={"temp": 72}, + result=messages.JsonOutput(value={"temp": 72}), ) ], ), @@ -124,8 +124,9 @@ async def test_tool_error_result(self) -> None: messages.ToolResultPart( tool_call_id="tc-1", tool_name="get_weather", - result="Connection timeout", - is_error=True, + result=messages.ErrorTextOutput( + value="Connection timeout" + ), ) ], ), @@ -389,3 +390,124 @@ def test_non_dict_returns_empty(self) -> None: usage = protocol._parse_usage("not a dict") assert usage.input_tokens == 0 assert usage.output_tokens == 0 + + +# --------------------------------------------------------------------------- +# Multi-part tool result helpers +# --------------------------------------------------------------------------- + + +class TestFilePartToV3Inline: + def test_image_data(self) -> None: + fp = messages.FilePart(data="b64data", media_type="image/png") + entry = protocol._file_part_to_v3_inline(fp) + assert entry == { + "type": "image-data", + "data": "b64data", + "mediaType": "image/png", + } + + def test_file_data_with_filename(self) -> None: + fp = messages.FilePart( + data="pdfdata", + media_type="application/pdf", + filename="doc.pdf", + ) + entry = protocol._file_part_to_v3_inline(fp) + assert entry["type"] == "file-data" + assert entry["mediaType"] == "application/pdf" + assert entry["filename"] == "doc.pdf" + + def test_bytes_become_base64(self) -> None: + fp = messages.FilePart(data=b"\x89PNG", media_type="image/png") + entry = protocol._file_part_to_v3_inline(fp) + assert entry["type"] == "image-data" + assert entry["data"] != "" + + +class TestToolResultOutput: + def test_text(self) -> None: + result = protocol._tool_result_output(messages.TextOutput(value="hi")) + assert result == {"type": "text", "value": "hi"} + + def test_json(self) -> None: + result = protocol._tool_result_output( + messages.JsonOutput(value={"key": "value"}) + ) + assert result == {"type": "json", "value": {"key": "value"}} + + def test_error_text(self) -> None: + result = protocol._tool_result_output( + messages.ErrorTextOutput(value="oops") + ) + assert result == {"type": "error-text", "value": "oops"} + + def test_error_json(self) -> None: + result = protocol._tool_result_output( + messages.ErrorJsonOutput(value={"code": 500}) + ) + assert result == {"type": "error-json", "value": {"code": 500}} + + def test_execution_denied_with_reason(self) -> None: + result = protocol._tool_result_output( + messages.ExecutionDeniedOutput(reason="user said no") + ) + assert result == {"type": "execution-denied", "reason": "user said no"} + + def test_execution_denied_no_reason(self) -> None: + result = protocol._tool_result_output(messages.ExecutionDeniedOutput()) + assert result == {"type": "execution-denied"} + + def test_content_multipart(self) -> None: + fp = messages.FilePart(data="b64", media_type="image/jpeg") + result = protocol._tool_result_output( + messages.ContentOutput(value=[messages.TextPart(text="desc"), fp]) + ) + assert result["type"] == "content" + assert result["value"][0] == {"type": "text", "text": "desc"} + assert result["value"][1]["type"] == "image-data" + + +class TestMessagesToPromptMultipart: + async def test_tool_result_with_file_part(self) -> None: + """ContentOutput with a FilePart uses the 'content' wire output.""" + fp = messages.FilePart(data="iVBOR", media_type="image/png") + msgs = [ + messages.Message( + role="assistant", + parts=[ + messages.ToolCallPart( + tool_call_id="tc-1", + tool_name="read", + tool_args='{"path": "test.png"}', + ) + ], + ), + messages.Message( + role="tool", + parts=[ + messages.ToolResultPart( + tool_call_id="tc-1", + tool_name="read", + result=messages.ContentOutput( + value=[ + messages.TextPart(text="Image loaded"), + fp, + ] + ), + ) + ], + ), + ] + result = await protocol._messages_to_prompt(msgs) + tr = result[1]["content"][0] + assert tr["output"]["type"] == "content" + assert tr["output"]["value"][0] == { + "type": "text", + "text": "Image loaded", + } + assert tr["output"]["value"][1] == { + "type": "image-data", + "data": "iVBOR", + "mediaType": "image/png", + } diff --git a/tests/providers/ai_gateway/test_stream.py b/tests/providers/ai_gateway/test_stream.py index 10962a6a..c9c69d81 100644 --- a/tests/providers/ai_gateway/test_stream.py +++ b/tests/providers/ai_gateway/test_stream.py @@ -457,7 +457,7 @@ def handler(req: httpx.Request) -> httpx.Response: tool_result = messages.ToolResultPart( tool_call_id="tc-1", tool_name="search", - result={"temp": 72}, + result=messages.JsonOutput(value={"temp": 72}), ) conversation = [ user_msg("What's the weather?"), diff --git a/tests/providers/anthropic/test_multipart_tool_result.py b/tests/providers/anthropic/test_multipart_tool_result.py new file mode 100644 index 00000000..e28e1a71 --- /dev/null +++ b/tests/providers/anthropic/test_multipart_tool_result.py @@ -0,0 +1,123 @@ +"""Tests for multi-part tool results in the Anthropic protocol.""" + +from __future__ import annotations + +from ai.providers.anthropic import protocol +from ai.types import messages + + +class TestToolResultToAnthropic: + def test_text_output(self) -> None: + result = protocol._tool_result_to_anthropic( + messages.TextOutput(value="hello") + ) + assert result == "hello" + + def test_json_output_none(self) -> None: + result = protocol._tool_result_to_anthropic(messages.JsonOutput()) + assert result == "" + + def test_json_output_dict(self) -> None: + result = protocol._tool_result_to_anthropic( + messages.JsonOutput(value={"key": "value"}) + ) + assert result == '{"key":"value"}' + + def test_json_output_list(self) -> None: + result = protocol._tool_result_to_anthropic( + messages.JsonOutput(value=[1, 2, 3]) + ) + assert result == "[1,2,3]" + + def test_error_text_output(self) -> None: + result = protocol._tool_result_to_anthropic( + messages.ErrorTextOutput(value="boom") + ) + assert result == "boom" + + def test_execution_denied(self) -> None: + result = protocol._tool_result_to_anthropic( + messages.ExecutionDeniedOutput(reason="user said no") + ) + assert result == "Tool execution denied: user said no" + + def test_content_text_and_file(self) -> None: + fp = messages.FilePart(data="b64data", media_type="image/png") + result = protocol._tool_result_to_anthropic( + messages.ContentOutput( + value=[messages.TextPart(text="Image loaded"), fp] + ) + ) + assert isinstance(result, list) + assert len(result) == 2 + assert result[0] == {"type": "text", "text": "Image loaded"} + assert result[1]["type"] == "image" + assert result[1]["source"]["type"] == "base64" + assert result[1]["source"]["media_type"] == "image/png" + assert result[1]["source"]["data"] == "b64data" + + def test_content_file_only(self) -> None: + fp = messages.FilePart(data="b64data", media_type="image/jpeg") + result = protocol._tool_result_to_anthropic( + messages.ContentOutput(value=[fp]) + ) + assert isinstance(result, list) + assert result[0]["type"] == "image" + assert result[0]["source"]["media_type"] == "image/jpeg" + + def test_content_bytes_file(self) -> None: + fp = messages.FilePart(data=b"\x89PNG", media_type="image/png") + result = protocol._tool_result_to_anthropic( + messages.ContentOutput(value=[messages.TextPart(text="desc"), fp]) + ) + assert isinstance(result, list) + assert result[1]["type"] == "image" + assert result[1]["source"]["data"] != "" + + +class TestMessagesToAnthropicMultipart: + async def test_tool_result_with_file_part(self) -> None: + """FilePart in tool results produces image content blocks.""" + fp = messages.FilePart(data="iVBOR", media_type="image/png") + msgs = [ + messages.Message( + role="user", + parts=[messages.TextPart(text="Read image")], + ), + messages.Message( + role="assistant", + parts=[ + messages.ToolCallPart( + tool_call_id="tc-1", + tool_name="read", + tool_args='{"path": "test.png"}', + ) + ], + ), + messages.Message( + role="tool", + parts=[ + messages.ToolResultPart( + tool_call_id="tc-1", + tool_name="read", + result=messages.ContentOutput( + value=[ + messages.TextPart(text="Image loaded"), + fp, + ] + ), + ) + ], + ), + ] + _, result = await protocol._messages_to_anthropic(msgs) + tool_msg = result[-1] + assert tool_msg["role"] == "user" + tr = tool_msg["content"][0] + assert tr["type"] == "tool_result" + content = tr["content"] + assert isinstance(content, list) + assert content[0] == {"type": "text", "text": "Image loaded"} + assert content[1]["type"] == "image" + assert content[1]["source"]["type"] == "base64" + assert content[1]["source"]["media_type"] == "image/png" diff --git a/tests/providers/openai/test_multipart_tool_result.py b/tests/providers/openai/test_multipart_tool_result.py new file mode 100644 index 00000000..0507db17 --- /dev/null +++ b/tests/providers/openai/test_multipart_tool_result.py @@ -0,0 +1,128 @@ +"""Tests for multi-part tool results in the OpenAI protocol.""" + +from __future__ import annotations + +from ai.providers.openai import protocol +from ai.types import messages + + +class TestToolResultToOpenai: + def test_text_output(self) -> None: + result = protocol._tool_result_to_openai( + messages.TextOutput(value="hello") + ) + assert result == "hello" + + def test_json_output_none(self) -> None: + result = protocol._tool_result_to_openai(messages.JsonOutput()) + assert result == "" + + def test_json_output_dict(self) -> None: + result = protocol._tool_result_to_openai( + messages.JsonOutput(value={"key": "value"}) + ) + assert result == '{"key":"value"}' + + def test_json_output_list(self) -> None: + result = protocol._tool_result_to_openai( + messages.JsonOutput(value=[1, 2, 3]) + ) + assert result == "[1,2,3]" + + def test_error_text_output(self) -> None: + result = protocol._tool_result_to_openai( + messages.ErrorTextOutput(value="boom") + ) + assert result == "boom" + + def test_execution_denied(self) -> None: + result = protocol._tool_result_to_openai( + messages.ExecutionDeniedOutput(reason="user said no") + ) + assert result == "Tool execution denied: user said no" + + def test_content_text_and_image(self) -> None: + fp = messages.FilePart(data="b64data", media_type="image/png") + result = protocol._tool_result_to_openai( + messages.ContentOutput( + value=[messages.TextPart(text="Image loaded"), fp] + ) + ) + assert isinstance(result, list) + assert len(result) == 2 + assert result[0] == {"type": "text", "text": "Image loaded"} + assert result[1]["type"] == "image_url" + assert result[1]["image_url"]["url"].startswith( + "data:image/png;base64," + ) + assert "b64data" in result[1]["image_url"]["url"] + + def test_content_image_only(self) -> None: + fp = messages.FilePart(data="b64data", media_type="image/jpeg") + result = protocol._tool_result_to_openai( + messages.ContentOutput(value=[fp]) + ) + assert isinstance(result, list) + assert result[0]["type"] == "image_url" + assert result[0]["image_url"]["url"].startswith( + "data:image/jpeg;base64," + ) + + def test_content_non_image_file(self) -> None: + fp = messages.FilePart(data="pdfdata", media_type="application/pdf") + result = protocol._tool_result_to_openai( + messages.ContentOutput(value=[messages.TextPart(text="desc"), fp]) + ) + assert isinstance(result, list) + assert result[1] == {"type": "text", "text": "[file: application/pdf]"} + + +class TestMessagesToOpenaiMultipart: + async def test_tool_result_with_file_part(self) -> None: + """ContentOutput with a FilePart produces image_url parts.""" + fp = messages.FilePart(data="iVBOR", media_type="image/png") + msgs = [ + messages.Message( + role="system", + parts=[messages.TextPart(text="System")], + ), + messages.Message( + role="user", + parts=[messages.TextPart(text="Read image")], + ), + messages.Message( + role="assistant", + parts=[ + messages.ToolCallPart( + tool_call_id="tc-1", + tool_name="read", + tool_args='{"path": "test.png"}', + ) + ], + ), + messages.Message( + role="tool", + parts=[ + messages.ToolResultPart( + tool_call_id="tc-1", + tool_name="read", + result=messages.ContentOutput( + value=[ + messages.TextPart(text="Image loaded"), + fp, + ] + ), + ) + ], + ), + ] + result = await protocol._messages_to_openai(msgs) + tool_msg = result[-1] + assert tool_msg["role"] == "tool" + content = tool_msg["content"] + assert isinstance(content, list) + assert content[0] == {"type": "text", "text": "Image loaded"} + assert content[1]["type"] == "image_url" + assert content[1]["image_url"]["url"].startswith( + "data:image/png;base64," + ) diff --git a/tests/test_middleware.py b/tests/test_middleware.py index ba82e78a..ac6a6ce3 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -275,5 +275,7 @@ async def double(x: int) -> int: assert len(tool_result_msgs) >= 1 # The fixer middleware supplied x=99, so double should return 198. - assert tool_result_msgs[0].tool_results[0].result == 198 + tr = tool_result_msgs[0].tool_results[0].result + assert isinstance(tr, ai.messages.TextOutput | ai.messages.JsonOutput) + assert tr.value == 198 assert tool_result_msgs[0].tool_results[0].is_error is False diff --git a/tests/types/test_builders.py b/tests/types/test_builders.py index 2aae1275..30b76d1b 100644 --- a/tests/types/test_builders.py +++ b/tests/types/test_builders.py @@ -20,6 +20,25 @@ def test_user_message_mixed_content() -> None: assert isinstance(msg.parts[2], messages.TextPart) +def test_text_part() -> None: + tp = builders.text_part("hello", provider_metadata={"k": "v"}) + assert isinstance(tp, messages.TextPart) + assert tp.text == "hello" + assert tp.provider_metadata == {"k": "v"} + + +def test_content_output_coerces_strings() -> None: + fp = messages.FilePart( + data="https://example.com/img.png", media_type="image/png" + ) + out = builders.content_output("Here:", fp) + assert isinstance(out, messages.ContentOutput) + assert len(out.value) == 2 + assert isinstance(out.value[0], messages.TextPart) + assert out.value[0].text == "Here:" + assert isinstance(out.value[1], messages.FilePart) + + def test_file_part_from_url() -> None: fp = builders.file_part("https://example.com/image.png") assert isinstance(fp, messages.FilePart) diff --git a/tests/types/test_integrity.py b/tests/types/test_integrity.py index d65eecc9..96e5caae 100644 --- a/tests/types/test_integrity.py +++ b/tests/types/test_integrity.py @@ -94,13 +94,13 @@ def _parallel_tool_turn( id=f"{assistant_prefix}:result:bash", tool_call_id=tc_bash, tool_name="bash", - result="Tue May 19 2026", + result=messages.TextOutput(value="Tue May 19 2026"), ), messages.ToolResultPart( id=f"{assistant_prefix}:result:web", tool_call_id=tc_web, tool_name="web_fetch", - result={"status": 200}, + result=messages.JsonOutput(value={"status": 200}), ), ], ), diff --git a/tests/types/test_media.py b/tests/types/test_media.py index 39ef8f4f..56afab9d 100644 --- a/tests/types/test_media.py +++ b/tests/types/test_media.py @@ -73,3 +73,20 @@ def test_empty_or_short_media_returns_none() -> None: assert media.detect_audio_media_type(b"") is None assert media.detect_image_media_type(bytes([0x89])) is None assert media.detect_audio_media_type(bytes([0xFF])) is None + + +def test_data_to_base64_bytes_produces_standard() -> None: + """bytes input produces standard base-64 (+ and /).""" + data = b"\xff\xd8\xff\xe0" # JPEG header + result = media.data_to_base64(data) + decoded = base64.b64decode(result) + assert decoded == data + # Must be standard base-64, not URL-safe + assert "-" not in result + assert "_" not in result + + +def test_data_to_base64_str_passthrough() -> None: + """Standard base-64 string passes through unchanged.""" + standard = "/9j/4AAQSkZJRg==" + assert media.data_to_base64(standard) == standard diff --git a/tests/types/test_messages.py b/tests/types/test_messages.py index f684f884..d9c31cd5 100644 --- a/tests/types/test_messages.py +++ b/tests/types/test_messages.py @@ -2,6 +2,8 @@ from __future__ import annotations +from typing import Any + import pytest from ai.types import messages, usage @@ -112,3 +114,113 @@ def test_from_bytes_explicit_overrides() -> None: def test_from_bytes_unknown_raises() -> None: with pytest.raises(ValueError, match="Cannot detect media_type"): messages.FilePart.from_bytes(b"\x00\x01\x02\x03") + + +# --------------------------------------------------------------------------- +# ToolResultPart -- typed result coercion and round-trip +# --------------------------------------------------------------------------- + + +def test_tool_result_content_output_with_file_part_round_trip() -> None: + """FilePart inside ContentOutput survives JSON round-trip.""" + fp = messages.FilePart(data=b"fake-image-data", media_type="image/png") + trp = messages.ToolResultPart( + tool_call_id="tc1", + tool_name="read", + result=messages.ContentOutput( + value=[messages.TextPart(text="label"), fp] + ), + ) + j = trp.model_dump_json() + restored = messages.ToolResultPart.model_validate_json(j) + assert isinstance(restored.result, messages.ContentOutput) + assert len(restored.result.value) == 2 + text_part, file_part = restored.result.value + assert isinstance(text_part, messages.TextPart) + assert text_part.text == "label" + assert isinstance(file_part, messages.FilePart) + assert file_part.media_type == "image/png" + + +def test_tool_result_plain_values_coerced() -> None: + """Plain str / dict / list / None results coerce to typed output. + + Exercises the ``BeforeValidator`` back-compat path that lets callers + pass raw values into ``ToolResultPart(result=...)``. The raw inputs + are typed ``Any`` so the static type checker doesn't reject the call. + """ + cases: list[tuple[Any, type[messages.ToolResultOutput], Any]] = [ + ("hello", messages.TextOutput, "hello"), + (None, messages.JsonOutput, None), + ([1, 2, 3], messages.JsonOutput, [1, 2, 3]), + ({"key": "val"}, messages.JsonOutput, {"key": "val"}), + ] + for raw, expected_cls, expected_value in cases: + trp = messages.ToolResultPart( + tool_call_id="tc", tool_name="t", result=raw + ) + assert isinstance(trp.result, messages.TextOutput | messages.JsonOutput) + assert isinstance(trp.result, expected_cls) + assert trp.result.value == expected_value + restored = messages.ToolResultPart.model_validate_json( + trp.model_dump_json() + ) + assert isinstance( + restored.result, messages.TextOutput | messages.JsonOutput + ) + assert isinstance(restored.result, expected_cls) + assert restored.result.value == expected_value + + +def test_tool_result_content_in_message_round_trip() -> None: + """ContentOutput with a FilePart survives Message round-trip.""" + fp = messages.FilePart(data=b"img-data", media_type="image/webp") + msg = messages.Message( + role="tool", + parts=[ + messages.ToolResultPart( + tool_call_id="tc", + tool_name="read", + result=messages.ContentOutput( + value=[messages.TextPart(text="Read image"), fp] + ), + ) + ], + ) + j = msg.model_dump_json() + restored = messages.Message.model_validate_json(j) + part = restored.parts[0] + assert isinstance(part, messages.ToolResultPart) + assert isinstance(part.result, messages.ContentOutput) + fp2 = part.result.value[1] + assert isinstance(fp2, messages.FilePart) + assert fp2.media_type == "image/webp" + + +def test_tool_result_file_part_base64_valid_after_round_trip() -> None: + """After round-trip, data_to_base64 produces standard base-64.""" + import base64 + + from ai.types import media as media_ + + raw = b"\xff\xd8\xff\xe0\x00\x10JFIF" * 10 + fp = messages.FilePart(data=raw, media_type="image/jpeg") + trp = messages.ToolResultPart( + tool_call_id="tc", + tool_name="read", + result=messages.ContentOutput( + value=[messages.TextPart(text="label"), fp] + ), + ) + restored = messages.ToolResultPart.model_validate_json( + trp.model_dump_json() + ) + assert isinstance(restored.result, messages.ContentOutput) + fp2 = restored.result.value[1] + assert isinstance(fp2, messages.FilePart) + + b64 = media_.data_to_base64(fp2.data) + assert "_" not in b64 + assert "-" not in b64 + decoded = base64.b64decode(b64) + assert decoded == raw