Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions examples/read_file_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""Tool that returns a ContentOutput so the model can see image files directly.

The ``read_file`` tool reads a path from disk and inspects the bytes:

* If the file is an image, it returns a :class:`ContentOutput` carrying
a summary line and an image :class:`FilePart`. All three providers
turn that into a real image content block on the next model turn, so
the model actually *sees* the picture.
* Otherwise it returns the decoded text -- the framework wraps that in
a :class:`TextOutput` automatically.

A single tool covers both code-reading and image-reading duties in an
agentic loop.
"""

import asyncio
import json
import pathlib

import ai
from ai.types import media

# Restrict the tool to a directory we trust the model to roam in.
# `.resolve()` collapses symlinks so a path inside ALLOWED_ROOT cannot
# escape via a symlink that points elsewhere.
ALLOWED_ROOT = pathlib.Path(__file__).parent.resolve()


def _resolve_within_allowed(path: str) -> pathlib.Path:
resolved = pathlib.Path(path).resolve()
if not resolved.is_relative_to(ALLOWED_ROOT):
raise ValueError(
f"Refusing to read {path!r}: outside allowed root {ALLOWED_ROOT}"
)
return resolved


@ai.tool
async def read_file(path: str) -> str | ai.messages.ContentOutput:
"""Read a file from disk.

