Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ jobs:

- run: uv run ruff format --check src tests examples
- run: uv run ruff check src tests examples
- run: uv run mypy src tests
- run: uv run pyright src tests
- run: uv run mypy
- run: uv run ty check

- run: uv run pytest

Expand Down
5 changes: 2 additions & 3 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
2. after making changes run format, lint, and typecheck like ci:
- uv run ruff format --check src tests examples
- uv run ruff check src tests examples
- uv run mypy src tests
- uv run pyright src tests
- uv run mypy
- uv run ty check
3. imports:
- import by module, using the shortest unambiguous relative path. `from ..core import helpers`, `from . import streaming`
- UNLESS it's `typing` — then `from typing import Foo` (there are too many of them).
Expand Down Expand Up @@ -37,4 +37,3 @@ ensure state is easy to serialize and deserialize, modify, and compose at any le
move normalization and translation complexity inside the framework and keep the public data model minimal.

- *example*: public data model consists of a single unified `Message` type. the framework does not expose events and other intermediate steps unless the user is writing a custom adapter.

8 changes: 6 additions & 2 deletions examples/fastapi-vite/backend/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,18 @@
async def get_weather(city: str) -> str:
"""Get current weather for a city."""
await asyncio.sleep(2)
return f"Sunny, 72F in {city}" if city == "Tokyo" else f"Cloudy, 55F in {city}"
return (
f"Sunny, 72F in {city}" if city == "Tokyo" else f"Cloudy, 55F in {city}"
)


@ai.tool
async def get_population(city: str) -> int:
"""Get population of a city."""
await asyncio.sleep(1)
return {"new york": 8_336_817, "tokyo": 13_960_000}.get(city.lower(), 1_000_000)
return {"new york": 8_336_817, "tokyo": 13_960_000}.get(
city.lower(), 1_000_000
)


@ai.tool(require_approval=True)
Expand Down
9 changes: 7 additions & 2 deletions examples/fastapi-vite/backend/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import sys
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING

import agent as agent_
import fastapi
Expand All @@ -14,6 +14,9 @@

import ai

if TYPE_CHECKING:
from collections.abc import AsyncGenerator

