Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/fastapi-vite/backend/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ async def talk_to_mothership(question: str) -> str:


@chat_agent.loop
async def graph(context: ai.Context) -> AsyncGenerator[ai.Event]:
async def graph(context: ai.Context) -> AsyncGenerator[ai.StreamItem]:
"""Agent graph with human-in-the-loop tool approval.

Loops: stream LLM -> request approval -> execute tools -> repeat.
Expand All @@ -50,7 +50,7 @@ async def graph(context: ai.Context) -> AsyncGenerator[ai.Event]:
yield ai.tool_result(*results)


async def _execute_with_approval(tc: ai.ToolCall) -> ai.Message:
async def _execute_with_approval(tc: ai.ToolCall) -> ai.ToolCallResult:
"""Execute a tool call only after the user grants approval.

Creates a ToolApproval hook that suspends execution until the
Expand All @@ -66,7 +66,7 @@ async def _execute_with_approval(tc: ai.ToolCall) -> ai.Message:
if approval.granted:
return await tc()

return ai.tool_message(
return ai.tool_result(
tool_call_id=tc.id,
tool_name=tc.name,
result="Tool call was denied by the user.",
Expand Down
13 changes: 7 additions & 6 deletions examples/multiagent-textual/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import asyncio
import json
from typing import Any

import rich.text
import pydantic
Expand Down Expand Up @@ -91,7 +92,7 @@ def append_line(self, text: str, style: str = "dim") -> None:
# ---------------------------------------------------------------------------


class MultiAgentApp(textual.app.App):
class MultiAgentApp(textual.app.App[None]):
"""Textual app for the multi-agent hooks demo."""

CSS = """
Expand All @@ -109,10 +110,10 @@ class MultiAgentApp(textual.app.App):

def __init__(self) -> None:
super().__init__()
self._hook_queue: asyncio.Queue[ai.HookPart] = asyncio.Queue()
self._current_hook: ai.HookPart | None = None
self._hook_queue: asyncio.Queue[ai.HookPart[Any]] = asyncio.Queue()
self._current_hook: ai.HookPart[Any] | None = None
self._ws: websockets.ClientConnection | None = None
self._event_adapter = pydantic.TypeAdapter(ai.AgentEvent)
self._event_adapter: pydantic.TypeAdapter[ai.AgentEvent] = pydantic.TypeAdapter(ai.AgentEvent)
self._current_label = "unknown"

def compose(self) -> textual.app.ComposeResult:
Expand Down Expand Up @@ -207,7 +208,7 @@ def _handle_event(self, event: ai.AgentEvent) -> None:
# Hook lifecycle
# ------------------------------------------------------------------

def _on_hook_pending(self, hook_part: ai.HookPart) -> None:
def _on_hook_pending(self, hook_part: ai.HookPart[Any]) -> None:
branch = hook_part.metadata.get("branch", "unknown")
tool = hook_part.metadata.get("tool", "?")

Expand All @@ -219,7 +220,7 @@ def _on_hook_pending(self, hook_part: ai.HookPart) -> None:
self._hook_queue.put_nowait(hook_part)
self._maybe_activate_next_hook()

def _on_hook_resolved(self, hook_part: ai.HookPart) -> None:
def _on_hook_resolved(self, hook_part: ai.HookPart[Any]) -> None:
branch = hook_part.metadata.get("branch", "unknown")
granted = hook_part.resolution and hook_part.resolution.get("granted")
tag = "approved" if granted else "denied"
Expand Down
8 changes: 4 additions & 4 deletions examples/multiagent-textual/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def _gated_agent(
gated = ai.agent(tools=tools)

@gated.loop
async def gated_loop(context: ai.Context) -> AsyncGenerator[ai.Event]:
async def gated_loop(context: ai.Context) -> AsyncGenerator[ai.StreamItem]:
while True:
s = ai.stream(context.model, context.messages, tools=context.tools)
async for event in s:
Expand All @@ -90,7 +90,7 @@ async def gated_loop(context: ai.Context) -> AsyncGenerator[ai.Event]:
if not tool_calls:
break

results: list[ai.Message] = []
results: list[ai.ToolCallResult] = []
for tc in tool_calls:
if tc.name == approval_tool:
approval = await ai.hook(
Expand All @@ -102,7 +102,7 @@ async def gated_loop(context: ai.Context) -> AsyncGenerator[ai.Event]:
results.append(await tc())
else:
results.append(
ai.tool_message(
ai.tool_result(
tool_call_id=tc.id,
tool_name=tc.name,
result=f"Denied: {approval.reason}",
Expand Down Expand Up @@ -138,7 +138,7 @@ async def gated_loop(context: ai.Context) -> AsyncGenerator[ai.Event]:


@orchestrator.loop
async def multiagent_loop(context: ai.Context) -> AsyncGenerator[ai.Event]:
async def multiagent_loop(context: ai.Context) -> AsyncGenerator[ai.AgentEvent]:
"""Run two gated agents in parallel, then summarise their results."""
query = context.messages[-1].text

Expand Down
65 changes: 33 additions & 32 deletions examples/samples/agent_custom_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,55 +9,56 @@
@ai.tool
async def get_weather(city: str) -> str:
"""Get current weather for a city."""
return f"Sunny, 72F in {city}"
await asyncio.sleep(2)
return f"Sunny, 72F in {city}" if city == "Tokyo" else f"Cloudy, 55F in {city}"


@ai.tool
async def get_population(city: str) -> int:
"""Get population of a city."""
await asyncio.sleep(1)
return {"new york": 8_336_817, "tokyo": 13_960_000}.get(city.lower(), 1_000_000)


async def main() -> None:
model = ai.ai_gateway("anthropic/claude-sonnet-4")

tools = [get_weather, get_population]
my_agent = ai.agent(tools=tools)

@my_agent.loop
async def custom(context: ai.Context) -> AsyncGenerator[ai.Event]:
class CustomAgent(ai.Agent):
async def default_loop(self, context: ai.Context) -> AsyncGenerator[ai.AgentEvent]:
"""Stream, execute tools with logging, repeat."""
while True:
s = ai.models.stream(context.model, context.messages, tools=context.tools)
async for event in s:
yield event

# Yield the assistant message for silent history collection.
if s.message is not None:
yield s.message
while context.keep_running():
async with (
ai.models.stream(
context.model, context.messages, tools=context.tools
) as stream,
ai.ToolRunner(stream) as tr,
):
async for event in ai.util.merge(stream, tr.events()):
yield event

if isinstance(event, ai.ToolEnd):
call = event.tool_call
print(f"Launching tool {call.tool_name}({call.tool_args})")
tool = context.resolve(call)
tr.schedule(tool)

context.add(stream.message)
# This adds the tool message to the history, which
# also has the effect of causing another turn through
# the loop.
context.add(tr.get_tool_message())

tool_calls = context.resolve(s.tool_calls)
if not tool_calls:
return

print(
f"\n [calling {len(tool_calls)} tool(s): "
f"{', '.join(tc.name for tc in tool_calls)}]"
)
# Each resolved tool call exposes tc.fn and tc.kwargs, and
# tc(**overrides) lets you adjust arguments before execution.

async with asyncio.TaskGroup() as tg:
tasks = [tg.create_task(tc()) for tc in tool_calls]
async def main() -> None:
model = ai.ai_gateway("anthropic/claude-sonnet-4")

yield ai.tool_result(*(t.result() for t in tasks))
tools = [get_weather, get_population]
my_agent = CustomAgent(tools=tools)

async for event in my_agent.run(
model,
[ai.user_message("Compare the weather and population of New York and Tokyo.")],
):
if isinstance(event, ai.TextDelta):
print(event.chunk, end="", flush=True)
if isinstance(event, ai.StreamEnd) and event.message.role == "assistant":
print("====", event.message.text, flush=True)

print()


Expand Down
12 changes: 6 additions & 6 deletions examples/samples/agent_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ async def main() -> None:
my_agent = ai.agent(tools=[contact_mothership])

@my_agent.loop
async def with_approval(context: ai.Context) -> AsyncGenerator[ai.Event]:
async def with_approval(
context: ai.Context,
) -> AsyncGenerator[ai.StreamItem]:
while True:
s = ai.models.stream(
context.model, context.messages, tools=context.tools
)
s = ai.models.stream(context.model, context.messages, tools=context.tools)
async for event in s:
yield event
if s.message is not None:
Expand All @@ -45,7 +45,7 @@ async def with_approval(context: ai.Context) -> AsyncGenerator[ai.Event]:
if not tool_calls:
return

results = []
results: list[ai.ToolCallResult] = []
for tc in tool_calls:
if tc.name == "contact_mothership":
# Suspends until resolved from outside the loop.
Expand All @@ -58,7 +58,7 @@ async def with_approval(context: ai.Context) -> AsyncGenerator[ai.Event]:
results.append(await tc())
else:
results.append(
ai.tool_message(
ai.tool_result(
tool_call_id=tc.id,
tool_name=tc.name,
result=f"Rejected: {approval.reason}",
Expand Down
15 changes: 7 additions & 8 deletions examples/samples/agent_hooks_serverless.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ async def main() -> None:
my_agent = ai.agent(tools=[delete_file])

@my_agent.loop
async def with_confirmation(context: ai.Context) -> AsyncGenerator[ai.Event]:
async def with_confirmation(
context: ai.Context,
) -> AsyncGenerator[ai.StreamItem]:
while True:
s = ai.models.stream(
context.model, context.messages, tools=context.tools
)
s = ai.models.stream(context.model, context.messages, tools=context.tools)
async for event in s:
yield event
if s.message is not None:
Expand All @@ -50,7 +50,7 @@ async def with_confirmation(context: ai.Context) -> AsyncGenerator[ai.Event]:
if not tool_calls:
return

results = []
results: list[ai.ToolCallResult] = []
for tc in tool_calls:
try:
confirmation = await ai.hook(
Expand All @@ -67,7 +67,7 @@ async def with_confirmation(context: ai.Context) -> AsyncGenerator[ai.Event]:
results.append(await tc())
else:
results.append(
ai.tool_message(
ai.tool_result(
tool_call_id=tc.id,
tool_name=tc.name,
result=f"Rejected: {confirmation.reason}",
Expand All @@ -93,8 +93,7 @@ async def with_confirmation(context: ai.Context) -> AsyncGenerator[ai.Event]:
hook_part = event.hook
pending_hook_labels.append(hook_part.hook_id)
print(
f" Hook pending: {hook_part.hook_id}"
f" (metadata={hook_part.metadata})"
f" Hook pending: {hook_part.hook_id} (metadata={hook_part.metadata})"
)

print("\n Run interrupted; approval will be pre-registered for re-entry.\n")
Expand Down
2 changes: 1 addition & 1 deletion examples/samples/agent_nested.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ async def get_facts(topic: str) -> str:
# This tool is an async generator — it streams intermediate messages
# through the runtime sink, then returns the final result.
@ai.tool # type: ignore[arg-type] # async generator tools are supported at runtime
async def research(topic: str) -> AsyncGenerator[ai.Event]:
async def research(topic: str) -> AsyncGenerator[ai.AgentEvent]:
"""Research a topic in depth using a sub-agent."""
researcher = ai.agent(tools=[get_facts])

Expand Down
2 changes: 2 additions & 0 deletions examples/samples/agent_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ async def main() -> None:
async for event in my_agent.run(model, messages):
if isinstance(event, ai.TextDelta):
print(event.chunk, end="", flush=True)
if isinstance(event, ai.StreamEnd):
print()
print()


Expand Down
2 changes: 1 addition & 1 deletion examples/samples/check_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ async def _check(model: ai.Model) -> None:

async def _list_models(name: str, provider: object) -> None:
try:
ids: list[str] = await provider.list() # type: ignore[union-attr]
ids: list[str] = await provider.list() # type: ignore[attr-defined]
print(f" {name}: {len(ids)} models")
for mid in ids:
print(f" - {mid}")
Expand Down
39 changes: 30 additions & 9 deletions examples/samples/middleware_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,22 @@

import asyncio
import time
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any

import pydantic

import ai


class PrintMiddleware(ai.Middleware):
"""Logs every execution surface to stdout."""

async def wrap_agent_run(self, call, next):
async def wrap_agent_run(
self,
call: ai.middleware.AgentRunContext,
next: Callable[[ai.middleware.AgentRunContext], AsyncGenerator[Any]],
) -> AsyncGenerator[Any]:
label = call.label or "(default)"
print(f">>> [run] agent starting label={label} tools={len(call.tools)}")
t0 = time.perf_counter()
Expand All @@ -31,20 +39,26 @@ async def wrap_agent_run(self, call, next):
elapsed = time.perf_counter() - t0
print(f"<<< [run] agent finished label={label} {elapsed:.2f}s")

async def wrap_model(self, call, next):
async def wrap_model(
self,
call: ai.middleware.ModelContext,
next: Callable[[ai.middleware.ModelContext], Awaitable[Any]],
) -> Any:
print(f"\n>>> [model] calling {call.model.id}")
print(f" messages: {len(call.messages)}")
if call.tools:
print(f" tools: {[t.name for t in call.tools]}")

result = await next(call)

# The result is a StreamResult — async-iterable of Event objects.
# We return it as-is; the consumer iterates it normally.
print("<<< [model] stream started")
return result

async def wrap_generate(self, call, next):
async def wrap_generate(
self,
call: ai.middleware.GenerateContext,
next: Callable[[ai.middleware.GenerateContext], Awaitable[ai.Message]],
) -> ai.Message:
print(f"\n>>> [generate] calling {call.model.id}")
print(f" messages: {len(call.messages)}")

Expand All @@ -53,20 +67,27 @@ async def wrap_generate(self, call, next):
print("<<< [generate] done")
return result

async def wrap_tool(self, call, next):
async def wrap_tool(
self,
call: ai.middleware.ToolContext,
next: Callable[[ai.middleware.ToolContext], Awaitable[ai.ToolCallResult]],
) -> ai.ToolCallResult:
print(f"\n>>> [tool] {call.tool_name}({call.kwargs})")

result = await next(call)

# result is a tool-result Message.
tr = result.tool_results[0] if result.tool_results else None
tr = result.results[0] if result.results else None
if tr and not tr.is_error:
print(f"<<< [tool] {call.tool_name} -> {tr.result}")
elif tr:
print(f"<<< [tool] {call.tool_name} ERROR: {tr.result}")
return result

async def wrap_hook(self, call, next):
async def wrap_hook(
self,
call: ai.middleware.HookContext,
next: Callable[[ai.middleware.HookContext], Awaitable[pydantic.BaseModel]],
) -> pydantic.BaseModel:
print(f"\n>>> [hook] {call.label} payload={call.payload.__name__}")

result = await next(call)
Expand Down
Loading
Loading