Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,13 @@ messages while the model input is the final assistant text.
## Rich snapshots

Aggregators can preserve more than text. `MessageAggregator` stores nested
messages from a subagent:
messages from a subagent. The rich snapshot type is
`ai.types.messages.MessageBundle`:

```python
if isinstance(event, ai.events.ToolCallResult):
result = event.results[0].result
if isinstance(result, ai.agents.MessageBundle):
if isinstance(result, ai.types.messages.MessageBundle):
print(result.messages[-1].text)
```

Expand Down
2 changes: 1 addition & 1 deletion src/ai/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ..types.messages import MessageBundle
from . import mcp, ui
from .agent import (
Agent,
Expand All @@ -9,7 +10,6 @@
GatedToolCall,
LastAggregator,
MessageAggregator,
MessageBundle,
SimpleAggregator,
StreamingStatusTool,
StreamingTextTool,
Expand Down
20 changes: 2 additions & 18 deletions src/ai/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
Any,
ClassVar,
Generic,
Literal,
Protocol,
Self,
cast,
Expand All @@ -39,6 +38,7 @@
from .. import models, types, util
from ..types import builders
from ..types import events as events_
from ..types.messages import MessageBundle
from . import _middleware as middleware_
from . import hooks as hooks_
from . import runtime
Expand All @@ -59,18 +59,6 @@ def _unwrap_singleton_group(exc: BaseException) -> BaseException:
return exc


def _result_kind(value: Any) -> Literal["json", "content"]:
"""Tag a successful tool return value for ``ToolResultPart.result_kind``.

A :class:`ContentOutput` becomes ``"content"`` (expanded into provider
multimodal blocks); everything else is ``"json"`` (the encoder sends a
``str`` raw and JSON-encodes anything else).
"""
if isinstance(value, types.messages.ContentOutput):
return "content"
return "json"


def _error_tool_result(
exc: BaseException,
*,
Expand Down Expand Up @@ -232,10 +220,6 @@ def snapshot(self) -> T | None:
return self._val


class MessageBundle(pydantic.BaseModel):
messages: tuple[types.messages.Message, ...]


class MessageAggregator(
events_.Aggregator[events_.AgentEvent, MessageBundle, str]
):
Expand Down Expand Up @@ -629,7 +613,7 @@ async def _real(
tool_call_id=call.tool_call_id,
tool_name=call.tool_name,
result=result,
result_kind=_result_kind(result),
result_kind=types.messages.ToolResultPart.kind_for(result),
)
part.set_model_input(model_input)
return tool_result(part)
Expand Down
13 changes: 8 additions & 5 deletions src/ai/agents/ui/ai_sdk/id_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,18 @@ def _restore_message_ids(
def _tool_result_kinds(
source_messages: list[messages_.Message],
) -> dict[str, str]:
"""Collect ``{tool_call_id: result_kind}`` for content tool results."""
"""Collect ``{tool_call_id: subtype}`` for special tool results.

The recorded value is the :class:`SpecialToolResult` discriminator so the
inbound side can rehydrate the typed result without shape-sniffing it.
"""
kinds: dict[str, str] = {}
for message in source_messages:
for part in message.parts:
if (
isinstance(part, messages_.ToolResultPart)
and part.result_kind == "content"
if isinstance(part, messages_.ToolResultPart) and isinstance(
part.result, messages_.SpecialToolResult
):
kinds[part.tool_call_id] = part.result_kind
kinds[part.tool_call_id] = part.result.type
return kinds


Expand Down
54 changes: 21 additions & 33 deletions src/ai/agents/ui/ai_sdk/inbound_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from typing import Any

from ....types import messages as messages_
from ...agent import MessageBundle
from ....types.messages import MessageBundle
from . import approvals, id_utils
from . import ui_messages as ui_messages_
from .approvals import ApprovalResponse, extract_approvals
Expand Down Expand Up @@ -57,27 +57,6 @@ def _error_result(error_text: str | None, output: Any) -> dict[str, Any] | None:
return normalized


def _decode_wire_output(output: Any) -> Any:
"""Reconstruct the internal snapshot type from a wire tool output.

Hacky special case: when the wire output looks like a ``UIMessage``
(the wire shape we emit for sub-agent / ``MessageAggregator`` tools),
decode it back to a ``MessageBundle``. Other shapes pass through
unchanged. This avoids requiring callers to thread the tool
registry into inbound parsing.
"""
if not isinstance(output, dict):
return output
if output.get("role") != "assistant" or "parts" not in output:
return output
try:
ui_msg = ui_messages_.UIMessage.model_validate(output)
except Exception:
return output
inner = list(_parse([ui_msg]))
return MessageBundle(messages=tuple(inner))


def _build_result_part(
*,
tool_call_id: str,
Expand All @@ -89,10 +68,14 @@ def _build_result_part(
"""Reconstruct a tool result from its wire form.

``kind_hint`` comes from the adapter's ``toolResultKinds`` metadata
(see :mod:`id_utils`). When it marks the result as ``content``, the
``output`` -- a list of dumped content parts -- is rehydrated into a
typed :class:`ContentOutput` so providers re-expand it into multimodal
blocks; otherwise behaviour matches a plain value round-trip.
(see :mod:`id_utils`) and names the :class:`SpecialToolResult` subtype:

* ``"content"`` rehydrates a :class:`ContentOutput` from the dumped
content parts so providers re-expand it into multimodal blocks;
* ``"messages"`` rebuilds a :class:`MessageBundle` by parsing the
carried sub-agent UIMessage(s).

Without a hint the output is treated as a plain value round-trip.
"""
result: Any
result_kind: messages_.ResultKind
Expand All @@ -101,14 +84,19 @@ def _build_result_part(
result_kind = "error"
elif kind_hint == "content":
result = messages_.ContentOutput.model_validate({"value": output})
result_kind = "content"
result_kind = "special"
elif kind_hint == "messages":
raw = output if isinstance(output, list) else [output]
ui_msgs = [
m
if isinstance(m, ui_messages_.UIMessage)
else ui_messages_.UIMessage.model_validate(m)
for m in raw
]
result = MessageBundle(messages=tuple(_parse(ui_msgs)))
result_kind = "special"
else:
decoded = _decode_wire_output(output)
result = (
decoded
if isinstance(decoded, MessageBundle)
else _normalize_tool_result(decoded)
)
result = _normalize_tool_result(output)
result_kind = "json"
return messages_.ToolResultPart(
tool_call_id=tool_call_id,
Expand Down
23 changes: 23 additions & 0 deletions src/ai/agents/ui/ai_sdk/outbound_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,23 @@ def dedupe_tool_parts(
return result


def bundle_to_wire_output(bundle: messages_.MessageBundle) -> Any:
"""Serialize a sub-agent transcript to its UI tool ``output``.

Follows the AI SDK sub-agent convention of a single ``UIMessage`` for the
common case (one bubble), and only falls back to a JSON list when the
transcript spans multiple bubbles. Returns ``None`` for an empty bundle
so streaming callers can skip emitting until there's something to show.
The inbound side accepts either shape (see ``_build_result_part``).
"""
dumped = [
m.model_dump(mode="json") for m in to_ui_messages(list(bundle.messages))
]
if not dumped:
return None
return dumped[0] if len(dumped) == 1 else dumped


def _output_view(
part: messages_.ToolResultPart,
) -> tuple[str, dict[str, Any]]:
Expand All @@ -117,6 +134,12 @@ def _output_view(
return "output-available", {
"output": [item.model_dump(mode="json") for item in result.value]
}
if isinstance(result, messages_.MessageBundle):
# `None` (empty bundle) becomes `[]` so a completed result still
# round-trips to an (empty) MessageBundle rather than a null output.
return "output-available", {
"output": bundle_to_wire_output(result) or []
}
if part.is_error:
text = result if isinstance(result, str) else json.dumps(result)
return "output-error", {"error_text": text}
Expand Down
19 changes: 10 additions & 9 deletions src/ai/agents/ui/ai_sdk/outbound_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ....types import events as events_
from ....types import media
from ....types import messages as messages_
from ...agent import MessageBundle
from ....types.messages import MessageBundle
from . import approvals, outbound_messages, ui_events
from .tool_utils import normalize_tool_input

Expand All @@ -35,17 +35,18 @@ def _tool_error_text(part: messages_.ToolResultPart) -> str:
def _to_wire_output(snapshot: Any) -> Any:
"""Convert an aggregator snapshot to its UI wire representation.

For ``MessageBundle`` (sub-agent transcripts) this produces a single
``UIMessage`` assistant bubble — the canonical AI SDK shape. Other
snapshot types pass through unchanged.
For ``MessageBundle`` (sub-agent transcripts) this follows the AI SDK
sub-agent convention -- a single ``UIMessage`` for the common one-bubble
case, a JSON list only when the transcript spans multiple bubbles -- and
is paired with a ``toolResultKinds`` ``"messages"`` hint so the inbound
side can rebuild the bundle. Other snapshot types pass through unchanged.

Returns ``None`` if the bundle has no assistant anchor yet (e.g. a
streaming sub-agent that has produced no messages); callers should
skip emitting in that case.
Returns ``None`` if the bundle has no messages yet (e.g. a streaming
sub-agent that has produced nothing); callers should skip emitting in
that case.
"""
if isinstance(snapshot, MessageBundle):
ui_msgs = outbound_messages.to_ui_messages(list(snapshot.messages))
return ui_msgs[-1] if ui_msgs else None
return outbound_messages.bundle_to_wire_output(snapshot)
return snapshot


Expand Down
15 changes: 6 additions & 9 deletions src/ai/types/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,18 +241,15 @@ def tool_result_part(
"""Create a :class:`ToolResultPart`.

``result`` is stored as-is; ``result_kind`` is derived: ``"error"`` when
``is_error`` is set, ``"content"`` for a :class:`ContentOutput`, else
``"json"`` (a ``str`` is sent raw to the model, anything else is
JSON-encoded at the provider boundary).
``is_error`` is set, ``"special"`` for a :class:`ContentOutput` or
:class:`MessageBundle`, else ``"json"`` (a ``str`` is sent raw to the
model, anything else is JSON-encoded at the provider boundary).

>>> ai.tool_result_part("tc-1", result={"temp": 72}, tool_name="weather")
"""
if is_error:
result_kind: ResultKind = "error"
elif isinstance(result, ContentOutput):
result_kind = "content"
else:
result_kind = "json"
result_kind: ResultKind = (
"error" if is_error else ToolResultPart.kind_for(result)
)
return ToolResultPart(
tool_call_id=tool_call_id,
tool_name=tool_name,
Expand Down
54 changes: 42 additions & 12 deletions src/ai/types/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def from_bytes(
# ---------------------------------------------------------------------------
# Multipart tool result -- a tool may return a mix of text and file/image
# parts so the model sees actual media. Stored on ``ToolResultPart.result``
# with ``result_kind="content"``; providers expand it into their multimodal
# with ``result_kind="special"``; providers expand it into their multimodal
# wire format.
# ---------------------------------------------------------------------------

Expand All @@ -122,13 +122,31 @@ class ContentOutput(pydantic.BaseModel):
model_config = pydantic.ConfigDict(frozen=True)


class MessageBundle(pydantic.BaseModel):
type: Literal["messages"] = "messages"
messages: tuple["Message", ...]


SpecialToolResult = ContentOutput | MessageBundle

_SPECIAL_TOOL_RESULT_ADAPTER: pydantic.TypeAdapter[SpecialToolResult] = (
pydantic.TypeAdapter(
Annotated[
SpecialToolResult,
pydantic.Field(discriminator="type"),
]
)
)


_MODEL_INPUT_UNSET: Any = object()

# Coarse tag for the shape of ``ToolResultPart.result``. ``"content"`` means
# a :class:`ContentOutput`; ``"error"`` flags an error result; ``"json"`` (the
# default) is any plain value. Providers decide text-vs-json at the wire
# boundary (a ``str`` is sent raw, everything else is JSON-encoded).
ResultKind = Literal["error", "json", "content"]
# Coarse tag for the shape of ``ToolResultPart.result``.
# ``"special"`` means a :class:`SpecialToolResult`; ``"error"`` flags
# an error result; ``"json"`` (the default) is any plain value.
# Providers decide text-vs-json at the wire boundary (a ``str`` is
# sent raw, everything else is JSON-encoded).
ResultKind = Literal["error", "json", "special"]


class ToolResultPart(pydantic.BaseModel):
Expand Down Expand Up @@ -162,24 +180,36 @@ class ToolResultPart(pydantic.BaseModel):
@pydantic.model_validator(mode="before")
@classmethod
def _restore_content(cls, data: Any) -> Any:
"""Rebuild a typed :class:`ContentOutput` after a JSON round-trip.
"""Rebuild a typed :class:`SpecialToolResult` after a JSON round-trip.

``result`` is ``Any``, so pydantic restores a serialized
``ContentOutput`` as a plain dict. When ``result_kind`` says the
result is content, coerce it back so providers (and the UI adapter)
can rely on ``isinstance(result, ContentOutput)``.
``ContentOutput`` / ``MessageBundle`` as a plain dict. When
``result_kind`` is ``"special"``, coerce it back to the typed result
so providers (and the UI adapter) can rely on ``isinstance`` checks.
"""
if (
isinstance(data, dict)
and data.get("result_kind") == "content"
and data.get("result_kind") == "special"
and isinstance(data.get("result"), dict)
):
data = {
**data,
"result": ContentOutput.model_validate(data["result"]),
"result": _SPECIAL_TOOL_RESULT_ADAPTER.validate_python(
data["result"]
),
}
return data

@staticmethod
def kind_for(result: Any) -> ResultKind:
"""Derive ``result_kind`` for a non-error result value.

A :data:`SpecialToolResult` is ``"special"``; anything else is
``"json"``. Error results are tagged ``"error"`` by the
caller, independent of the value.
"""
return "special" if isinstance(result, SpecialToolResult) else "json"

@property
def is_error(self) -> bool:
"""Whether this result represents an error to the model."""
Expand Down
2 changes: 1 addition & 1 deletion tests/agents/test_generator_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@

import ai
from ai import models
from ai.agents.agent import MessageBundle
from ai.types import events as agent_events_
from ai.types import events as events_
from ai.types import messages as messages_
from ai.types.messages import MessageBundle

from ..conftest import (
MOCK_MODEL,
Expand Down
Loading
Loading