diff --git a/src/ai/agents/agent.py b/src/ai/agents/agent.py index 37fcdc20..57735f99 100644 --- a/src/ai/agents/agent.py +++ b/src/ai/agents/agent.py @@ -225,40 +225,6 @@ async def _coerce_events( yield item -async def _default_loop(context: Context) -> AsyncGenerator[events_.AgentEvent]: - while True: - stream = models.stream( - context.model, - context.messages, - tools=context.tools, - ) - async for stream_event in stream: - yield stream_event - - # Bridge: emit MessageStart/MessageEnd around the assistant message - # the model stream just produced, so _collect_messages and downstream - # consumers (AI-SDK outbound, label stamping) see the same boundary - # events they did under the previous adapter contract. - if stream.message is not None and stream.message.parts: - async for boundary in _message_events(stream.message): - yield boundary - - tool_calls = context.resolve(stream.tool_calls) - if not tool_calls: - break - - # Execute tool calls in parallel. - async with asyncio.TaskGroup() as tg: - tasks = [tg.create_task(tc()) for tc in tool_calls] - - # Yield one merged tool-result message — history auto-collects it. - # Left un-stamped: the tool result is the input of the *next* turn, - # so the next stream() call will stamp it with that turn's id. - tool_msg = builders.tool_message(*(t.result() for t in tasks)) - async for boundary in _message_events(tool_msg): - yield boundary - - async def _collect_messages( source: AsyncIterable[StreamItem], messages: list[types.Message], @@ -317,18 +283,54 @@ def __init__( tools: list[Tool[..., Any]] | None = None, ) -> None: self._tools: list[Tool[..., Any]] = tools or [] - self._loop_fn: LoopFn = _default_loop + self._loop_fn: LoopFn | None = None @property def tools(self) -> list[Tool[..., Any]]: """The agent's registered tools (read-only copy).""" return list(self._tools) + # TODO: remove? def loop(self, fn: LoopFn) -> LoopFn: """Decorator: override the default loop function.""" self._loop_fn = fn return fn + async def default_loop( + self, context: Context + ) -> AsyncGenerator[events_.AgentEvent]: + while True: + stream = models.stream( + context.model, + context.messages, + tools=context.tools, + ) + async for stream_event in stream: + yield stream_event + + # Bridge: emit MessageStart/MessageEnd around the assistant message + # the model stream just produced, so _collect_messages and downstream + # consumers (AI-SDK outbound, label stamping) see the same boundary + # events they did under the previous adapter contract. + if stream.message is not None and stream.message.parts: + async for boundary in _message_events(stream.message): + yield boundary + + tool_calls = context.resolve(stream.tool_calls) + if not tool_calls: + break + + # Execute tool calls in parallel. + async with asyncio.TaskGroup() as tg: + tasks = [tg.create_task(tc()) for tc in tool_calls] + + # Yield one merged tool-result message — history auto-collects it. + # Left un-stamped: the tool result is the input of the *next* turn, + # so the next stream() call will stamp it with that turn's id. + tool_msg = builders.tool_message(*(t.result() for t in tasks)) + async for boundary in _message_events(tool_msg): + yield boundary + async def run( self, model: models.Model, @@ -356,7 +358,7 @@ async def run( label=label, ) - loop_fn = self._loop_fn + loop_fn = self._loop_fn or self.default_loop async def _real( call: middleware_.AgentRunContext,