Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions src/ai/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,19 +603,22 @@ async def _real(
else:
result = await tool.fn(**kwargs)
model_input = result
# Built inside the try so a non-serializable result (which
# ToolResultPart rejects) surfaces as a tool error rather than
# crashing the run.
part = types.messages.ToolResultPart(
tool_call_id=call.tool_call_id,
tool_name=call.tool_name,
result=result,
result_kind=types.messages.ToolResultPart.kind_for(result),
)
part.set_model_input(model_input)
except Exception as exc:
return _error_tool_result(
exc,
tool_call_id=call.tool_call_id,
tool_name=call.tool_name,
)
part = types.messages.ToolResultPart(
tool_call_id=call.tool_call_id,
tool_name=call.tool_name,
result=result,
result_kind=types.messages.ToolResultPart.kind_for(result),
)
part.set_model_input(model_input)
return tool_result(part)

chain = middleware_._build_tool_chain(_real)
Expand Down
57 changes: 38 additions & 19 deletions src/ai/types/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Annotated, Any, Literal, Self, overload

import pydantic
from pydantic_core import to_jsonable_python

from . import media
from . import usage as usage_
Expand Down Expand Up @@ -139,6 +140,22 @@ class MessageBundle(pydantic.BaseModel):
)


def _jsonify_result(value: Any) -> Any:
"""Reduce a tool-result value to JSON-y data.

:class:`ContentOutput` and :class:`MessageBundle` are kept as typed
models -- providers and the UI adapter dispatch on ``isinstance``.
Everything else, including any other pydantic model, is dumped to plain
JSON-y Python so a tool result never carries an arbitrary model and its
in-memory shape matches what survives a serialization round-trip.
"""
if isinstance(value, SpecialToolResult):
return value
# Raise (rather than stringify) on a value that isn't JSON-serializable:
# a tool returning such a result is a bug worth surfacing, not hiding.
return to_jsonable_python(value)


_MODEL_INPUT_UNSET: Any = object()

# Coarse tag for the shape of ``ToolResultPart.result``.
Expand Down Expand Up @@ -179,26 +196,24 @@ class ToolResultPart(pydantic.BaseModel):

@pydantic.model_validator(mode="before")
@classmethod
def _restore_content(cls, data: Any) -> Any:
"""Rebuild a typed :class:`SpecialToolResult` after a JSON round-trip.

``result`` is ``Any``, so pydantic restores a serialized
``ContentOutput`` / ``MessageBundle`` as a plain dict. When
``result_kind`` is ``"special"``, coerce it back to the typed result
so providers (and the UI adapter) can rely on ``isinstance`` checks.
def _normalize_result(cls, data: Any) -> Any:
"""Normalize ``result`` to its stored invariant.

A serialized special result (a dict tagged ``result_kind="special"``)
is rebuilt into its :class:`ContentOutput` / :class:`MessageBundle`
model so providers and the UI adapter can rely on ``isinstance``.
Any other value is reduced to JSON-y data -- a tool result never
stores an arbitrary pydantic model (see :func:`_jsonify_result`).
"""
if (
isinstance(data, dict)
and data.get("result_kind") == "special"
and isinstance(data.get("result"), dict)
):
data = {
if not isinstance(data, dict) or "result" not in data:
return data
result = data["result"]
if data.get("result_kind") == "special" and isinstance(result, dict):
return {
**data,
"result": _SPECIAL_TOOL_RESULT_ADAPTER.validate_python(
data["result"]
),
"result": _SPECIAL_TOOL_RESULT_ADAPTER.validate_python(result),
}
return data
return {**data, "result": _jsonify_result(result)}

@staticmethod
def kind_for(result: Any) -> ResultKind:
Expand All @@ -222,8 +237,12 @@ def get_model_input(self) -> Any:
return self._model_input

def set_model_input(self, value: Any) -> None:
"""Set the model-facing value (overrides the ``result`` fallback)."""
self._model_input = value
"""Set the model-facing value (overrides the ``result`` fallback).

Reduced to JSON-y data like ``result`` so the model never sees an
arbitrary pydantic model (see :func:`_jsonify_result`).
"""
self._model_input = _jsonify_result(value)

@property
def has_model_input(self) -> bool:
Expand Down
19 changes: 19 additions & 0 deletions tests/agents/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,25 @@ async def fail(x: int) -> int:
assert str(result.exception) == "boom"


async def test_tool_call_unserializable_result_becomes_error() -> None:
"""A tool returning a non-JSON-able value yields a tool error, not crash."""

@ai.tool
async def make_widget() -> object:
"""Returns something that can't be serialized."""
return object()

part = ai.messages.ToolCallPart(
tool_call_id="tc-widget",
tool_name="make_widget",
tool_args="{}",
)
tc = ai.agents.BoundToolCall(part=part, tool=make_widget)
result = await tc()

assert result.results[0].is_error


async def test_tool_call_unwraps_singleton_exceptiongroup() -> None:
"""When a tool's body raises an ExceptionGroup wrapping a single
exception (typical when it runs an asyncio TaskGroup internally),
Expand Down
57 changes: 57 additions & 0 deletions tests/types/test_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from typing import Any

import pydantic
import pytest

from ai.types import messages, usage
Expand Down Expand Up @@ -143,6 +144,62 @@ def test_tool_result_content_output_with_file_part_round_trip() -> None:
assert file_part.media_type == "image/png"


class _Widget(pydantic.BaseModel):
a: int
b: str


def test_tool_result_dumps_non_special_pydantic_models() -> None:
"""A returned pydantic model is stored as JSON-y data, never the model.

Only the special results (:class:`ContentOutput`, :class:`MessageBundle`)
stay typed; everything else is dumped so the in-memory shape matches what
survives a round-trip.
"""
trp = messages.ToolResultPart(
tool_call_id="tc", tool_name="t", result=_Widget(a=1, b="hi")
)
assert trp.result == {"a": 1, "b": "hi"}
assert not isinstance(trp.result, pydantic.BaseModel)
assert trp.result_kind == "json"

# Models nested inside a container are dumped recursively.
nested = messages.ToolResultPart(
tool_call_id="tc",
tool_name="t",
result={"w": _Widget(a=2, b="x"), "n": [_Widget(a=3, b="y")]},
)
assert nested.result == {"w": {"a": 2, "b": "x"}, "n": [{"a": 3, "b": "y"}]}


def test_tool_result_model_input_is_dumped() -> None:
"""The model-facing value is dumped too, so the LLM never sees a model."""
trp = messages.ToolResultPart(tool_call_id="tc", tool_name="t")
trp.set_model_input(_Widget(a=7, b="z"))
assert trp.get_model_input() == {"a": 7, "b": "z"}


def test_tool_result_special_values_kept_typed() -> None:
"""ContentOutput / MessageBundle results are not dumped."""
co = messages.ContentOutput(value=[messages.TextPart(text="hi")])
trp = messages.ToolResultPart(
tool_call_id="tc", tool_name="t", result=co, result_kind="special"
)
assert isinstance(trp.result, messages.ContentOutput)


def test_tool_result_rejects_unserializable_result() -> None:
"""A non-JSON-able, non-special result is rejected, not stringified."""

class Weird:
pass

with pytest.raises(pydantic.ValidationError):
messages.ToolResultPart(
tool_call_id="tc", tool_name="t", result=Weird()
)


def test_tool_result_plain_values_stored_raw() -> None:
"""Plain str / dict / list / None results are stored as-is and round-trip.

Expand Down
Loading