diff --git a/examples/fastapi-vite/README.md b/examples/fastapi-vite/README.md index 7f75c5e5..e8552512 100644 --- a/examples/fastapi-vite/README.md +++ b/examples/fastapi-vite/README.md @@ -1,12 +1,29 @@ # fastapi-chat Chat demo using the Python Vercel AI SDK with a FastAPI backend and React frontend. +Includes **human-in-the-loop tool approval** — every tool call is gated +behind user confirmation before execution. ## Stack - **Backend:** FastAPI + vercel-ai-sdk (Python 3.12) - **Frontend:** Vite + React + AI SDK UI + AI Elements +## Human-in-the-Loop + +The agent graph in `backend/agent.py` uses the `ToolApproval` hook to +suspend execution whenever the LLM wants to call a tool. The flow is: + +1. LLM emits a tool call +2. Backend creates a `ToolApproval` hook — this emits an + `approval-requested` event on the SSE stream and suspends execution +3. The frontend renders Approve / Reject buttons via the + `` component (from AI Elements) +4. When the user clicks a button, `addToolApprovalResponse()` patches + the message and sends a new request with the decision +5. The backend resumes from the checkpoint and either executes the tool + or marks it as denied + ## Setup ```bash diff --git a/examples/fastapi-vite/backend/__init__.py b/examples/fastapi-vite/backend/__init__.py deleted file mode 100644 index 7f831694..00000000 --- a/examples/fastapi-vite/backend/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Backend package diff --git a/examples/fastapi-vite/backend/agent.py b/examples/fastapi-vite/backend/agent.py index 94ecd714..023a66d2 100644 --- a/examples/fastapi-vite/backend/agent.py +++ b/examples/fastapi-vite/backend/agent.py @@ -1,5 +1,10 @@ -"""Agent logic for the chat demo.""" +"""Agent logic for the chat demo. +Demonstrates human-in-the-loop tool approval using ToolApproval hooks. +Every tool call is gated behind user approval before execution. +""" + +import asyncio from typing import Any import vercel_ai_sdk as ai @@ -19,16 +24,49 @@ def get_llm() -> ai.LanguageModel: TOOLS: list[ai.Tool[..., Any]] = [talk_to_mothership] +async def _execute_with_approval( + tc: ai.ToolPart, message: ai.Message | None = None +) -> None: + """Execute a tool call only after the user grants approval. + + Creates a ToolApproval hook that suspends execution until the + frontend responds with an approve/reject decision. + """ + approval = await ai.ToolApproval.create( # type: ignore[attr-defined] + f"approve_{tc.tool_call_id}", + metadata={"tool_name": tc.tool_name, "tool_args": tc.tool_args}, + ) + + if approval.granted: + await ai.execute_tool(tc, message=message) + else: + tc.set_error("Tool call was denied by the user.") + + async def graph( llm: ai.LanguageModel, messages: list[ai.Message], tools: list[ai.Tool[..., Any]], ) -> ai.StreamResult: - """ - Agent graph: stream LLM, execute tools, repeat until done. + """Agent graph with human-in-the-loop tool approval. - This is a plain async function that goes through the Runtime queue - via stream_loop. When hooks are added later, they slot in here - between tool calls — no structural change needed. + Loops: stream LLM -> request approval -> execute tools -> repeat. + The ToolApproval hook suspends execution and emits an approval- + request event on the SSE stream. The frontend displays Approve / + Reject buttons and sends the decision back on the next request. """ - return await ai.stream_loop(llm, messages, tools) + local_messages = list(messages) + + while True: + result = await ai.stream_step(llm, local_messages, tools) + + if not result.tool_calls: + return result + + last_msg = result.last_message + assert last_msg is not None + local_messages.append(last_msg) + + await asyncio.gather( + *(_execute_with_approval(tc, message=last_msg) for tc in result.tool_calls) + ) diff --git a/examples/fastapi-vite/backend/main.py b/examples/fastapi-vite/backend/main.py index 7aa8cf42..d34322f4 100644 --- a/examples/fastapi-vite/backend/main.py +++ b/examples/fastapi-vite/backend/main.py @@ -1,15 +1,25 @@ """FastAPI application entry point.""" +from __future__ import annotations + +from collections.abc import AsyncGenerator + +import agent import fastapi import fastapi.middleware.cors -from routes import chat +import fastapi.responses +import pydantic +import storage -api = fastapi.FastAPI( +import vercel_ai_sdk as ai +import vercel_ai_sdk.ai_sdk_ui + +app = fastapi.FastAPI( title="py-ai-fastapi-chat", description="Chat demo using Python Vercel AI SDK", ) -api.add_middleware( +app.add_middleware( fastapi.middleware.cors.CORSMiddleware, allow_origins=["*"], allow_credentials=True, @@ -17,14 +27,59 @@ allow_headers=["*"], ) -api.include_router(chat.router) - -@api.get("/health") +@app.get("/health") async def health() -> dict[str, str]: """Health check endpoint.""" return {"status": "ok"} -app = fastapi.FastAPI() -app.mount("/api", api) +file_storage = storage.FileStorage() + + +class ChatRequest(pydantic.BaseModel): + """Request body for the chat endpoint.""" + + messages: list[ai.ai_sdk_ui.UIMessage] + session_id: str | None = None + + +@app.post("/chat") +async def chat(request: ChatRequest) -> fastapi.responses.StreamingResponse: + """Handle chat requests and stream responses.""" + messages = ai.ai_sdk_ui.to_messages(request.messages) + session_id = request.session_id or "default" + checkpoint_key = f"checkpoint:{session_id}" + + llm = agent.get_llm() + + checkpoint = None + saved = await file_storage.get(checkpoint_key) + if saved: + checkpoint = ai.Checkpoint.model_validate(saved) + + result = ai.run( + agent.graph, + llm, + messages, + agent.TOOLS, + checkpoint=checkpoint, + cancel_on_hooks=True, + ) + + async def stream_response() -> AsyncGenerator[str]: + async for chunk in ai.ai_sdk_ui.to_sse_stream(result): + yield chunk + + if result.checkpoint.pending_hooks: + await file_storage.put( + checkpoint_key, + result.checkpoint.model_dump(), + ) + else: + await file_storage.delete(checkpoint_key) + + return fastapi.responses.StreamingResponse( + stream_response(), + headers=ai.ai_sdk_ui.UI_MESSAGE_STREAM_HEADERS, + ) diff --git a/examples/fastapi-vite/backend/pyproject.toml b/examples/fastapi-vite/backend/pyproject.toml index 1909ae57..3fa5b347 100644 --- a/examples/fastapi-vite/backend/pyproject.toml +++ b/examples/fastapi-vite/backend/pyproject.toml @@ -7,3 +7,4 @@ dependencies = [ "fastapi[standard]>=0.128.1", "vercel-ai-sdk>=0.0.1.dev5", ] + diff --git a/examples/fastapi-vite/backend/routes/__init__.py b/examples/fastapi-vite/backend/routes/__init__.py deleted file mode 100644 index d212dab6..00000000 --- a/examples/fastapi-vite/backend/routes/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Routes package diff --git a/examples/fastapi-vite/backend/routes/chat.py b/examples/fastapi-vite/backend/routes/chat.py deleted file mode 100644 index 69b95d88..00000000 --- a/examples/fastapi-vite/backend/routes/chat.py +++ /dev/null @@ -1,60 +0,0 @@ -"""Chat route — streams LLM responses via the AI SDK UI protocol.""" - -from __future__ import annotations - -from collections.abc import AsyncGenerator - -import agent -import fastapi -import fastapi.responses -import pydantic -import storage - -import vercel_ai_sdk as ai -import vercel_ai_sdk.ai_sdk_ui - -router = fastapi.APIRouter() -file_storage = storage.FileStorage() - - -class ChatRequest(pydantic.BaseModel): - """Request body for the chat endpoint.""" - - messages: list[ai.ai_sdk_ui.UIMessage] - session_id: str | None = None - - -@router.post("/chat") -async def chat(request: ChatRequest) -> fastapi.responses.StreamingResponse: - """Handle chat requests and stream responses.""" - messages = ai.ai_sdk_ui.to_messages(request.messages) - session_id = request.session_id or "default" - checkpoint_key = f"checkpoint:{session_id}" - - llm = agent.get_llm() - - # Checkpoints resume an *interrupted* run (e.g. a hook that needed - # user input in serverless mode). Each normal chat turn is a fresh - # run — the frontend carries the full message history — so we only - # load a checkpoint when one was saved from a previous incomplete run. - saved = await file_storage.get(checkpoint_key) - checkpoint = ai.Checkpoint.model_validate(saved) if saved else None - - result = ai.run(agent.graph, llm, messages, agent.TOOLS, checkpoint=checkpoint) - - async def stream_response() -> AsyncGenerator[str]: - async for chunk in ai.ai_sdk_ui.to_sse_stream(result): - yield chunk - - # If the run completed (no pending hooks), clear the checkpoint - # so the next request starts fresh. If hooks are pending, save - # the checkpoint so the next request can resume from here. - if result.pending_hooks: - await file_storage.put(checkpoint_key, result.checkpoint.model_dump()) - else: - await file_storage.delete(checkpoint_key) - - return fastapi.responses.StreamingResponse( - stream_response(), - headers=ai.ai_sdk_ui.UI_MESSAGE_STREAM_HEADERS, - ) diff --git a/examples/fastapi-vite/frontend/src/App.tsx b/examples/fastapi-vite/frontend/src/App.tsx index a7142590..0fedfd98 100644 --- a/examples/fastapi-vite/frontend/src/App.tsx +++ b/examples/fastapi-vite/frontend/src/App.tsx @@ -1,8 +1,21 @@ import { useChat } from "@ai-sdk/react"; -import { DefaultChatTransport } from "ai"; +import { + DefaultChatTransport, + lastAssistantMessageIsCompleteWithApprovalResponses, +} from "ai"; import type { ToolUIPart } from "ai"; +import { CheckIcon, XIcon } from "lucide-react"; import { Fragment } from "react"; +import { + Confirmation, + ConfirmationAccepted, + ConfirmationAction, + ConfirmationActions, + ConfirmationRejected, + ConfirmationRequest, + ConfirmationTitle, +} from "@/components/ai-elements/confirmation"; import { Conversation, ConversationContent, @@ -29,11 +42,16 @@ import { import { TooltipProvider } from "@/components/ui/tooltip"; export default function App() { - const { messages, sendMessage, status, stop } = useChat({ - transport: new DefaultChatTransport({ - api: "/api/chat", - }), - }); + const { messages, sendMessage, addToolApprovalResponse, status, stop } = + useChat({ + transport: new DefaultChatTransport({ + api: "/api/chat", + }), + // After the user approves/rejects a tool, automatically send the + // updated messages back to the backend so it can resume execution. + sendAutomaticallyWhen: + lastAssistantMessageIsCompleteWithApprovalResponses, + }); const isLoading = status === "submitted" || status === "streaming"; @@ -63,7 +81,8 @@ export default function App() { // Handle tool parts (type starts with "tool-") if (part.type.startsWith("tool-")) { const toolPart = part as ToolUIPart; - const isComplete = toolPart.state === "output-available"; + const isComplete = + toolPart.state === "output-available"; return ( + + {/* Human-in-the-loop approval UI */} + + + + Allow this tool to run? + + + + Approved + + + + Rejected + + + + + addToolApprovalResponse({ + id: toolPart.approval!.id, + approved: false, + }) + } + > + Reject + + + addToolApprovalResponse({ + id: toolPart.approval!.id, + approved: true, + }) + } + > + Approve + + + + (null); + +const useConfirmation = () => { + const ctx = useContext(ConfirmationContext); + if (!ctx) throw new Error("Confirmation components must be used within "); + return ctx; +}; + +/* ------------------------------------------------------------------ */ +/* */ +/* ------------------------------------------------------------------ */ + +export type ConfirmationProps = ComponentProps<"div"> & { + approval?: ToolUIPartApproval; + state: ToolUIPart["state"]; +}; + +export const Confirmation = ({ + className, + approval, + state, + children, + ...props +}: ConfirmationProps) => { + if (!approval || state === "input-streaming" || state === "input-available") { + return null; + } + + return ( + +
+ {children} +
+
+ ); +}; + +/* ------------------------------------------------------------------ */ +/* */ +/* ------------------------------------------------------------------ */ + +export type ConfirmationTitleProps = ComponentProps<"p">; + +export const ConfirmationTitle = ({ + className, + ...props +}: ConfirmationTitleProps) => ( +

+); + +/* ------------------------------------------------------------------ */ +/* State-conditional wrappers */ +/* ------------------------------------------------------------------ */ + +export const ConfirmationRequest = ({ children }: { children?: ReactNode }) => { + const { state } = useConfirmation(); + return state === "approval-requested" ? <>{children} : null; +}; + +export const ConfirmationAccepted = ({ children }: { children?: ReactNode }) => { + const { approval, state } = useConfirmation(); + const show = + approval?.approved === true && + (state === "approval-responded" || + state === "output-available" || + state === "output-denied"); + return show ? <>{children} : null; +}; + +export const ConfirmationRejected = ({ children }: { children?: ReactNode }) => { + const { approval, state } = useConfirmation(); + const show = + approval?.approved === false && + (state === "approval-responded" || + state === "output-available" || + state === "output-denied"); + return show ? <>{children} : null; +}; + +/* ------------------------------------------------------------------ */ +/* Actions */ +/* ------------------------------------------------------------------ */ + +export type ConfirmationActionsProps = ComponentProps<"div">; + +export const ConfirmationActions = ({ + className, + ...props +}: ConfirmationActionsProps) => { + const { state } = useConfirmation(); + if (state !== "approval-requested") return null; + + return ( +

+ ); +}; + +export type ConfirmationActionProps = ComponentProps; + +export const ConfirmationAction = (props: ConfirmationActionProps) => ( +