Image files come back as a ContentOutput so the model can view them.
"""
data = _resolve_within_allowed(path).read_bytes()
image_type = media.detect_image_media_type(data)
if image_type is not None:
return ai.content_output(
f"Loaded {path} ({image_type}, {len(data)} bytes).",
ai.file_part(data, media_type=image_type),
)
return data.decode("utf-8", errors="replace")


async def main() -> None:
model = ai.get_model("gateway:anthropic/claude-sonnet-4.6")
my_agent = ai.agent(tools=[read_file])

here = pathlib.Path(__file__).parent
image_path = here / "sample_image.jpg"
text_path = here / "agent_simple.py"

messages = [
ai.system_message(
"Use the read_file tool to inspect any files the user mentions."
),
ai.user_message(
f"First read {image_path} and describe what you see in the "
f"picture. Then read {text_path} and summarize what the "
f"script does in one sentence."
),
]

async with my_agent.run(model, messages) as stream:
async for event in stream:
if isinstance(event, ai.events.TextDelta):
print(event.chunk, end="", flush=True)
elif isinstance(event, ai.events.ToolEnd):
args = json.loads(event.tool_call.tool_args or "{}")
print(f"\n[read_file({args.get('path')!r})]")
elif isinstance(event, ai.events.StreamEnd):
print()
print()


if __name__ == "__main__":
asyncio.run(main())
4 changes: 4 additions & 0 deletions src/ai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,10 @@
from .types import events, messages, tools
from .types.builders import (
assistant_message,
content_output,
file_part,
system_message,
text_part,
thinking,
tool_message,
tool_result_part,
Expand Down Expand Up @@ -119,6 +121,7 @@
"agent",
"assistant_message",
"cancel_hook",
"content_output",
"errors",
# Submodules
"events",
Expand All @@ -137,6 +140,7 @@
"resolve_hook",
"stream",
"system_message",
"text_part",
"thinking",
"tool",
"tool_message",
Expand Down
14 changes: 9 additions & 5 deletions src/ai/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,9 @@ def _error_tool_result(
types.messages.ToolResultPart(
tool_call_id=tool_call_id,
tool_name=tool_name,
result=f"{type(unwrapped).__name__}: {unwrapped}",
is_error=True,
result=types.messages.ErrorTextOutput(
value=f"{type(unwrapped).__name__}: {unwrapped}"
),
),
exception=unwrapped,
)
Expand Down Expand Up @@ -174,6 +175,8 @@ def _populate_model_inputs(
Tool execution sets ``model_input`` directly; this fills in the
value for tool results that were reconstructed from a wire round-
trip (e.g. the AI SDK UI inbound path) and never had it computed.
The aggregator's ``model_input_from_result`` does any snapshot
unwrapping internally.
"""
for msg in messages:
if msg.role != "tool":
Expand All @@ -187,7 +190,7 @@ def _populate_model_inputs(
agg_cls = _aggregator_cls(tool.aggregator)
if agg_cls is None:
continue
part.set_model_input(agg_cls.to_model_input(part.result))
part.set_model_input(agg_cls.model_input_from_result(part.result))


class SimpleAggregator[Item, Result](events_.Aggregator[Item, Result, Result]):
Expand Down Expand Up @@ -1038,8 +1041,9 @@ def pending_tool_result(
part = types.messages.ToolResultPart(
tool_call_id=tool_call_id,
tool_name=tool_name,
result=f"Pending on hook {hook.hook_id!r}",
is_error=True,
result=types.messages.ErrorTextOutput(
value=f"Pending on hook {hook.hook_id!r}"
),
is_hook_pending=True,
)
msg = types.messages.Message(role="tool", parts=[part])
Expand Down
13 changes: 8 additions & 5 deletions src/ai/agents/ui/ai_sdk/inbound_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,20 +85,22 @@ def _build_result_part(
output: Any,
is_error: bool,
) -> messages_.ToolResultPart:
result: messages_.ToolResultOutput
if is_error:
result: Any = output
text = str(output) if output is not None else ""
result = messages_.ErrorTextOutput(value=text)
else:
decoded = _decode_wire_output(output)
result = (
raw = (
decoded
if isinstance(decoded, MessageBundle)
else _normalize_tool_result(decoded)
)
result = messages_.coerce_to_output(raw)
return messages_.ToolResultPart(
tool_call_id=tool_call_id,
tool_name=tool_name,
result=result,
is_error=is_error,
)


Expand Down Expand Up @@ -189,8 +191,9 @@ def _patch_pending_hook_aborts(
messages_.ToolResultPart(
tool_call_id=tc.tool_call_id,
tool_name=tc.tool_name,
result=f"Pending on hook '{hook.hook_id}'",
is_error=True,
result=messages_.ErrorTextOutput(
value=f"Pending on hook '{hook.hook_id}'"
),
is_hook_pending=True,
)
)
Expand Down
31 changes: 26 additions & 5 deletions src/ai/agents/ui/ai_sdk/outbound_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import json
from typing import Any, cast

from ....types import media
Expand Down Expand Up @@ -107,6 +108,29 @@ def dedupe_tool_parts(
return result


def _output_view(
output: messages_.ToolResultOutput,
) -> tuple[str, dict[str, Any]]:
"""Map a :class:`ToolResultOutput` to ``(state, field_updates)``."""
match output:
case messages_.TextOutput(value=value):
return "output-available", {"output": value}
case messages_.JsonOutput(value=value):
return "output-available", {"output": value}
case messages_.ContentOutput(value=items):
return "output-available", {
"output": [item.model_dump(mode="json") for item in items]
}
case messages_.ErrorTextOutput(value=value):
return "output-error", {"error_text": value}
case messages_.ErrorJsonOutput(value=value):
return "output-error", {"error_text": json.dumps(value)}
case messages_.ExecutionDeniedOutput(reason=reason):
return "output-denied", {
"error_text": reason or "Tool execution denied."
}


def merge_tool_results(
ui_parts: list[ui_messages.UIMessagePart],
tool_parts: list[messages_.Part],
Expand All @@ -121,15 +145,12 @@ def merge_tool_results(
continue
case messages_.ToolResultPart():
tool_call_id = part.tool_call_id
state = "output-error" if part.is_error else "output-available"
state, field_updates = _output_view(part.result)
updates = {
"state": state,
"result_provider_metadata": part.provider_metadata,
**field_updates,
}
if part.is_error:
updates["error_text"] = str(part.result)
else:
updates["output"] = part.result
case messages_.BuiltinToolReturnPart():
tool_call_id = part.tool_call_id
updates = {
Expand Down
32 changes: 24 additions & 8 deletions src/ai/agents/ui/ai_sdk/outbound_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,20 @@

def _tool_error_text(part: messages_.ToolResultPart) -> str:
"""Best-effort error text extraction from a failed tool result."""
if isinstance(part.result, str) and part.result:
return part.result
if isinstance(part.result, dict):
for key in ("error", "message", "detail"):
value = part.result.get(key)
if isinstance(value, str) and value:
return value
output = part.result
if isinstance(output, messages_.ErrorTextOutput):
return output.value or "Tool execution failed"
if isinstance(output, messages_.ErrorJsonOutput):
value = output.value
if isinstance(value, str) and value:
return value
if isinstance(value, dict):
for key in ("error", "message", "detail"):
inner = value.get(key)
if isinstance(inner, str) and inner:
return inner
if isinstance(output, messages_.ExecutionDeniedOutput):
return output.reason or "Tool execution denied"
return "Tool execution failed"


Expand Down Expand Up @@ -403,7 +410,16 @@ def on_tool_result(
)
)
else:
wire_output = _to_wire_output(part.result)
output = part.result
raw = (
output.value
if isinstance(
output,
messages_.TextOutput | messages_.JsonOutput,
)
else output
)
wire_output = _to_wire_output(raw)
if wire_output is None:
# Aggregator produced no anchor (e.g. sub-agent
# tool that yielded nothing). Skip the final
Expand Down
68 changes: 52 additions & 16 deletions src/ai/providers/ai_gateway/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,57 @@ def _file_part_to_wire(part: types.messages.FilePart) -> dict[str, Any]:
return {"type": "file", "data": b64, "mediaType": part.media_type}


# ---------------------------------------------------------------------------
# Tool result output -> v3 wire
# ---------------------------------------------------------------------------


def _file_part_to_v3_inline(part: types.messages.FilePart) -> dict[str, Any]:
"""Convert a :class:`FilePart` to an inline v3 content element.

Images become ``image-data``; everything else becomes ``file-data``.
"""
b64 = types.media.data_to_base64(part.data)
if part.media_type.startswith("image/"):
return {"type": "image-data", "data": b64, "mediaType": part.media_type}
entry: dict[str, Any] = {
"type": "file-data",
"data": b64,
"mediaType": part.media_type,
}
if part.filename is not None:
entry["filename"] = part.filename
return entry


def _tool_result_output(
output: types.messages.ToolResultOutput,
) -> dict[str, Any]:
"""Convert a :class:`ToolResultOutput` to its v3 ``output`` wire form."""
match output:
case types.messages.TextOutput(value=value):
return {"type": "text", "value": value}
case types.messages.JsonOutput(value=value):
return {"type": "json", "value": value}
case types.messages.ErrorTextOutput(value=value):
return {"type": "error-text", "value": value}
case types.messages.ErrorJsonOutput(value=value):
return {"type": "error-json", "value": value}
case types.messages.ExecutionDeniedOutput(reason=reason):
entry: dict[str, Any] = {"type": "execution-denied"}
if reason is not None:
entry["reason"] = reason
return entry
case types.messages.ContentOutput(value=items):
parts: list[dict[str, Any]] = []
for item in items:
if isinstance(item, types.messages.FilePart):
parts.append(_file_part_to_v3_inline(item))
else:
parts.append({"type": "text", "text": item.text})
return {"type": "content", "value": parts}


# ---------------------------------------------------------------------------
# Streaming request building — Message list → v3 prompt
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -172,22 +223,7 @@ async def _messages_to_prompt(
tool_results: list[dict[str, Any]] = []
for part in msg.parts:
if isinstance(part, types.messages.ToolResultPart):
model_input = part.get_model_input()
output = (
{
"type": "error-text",
"value": (
str(model_input)
if model_input is not None
else ""
),
}
if part.is_error
else {
"type": "json",
"value": model_input,
}
)
output = _tool_result_output(part.get_model_input())
tool_results.append(
{
"type": "tool-result",
Expand Down
Loading
Loading