Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions examples/fastapi-vite/backend/main.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
"""FastAPI application entry point."""

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import fastapi
import fastapi.middleware.cors

from .routes import chat

app = FastAPI(
app = fastapi.FastAPI(
title="py-ai-fastapi-chat",
description="Chat demo using Python Vercel AI SDK",
)

app.add_middleware(
CORSMiddleware,
fastapi.middleware.cors.CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
Expand Down
43 changes: 19 additions & 24 deletions examples/fastapi-vite/backend/routes/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,63 +2,58 @@

from __future__ import annotations

from fastapi import APIRouter
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
import fastapi
import fastapi.responses
import pydantic

import vercel_ai_sdk as ai
from vercel_ai_sdk.ai_sdk_ui import (
UI_MESSAGE_STREAM_HEADERS,
UIMessage,
to_messages,
to_sse_stream,
)
import vercel_ai_sdk.ai_sdk_ui

from ..agent import TOOLS, get_llm, graph
from ..storage import FileStorage
from .. import agent
from .. import storage

router = APIRouter()
storage = FileStorage()
router = fastapi.APIRouter()
file_storage = storage.FileStorage()


class ChatRequest(BaseModel):
class ChatRequest(pydantic.BaseModel):
"""Request body for the chat endpoint."""

messages: list[UIMessage]
messages: list[ai.ai_sdk_ui.UIMessage]
session_id: str | None = None


@router.post("/chat")
async def chat(request: ChatRequest):
"""Handle chat requests and stream responses."""
messages = to_messages(request.messages)
messages = ai.ai_sdk_ui.to_messages(request.messages)
session_id = request.session_id or "default"
checkpoint_key = f"checkpoint:{session_id}"

llm = get_llm()
llm = agent.get_llm()

# Checkpoints resume an *interrupted* run (e.g. a hook that needed
# user input in serverless mode). Each normal chat turn is a fresh
# run — the frontend carries the full message history — so we only
# load a checkpoint when one was saved from a previous incomplete run.
saved = await storage.get(checkpoint_key)
saved = await file_storage.get(checkpoint_key)
checkpoint = ai.Checkpoint.deserialize(saved) if saved else None

result = ai.run(graph, llm, messages, TOOLS, checkpoint=checkpoint)
result = ai.run(agent.graph, llm, messages, agent.TOOLS, checkpoint=checkpoint)

async def stream_response():
async for chunk in to_sse_stream(result):
async for chunk in ai.ai_sdk_ui.to_sse_stream(result):
yield chunk

# If the run completed (no pending hooks), clear the checkpoint
# so the next request starts fresh. If hooks are pending, save
# the checkpoint so the next request can resume from here.
if result.pending_hooks:
await storage.put(checkpoint_key, result.checkpoint.serialize())
await file_storage.put(checkpoint_key, result.checkpoint.serialize())
else:
await storage.delete(checkpoint_key)
await file_storage.delete(checkpoint_key)

return StreamingResponse(
return fastapi.responses.StreamingResponse(
stream_response(),
headers=UI_MESSAGE_STREAM_HEADERS,
headers=ai.ai_sdk_ui.UI_MESSAGE_STREAM_HEADERS,
)
8 changes: 4 additions & 4 deletions examples/fastapi-vite/backend/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from __future__ import annotations

import json
from pathlib import Path
import pathlib
from typing import Any, Protocol, runtime_checkable


Expand All @@ -31,11 +31,11 @@ class FileStorage:
local development; replace with a real database for production.
"""

def __init__(self, directory: str | Path = "./data") -> None:
self._dir = Path(directory)
def __init__(self, directory: str | pathlib.Path = "./data") -> None:
self._dir = pathlib.Path(directory)
self._dir.mkdir(parents=True, exist_ok=True)

def _path(self, key: str) -> Path:
def _path(self, key: str) -> pathlib.Path:
# Sanitise the key so it's safe as a filename
safe = key.replace("/", "__").replace(":", "_")
return self._dir / f"{safe}.json"
Expand Down
42 changes: 21 additions & 21 deletions examples/multiagent-textual/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
import asyncio
import json

import textual
import textual.app
import textual.containers
import textual.widgets
import textual.worker
import rich.text
import websockets
from rich.text import Text
from textual import work
from textual.app import App, ComposeResult
from textual.containers import VerticalScroll
from textual.widgets import Input, Static
from textual.worker import get_current_worker

import vercel_ai_sdk as ai

Expand All @@ -29,7 +29,7 @@
# ---------------------------------------------------------------------------


class AgentPanel(VerticalScroll):
class AgentPanel(textual.containers.VerticalScroll):
"""Scrolling panel for one agent's output stream."""

DEFAULT_CSS = """
Expand All @@ -47,15 +47,15 @@ def __init__(self, agent_id: str, title: str) -> None:
super().__init__(id=agent_id)
self._title = title
self._status = "idle"
self._content = Text()
self._content = rich.text.Text()
self._update_border_title()

def compose(self) -> ComposeResult:
yield Static(id=f"{self.id}-text")
def compose(self) -> textual.app.ComposeResult:
yield textual.widgets.Static(id=f"{self.id}-text")

@property
def text_widget(self) -> Static:
return self.query_one(f"#{self.id}-text", Static)
def text_widget(self) -> textual.widgets.Static:
return self.query_one(f"#{self.id}-text", textual.widgets.Static)

# -- status management -------------------------------------------------

Expand Down Expand Up @@ -90,7 +90,7 @@ def append_line(self, text: str, style: str = "dim") -> None:
# ---------------------------------------------------------------------------


class MultiAgentApp(App):
class MultiAgentApp(textual.app.App):
"""Textual app for the multi-agent hooks demo."""

CSS = """
Expand All @@ -112,11 +112,11 @@ def __init__(self) -> None:
self._current_hook: ai.HookPart | None = None
self._ws: websockets.ClientConnection | None = None

def compose(self) -> ComposeResult:
def compose(self) -> textual.app.ComposeResult:
yield AgentPanel("mothership", "mothership")
yield AgentPanel("data_centers", "data-centers")
yield AgentPanel("summary", "summary")
yield Input(
yield textual.widgets.Input(
placeholder="waiting for agents...",
disabled=True,
id="input-bar",
Expand All @@ -129,9 +129,9 @@ def on_mount(self) -> None:
# WebSocket reader (background worker)
# ------------------------------------------------------------------

@work(exclusive=True)
@textual.work(exclusive=True)
async def run_websocket(self) -> None:
worker = get_current_worker()
worker = textual.worker.get_current_worker()

try:
async with websockets.connect(WS_URL) as ws:
Expand Down Expand Up @@ -228,7 +228,7 @@ def _on_run_complete(self) -> None:
if panel:
panel.status = "complete"

inp = self.query_one("#input-bar", Input)
inp = self.query_one("#input-bar", textual.widgets.Input)
inp.disabled = True
inp.placeholder = "done — press q to quit"

Expand All @@ -249,12 +249,12 @@ def _maybe_activate_next_hook(self) -> None:
branch = hook.metadata.get("branch", "unknown")
tool = hook.metadata.get("tool", "?")

inp = self.query_one("#input-bar", Input)
inp = self.query_one("#input-bar", textual.widgets.Input)
inp.disabled = False
inp.placeholder = f"approve {branch}/{tool}? [y/n]"
inp.focus()

async def on_input_submitted(self, event: Input.Submitted) -> None:
async def on_input_submitted(self, event: textual.widgets.Input.Submitted) -> None:
if self._current_hook is None:
event.input.clear()
return
Expand Down Expand Up @@ -299,7 +299,7 @@ def _get_panel(self, label: str) -> AgentPanel | None:
return None

def _set_input_placeholder(self, text: str) -> None:
inp = self.query_one("#input-bar", Input)
inp = self.query_one("#input-bar", textual.widgets.Input)
inp.placeholder = text


Expand Down
8 changes: 4 additions & 4 deletions examples/multiagent-textual/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
import os
import warnings

import fastapi
import pydantic
from fastapi import FastAPI, WebSocket, WebSocketDisconnect

import vercel_ai_sdk as ai

# ToolPart.result is typed as dict but tools can return plain strings.
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")

app = FastAPI(title="multiagent-textual")
app = fastapi.FastAPI(title="multiagent-textual")

# ---------------------------------------------------------------------------
# Tools
Expand Down Expand Up @@ -168,7 +168,7 @@ def _normalise_message(data: dict) -> dict:


@app.websocket("/ws")
async def ws_endpoint(websocket: WebSocket):
async def ws_endpoint(websocket: fastapi.WebSocket):
await websocket.accept()
print("Client connected")

Expand All @@ -191,7 +191,7 @@ async def read_resolutions():
data["hook_id"],
{"granted": data["granted"], "reason": data["reason"]},
)
except WebSocketDisconnect:
except fastapi.WebSocketDisconnect:
pass

reader = asyncio.create_task(read_resolutions())
Expand Down
18 changes: 9 additions & 9 deletions examples/temporal-durable/activities.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,25 @@

from __future__ import annotations

import dataclasses
import os
from dataclasses import dataclass
from typing import Any

from temporalio import activity
import temporalio.activity

import vercel_ai_sdk as ai
from vercel_ai_sdk.anthropic import AnthropicModel
import vercel_ai_sdk.anthropic


# ── Tool activities (one per tool, plain functions) ───────────────


@activity.defn(name="get_weather")
@temporalio.activity.defn(name="get_weather")
async def get_weather_activity(city: str) -> str:
return f"Sunny, 72F in {city}"


@activity.defn(name="get_population")
@temporalio.activity.defn(name="get_population")
async def get_population_activity(city: str) -> int:
return {"new york": 8_336_817, "los angeles": 3_979_576}.get(
city.lower(), 1_000_000
Expand All @@ -35,21 +35,21 @@ async def get_population_activity(city: str) -> int:
# ── LLM activity ─────────────────────────────────────────────────


@dataclass
@dataclasses.dataclass
class LLMCallParams:
messages: list[dict[str, Any]]
tool_schemas: list[dict[str, Any]]


@dataclass
@dataclasses.dataclass
class LLMCallResult:
message: dict[str, Any] # serialized ai.Message


@activity.defn(name="llm_call")
@temporalio.activity.defn(name="llm_call")
async def llm_call_activity(params: LLMCallParams) -> LLMCallResult:
"""Call the LLM, drain the stream, return the final message."""
llm = AnthropicModel(
llm = ai.anthropic.AnthropicModel(
model="anthropic/claude-sonnet-4",
base_url="https://ai-gateway.vercel.sh",
api_key=os.environ.get("AI_GATEWAY_API_KEY"),
Expand Down
26 changes: 15 additions & 11 deletions examples/temporal-durable/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,30 +15,34 @@
import sys
import uuid

from temporalio.client import Client
from temporalio.worker import Worker
import temporalio.client
import temporalio.worker

from activities import get_weather_activity, get_population_activity, llm_call_activity
from workflow import AgentWorkflow
import activities
import workflow

TASK_QUEUE = "agent-durable"


async def main(user_query: str) -> None:
client = await Client.connect("localhost:7233")
temporal = await temporalio.client.Client.connect("localhost:7233")

async with Worker(
client,
async with temporalio.worker.Worker(
temporal,
task_queue=TASK_QUEUE,
workflows=[AgentWorkflow],
activities=[llm_call_activity, get_weather_activity, get_population_activity],
workflows=[workflow.AgentWorkflow],
activities=[
activities.llm_call_activity,
activities.get_weather_activity,
activities.get_population_activity,
],
):
workflow_id = f"agent-durable-{uuid.uuid4().hex[:8]}"
print(f"Workflow {workflow_id}")
print(f"Query: {user_query}\n")

result = await client.execute_workflow(
AgentWorkflow.run,
result = await temporal.execute_workflow(
workflow.AgentWorkflow.run,
user_query,
id=workflow_id,
task_queue=TASK_QUEUE,
Expand Down
Loading