Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 38 additions & 36 deletions src/ai/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,40 +225,6 @@ async def _coerce_events(
yield item


async def _default_loop(context: Context) -> AsyncGenerator[events_.AgentEvent]:
while True:
stream = models.stream(
context.model,
context.messages,
tools=context.tools,
)
async for stream_event in stream:
yield stream_event

# Bridge: emit MessageStart/MessageEnd around the assistant message
# the model stream just produced, so _collect_messages and downstream
# consumers (AI-SDK outbound, label stamping) see the same boundary
# events they did under the previous adapter contract.
if stream.message is not None and stream.message.parts:
async for boundary in _message_events(stream.message):
yield boundary

tool_calls = context.resolve(stream.tool_calls)
if not tool_calls:
break

# Execute tool calls in parallel.
async with asyncio.TaskGroup() as tg:
tasks = [tg.create_task(tc()) for tc in tool_calls]

# Yield one merged tool-result message — history auto-collects it.
# Left un-stamped: the tool result is the input of the *next* turn,
# so the next stream() call will stamp it with that turn's id.
tool_msg = builders.tool_message(*(t.result() for t in tasks))
async for boundary in _message_events(tool_msg):
yield boundary


async def _collect_messages(
source: AsyncIterable[StreamItem],
messages: list[types.Message],
Expand Down Expand Up @@ -317,18 +283,54 @@ def __init__(
tools: list[Tool[..., Any]] | None = None,
) -> None:
self._tools: list[Tool[..., Any]] = tools or []
self._loop_fn: LoopFn = _default_loop
self._loop_fn: LoopFn | None = None

@property
def tools(self) -> list[Tool[..., Any]]:
"""The agent's registered tools (read-only copy)."""
return list(self._tools)

# TODO: remove?
def loop(self, fn: LoopFn) -> LoopFn:
"""Decorator: override the default loop function."""
self._loop_fn = fn
return fn

async def default_loop(
self, context: Context
) -> AsyncGenerator[events_.AgentEvent]:
while True:
stream = models.stream(
context.model,
context.messages,
tools=context.tools,
)
async for stream_event in stream:
yield stream_event

# Bridge: emit MessageStart/MessageEnd around the assistant message
# the model stream just produced, so _collect_messages and downstream
# consumers (AI-SDK outbound, label stamping) see the same boundary
# events they did under the previous adapter contract.
if stream.message is not None and stream.message.parts:
async for boundary in _message_events(stream.message):
yield boundary

tool_calls = context.resolve(stream.tool_calls)
if not tool_calls:
break

# Execute tool calls in parallel.
async with asyncio.TaskGroup() as tg:
tasks = [tg.create_task(tc()) for tc in tool_calls]

# Yield one merged tool-result message — history auto-collects it.
# Left un-stamped: the tool result is the input of the *next* turn,
# so the next stream() call will stamp it with that turn's id.
tool_msg = builders.tool_message(*(t.result() for t in tasks))
async for boundary in _message_events(tool_msg):
yield boundary

async def run(
self,
model: models.Model,
Expand Down Expand Up @@ -356,7 +358,7 @@ async def run(
label=label,
)

loop_fn = self._loop_fn
loop_fn = self._loop_fn or self.default_loop

async def _real(
call: middleware_.AgentRunContext,
Expand Down
Loading