app = fastapi.FastAPI(
title="py-ai-fastapi-chat",
description="Chat demo using Python Vercel AI SDK",
Expand All @@ -38,7 +41,9 @@ async def log_validation_errors(
file=sys.stderr,
flush=True,
)
return fastapi.responses.JSONResponse({"detail": exc.errors()}, status_code=422)
return fastapi.responses.JSONResponse(
{"detail": exc.errors()}, status_code=422
)


@app.get("/health")
Expand Down
6 changes: 2 additions & 4 deletions examples/fastapi-vite/backend/storage.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""
Pluggable storage for checkpoints and session data.
"""Pluggable storage for checkpoints and session data.

Provides a minimal Storage protocol and a FileStorage implementation
that persists data as JSON files on disk. Swap in any backend that
Expand All @@ -25,8 +24,7 @@ async def delete(self, key: str) -> None: ...


class FileStorage:
"""
JSON-file-per-key storage backend.
"""JSON-file-per-key storage backend.

Each key is stored as ``{directory}/{key}.json``. Good enough for
local development; replace with a real database for production.
Expand Down
27 changes: 20 additions & 7 deletions examples/multiagent-textual/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@
import asyncio
import json
import os
from typing import Any
from typing import Any, ClassVar

import pydantic
import rich.text
import textual
import textual.app
import textual.binding
import textual.containers
import textual.widgets
import textual.worker
Expand Down Expand Up @@ -107,11 +108,15 @@ class MultiAgentApp(textual.app.App[None]):
}
"""

BINDINGS = [("q", "quit", "quit")]
BINDINGS: ClassVar[
list[textual.binding.Binding | tuple[str, str] | tuple[str, str, str]]
] = [("q", "quit", "quit")]

def __init__(self) -> None:
super().__init__()
self._hook_queue: asyncio.Queue[ai.messages.HookPart[Any]] = asyncio.Queue()
self._hook_queue: asyncio.Queue[ai.messages.HookPart[Any]] = (
asyncio.Queue()
)
self._current_hook: ai.messages.HookPart[Any] | None = None
self._ws: websockets.ClientConnection | None = None
self._event_adapter: pydantic.TypeAdapter[ai.events.AgentEvent] = (
Expand Down Expand Up @@ -186,9 +191,13 @@ def _render(self, label: str, event: ai.events.AgentEvent) -> None:
if panel is not None:
for part in event.message.parts:
match part:
case ai.messages.ToolCallPart(tool_name=name, tool_args=args):
case ai.messages.ToolCallPart(
tool_name=name, tool_args=args
):
panel.append_line(f"> {name}({args})")
case ai.messages.ToolResultPart(tool_name=name, result=result):
case ai.messages.ToolResultPart(
tool_name=name, result=result
):
panel.append_line(f"< {name} = {result}")
return

Expand Down Expand Up @@ -221,7 +230,9 @@ def _on_hook_pending(self, hook_part: ai.messages.HookPart[Any]) -> None:

panel = self._get_panel(branch)
if panel:
panel.append_line(f"!! approval required: {tool}", style="dim yellow")
panel.append_line(
f"!! approval required: {tool}", style="dim yellow"
)
panel.status = "awaiting approval"

self._hook_queue.put_nowait(hook_part)
Expand Down Expand Up @@ -270,7 +281,9 @@ def _maybe_activate_next_hook(self) -> None:
inp.placeholder = f"approve {branch}/{tool}? [y/n]"
inp.focus()

async def on_input_submitted(self, event: textual.widgets.Input.Submitted) -> None:
async def on_input_submitted(
self, event: textual.widgets.Input.Submitted
) -> None:
if self._current_hook is None:
event.input.clear()
return
Expand Down
16 changes: 11 additions & 5 deletions examples/multiagent-textual/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,16 @@
import contextlib
import json
import warnings
from collections.abc import AsyncGenerator
from typing import Any
from typing import TYPE_CHECKING, Any

import fastapi
import pydantic

import ai

if TYPE_CHECKING:
from collections.abc import AsyncGenerator

# ToolResultPart.result is typed as dict but tools can return plain strings.
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")

Expand Down Expand Up @@ -149,7 +151,9 @@ async def loop(
class Orchestrator(ai.Agent):
"""Run two gated agents in parallel, then summarise their results."""

async def loop(self, context: ai.Context) -> AsyncGenerator[ai.events.AgentEvent]:
async def loop(
self, context: ai.Context
) -> AsyncGenerator[ai.events.AgentEvent]:
query = context.messages[-1].text

# Fan out: both branches stream concurrently via yield_from.
Expand Down Expand Up @@ -219,9 +223,11 @@ async def loop(self, context: ai.Context) -> AsyncGenerator[ai.events.AgentEvent


def _normalise_message(data: dict[str, Any]) -> dict[str, Any]:
"""Ensure ToolResultPart.result is always a dict for safe deserialisation."""
"""Ensure ToolResultPart.result is always safe to deserialize."""
for part in data.get("parts", []):
if part.get("kind") == "tool_result" and isinstance(part.get("result"), str):
if part.get("kind") == "tool_result" and isinstance(
part.get("result"), str
):
part["result"] = {"value": part["result"]}
return data

Expand Down
8 changes: 5 additions & 3 deletions examples/multiagent-textual/test-e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,14 @@
import tempfile
import time
import urllib.request
from http.client import HTTPResponse
from pathlib import Path
from typing import cast
from typing import TYPE_CHECKING, cast

import ai

if TYPE_CHECKING:
from http.client import HTTPResponse

HERE = Path(__file__).resolve().parent
SESSION = f"multiagent-e2e-{os.getpid()}"
SERVER_PORT = os.environ.get("SERVER_PORT", "8000")
Expand All @@ -34,7 +36,7 @@
def _check_health() -> bool:
try:
with urllib.request.urlopen(f"{SERVER_URL}/api/health", timeout=1) as r:
return cast(HTTPResponse, r).status == 200
return cast("HTTPResponse", r).status == 200
except Exception:
return False

Expand Down
37 changes: 28 additions & 9 deletions examples/run-examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,9 @@ def _sample_path(name: str) -> Path:
return SAMPLES / path


def _select_sample(name: str, known_samples: dict[str, Sample]) -> Sample | None:
def _select_sample(
name: str, known_samples: dict[str, Sample]
) -> Sample | None:
sample = known_samples.get(name)
if sample is not None:
return sample
Expand All @@ -178,7 +180,9 @@ def _select_sample(name: str, known_samples: dict[str, Sample]) -> Sample | None
return None


def _sample_cmd(sample: Sample, model: str | None, protocol: str | None) -> list[str]:
def _sample_cmd(
sample: Sample, model: str | None, protocol: str | None
) -> list[str]:
if sample.cmd is not None:
return sample.cmd
base = [
Expand Down Expand Up @@ -258,11 +262,21 @@ def run_sample_quiet(

def main() -> None:
parser = argparse.ArgumentParser(description="Run example samples.")
parser.add_argument("--text", action="store_true", help="include text samples")
parser.add_argument("--image", action="store_true", help="include image samples")
parser.add_argument("--video", action="store_true", help="include video samples")
parser.add_argument("--broken", action="store_true", help="include broken samples")
parser.add_argument("--e2e", action="store_true", help="include e2e test scripts")
parser.add_argument(
"--text", action="store_true", help="include text samples"
)
parser.add_argument(
"--image", action="store_true", help="include image samples"
)
parser.add_argument(
"--video", action="store_true", help="include video samples"
)
parser.add_argument(
"--broken", action="store_true", help="include broken samples"
)
parser.add_argument(
"--e2e", action="store_true", help="include e2e test scripts"
)
parser.add_argument("--all", action="store_true", help="run all samples")
parser.add_argument(
"--parallel", action="store_true", help="run samples in parallel"
Expand All @@ -287,11 +301,16 @@ def main() -> None:
"examples",
nargs="*",
metavar="example",
help="example file(s) to run, e.g. stream.py or examples/samples/stream.py",
help=(
"example file(s) to run, e.g. stream.py or "
"examples/samples/stream.py"
),
)
args = parser.parse_args()

has_category = args.text or args.image or args.video or args.broken or args.e2e
has_category = (
args.text or args.image or args.video or args.broken or args.e2e
)

samples: list[Sample] = []
if args.examples:
Expand Down
Loading
Loading