Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions examples/samples/agent_custom_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@ async def main() -> None:
async def custom(context: ai.Context) -> AsyncGenerator[ai.Event]:
"""Stream, execute tools with logging, repeat."""
while True:
s = ai.models.stream(
context.model, context.messages, tools=context.tools
)
s = ai.models.stream(context.model, context.messages, tools=context.tools)
async for event in s:
yield event

Expand Down
11 changes: 9 additions & 2 deletions src/ai/models/ai_gateway/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,9 @@ def _expand_tool_call(
return [
types.events.ToolStart(tool_call_id=tc_id, tool_name=tool_name),
types.events.ToolDelta(tool_call_id=tc_id, chunk=args_str),
types.events.ToolEnd(tool_call_id=tc_id),
types.events.ToolEnd(
tool_call_id=tc_id, tool_call=types.messages.DUMMY_TOOL_CALL
),
]


Expand Down Expand Up @@ -293,7 +295,12 @@ def _parse_stream_part(
]

case "tool-input-end":
return [types.events.ToolEnd(tool_call_id=data.get("id", ""))]
return [
types.events.ToolEnd(
tool_call_id=data.get("id", ""),
tool_call=types.messages.DUMMY_TOOL_CALL,
)
]

case "tool-call":
return _expand_tool_call(data, streamed_tool_ids)
Expand Down
6 changes: 5 additions & 1 deletion src/ai/models/anthropic/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from ... import types
from ...types import events
from ...types import messages as messages_
from .. import core

# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -352,7 +353,10 @@ async def stream(
case "tool_use":
tool_id = tool_ids.get(idx)
if tool_id:
yield events.ToolEnd(tool_call_id=tool_id)
yield events.ToolEnd(
tool_call_id=tool_id,
tool_call=messages_.DUMMY_TOOL_CALL,
)

snapshot = sdk_stream.current_message_snapshot
sdk_usage = snapshot.usage
Expand Down
14 changes: 11 additions & 3 deletions src/ai/models/core/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ def __aiter__(self) -> Self:

async def __anext__(self) -> types.Event:
event = await self._gen.__anext__()
self._aggregate_event(event)
return event.model_copy(update={"message": self._message})
updates = self._aggregate_event(event)
return event.model_copy(update={"message": self._message, **updates})

@property
def message(self) -> types.Message:
Expand All @@ -107,7 +107,9 @@ def tool_calls(self) -> list[types.ToolCallPart]:
def output(self) -> Any:
return self._message.output

def _aggregate_event(self, event: types.Event) -> None:
def _aggregate_event(self, event: types.Event) -> dict[str, Any]:
updates: dict[str, Any] = {}

# grab usage from any event that carries one
if event.usage is not None:
self._message.usage = event.usage
Expand Down Expand Up @@ -149,6 +151,10 @@ def _aggregate_event(self, event: types.Event) -> None:
existing_tool = self._parts.get(tcid)
if isinstance(existing_tool, types.ToolCallPart):
existing_tool.tool_args += c
case types.ToolEnd(tool_call_id=tcid):
existing_tool = self._parts.get(tcid)
if isinstance(existing_tool, types.ToolCallPart):
updates["tool_call"] = existing_tool
case types.FileEvent(block_id=bid, media_type=mt, data=d, filename=fname):
fp = types.FilePart(
id=bid or types.generate_id(),
Expand All @@ -161,6 +167,8 @@ def _aggregate_event(self, event: types.Event) -> None:
case _:
pass

return updates


def stream(
model: model_.Model,
Expand Down
5 changes: 4 additions & 1 deletion src/ai/models/openai/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,10 @@ async def stream(
text_started = False
for tc in tc_state.values():
if tc["started"] and tc["id"]:
yield types.events.ToolEnd(tool_call_id=tc["id"])
yield types.events.ToolEnd(
tool_call_id=tc["id"],
tool_call=types.messages.DUMMY_TOOL_CALL,
)
tc["started"] = False

yield types.events.StreamEnd(usage=usage)
Expand Down
1 change: 1 addition & 0 deletions src/ai/types/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ class ToolDelta(BaseEvent):


class ToolEnd(BaseEvent):
tool_call: messages.ToolCallPart
tool_call_id: str = ""

kind: Literal["tool_end"] = "tool_end"
Expand Down
5 changes: 5 additions & 0 deletions src/ai/types/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ class ToolCallPart(pydantic.BaseModel):
kind: Literal["tool_call"] = "tool_call"


DUMMY_TOOL_CALL = ToolCallPart(
id="<invalid>", tool_call_id="", tool_name="", tool_args=""
)


class ToolResultPart(pydantic.BaseModel):
id: str = pydantic.Field(default_factory=generate_id)
tool_call_id: str
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ async def emit_events_for_messages(
tool_call_id=part.tool_call_id,
chunk=part.tool_args,
)
yield events_.ToolEnd(tool_call_id=part.tool_call_id)
yield events_.ToolEnd(tool_call_id=part.tool_call_id, tool_call=part)

elif isinstance(part, messages_.FilePart):
yield events_.FileEvent(
Expand Down
37 changes: 37 additions & 0 deletions tests/models/test_public_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,43 @@ async def test_stream_basic() -> None:
assert "".join(deltas) == "Hello world"


async def test_stream_tool_end_includes_aggregated_tool_call() -> None:
"""ToolEnd exposes the full ToolCallPart assembled from streamed input."""

async def _tool_stream(
client: models.Client,
model: models.Model,
messages: list[messages_.Message],
*,
tools: Sequence[ai.ToolLike] | None = None,
output_type: type[pydantic.BaseModel] | None = None,
**kwargs: Any,
) -> AsyncGenerator[events_.Event]:
yield events_.StreamStart()
yield events_.ToolStart(tool_call_id="tc-1", tool_name="weather")
yield events_.ToolDelta(tool_call_id="tc-1", chunk='{"city"')
yield events_.ToolDelta(tool_call_id="tc-1", chunk=':"SF"}')
yield events_.ToolEnd(
tool_call_id="tc-1",
tool_call=messages_.DUMMY_TOOL_CALL,
)
yield events_.StreamEnd()

models.register_stream("mock", _tool_stream)

s = models.stream(MOCK_MODEL, [ai.user_message("Check weather")])
tool_end: events_.ToolEnd | None = None
async for event in s:
if isinstance(event, events_.ToolEnd):
tool_end = event

assert tool_end is not None
assert tool_end.tool_call.tool_call_id == "tc-1"
assert tool_end.tool_call.tool_name == "weather"
assert tool_end.tool_call.tool_args == '{"city":"SF"}'
assert s.tool_calls == [tool_end.tool_call]


async def test_stream_with_explicit_client() -> None:
"""Model with explicit client= forwards it to the adapter."""
received_clients: list[models.Client] = []
Expand Down
Loading