diff --git a/examples/samples/agent_custom_loop.py b/examples/samples/agent_custom_loop.py index 91819ea9..6342b2ea 100644 --- a/examples/samples/agent_custom_loop.py +++ b/examples/samples/agent_custom_loop.py @@ -28,9 +28,7 @@ async def main() -> None: async def custom(context: ai.Context) -> AsyncGenerator[ai.Event]: """Stream, execute tools with logging, repeat.""" while True: - s = ai.models.stream( - context.model, context.messages, tools=context.tools - ) + s = ai.models.stream(context.model, context.messages, tools=context.tools) async for event in s: yield event diff --git a/src/ai/models/ai_gateway/adapter.py b/src/ai/models/ai_gateway/adapter.py index 8ebbba25..b414d41b 100644 --- a/src/ai/models/ai_gateway/adapter.py +++ b/src/ai/models/ai_gateway/adapter.py @@ -210,7 +210,9 @@ def _expand_tool_call( return [ types.events.ToolStart(tool_call_id=tc_id, tool_name=tool_name), types.events.ToolDelta(tool_call_id=tc_id, chunk=args_str), - types.events.ToolEnd(tool_call_id=tc_id), + types.events.ToolEnd( + tool_call_id=tc_id, tool_call=types.messages.DUMMY_TOOL_CALL + ), ] @@ -293,7 +295,12 @@ def _parse_stream_part( ] case "tool-input-end": - return [types.events.ToolEnd(tool_call_id=data.get("id", ""))] + return [ + types.events.ToolEnd( + tool_call_id=data.get("id", ""), + tool_call=types.messages.DUMMY_TOOL_CALL, + ) + ] case "tool-call": return _expand_tool_call(data, streamed_tool_ids) diff --git a/src/ai/models/anthropic/adapter.py b/src/ai/models/anthropic/adapter.py index 20ffe6d7..e4448067 100644 --- a/src/ai/models/anthropic/adapter.py +++ b/src/ai/models/anthropic/adapter.py @@ -13,6 +13,7 @@ from ... import types from ...types import events +from ...types import messages as messages_ from .. import core # --------------------------------------------------------------------------- @@ -352,7 +353,10 @@ async def stream( case "tool_use": tool_id = tool_ids.get(idx) if tool_id: - yield events.ToolEnd(tool_call_id=tool_id) + yield events.ToolEnd( + tool_call_id=tool_id, + tool_call=messages_.DUMMY_TOOL_CALL, + ) snapshot = sdk_stream.current_message_snapshot sdk_usage = snapshot.usage diff --git a/src/ai/models/core/api.py b/src/ai/models/core/api.py index 3f4e5c78..72dc64bf 100644 --- a/src/ai/models/core/api.py +++ b/src/ai/models/core/api.py @@ -84,8 +84,8 @@ def __aiter__(self) -> Self: async def __anext__(self) -> types.Event: event = await self._gen.__anext__() - self._aggregate_event(event) - return event.model_copy(update={"message": self._message}) + updates = self._aggregate_event(event) + return event.model_copy(update={"message": self._message, **updates}) @property def message(self) -> types.Message: @@ -107,7 +107,9 @@ def tool_calls(self) -> list[types.ToolCallPart]: def output(self) -> Any: return self._message.output - def _aggregate_event(self, event: types.Event) -> None: + def _aggregate_event(self, event: types.Event) -> dict[str, Any]: + updates: dict[str, Any] = {} + # grab usage from any event that carries one if event.usage is not None: self._message.usage = event.usage @@ -149,6 +151,10 @@ def _aggregate_event(self, event: types.Event) -> None: existing_tool = self._parts.get(tcid) if isinstance(existing_tool, types.ToolCallPart): existing_tool.tool_args += c + case types.ToolEnd(tool_call_id=tcid): + existing_tool = self._parts.get(tcid) + if isinstance(existing_tool, types.ToolCallPart): + updates["tool_call"] = existing_tool case types.FileEvent(block_id=bid, media_type=mt, data=d, filename=fname): fp = types.FilePart( id=bid or types.generate_id(), @@ -161,6 +167,8 @@ def _aggregate_event(self, event: types.Event) -> None: case _: pass + return updates + def stream( model: model_.Model, diff --git a/src/ai/models/openai/adapter.py b/src/ai/models/openai/adapter.py index b7f81df8..ab9e6499 100644 --- a/src/ai/models/openai/adapter.py +++ b/src/ai/models/openai/adapter.py @@ -349,7 +349,10 @@ async def stream( text_started = False for tc in tc_state.values(): if tc["started"] and tc["id"]: - yield types.events.ToolEnd(tool_call_id=tc["id"]) + yield types.events.ToolEnd( + tool_call_id=tc["id"], + tool_call=types.messages.DUMMY_TOOL_CALL, + ) tc["started"] = False yield types.events.StreamEnd(usage=usage) diff --git a/src/ai/types/events.py b/src/ai/types/events.py index a0d5e30b..8ce0edae 100644 --- a/src/ai/types/events.py +++ b/src/ai/types/events.py @@ -86,6 +86,7 @@ class ToolDelta(BaseEvent): class ToolEnd(BaseEvent): + tool_call: messages.ToolCallPart tool_call_id: str = "" kind: Literal["tool_end"] = "tool_end" diff --git a/src/ai/types/messages.py b/src/ai/types/messages.py index 4be9d668..09c5c415 100644 --- a/src/ai/types/messages.py +++ b/src/ai/types/messages.py @@ -30,6 +30,11 @@ class ToolCallPart(pydantic.BaseModel): kind: Literal["tool_call"] = "tool_call" +DUMMY_TOOL_CALL = ToolCallPart( + id="", tool_call_id="", tool_name="", tool_args="" +) + + class ToolResultPart(pydantic.BaseModel): id: str = pydantic.Field(default_factory=generate_id) tool_call_id: str diff --git a/tests/conftest.py b/tests/conftest.py index 4bb74ab5..4be7dae2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -128,7 +128,7 @@ async def emit_events_for_messages( tool_call_id=part.tool_call_id, chunk=part.tool_args, ) - yield events_.ToolEnd(tool_call_id=part.tool_call_id) + yield events_.ToolEnd(tool_call_id=part.tool_call_id, tool_call=part) elif isinstance(part, messages_.FilePart): yield events_.FileEvent( diff --git a/tests/models/test_public_api.py b/tests/models/test_public_api.py index 254b0d0a..d5fd859e 100644 --- a/tests/models/test_public_api.py +++ b/tests/models/test_public_api.py @@ -47,6 +47,43 @@ async def test_stream_basic() -> None: assert "".join(deltas) == "Hello world" +async def test_stream_tool_end_includes_aggregated_tool_call() -> None: + """ToolEnd exposes the full ToolCallPart assembled from streamed input.""" + + async def _tool_stream( + client: models.Client, + model: models.Model, + messages: list[messages_.Message], + *, + tools: Sequence[ai.ToolLike] | None = None, + output_type: type[pydantic.BaseModel] | None = None, + **kwargs: Any, + ) -> AsyncGenerator[events_.Event]: + yield events_.StreamStart() + yield events_.ToolStart(tool_call_id="tc-1", tool_name="weather") + yield events_.ToolDelta(tool_call_id="tc-1", chunk='{"city"') + yield events_.ToolDelta(tool_call_id="tc-1", chunk=':"SF"}') + yield events_.ToolEnd( + tool_call_id="tc-1", + tool_call=messages_.DUMMY_TOOL_CALL, + ) + yield events_.StreamEnd() + + models.register_stream("mock", _tool_stream) + + s = models.stream(MOCK_MODEL, [ai.user_message("Check weather")]) + tool_end: events_.ToolEnd | None = None + async for event in s: + if isinstance(event, events_.ToolEnd): + tool_end = event + + assert tool_end is not None + assert tool_end.tool_call.tool_call_id == "tc-1" + assert tool_end.tool_call.tool_name == "weather" + assert tool_end.tool_call.tool_args == '{"city":"SF"}' + assert s.tool_calls == [tool_end.tool_call] + + async def test_stream_with_explicit_client() -> None: """Model with explicit client= forwards it to the adapter.""" received_clients: list[models.Client] = []