diff --git a/examples/fastapi-vite/backend/agent.py b/examples/fastapi-vite/backend/agent.py index 3f1303d2..f2885e63 100644 --- a/examples/fastapi-vite/backend/agent.py +++ b/examples/fastapi-vite/backend/agent.py @@ -25,7 +25,7 @@ async def talk_to_mothership(question: str) -> str: @chat_agent.loop -async def graph(context: ai.Context) -> AsyncGenerator[ai.Event]: +async def graph(context: ai.Context) -> AsyncGenerator[ai.StreamItem]: """Agent graph with human-in-the-loop tool approval. Loops: stream LLM -> request approval -> execute tools -> repeat. @@ -50,7 +50,7 @@ async def graph(context: ai.Context) -> AsyncGenerator[ai.Event]: yield ai.tool_result(*results) -async def _execute_with_approval(tc: ai.ToolCall) -> ai.Message: +async def _execute_with_approval(tc: ai.ToolCall) -> ai.ToolCallResult: """Execute a tool call only after the user grants approval. Creates a ToolApproval hook that suspends execution until the @@ -66,7 +66,7 @@ async def _execute_with_approval(tc: ai.ToolCall) -> ai.Message: if approval.granted: return await tc() - return ai.tool_message( + return ai.tool_result( tool_call_id=tc.id, tool_name=tc.name, result="Tool call was denied by the user.", diff --git a/examples/multiagent-textual/client.py b/examples/multiagent-textual/client.py index 5badedcd..a47d2fcc 100644 --- a/examples/multiagent-textual/client.py +++ b/examples/multiagent-textual/client.py @@ -11,6 +11,7 @@ import asyncio import json +from typing import Any import rich.text import pydantic @@ -91,7 +92,7 @@ def append_line(self, text: str, style: str = "dim") -> None: # --------------------------------------------------------------------------- -class MultiAgentApp(textual.app.App): +class MultiAgentApp(textual.app.App[None]): """Textual app for the multi-agent hooks demo.""" CSS = """ @@ -109,10 +110,10 @@ class MultiAgentApp(textual.app.App): def __init__(self) -> None: super().__init__() - self._hook_queue: asyncio.Queue[ai.HookPart] = asyncio.Queue() - self._current_hook: ai.HookPart | None = None + self._hook_queue: asyncio.Queue[ai.HookPart[Any]] = asyncio.Queue() + self._current_hook: ai.HookPart[Any] | None = None self._ws: websockets.ClientConnection | None = None - self._event_adapter = pydantic.TypeAdapter(ai.AgentEvent) + self._event_adapter: pydantic.TypeAdapter[ai.AgentEvent] = pydantic.TypeAdapter(ai.AgentEvent) self._current_label = "unknown" def compose(self) -> textual.app.ComposeResult: @@ -207,7 +208,7 @@ def _handle_event(self, event: ai.AgentEvent) -> None: # Hook lifecycle # ------------------------------------------------------------------ - def _on_hook_pending(self, hook_part: ai.HookPart) -> None: + def _on_hook_pending(self, hook_part: ai.HookPart[Any]) -> None: branch = hook_part.metadata.get("branch", "unknown") tool = hook_part.metadata.get("tool", "?") @@ -219,7 +220,7 @@ def _on_hook_pending(self, hook_part: ai.HookPart) -> None: self._hook_queue.put_nowait(hook_part) self._maybe_activate_next_hook() - def _on_hook_resolved(self, hook_part: ai.HookPart) -> None: + def _on_hook_resolved(self, hook_part: ai.HookPart[Any]) -> None: branch = hook_part.metadata.get("branch", "unknown") granted = hook_part.resolution and hook_part.resolution.get("granted") tag = "approved" if granted else "denied" diff --git a/examples/multiagent-textual/server.py b/examples/multiagent-textual/server.py index a673fe2c..f7716062 100644 --- a/examples/multiagent-textual/server.py +++ b/examples/multiagent-textual/server.py @@ -78,7 +78,7 @@ def _gated_agent( gated = ai.agent(tools=tools) @gated.loop - async def gated_loop(context: ai.Context) -> AsyncGenerator[ai.Event]: + async def gated_loop(context: ai.Context) -> AsyncGenerator[ai.StreamItem]: while True: s = ai.stream(context.model, context.messages, tools=context.tools) async for event in s: @@ -90,7 +90,7 @@ async def gated_loop(context: ai.Context) -> AsyncGenerator[ai.Event]: if not tool_calls: break - results: list[ai.Message] = [] + results: list[ai.ToolCallResult] = [] for tc in tool_calls: if tc.name == approval_tool: approval = await ai.hook( @@ -102,7 +102,7 @@ async def gated_loop(context: ai.Context) -> AsyncGenerator[ai.Event]: results.append(await tc()) else: results.append( - ai.tool_message( + ai.tool_result( tool_call_id=tc.id, tool_name=tc.name, result=f"Denied: {approval.reason}", @@ -138,7 +138,7 @@ async def gated_loop(context: ai.Context) -> AsyncGenerator[ai.Event]: @orchestrator.loop -async def multiagent_loop(context: ai.Context) -> AsyncGenerator[ai.Event]: +async def multiagent_loop(context: ai.Context) -> AsyncGenerator[ai.AgentEvent]: """Run two gated agents in parallel, then summarise their results.""" query = context.messages[-1].text diff --git a/examples/samples/agent_custom_loop.py b/examples/samples/agent_custom_loop.py index 7bcfe152..9ba0c498 100644 --- a/examples/samples/agent_custom_loop.py +++ b/examples/samples/agent_custom_loop.py @@ -9,55 +9,56 @@ @ai.tool async def get_weather(city: str) -> str: """Get current weather for a city.""" - return f"Sunny, 72F in {city}" + await asyncio.sleep(2) + return f"Sunny, 72F in {city}" if city == "Tokyo" else f"Cloudy, 55F in {city}" @ai.tool async def get_population(city: str) -> int: """Get population of a city.""" + await asyncio.sleep(1) return {"new york": 8_336_817, "tokyo": 13_960_000}.get(city.lower(), 1_000_000) -async def main() -> None: - model = ai.ai_gateway("anthropic/claude-sonnet-4") - - tools = [get_weather, get_population] - my_agent = ai.agent(tools=tools) - - @my_agent.loop - async def custom(context: ai.Context) -> AsyncGenerator[ai.Event]: +class CustomAgent(ai.Agent): + async def default_loop(self, context: ai.Context) -> AsyncGenerator[ai.AgentEvent]: """Stream, execute tools with logging, repeat.""" - while True: - s = ai.models.stream(context.model, context.messages, tools=context.tools) - async for event in s: - yield event - - # Yield the assistant message for silent history collection. - if s.message is not None: - yield s.message + while context.keep_running(): + async with ( + ai.models.stream( + context.model, context.messages, tools=context.tools + ) as stream, + ai.ToolRunner(stream) as tr, + ): + async for event in ai.util.merge(stream, tr.events()): + yield event + + if isinstance(event, ai.ToolEnd): + call = event.tool_call + print(f"Launching tool {call.tool_name}({call.tool_args})") + tool = context.resolve(call) + tr.schedule(tool) + + context.add(stream.message) + # This adds the tool message to the history, which + # also has the effect of causing another turn through + # the loop. + context.add(tr.get_tool_message()) - tool_calls = context.resolve(s.tool_calls) - if not tool_calls: - return - print( - f"\n [calling {len(tool_calls)} tool(s): " - f"{', '.join(tc.name for tc in tool_calls)}]" - ) - # Each resolved tool call exposes tc.fn and tc.kwargs, and - # tc(**overrides) lets you adjust arguments before execution. - - async with asyncio.TaskGroup() as tg: - tasks = [tg.create_task(tc()) for tc in tool_calls] +async def main() -> None: + model = ai.ai_gateway("anthropic/claude-sonnet-4") - yield ai.tool_result(*(t.result() for t in tasks)) + tools = [get_weather, get_population] + my_agent = CustomAgent(tools=tools) async for event in my_agent.run( model, [ai.user_message("Compare the weather and population of New York and Tokyo.")], ): - if isinstance(event, ai.TextDelta): - print(event.chunk, end="", flush=True) + if isinstance(event, ai.StreamEnd) and event.message.role == "assistant": + print("====", event.message.text, flush=True) + print() diff --git a/examples/samples/agent_hooks.py b/examples/samples/agent_hooks.py index c103035c..a3ed0677 100644 --- a/examples/samples/agent_hooks.py +++ b/examples/samples/agent_hooks.py @@ -31,11 +31,11 @@ async def main() -> None: my_agent = ai.agent(tools=[contact_mothership]) @my_agent.loop - async def with_approval(context: ai.Context) -> AsyncGenerator[ai.Event]: + async def with_approval( + context: ai.Context, + ) -> AsyncGenerator[ai.StreamItem]: while True: - s = ai.models.stream( - context.model, context.messages, tools=context.tools - ) + s = ai.models.stream(context.model, context.messages, tools=context.tools) async for event in s: yield event if s.message is not None: @@ -45,7 +45,7 @@ async def with_approval(context: ai.Context) -> AsyncGenerator[ai.Event]: if not tool_calls: return - results = [] + results: list[ai.ToolCallResult] = [] for tc in tool_calls: if tc.name == "contact_mothership": # Suspends until resolved from outside the loop. @@ -58,7 +58,7 @@ async def with_approval(context: ai.Context) -> AsyncGenerator[ai.Event]: results.append(await tc()) else: results.append( - ai.tool_message( + ai.tool_result( tool_call_id=tc.id, tool_name=tc.name, result=f"Rejected: {approval.reason}", diff --git a/examples/samples/agent_hooks_serverless.py b/examples/samples/agent_hooks_serverless.py index 8262c9ab..72376b49 100644 --- a/examples/samples/agent_hooks_serverless.py +++ b/examples/samples/agent_hooks_serverless.py @@ -36,11 +36,11 @@ async def main() -> None: my_agent = ai.agent(tools=[delete_file]) @my_agent.loop - async def with_confirmation(context: ai.Context) -> AsyncGenerator[ai.Event]: + async def with_confirmation( + context: ai.Context, + ) -> AsyncGenerator[ai.StreamItem]: while True: - s = ai.models.stream( - context.model, context.messages, tools=context.tools - ) + s = ai.models.stream(context.model, context.messages, tools=context.tools) async for event in s: yield event if s.message is not None: @@ -50,7 +50,7 @@ async def with_confirmation(context: ai.Context) -> AsyncGenerator[ai.Event]: if not tool_calls: return - results = [] + results: list[ai.ToolCallResult] = [] for tc in tool_calls: try: confirmation = await ai.hook( @@ -67,7 +67,7 @@ async def with_confirmation(context: ai.Context) -> AsyncGenerator[ai.Event]: results.append(await tc()) else: results.append( - ai.tool_message( + ai.tool_result( tool_call_id=tc.id, tool_name=tc.name, result=f"Rejected: {confirmation.reason}", @@ -93,8 +93,7 @@ async def with_confirmation(context: ai.Context) -> AsyncGenerator[ai.Event]: hook_part = event.hook pending_hook_labels.append(hook_part.hook_id) print( - f" Hook pending: {hook_part.hook_id}" - f" (metadata={hook_part.metadata})" + f" Hook pending: {hook_part.hook_id} (metadata={hook_part.metadata})" ) print("\n Run interrupted; approval will be pre-registered for re-entry.\n") diff --git a/examples/samples/agent_nested.py b/examples/samples/agent_nested.py index c5a11e87..638c2899 100644 --- a/examples/samples/agent_nested.py +++ b/examples/samples/agent_nested.py @@ -21,7 +21,7 @@ async def get_facts(topic: str) -> str: # This tool is an async generator — it streams intermediate messages # through the runtime sink, then returns the final result. @ai.tool # type: ignore[arg-type] # async generator tools are supported at runtime -async def research(topic: str) -> AsyncGenerator[ai.Event]: +async def research(topic: str) -> AsyncGenerator[ai.AgentEvent]: """Research a topic in depth using a sub-agent.""" researcher = ai.agent(tools=[get_facts]) diff --git a/examples/samples/agent_simple.py b/examples/samples/agent_simple.py index 4c376d0a..02916c9a 100644 --- a/examples/samples/agent_simple.py +++ b/examples/samples/agent_simple.py @@ -24,6 +24,8 @@ async def main() -> None: async for event in my_agent.run(model, messages): if isinstance(event, ai.TextDelta): print(event.chunk, end="", flush=True) + if isinstance(event, ai.StreamEnd): + print() print() diff --git a/examples/samples/check_connection.py b/examples/samples/check_connection.py index b89d2f16..2799a807 100644 --- a/examples/samples/check_connection.py +++ b/examples/samples/check_connection.py @@ -28,7 +28,7 @@ async def _check(model: ai.Model) -> None: async def _list_models(name: str, provider: object) -> None: try: - ids: list[str] = await provider.list() # type: ignore[union-attr] + ids: list[str] = await provider.list() # type: ignore[attr-defined] print(f" {name}: {len(ids)} models") for mid in ids: print(f" - {mid}") diff --git a/examples/samples/middleware_simple.py b/examples/samples/middleware_simple.py index f317ba3a..7651a34e 100644 --- a/examples/samples/middleware_simple.py +++ b/examples/samples/middleware_simple.py @@ -13,6 +13,10 @@ import asyncio import time +from collections.abc import AsyncGenerator, Awaitable, Callable +from typing import Any + +import pydantic import ai @@ -20,7 +24,11 @@ class PrintMiddleware(ai.Middleware): """Logs every execution surface to stdout.""" - async def wrap_agent_run(self, call, next): + async def wrap_agent_run( + self, + call: ai.middleware.AgentRunContext, + next: Callable[[ai.middleware.AgentRunContext], AsyncGenerator[Any]], + ) -> AsyncGenerator[Any]: label = call.label or "(default)" print(f">>> [run] agent starting label={label} tools={len(call.tools)}") t0 = time.perf_counter() @@ -31,7 +39,11 @@ async def wrap_agent_run(self, call, next): elapsed = time.perf_counter() - t0 print(f"<<< [run] agent finished label={label} {elapsed:.2f}s") - async def wrap_model(self, call, next): + async def wrap_model( + self, + call: ai.middleware.ModelContext, + next: Callable[[ai.middleware.ModelContext], Awaitable[Any]], + ) -> Any: print(f"\n>>> [model] calling {call.model.id}") print(f" messages: {len(call.messages)}") if call.tools: @@ -39,12 +51,14 @@ async def wrap_model(self, call, next): result = await next(call) - # The result is a StreamResult — async-iterable of Event objects. - # We return it as-is; the consumer iterates it normally. print("<<< [model] stream started") return result - async def wrap_generate(self, call, next): + async def wrap_generate( + self, + call: ai.middleware.GenerateContext, + next: Callable[[ai.middleware.GenerateContext], Awaitable[ai.Message]], + ) -> ai.Message: print(f"\n>>> [generate] calling {call.model.id}") print(f" messages: {len(call.messages)}") @@ -53,20 +67,27 @@ async def wrap_generate(self, call, next): print("<<< [generate] done") return result - async def wrap_tool(self, call, next): + async def wrap_tool( + self, + call: ai.middleware.ToolContext, + next: Callable[[ai.middleware.ToolContext], Awaitable[ai.ToolCallResult]], + ) -> ai.ToolCallResult: print(f"\n>>> [tool] {call.tool_name}({call.kwargs})") result = await next(call) - # result is a tool-result Message. - tr = result.tool_results[0] if result.tool_results else None + tr = result.results[0] if result.results else None if tr and not tr.is_error: print(f"<<< [tool] {call.tool_name} -> {tr.result}") elif tr: print(f"<<< [tool] {call.tool_name} ERROR: {tr.result}") return result - async def wrap_hook(self, call, next): + async def wrap_hook( + self, + call: ai.middleware.HookContext, + next: Callable[[ai.middleware.HookContext], Awaitable[pydantic.BaseModel]], + ) -> pydantic.BaseModel: print(f"\n>>> [hook] {call.label} payload={call.payload.__name__}") result = await next(call) diff --git a/examples/samples/streaming_tool.py b/examples/samples/streaming_tool.py index 277e3b33..694549ac 100644 --- a/examples/samples/streaming_tool.py +++ b/examples/samples/streaming_tool.py @@ -12,7 +12,7 @@ @ai.tool # type: ignore[arg-type] # async generator tools are supported at runtime -async def talk_to_mothership(question: str) -> AsyncGenerator[ai.Event]: +async def talk_to_mothership(question: str) -> AsyncGenerator[ai.StreamItem]: """Ask the mothership a question. Streams progress back to the caller.""" for step in ["Connecting...", "Transmitting...", "Awaiting response..."]: yield ai.TextStart(block_id=f"progress-{step}") diff --git a/examples/temporal-direct/main.py b/examples/temporal-direct/main.py index 3af2bf0e..66f7b2d3 100644 --- a/examples/temporal-direct/main.py +++ b/examples/temporal-direct/main.py @@ -118,7 +118,7 @@ async def llm_call_activity(params: LLMParams) -> LLMResult: @weather_agent.loop -async def temporal_loop(context: ai.Context) -> AsyncGenerator[ai.Event]: +async def temporal_loop(context: ai.Context) -> AsyncGenerator[ai.StreamItem]: tool_schemas = [ {"name": t.name, "description": t.description, "param_schema": t.param_schema} for t in context.tools diff --git a/examples/temporal-middleware/main.py b/examples/temporal-middleware/main.py index 0eebe327..eafbbc80 100644 --- a/examples/temporal-middleware/main.py +++ b/examples/temporal-middleware/main.py @@ -149,7 +149,7 @@ async def _replay_as_stream(msg: ai.Message) -> AsyncGenerator[ai.Event]: yield ai.ToolDelta( tool_call_id=part.tool_call_id, chunk=part.tool_args ) - yield ai.ToolEnd(tool_call_id=part.tool_call_id) + yield ai.ToolEnd(tool_call_id=part.tool_call_id, tool_call=part) yield ai.StreamEnd() @@ -193,7 +193,7 @@ async def wrap_tool( self, call: ai.middleware.ToolContext, next: Any, - ) -> ai.Message: + ) -> ai.ToolCallResult: """Tool execution → Temporal activity.""" result = await temporalio.workflow.execute_activity( tool_dispatch_activity, @@ -203,13 +203,11 @@ async def wrap_tool( ), start_to_close_timeout=datetime.timedelta(minutes=2), ) - return ai.tool_message( - ai.ToolResultPart( - tool_call_id=call.tool_call_id, - tool_name=call.tool_name, - result=result.result, - is_error=result.is_error, - ) + return ai.tool_result( + tool_call_id=call.tool_call_id, + tool_name=call.tool_name, + result=result.result, + is_error=result.is_error, ) diff --git a/src/ai/__init__.py b/src/ai/__init__.py index 816a506f..c80c9dc7 100644 --- a/src/ai/__init__.py +++ b/src/ai/__init__.py @@ -1,15 +1,17 @@ -from . import middleware, models +from . import middleware, models, util from .agents import ( TOOL_APPROVAL_HOOK_TYPE, Agent, AgentEvent, Context, HookEvent, + StreamItem, TerminalEvent, Tool, ToolApproval, ToolCall, ToolCallResult, + ToolRunner, agent, cancel_hook, hook, @@ -141,12 +143,14 @@ # Agents — primary API "Agent", "AgentEvent", + "StreamItem", "agent", "Context", # Agents — tools "Tool", "ToolCall", "ToolCallResult", + "ToolRunner", "tool", # Agents — composition "yield_from", @@ -164,4 +168,5 @@ "middleware", # Submodules "mcp", + "util", ] diff --git a/src/ai/agents/__init__.py b/src/ai/agents/__init__.py index 575ab7eb..cd09e453 100644 --- a/src/ai/agents/__init__.py +++ b/src/ai/agents/__init__.py @@ -1,5 +1,16 @@ from . import mcp, ui -from .agent import Agent, Context, Tool, ToolCall, agent, tool, tool_result, yield_from +from .agent import ( + Agent, + Context, + StreamItem, + Tool, + ToolCall, + ToolRunner, + agent, + tool, + tool_result, + yield_from, +) from .events import AgentEvent, HookEvent, TerminalEvent, ToolCallResult from .hooks import ( TOOL_APPROVAL_HOOK_TYPE, @@ -12,11 +23,13 @@ __all__ = [ "Agent", "AgentEvent", + "StreamItem", "Context", "HookEvent", "Tool", "ToolApproval", "ToolCall", + "ToolRunner", "TerminalEvent", "ToolCallResult", "TOOL_APPROVAL_HOOK_TYPE", diff --git a/src/ai/agents/agent.py b/src/ai/agents/agent.py index 701f9e7f..fae6b19b 100644 --- a/src/ai/agents/agent.py +++ b/src/ai/agents/agent.py @@ -5,13 +5,13 @@ import asyncio import inspect import json -from collections.abc import AsyncGenerator, AsyncIterable, Awaitable, Callable -from typing import Any, Protocol, get_type_hints +from collections.abc import AsyncGenerator, AsyncIterable, Awaitable, Callable, Sequence +from typing import Any, Protocol, Self, get_type_hints, overload import pydantic from .. import middleware as middleware_ -from .. import models, types +from .. import models, types, util from ..types import builders from . import events as events_ from . import runtime @@ -140,8 +140,8 @@ def kwargs(self) -> dict[str, Any]: self._kwargs = self._tool.parse_args(self._part.tool_args) return dict(self._kwargs) - async def __call__(self, **overrides: Any) -> types.Message: - """Execute the tool and return a single tool-result message.""" + async def __call__(self, **overrides: Any) -> events_.ToolCallResult: + """Execute the tool and return a :class:`ToolCallResult`.""" # Best-effort parse so middleware sees usable kwargs when possible. # If parsing fails, middleware still gets the raw tool_call_id / # tool_name and can replace kwargs before _real() executes. @@ -163,11 +163,11 @@ async def __call__(self, **overrides: Any) -> types.Message: tool = self._tool - async def _real(call: middleware_.ToolContext) -> types.Message: + async def _real(call: middleware_.ToolContext) -> events_.ToolCallResult: try: result = await tool.execute_kwargs(call.kwargs) except Exception as exc: - return builders.tool_message( + return tool_result( types.ToolResultPart( tool_call_id=call.tool_call_id, tool_name=call.tool_name, @@ -175,7 +175,7 @@ async def _real(call: middleware_.ToolContext) -> types.Message: is_error=True, ) ) - return builders.tool_message( + return tool_result( types.ToolResultPart( tool_call_id=call.tool_call_id, tool_name=call.tool_name, @@ -187,6 +187,64 @@ async def _real(call: middleware_.ToolContext) -> types.Message: return await chain(call) +class ToolRunner: + def __init__(self, stream: models.Stream) -> None: + self._stream = stream + # finish_future gets set when the stream exhausts. We won't + # exhaust until that happens, since the stream can cause more + # tools to get triggered. + self._finish_future = stream.finish_future + # A future that gets signalled when we add a new tool, so that + # asyncio.wait gets woken up and cycles around in the loop to + # wait on the new thing as well. + self._sched_waiter: asyncio.Future[None] = ( + asyncio.get_running_loop().create_future() + ) + self._active: set[ + asyncio.Future[events_.ToolCallResult] | asyncio.Future[None] + ] = {self._finish_future} + self._tool_results: list[events_.ToolCallResult] = [] + self._tg_base = asyncio.TaskGroup() + + async def __aenter__(self) -> Self: + self._tg = await self._tg_base.__aenter__() + return self + + async def __aexit__(self, *args: Any) -> None: + await self._tg_base.__aexit__(*args) + + def events(self) -> AsyncGenerator[events_.ToolCallResult]: + return self._iterate() + + def schedule(self, tc: ToolCall) -> None: + self._active.add(self._tg.create_task(tc())) + self._sched_waiter.set_result(None) + + def get_tool_message(self) -> types.Message | None: + if self._tool_results: + return builders.tool_message(*[t.message for t in self._tool_results]) + return None + + async def _iterate(self) -> AsyncGenerator[events_.ToolCallResult]: + while self._active: + done, _ = await asyncio.wait( + [*self._active, self._sched_waiter], + return_when=asyncio.FIRST_COMPLETED, + ) + for t in done: + self._active.discard(t) + if t is self._finish_future: + t.result() + elif t is self._sched_waiter: + t.result() + self._sched_waiter = asyncio.get_running_loop().create_future() + else: + res = t.result() + assert res is not None + self._tool_results.append(res) + yield res + + class Context(pydantic.BaseModel): """Everything that goes into the LLM.""" @@ -194,14 +252,41 @@ class Context(pydantic.BaseModel): messages: list[types.Message] tools: list[Tool[..., Any]] + _tools_by_name: dict[str, Tool[..., Any]] = pydantic.PrivateAttr() + model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) - def resolve(self, tool_parts: list[types.ToolCallPart]) -> list[ToolCall]: - """Resolve ToolCallParts into callable ToolCall objects.""" - tools_by_name = {t.name: t for t in self.tools} - return [ - ToolCall(part=tp, tool=tools_by_name[tp.tool_name]) for tp in tool_parts - ] + def model_post_init(self, __context: Any) -> None: + self._tools_by_name = {t.name: t for t in self.tools} + + def keep_running(self) -> bool: + """Call at top of an agent loop to see whether to keep running.""" + return bool( + self.messages and self.messages[-1].role not in ("assistant", "internal") + ) + + @overload + def resolve(self, tool_part: types.ToolCallPart) -> ToolCall: ... + @overload + def resolve(self, tool_part: Sequence[types.ToolCallPart]) -> list[ToolCall]: ... + + def resolve( + self, tool_part: types.ToolCallPart | Sequence[types.ToolCallPart] + ) -> ToolCall | list[ToolCall]: + """Resolve ToolCallPart(s) into callable ToolCall object(s).""" + if isinstance(tool_part, types.ToolCallPart): + return ToolCall( + part=tool_part, tool=self._tools_by_name[tool_part.tool_name] + ) + return [self.resolve(tp) for tp in tool_part] + + def add(self, message: types.Message | Sequence[types.Message] | None) -> None: + if message is None: + return + if isinstance(message, types.Message): + self.messages.append(message) + else: + self.messages.extend(message) class LoopFn(Protocol): @@ -209,16 +294,42 @@ def __call__(self, context: Context) -> AsyncGenerator[StreamItem]: ... def tool_result( - *items: types.Message | types.ToolResultPart | list[types.Message], + *items: types.Message + | types.ToolResultPart + | events_.ToolCallResult + | list[types.Message], + tool_call_id: str | None = None, + result: Any = None, + tool_name: str = "", + is_error: bool = False, ) -> events_.ToolCallResult: - """Create a :class:`ToolCallResult` from tool messages or parts. + """Create a :class:`ToolCallResult` from tool messages, parts, or kwargs. - Wraps :func:`ai.tool_message` and returns a ``ToolCallResult`` - event ready to yield from a custom loop:: + Accepts ``ToolCallResult`` items (extracts their ``.message``), + plain ``Message`` objects, ``ToolResultPart`` instances, or keyword + arguments matching :func:`ai.tool_message`:: yield ai.tool_result(*(t.result() for t in tasks)) + ai.tool_result(tool_call_id="tc-1", result="denied", is_error=True) """ - msg = builders.tool_message(*items) + if tool_call_id is not None: + msg = builders.tool_message( + tool_call_id=tool_call_id, + result=result, + tool_name=tool_name, + is_error=is_error, + ) + return events_.ToolCallResult(message=msg, results=msg.tool_results) + + unwrapped: list[types.Message | types.ToolResultPart] = [] + for item in items: + if isinstance(item, events_.ToolCallResult): + unwrapped.append(item.message) + elif isinstance(item, list): + unwrapped.extend(item) + else: + unwrapped.append(item) + msg = builders.tool_message(*unwrapped) return events_.ToolCallResult(message=msg, results=msg.tool_results) @@ -231,6 +342,7 @@ def _upsert_message(messages: list[types.Message], message: types.Message) -> No messages.append(message) +# TODO: Stop doing this? async def _collect_messages( source: AsyncIterable[StreamItem], messages: list[types.Message], @@ -238,7 +350,6 @@ async def _collect_messages( """Intercept yielded items and maintain the *messages* history list. * Bare ``Message`` — silently collected (not forwarded to consumer). - * ``ToolCallResult`` — collected *and* forwarded. * Any other ``AgentEvent`` — forwarded as-is. This runs on the **producer** side (same coroutine as the loop function), @@ -248,9 +359,6 @@ async def _collect_messages( async for item in source: if isinstance(item, types.Message): _upsert_message(messages, item) - elif isinstance(item, events_.ToolCallResult): - _upsert_message(messages, item.message) - yield item else: yield item @@ -306,29 +414,29 @@ def loop(self, fn: LoopFn) -> LoopFn: self._loop_fn = fn return fn - async def default_loop(self, context: Context) -> AsyncGenerator[StreamItem]: - while True: - stream = models.stream( - context.model, - context.messages, - tools=context.tools, - ) - async for stream_event in stream: - yield stream_event - - # Yield the assistant message for silent history collection. - if stream.message is not None and stream.message.parts: - yield stream.message - - tool_calls = context.resolve(stream.tool_calls) - if not tool_calls: - break - - # Execute tool calls in parallel. - async with asyncio.TaskGroup() as tg: - tasks = [tg.create_task(tc()) for tc in tool_calls] - - yield tool_result(*(t.result() for t in tasks)) + async def default_loop( + self, context: Context + ) -> AsyncGenerator[events_.AgentEvent]: + """Stream, execute tools, repeat.""" + while context.keep_running(): + async with ( + models.stream( + context.model, context.messages, tools=context.tools + ) as stream, + ToolRunner(stream) as tr, + ): + async for event in util.merge(stream, tr.events()): + yield event + + if isinstance(event, types.ToolEnd): + tool = context.resolve(event.tool_call) + tr.schedule(tool) + + context.add(stream.message) + # This adds the tool message to the history, which + # also has the effect of causing another turn through + # the loop. + context.add(tr.get_tool_message()) async def run( self, diff --git a/src/ai/middleware.py b/src/ai/middleware.py index 8e0c22ff..d9e9c2c5 100644 --- a/src/ai/middleware.py +++ b/src/ai/middleware.py @@ -44,6 +44,7 @@ if TYPE_CHECKING: from .agents.agent import Tool + from .agents.events import ToolCallResult from .models.core.model import Model @@ -191,11 +192,11 @@ async def wrap_generate( async def wrap_tool( self, call: ToolContext, - next: Callable[[ToolContext], Awaitable[_Message]], - ) -> _Message: + next: Callable[[ToolContext], Awaitable[ToolCallResult]], + ) -> ToolCallResult: """Wrap a tool execution. - ``next(call)`` returns the tool-result ``Message``. + ``next(call)`` returns a :class:`ToolCallResult`. """ return await next(call) @@ -298,8 +299,8 @@ async def _wrapped(call: GenerateContext) -> _Message: def _build_tool_chain( - real: Callable[[ToolContext], Awaitable[_Message]], -) -> Callable[[ToolContext], Awaitable[_Message]]: + real: Callable[[ToolContext], Awaitable[ToolCallResult]], +) -> Callable[[ToolContext], Awaitable[ToolCallResult]]: mw = get() if not mw: return real @@ -308,9 +309,9 @@ def _build_tool_chain( for m in reversed(mw): def _make( - m: Middleware, nxt: Callable[[ToolContext], Awaitable[_Message]] - ) -> Callable[[ToolContext], Awaitable[_Message]]: - async def _wrapped(call: ToolContext) -> _Message: + m: Middleware, nxt: Callable[[ToolContext], Awaitable[ToolCallResult]] + ) -> Callable[[ToolContext], Awaitable[ToolCallResult]]: + async def _wrapped(call: ToolContext) -> ToolCallResult: return await m.wrap_tool(call, nxt) return _wrapped diff --git a/src/ai/models/core/api.py b/src/ai/models/core/api.py index 72dc64bf..54785fbf 100644 --- a/src/ai/models/core/api.py +++ b/src/ai/models/core/api.py @@ -1,3 +1,4 @@ +import asyncio import dataclasses from collections.abc import AsyncGenerator, Sequence from typing import Any, Protocol, Self, runtime_checkable @@ -67,6 +68,13 @@ def __init__(self, gen: AsyncGenerator[types.Event]) -> None: self._gen = gen self._message: types.Message = types.Message(role="assistant", parts=[]) self._parts: dict[str, types.Part] = {} + self._finish_future: asyncio.Future[None] = ( + asyncio.get_event_loop().create_future() + ) + + @property + def finish_future(self) -> asyncio.Future[None]: + return self._finish_future async def __aenter__(self) -> Self: return self @@ -76,14 +84,21 @@ async def __aexit__( exc_type: type[BaseException] | None, exc: BaseException | None, tb: object, - ) -> None: + ) -> bool: await self._gen.aclose() + return False def __aiter__(self) -> Self: return self - async def __anext__(self) -> types.Event: - event = await self._gen.__anext__() + async def __anext__(self: Self) -> types.Event: + try: + event = await self._gen.__anext__() + except Exception: + # Usually this fires on StopAsyncIteration, but could be a + # real exception too + self._finish_future.set_result(None) + raise updates = self._aggregate_event(event) return event.model_copy(update={"message": self._message, **updates}) diff --git a/tests/agents/test_tools.py b/tests/agents/test_tools.py index 1e2e04f4..9f9b9f8a 100644 --- a/tests/agents/test_tools.py +++ b/tests/agents/test_tools.py @@ -88,12 +88,12 @@ async def double(x: int) -> int: assert tc.fn is double.fn assert tc.kwargs == {"x": 5} - assert result.role == "tool" - assert len(result.tool_results) == 1 - assert result.tool_results[0].tool_call_id == "tc-1" - assert result.tool_results[0].tool_name == "double" - assert result.tool_results[0].result == 10 - assert not result.tool_results[0].is_error + assert result.message.role == "tool" + assert len(result.results) == 1 + assert result.results[0].tool_call_id == "tc-1" + assert result.results[0].tool_name == "double" + assert result.results[0].result == 10 + assert not result.results[0].is_error async def test_tool_call_catches_errors() -> None: @@ -110,8 +110,8 @@ async def fail(x: int) -> int: tc = ai.ToolCall(part=part, tool=fail) result = await tc() - assert result.tool_results[0].is_error - assert "boom" in str(result.tool_results[0].result) + assert result.results[0].is_error + assert "boom" in str(result.results[0].result) async def test_tool_call_allows_kwarg_overrides() -> None: @@ -129,7 +129,7 @@ async def double(x: int) -> int: result = await tc(x=7) - assert result.tool_results[0].result == 14 + assert result.results[0].result == 14 async def test_tool_call_override_validation_failure() -> None: @@ -164,7 +164,7 @@ async def double(x: int) -> int: result = await tc() - assert result.tool_results[0].is_error + assert result.results[0].is_error # -- Helpers ---------------------------------------------------------------