# NOTE(review): this text is a unified git diff, not a Python module, and its
# line breaks have been collapsed -- the indentation of the patched code cannot
# be recovered from here. The patch text below is kept byte-for-byte.
#
# What the hunks change, per the `diff --git` headers and +/- lines:
#   * src/ai/agents/agent.py (`_real`): the `ToolResultPart(...)` construction
#     and `part.set_model_input(model_input)` are removed from after the
#     `except Exception` handler and added before it, so an exception raised
#     there is returned through `_error_tool_result(...)`.
#   * src/ai/types/messages.py: imports `to_jsonable_python` from
#     `pydantic_core`; adds `_jsonify_result`, which returns
#     `SpecialToolResult` instances unchanged and passes every other value to
#     `to_jsonable_python`; replaces the `_restore_content` before-validator
#     with `_normalize_result`, which still rebuilds a dict tagged
#     `result_kind == "special"` via `_SPECIAL_TOOL_RESULT_ADAPTER` and now
#     runs any other `result` through `_jsonify_result`; `set_model_input`
#     stores `_jsonify_result(value)` instead of `value`.
#   * tests/agents/test_tools.py and tests/types/test_messages.py: add tests
#     for the behaviour above.
#
# NOTE(review): `kind_for(result)` in agent.py is evaluated on the value the
# tool returned, before `_normalize_result` dumps it -- confirm the tag it
# picks still matches the stored, dumped result.
# NOTE(review): `set_model_input` now propagates whatever `to_jsonable_python`
# raises; only the `_real` call site is visible here -- confirm other callers
# are prepared for that.
diff --git a/src/ai/agents/agent.py b/src/ai/agents/agent.py index f6ecc06..6dd84c9 100644 --- a/src/ai/agents/agent.py +++ b/src/ai/agents/agent.py @@ -603,19 +603,22 @@ async def _real( else: result = await tool.fn(**kwargs) model_input = result + # Built inside the try so a non-serializable result (which + # ToolResultPart rejects) surfaces as a tool error rather than + # crashing the run. + part = types.messages.ToolResultPart( + tool_call_id=call.tool_call_id, + tool_name=call.tool_name, + result=result, + result_kind=types.messages.ToolResultPart.kind_for(result), + ) + part.set_model_input(model_input) except Exception as exc: return _error_tool_result( exc, tool_call_id=call.tool_call_id, tool_name=call.tool_name, ) - part = types.messages.ToolResultPart( - tool_call_id=call.tool_call_id, - tool_name=call.tool_name, - result=result, - result_kind=types.messages.ToolResultPart.kind_for(result), - ) - part.set_model_input(model_input) return tool_result(part) chain = middleware_._build_tool_chain(_real) diff --git a/src/ai/types/messages.py b/src/ai/types/messages.py index 644e938..bc26e4e 100644 --- a/src/ai/types/messages.py +++ b/src/ai/types/messages.py @@ -3,6 +3,7 @@ from typing import Annotated, Any, Literal, Self, overload import pydantic +from pydantic_core import to_jsonable_python from . import media from . import usage as usage_ @@ -139,6 +140,22 @@ class MessageBundle(pydantic.BaseModel): ) +def _jsonify_result(value: Any) -> Any: + """Reduce a tool-result value to JSON-y data. + + :class:`ContentOutput` and :class:`MessageBundle` are kept as typed + models -- providers and the UI adapter dispatch on ``isinstance``. + Everything else, including any other pydantic model, is dumped to plain + JSON-y Python so a tool result never carries an arbitrary model and its + in-memory shape matches what survives a serialization round-trip. 
+ """ + if isinstance(value, SpecialToolResult): + return value + # Raise (rather than stringify) on a value that isn't JSON-serializable: + # a tool returning such a result is a bug worth surfacing, not hiding. + return to_jsonable_python(value) + + _MODEL_INPUT_UNSET: Any = object() # Coarse tag for the shape of ``ToolResultPart.result``. @@ -179,26 +196,24 @@ class ToolResultPart(pydantic.BaseModel): @pydantic.model_validator(mode="before") @classmethod - def _restore_content(cls, data: Any) -> Any: - """Rebuild a typed :class:`SpecialToolResult` after a JSON round-trip. - - ``result`` is ``Any``, so pydantic restores a serialized - ``ContentOutput`` / ``MessageBundle`` as a plain dict. When - ``result_kind`` is ``"special"``, coerce it back to the typed result - so providers (and the UI adapter) can rely on ``isinstance`` checks. + def _normalize_result(cls, data: Any) -> Any: + """Normalize ``result`` to its stored invariant. + + A serialized special result (a dict tagged ``result_kind="special"``) + is rebuilt into its :class:`ContentOutput` / :class:`MessageBundle` + model so providers and the UI adapter can rely on ``isinstance``. + Any other value is reduced to JSON-y data -- a tool result never + stores an arbitrary pydantic model (see :func:`_jsonify_result`). 
 """ - if ( - isinstance(data, dict) - and data.get("result_kind") == "special" - and isinstance(data.get("result"), dict) - ): - data = { + if not isinstance(data, dict) or "result" not in data: + return data + result = data["result"] + if data.get("result_kind") == "special" and isinstance(result, dict): + return { **data, - "result": _SPECIAL_TOOL_RESULT_ADAPTER.validate_python( - data["result"] - ), + "result": _SPECIAL_TOOL_RESULT_ADAPTER.validate_python(result), } - return data + return {**data, "result": _jsonify_result(result)} @staticmethod def kind_for(result: Any) -> ResultKind: @@ -222,8 +237,12 @@ def get_model_input(self) -> Any: return self._model_input def set_model_input(self, value: Any) -> None: - """Set the model-facing value (overrides the ``result`` fallback).""" - self._model_input = value + """Set the model-facing value (overrides the ``result`` fallback). + + Reduced to JSON-y data like ``result`` so the model never sees an + arbitrary pydantic model (see :func:`_jsonify_result`). 
+ """ + self._model_input = _jsonify_result(value) @property def has_model_input(self) -> bool: diff --git a/tests/agents/test_tools.py b/tests/agents/test_tools.py index f28650f..08771e2 100644 --- a/tests/agents/test_tools.py +++ b/tests/agents/test_tools.py @@ -128,6 +128,25 @@ async def fail(x: int) -> int: assert str(result.exception) == "boom" +async def test_tool_call_unserializable_result_becomes_error() -> None: + """A tool returning a non-JSON-able value yields a tool error, not crash.""" + + @ai.tool + async def make_widget() -> object: + """Returns something that can't be serialized.""" + return object() + + part = ai.messages.ToolCallPart( + tool_call_id="tc-widget", + tool_name="make_widget", + tool_args="{}", + ) + tc = ai.agents.BoundToolCall(part=part, tool=make_widget) + result = await tc() + + assert result.results[0].is_error + + async def test_tool_call_unwraps_singleton_exceptiongroup() -> None: """When a tool's body raises an ExceptionGroup wrapping a single exception (typical when it runs an asyncio TaskGroup internally), diff --git a/tests/types/test_messages.py b/tests/types/test_messages.py index 08dee49..d357896 100644 --- a/tests/types/test_messages.py +++ b/tests/types/test_messages.py @@ -4,6 +4,7 @@ from typing import Any +import pydantic import pytest from ai.types import messages, usage @@ -143,6 +144,62 @@ def test_tool_result_content_output_with_file_part_round_trip() -> None: assert file_part.media_type == "image/png" +class _Widget(pydantic.BaseModel): + a: int + b: str + + +def test_tool_result_dumps_non_special_pydantic_models() -> None: + """A returned pydantic model is stored as JSON-y data, never the model. + + Only the special results (:class:`ContentOutput`, :class:`MessageBundle`) + stay typed; everything else is dumped so the in-memory shape matches what + survives a round-trip. 
+ """ + trp = messages.ToolResultPart( + tool_call_id="tc", tool_name="t", result=_Widget(a=1, b="hi") + ) + assert trp.result == {"a": 1, "b": "hi"} + assert not isinstance(trp.result, pydantic.BaseModel) + assert trp.result_kind == "json" + + # Models nested inside a container are dumped recursively. + nested = messages.ToolResultPart( + tool_call_id="tc", + tool_name="t", + result={"w": _Widget(a=2, b="x"), "n": [_Widget(a=3, b="y")]}, + ) + assert nested.result == {"w": {"a": 2, "b": "x"}, "n": [{"a": 3, "b": "y"}]} + + +def test_tool_result_model_input_is_dumped() -> None: + """The model-facing value is dumped too, so the LLM never sees a model.""" + trp = messages.ToolResultPart(tool_call_id="tc", tool_name="t") + trp.set_model_input(_Widget(a=7, b="z")) + assert trp.get_model_input() == {"a": 7, "b": "z"} + + +def test_tool_result_special_values_kept_typed() -> None: + """ContentOutput / MessageBundle results are not dumped.""" + co = messages.ContentOutput(value=[messages.TextPart(text="hi")]) + trp = messages.ToolResultPart( + tool_call_id="tc", tool_name="t", result=co, result_kind="special" + ) + assert isinstance(trp.result, messages.ContentOutput) + + +def test_tool_result_rejects_unserializable_result() -> None: + """A non-JSON-able, non-special result is rejected, not stringified.""" + + class Weird: + pass + + with pytest.raises(pydantic.ValidationError): + messages.ToolResultPart( + tool_call_id="tc", tool_name="t", result=Weird() + ) + + def test_tool_result_plain_values_stored_raw() -> None: """Plain str / dict / list / None results are stored as-is and round-trip.