From d1ad30744a6b9eca41c33b5bb2011b3939cbc97f Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Mon, 2 Mar 2026 14:01:25 -0800 Subject: [PATCH 01/10] Wire tool approval into AI SDK UI protocol --- src/vercel_ai_sdk/__init__.py | 3 +- src/vercel_ai_sdk/agent/__init__.py | 4 +- src/vercel_ai_sdk/agent/agent.py | 10 +-- src/vercel_ai_sdk/ai_sdk_ui/adapter.py | 64 +++++++++++++++++ src/vercel_ai_sdk/ai_sdk_ui/ui_message.py | 24 +++++-- src/vercel_ai_sdk/core/hooks.py | 23 ++++-- tests/ai_sdk_ui/test_adapter.py | 87 ++++++++++++++++++++++- 7 files changed, 192 insertions(+), 23 deletions(-) diff --git a/src/vercel_ai_sdk/__init__.py b/src/vercel_ai_sdk/__init__.py index 79028cac..35f444d2 100644 --- a/src/vercel_ai_sdk/__init__.py +++ b/src/vercel_ai_sdk/__init__.py @@ -1,7 +1,7 @@ from . import ai_gateway, ai_sdk_ui, anthropic, mcp, openai from .core import telemetry from .core.checkpoint import Checkpoint -from .core.hooks import Hook, hook +from .core.hooks import Hook, ToolApproval, hook from .core.llm import LanguageModel # Re-export core types @@ -51,6 +51,7 @@ "StreamResult", "Hook", "HookPart", + "ToolApproval", "StructuredOutputPart", "Checkpoint", # Functions diff --git a/src/vercel_ai_sdk/agent/__init__.py b/src/vercel_ai_sdk/agent/__init__.py index 4b349261..ed06c24a 100644 --- a/src/vercel_ai_sdk/agent/__init__.py +++ b/src/vercel_ai_sdk/agent/__init__.py @@ -1,4 +1,4 @@ from . import local, proto, tools, vercel -from .agent import Agent, ToolApproval +from .agent import Agent -__all__ = ["Agent", "ToolApproval", "proto", "tools", "local", "vercel"] +__all__ = ["Agent", "proto", "tools", "local", "vercel"] diff --git a/src/vercel_ai_sdk/agent/agent.py b/src/vercel_ai_sdk/agent/agent.py index 8e3b4b2a..c7140613 100644 --- a/src/vercel_ai_sdk/agent/agent.py +++ b/src/vercel_ai_sdk/agent/agent.py @@ -4,20 +4,12 @@ import dataclasses from typing import Any -import pydantic - import vercel_ai_sdk as ai from . import proto from .tools import BUILTIN_TOOLS, _filesystem -@ai.hook -class ToolApproval(pydantic.BaseModel): - granted: bool - reason: str | None = None - - @dataclasses.dataclass class Agent: """ @@ -48,7 +40,7 @@ async def _execute_tool( """ # TODO: mypy doesn't support class decorators that change the class type — # @ai.hook returns type[Hook[T]] but mypy still sees the original BaseModel. - approval = await ToolApproval.create( # type: ignore[attr-defined] + approval = await ai.ToolApproval.create( # type: ignore[attr-defined] f"approve_{tc.tool_call_id}", metadata={"tool_name": tc.tool_name, "tool_args": tc.tool_args}, ) diff --git a/src/vercel_ai_sdk/ai_sdk_ui/adapter.py b/src/vercel_ai_sdk/ai_sdk_ui/adapter.py index ed080f1a..f050e3d6 100644 --- a/src/vercel_ai_sdk/ai_sdk_ui/adapter.py +++ b/src/vercel_ai_sdk/ai_sdk_ui/adapter.py @@ -11,6 +11,7 @@ from typing import Any, Literal from .. import core +from ..core import hooks from . import protocol, ui_message # ============================================================================ @@ -69,6 +70,7 @@ def __init__(self) -> None: self.started_tool_calls: set[str] = set() self.emitted_tool_results: set[str] = set() self.pending_tool_calls: set[str] = set() + self.emitted_approval_requests: set[str] = set() def close_open_blocks(self) -> list[protocol.UIMessageStreamPart]: """Close any open reasoning/text blocks, returning parts to emit.""" @@ -94,6 +96,7 @@ def reset_tool_tracking(self) -> None: self.started_tool_calls = set() self.emitted_tool_results = set() self.pending_tool_calls = set() + self.emitted_approval_requests = set() def begin_message( self, msg: core.messages.Message @@ -130,6 +133,22 @@ def begin_message( return parts +def _tool_call_id_from_approval_hook( + hook_part: core.messages.HookPart, +) -> str | None: + """Extract tool_call_id from a ToolApproval HookPart. + + Returns the tool_call_id if this is a ToolApproval hook whose hook_id + follows the ``approve_{tool_call_id}`` convention, otherwise None. + """ + if hook_part.hook_type != hooks.ToolApproval.hook_type: # type: ignore[attr-defined] + return None + prefix = "approve_" + if hook_part.hook_id.startswith(prefix): + return hook_part.hook_id[len(prefix) :] + return None + + async def to_ui_message_stream( messages: AsyncIterable[core.messages.Message], ) -> AsyncGenerator[protocol.UIMessageStreamPart]: @@ -256,6 +275,33 @@ async def to_ui_message_stream( output=result, ) + # Pass 3: Hook-based tool approvals + for msg_part in msg.parts: + if not isinstance(msg_part, core.messages.HookPart): + continue + approval_tc_id = _tool_call_id_from_approval_hook(msg_part) + if approval_tc_id is None: + continue + + if msg_part.status == "pending": + if approval_tc_id not in state.emitted_approval_requests: + state.emitted_approval_requests.add(approval_tc_id) + yield protocol.ToolApprovalRequestPart( + approval_id=msg_part.hook_id, + tool_call_id=approval_tc_id, + ) + elif msg_part.status == "resolved": + resolution = msg_part.resolution or {} + if not resolution.get("granted", False): + yield protocol.ToolOutputDeniedPart( + tool_call_id=approval_tc_id, + ) + elif msg_part.status == "cancelled": + yield protocol.ToolOutputErrorPart( + tool_call_id=approval_tc_id, + error_text="Hook cancelled", + ) + # Final cleanup for part in state.finish_step(): yield part @@ -333,6 +379,10 @@ def to_messages( ) -> list[core.messages.Message]: """Convert AI SDK v6 UI messages to internal Message format. + As a side-effect, tool parts in ``approval-responded`` state trigger + ``ToolApproval.resolve()`` so the agent loop can resume execution + without the caller needing to handle approval routing explicitly. + Args: ui_messages: List of UIMessage objects from the AI SDK v6 frontend. @@ -375,6 +425,20 @@ def to_messages( result=_normalize_tool_result(tp.output), ) ) + # Side-effect: resolve ToolApproval hooks from approval + # responses so the agent loop can resume execution. + if ( + tp.state == "approval-responded" + and tp.approval is not None + and tp.approval.approved is not None + ): + hooks.ToolApproval.resolve( # type: ignore[attr-defined] + tp.approval.id, + { + "granted": tp.approval.approved, + "reason": tp.approval.reason, + }, + ) case ( ui_message.UIStepStartPart() diff --git a/src/vercel_ai_sdk/ai_sdk_ui/ui_message.py b/src/vercel_ai_sdk/ai_sdk_ui/ui_message.py index cb87d9d2..f1a15a34 100644 --- a/src/vercel_ai_sdk/ai_sdk_ui/ui_message.py +++ b/src/vercel_ai_sdk/ai_sdk_ui/ui_message.py @@ -37,11 +37,11 @@ class UIReasoningPart(pydantic.BaseModel): # Tool invocation states in AI SDK v6: # - "input-streaming": Tool arguments are being streamed # - "input-available": Tool arguments are complete, ready for execution -# - "approval-requested": Tool requires user approval (TODO: approval workflow) -# - "approval-responded": User has responded to approval (TODO: approval workflow) +# - "approval-requested": Tool requires user approval before execution +# - "approval-responded": User has responded to approval request # - "output-available": Tool has been executed, result is available # - "output-error": Tool execution failed -# - "output-denied": Tool execution was denied by user (TODO: approval workflow) +# - "output-denied": Tool execution was denied by user UIToolInvocationState = Literal[ "input-streaming", "input-available", @@ -78,6 +78,21 @@ class UIStepStartPart(pydantic.BaseModel): type: Literal["step-start"] +class UIToolApproval(pydantic.BaseModel): + """Approval state on a tool part (AI SDK v6 protocol). + + Present when a tool requires user approval before execution. + ``id`` matches the hook label used by the ToolApproval hook. + ``approved`` is None while awaiting a response, True/False after. + """ + + model_config = pydantic.ConfigDict(populate_by_name=True) + + id: str + approved: bool | None = None + reason: str | None = None + + class UIToolPart(pydantic.BaseModel): """Tool part with dynamic type pattern: tool-{toolName}. @@ -95,8 +110,7 @@ class UIToolPart(pydantic.BaseModel): input: str | dict[str, Any] | None = None # JSON string or parsed dict output: Any | None = None error_text: str | None = pydantic.Field(default=None, alias="errorText") - # TODO: title, providerExecuted, preliminary fields - # TODO: approval workflow (approval object) + approval: UIToolApproval | None = None @property def tool_name(self) -> str: diff --git a/src/vercel_ai_sdk/core/hooks.py b/src/vercel_ai_sdk/core/hooks.py index 17a1c1f3..f4001d5e 100644 --- a/src/vercel_ai_sdk/core/hooks.py +++ b/src/vercel_ai_sdk/core/hooks.py @@ -72,7 +72,7 @@ class Hook[T: pydantic.BaseModel]: """ _schema: ClassVar[type[pydantic.BaseModel]] - _hook_type: ClassVar[str] + hook_type: ClassVar[str] @classmethod async def create(cls, label: str, metadata: dict[str, Any] | None = None) -> T: @@ -113,7 +113,7 @@ async def create(cls, label: str, metadata: dict[str, Any] | None = None) -> T: future: asyncio.Future[dict[str, Any]] = asyncio.Future() suspension = rt_mod.HookSuspension( label=label, - hook_type=cls._hook_type, + hook_type=cls.hook_type, metadata=metadata or {}, future=future, ) @@ -142,7 +142,7 @@ async def create(cls, label: str, metadata: dict[str, Any] | None = None) -> T: parts=[ messages_.HookPart( hook_id=label, - hook_type=cls._hook_type, + hook_type=cls.hook_type, status="resolved", metadata=hook_metadata, resolution=resolution, @@ -215,7 +215,7 @@ async def cancel(cls, label: str, reason: str | None = None) -> None: parts=[ messages_.HookPart( hook_id=label, - hook_type=cls._hook_type, + hook_type=cls.hook_type, status="cancelled", metadata=hook_metadata, ) @@ -235,9 +235,22 @@ def hook[T: pydantic.BaseModel](cls: type[T]) -> type[Hook[T]]: (Hook,), { "_schema": cls, - "_hook_type": cls.__name__, + "hook_type": cls.__name__, "__doc__": cls.__doc__, }, ) return hook_impl + + +@hook +class ToolApproval(pydantic.BaseModel): + """Prewired hook for tool call approval. + + Used by the AI SDK UI adapter to bridge the protocol's + tool-approval-request / approval-responded flow to the + hook system. + """ + + granted: bool + reason: str | None = None diff --git a/tests/ai_sdk_ui/test_adapter.py b/tests/ai_sdk_ui/test_adapter.py index 94de6bf7..68af026e 100644 --- a/tests/ai_sdk_ui/test_adapter.py +++ b/tests/ai_sdk_ui/test_adapter.py @@ -8,7 +8,7 @@ import vercel_ai_sdk as ai from vercel_ai_sdk.ai_sdk_ui import adapter, ui_message -from vercel_ai_sdk.core import messages +from vercel_ai_sdk.core import hooks, messages from ..conftest import MockLLM @@ -492,3 +492,88 @@ def test_ui_skips_unsupported_parts() -> None: internal = adapter.to_messages([ui_msg]) assert len(internal[0].parts) == 2 + + +# ----------------------------------------------------------------------------- +# Tool approval (human-in-the-loop) tests +# ----------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_tool_approval_hook_emits_approval_request() -> None: + """Pending ToolApproval HookPart emits tool-approval-request on the wire.""" + msgs = [ + # Tool pending (args complete, awaiting approval) + messages.Message( + id="msg-1", + role="assistant", + parts=[ + messages.ToolPart( + tool_call_id="tc-1", + tool_name="rm_rf", + tool_args='{"path": "/"}', + status="pending", + state="done", + ), + ], + ), + # Hook pending (approval requested) + messages.Message( + id="msg-1", + role="assistant", + parts=[ + messages.HookPart( + hook_id="approve_tc-1", + hook_type=hooks.ToolApproval.hook_type, # type: ignore[attr-defined] + status="pending", + metadata={"tool_name": "rm_rf", "tool_args": '{"path": "/"}'}, + ), + ], + ), + ] + + event_types = await get_event_types(msgs) + assert event_types == [ + "start", + "start-step", + "tool-input-start", + "tool-input-available", + "tool-approval-request", + "finish-step", + "finish", + ] + + +def test_approval_responded_resolves_hook() -> None: + """to_messages() resolves the ToolApproval hook for approval-responded parts.""" + label = "approve_tc-42" + raw_messages = [ + { + "id": "msg-1", + "role": "assistant", + "parts": [ + { + "type": "tool-dangerous_action", + "toolCallId": "tc-42", + "state": "approval-responded", + "input": '{"x": 1}', + "approval": { + "id": label, + "approved": True, + "reason": "looks safe", + }, + } + ], + }, + ] + + # Clean up any leftover state from other tests + hooks._pending_resolutions.pop(label, None) + + ui_msgs = [ui_message.UIMessage.model_validate(m) for m in raw_messages] + adapter.to_messages(ui_msgs) + + # The side-effect should have pre-registered the resolution + assert label in hooks._pending_resolutions + resolution = hooks._pending_resolutions.pop(label) + assert resolution == {"granted": True, "reason": "looks safe"} From f5ec3b5ab4855ef720dd8eddbab06a4083c4b35c Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Mon, 2 Mar 2026 16:06:03 -0800 Subject: [PATCH 02/10] Ensure the baseline example app works --- examples/fastapi-vite/backend/__init__.py | 1 - examples/fastapi-vite/backend/main.py | 66 ++++++++++++++++--- examples/fastapi-vite/backend/pyproject.toml | 6 +- .../fastapi-vite/backend/routes/__init__.py | 1 - examples/fastapi-vite/backend/routes/chat.py | 60 ----------------- examples/fastapi-vite/backend/uv.lock | 63 ++++++++++++++++-- examples/fastapi-vite/frontend/vite.config.ts | 5 -- 7 files changed, 121 insertions(+), 81 deletions(-) delete mode 100644 examples/fastapi-vite/backend/__init__.py delete mode 100644 examples/fastapi-vite/backend/routes/__init__.py delete mode 100644 examples/fastapi-vite/backend/routes/chat.py diff --git a/examples/fastapi-vite/backend/__init__.py b/examples/fastapi-vite/backend/__init__.py deleted file mode 100644 index 7f831694..00000000 --- a/examples/fastapi-vite/backend/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Backend package diff --git a/examples/fastapi-vite/backend/main.py b/examples/fastapi-vite/backend/main.py index 7aa8cf42..9ce7211a 100644 --- a/examples/fastapi-vite/backend/main.py +++ b/examples/fastapi-vite/backend/main.py @@ -1,15 +1,25 @@ """FastAPI application entry point.""" +from __future__ import annotations + +from collections.abc import AsyncGenerator + +import agent import fastapi import fastapi.middleware.cors -from routes import chat +import fastapi.responses +import pydantic +import storage -api = fastapi.FastAPI( +import vercel_ai_sdk as ai +import vercel_ai_sdk.ai_sdk_ui + +app = fastapi.FastAPI( title="py-ai-fastapi-chat", description="Chat demo using Python Vercel AI SDK", ) -api.add_middleware( +app.add_middleware( fastapi.middleware.cors.CORSMiddleware, allow_origins=["*"], allow_credentials=True, @@ -17,14 +27,54 @@ allow_headers=["*"], ) -api.include_router(chat.router) - -@api.get("/health") +@app.get("/health") async def health() -> dict[str, str]: """Health check endpoint.""" return {"status": "ok"} -app = fastapi.FastAPI() -app.mount("/api", api) +file_storage = storage.FileStorage() + + +class ChatRequest(pydantic.BaseModel): + """Request body for the chat endpoint.""" + + messages: list[ai.ai_sdk_ui.UIMessage] + session_id: str | None = None + + +@app.post("/chat") +async def chat(request: ChatRequest) -> fastapi.responses.StreamingResponse: + """Handle chat requests and stream responses.""" + messages = ai.ai_sdk_ui.to_messages(request.messages) + session_id = request.session_id or "default" + checkpoint_key = f"checkpoint:{session_id}" + + llm = agent.get_llm() + + # Checkpoints resume an *interrupted* run (e.g. a hook that needed + # user input in serverless mode). Each normal chat turn is a fresh + # run — the frontend carries the full message history — so we only + # load a checkpoint when one was saved from a previous incomplete run. + saved = await file_storage.get(checkpoint_key) + checkpoint = ai.Checkpoint.model_validate(saved) if saved else None + + result = ai.run(agent.graph, llm, messages, agent.TOOLS, checkpoint=checkpoint) + + async def stream_response() -> AsyncGenerator[str]: + async for chunk in ai.ai_sdk_ui.to_sse_stream(result): + yield chunk + + # If the run completed (no pending hooks), clear the checkpoint + # so the next request starts fresh. If hooks are pending, save + # the checkpoint so the next request can resume from here. + if result.pending_hooks: + await file_storage.put(checkpoint_key, result.checkpoint.model_dump()) + else: + await file_storage.delete(checkpoint_key) + + return fastapi.responses.StreamingResponse( + stream_response(), + headers=ai.ai_sdk_ui.UI_MESSAGE_STREAM_HEADERS, + ) diff --git a/examples/fastapi-vite/backend/pyproject.toml b/examples/fastapi-vite/backend/pyproject.toml index 1909ae57..a98fb231 100644 --- a/examples/fastapi-vite/backend/pyproject.toml +++ b/examples/fastapi-vite/backend/pyproject.toml @@ -5,5 +5,9 @@ description = "Chat demo using Python Vercel AI SDK with FastAPI" requires-python = ">=3.12" dependencies = [ "fastapi[standard]>=0.128.1", - "vercel-ai-sdk>=0.0.1.dev5", + "vercel-ai-sdk", + # "vercel-ai-sdk>=0.0.1.dev5", ] + +[tool.uv.sources] +vercel-ai-sdk = { path = "../../.." } diff --git a/examples/fastapi-vite/backend/routes/__init__.py b/examples/fastapi-vite/backend/routes/__init__.py deleted file mode 100644 index d212dab6..00000000 --- a/examples/fastapi-vite/backend/routes/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Routes package diff --git a/examples/fastapi-vite/backend/routes/chat.py b/examples/fastapi-vite/backend/routes/chat.py deleted file mode 100644 index 69b95d88..00000000 --- a/examples/fastapi-vite/backend/routes/chat.py +++ /dev/null @@ -1,60 +0,0 @@ -"""Chat route — streams LLM responses via the AI SDK UI protocol.""" - -from __future__ import annotations - -from collections.abc import AsyncGenerator - -import agent -import fastapi -import fastapi.responses -import pydantic -import storage - -import vercel_ai_sdk as ai -import vercel_ai_sdk.ai_sdk_ui - -router = fastapi.APIRouter() -file_storage = storage.FileStorage() - - -class ChatRequest(pydantic.BaseModel): - """Request body for the chat endpoint.""" - - messages: list[ai.ai_sdk_ui.UIMessage] - session_id: str | None = None - - -@router.post("/chat") -async def chat(request: ChatRequest) -> fastapi.responses.StreamingResponse: - """Handle chat requests and stream responses.""" - messages = ai.ai_sdk_ui.to_messages(request.messages) - session_id = request.session_id or "default" - checkpoint_key = f"checkpoint:{session_id}" - - llm = agent.get_llm() - - # Checkpoints resume an *interrupted* run (e.g. a hook that needed - # user input in serverless mode). Each normal chat turn is a fresh - # run — the frontend carries the full message history — so we only - # load a checkpoint when one was saved from a previous incomplete run. - saved = await file_storage.get(checkpoint_key) - checkpoint = ai.Checkpoint.model_validate(saved) if saved else None - - result = ai.run(agent.graph, llm, messages, agent.TOOLS, checkpoint=checkpoint) - - async def stream_response() -> AsyncGenerator[str]: - async for chunk in ai.ai_sdk_ui.to_sse_stream(result): - yield chunk - - # If the run completed (no pending hooks), clear the checkpoint - # so the next request starts fresh. If hooks are pending, save - # the checkpoint so the next request can resume from here. - if result.pending_hooks: - await file_storage.put(checkpoint_key, result.checkpoint.model_dump()) - else: - await file_storage.delete(checkpoint_key) - - return fastapi.responses.StreamingResponse( - stream_response(), - headers=ai.ai_sdk_ui.UI_MESSAGE_STREAM_HEADERS, - ) diff --git a/examples/fastapi-vite/backend/uv.lock b/examples/fastapi-vite/backend/uv.lock index f27e1630..8075cde9 100644 --- a/examples/fastapi-vite/backend/uv.lock +++ b/examples/fastapi-vite/backend/uv.lock @@ -460,6 +460,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0e/61/66938bbb5fc52dbdf84594873d5b51fb1f7c7794e9c0f5bd885f30bc507b/idna-3.11-py3-none-any.whl", hash = "sha256:771a87f49d9defaf64091e6e6fe9c18d4833f140bd19464795bc32d966ca37ea", size = 71008, upload-time = "2025-10-12T14:55:18.883Z" }, ] +[[package]] +name = "importlib-metadata" +version = "8.7.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "zipp" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f3/49/3b30cad09e7771a4982d9975a8cbf64f00d4a1ececb53297f1d9a7be1b10/importlib_metadata-8.7.1.tar.gz", hash = "sha256:49fef1ae6440c182052f407c8d34a68f72efc36db9ca90dc0113398f2fdde8bb", size = 57107, upload-time = "2025-12-21T10:00:19.278Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fa/5e/f8e9a1d23b9c20a551a8a02ea3637b4642e22c2626e3a13a9a29cdea99eb/importlib_metadata-8.7.1-py3-none-any.whl", hash = "sha256:5a1f80bf1daa489495071efbb095d75a634cf28a8bc299581244063b53176151", size = 27865, upload-time = "2025-12-21T10:00:18.329Z" }, +] + [[package]] name = "jinja2" version = "3.1.6" @@ -695,6 +707,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b5/a0/cf4297aa51bbc21e83ef0ac018947fa06aea8f2364aad7c96cbf148590e6/openai-2.20.0-py3-none-any.whl", hash = "sha256:38d989c4b1075cd1f76abc68364059d822327cf1a932531d429795f4fc18be99", size = 1098479, upload-time = "2026-02-10T19:02:52.157Z" }, ] +[[package]] +name = "opentelemetry-api" +version = "1.39.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "importlib-metadata" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/97/b9/3161be15bb8e3ad01be8be5a968a9237c3027c5be504362ff800fca3e442/opentelemetry_api-1.39.1.tar.gz", hash = "sha256:fbde8c80e1b937a2c61f20347e91c0c18a1940cecf012d62e65a7caf08967c9c", size = 65767, upload-time = "2025-12-11T13:32:39.182Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cf/df/d3f1ddf4bb4cb50ed9b1139cc7b1c54c34a1e7ce8fd1b9a37c0d1551a6bd/opentelemetry_api-1.39.1-py3-none-any.whl", hash = "sha256:2edd8463432a7f8443edce90972169b195e7d6a05500cd29e6d13898187c9950", size = 66356, upload-time = "2025-12-11T13:32:17.304Z" }, +] + [[package]] name = "py-ai-fastapi-chat" version = "0.1.0" @@ -707,7 +732,7 @@ dependencies = [ [package.metadata] requires-dist = [ { name = "fastapi", extras = ["standard"], specifier = ">=0.128.1" }, - { name = "vercel-ai-sdk", specifier = ">=0.0.1.dev5" }, + { name = "vercel-ai-sdk", directory = "../../../" }, ] [[package]] @@ -1319,18 +1344,37 @@ wheels = [ [[package]] name = "vercel-ai-sdk" version = "0.0.1.dev5" -source = { registry = "https://pypi.org/simple" } +source = { directory = "../../../" } dependencies = [ { name = "anthropic" }, { name = "httpx" }, { name = "mcp" }, { name = "openai" }, + { name = "opentelemetry-api" }, { name = "pydantic" }, { name = "vercel" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/b0/dd/3b399134076883247582af3919d5fbd38c9e270a42005fa27d1472705dd1/vercel_ai_sdk-0.0.1.dev5.tar.gz", hash = "sha256:998814780fc6163000be1b29e48dacbe710adb8a765636867bd6dd5a6b9b41b0", size = 37870, upload-time = "2026-02-25T16:31:01.844Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c5/7a/f50dd25ed596c07c2222f2abd81c413169534904bfc4e735b1a5e7084870/vercel_ai_sdk-0.0.1.dev5-py3-none-any.whl", hash = "sha256:22de26c8b667738a825f812aa5b7042d741dd905033c16b3515f6123cd220245", size = 50938, upload-time = "2026-02-25T16:31:00.175Z" }, + +[package.metadata] +requires-dist = [ + { name = "anthropic", specifier = ">=0.83.0" }, + { name = "httpx", specifier = ">=0.28.1" }, + { name = "mcp", specifier = ">=1.18.0" }, + { name = "openai", specifier = ">=2.14.0" }, + { name = "opentelemetry-api", specifier = ">=1.0" }, + { name = "pydantic", specifier = ">=2.12.5" }, + { name = "vercel", specifier = ">=0.3.8" }, +] + +[package.metadata.requires-dev] +dev = [ + { name = "mypy", specifier = ">=1.11" }, + { name = "opentelemetry-sdk", specifier = ">=1.0" }, + { name = "pytest", specifier = ">=8.0" }, + { name = "pytest-asyncio", specifier = ">=0.24" }, + { name = "python-dotenv", specifier = ">=1.2.1" }, + { name = "rich", specifier = ">=14.2.0" }, + { name = "ruff", specifier = ">=0.8" }, ] [[package]] @@ -1464,3 +1508,12 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9f/3e/28135a24e384493fa804216b79a6a6759a38cc4ff59118787b9fb693df93/websockets-16.0-cp314-cp314t-win_amd64.whl", hash = "sha256:b14dc141ed6d2dde437cddb216004bcac6a1df0935d79656387bd41632ba0bbd", size = 178531, upload-time = "2026-01-10T09:23:35.016Z" }, { url = "https://files.pythonhosted.org/packages/6f/28/258ebab549c2bf3e64d2b0217b973467394a9cea8c42f70418ca2c5d0d2e/websockets-16.0-py3-none-any.whl", hash = "sha256:1637db62fad1dc833276dded54215f2c7fa46912301a24bd94d45d46a011ceec", size = 171598, upload-time = "2026-01-10T09:23:45.395Z" }, ] + +[[package]] +name = "zipp" +version = "3.23.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e3/02/0f2892c661036d50ede074e376733dca2ae7c6eb617489437771209d4180/zipp-3.23.0.tar.gz", hash = "sha256:a07157588a12518c9d4034df3fbbee09c814741a33ff63c05fa29d26a2404166", size = 25547, upload-time = "2025-06-08T17:06:39.4Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2e/54/647ade08bf0db230bfea292f893923872fd20be6ac6f53b2b936ba839d75/zipp-3.23.0-py3-none-any.whl", hash = "sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e", size = 10276, upload-time = "2025-06-08T17:06:38.034Z" }, +] diff --git a/examples/fastapi-vite/frontend/vite.config.ts b/examples/fastapi-vite/frontend/vite.config.ts index 39d31769..0cca23e1 100644 --- a/examples/fastapi-vite/frontend/vite.config.ts +++ b/examples/fastapi-vite/frontend/vite.config.ts @@ -11,9 +11,4 @@ export default defineConfig({ '@': path.resolve(__dirname, './src'), }, }, - server: { - proxy: { - '/api': 'http://localhost:8000', - }, - }, }) From 8ca7b54849ad037e9269309d6c2569e58372beed Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Mon, 2 Mar 2026 17:22:47 -0800 Subject: [PATCH 03/10] Make human-in-the-loop work in the fastapi-vite example --- examples/fastapi-vite/README.md | 17 +++ examples/fastapi-vite/backend/agent.py | 52 ++++++- examples/fastapi-vite/backend/main.py | 67 +++++++-- examples/fastapi-vite/frontend/src/App.tsx | 80 +++++++++-- .../components/ai-elements/confirmation.tsx | 131 ++++++++++++++++++ src/vercel_ai_sdk/ai_sdk_ui/adapter.py | 18 +++ 6 files changed, 340 insertions(+), 25 deletions(-) create mode 100644 examples/fastapi-vite/frontend/src/components/ai-elements/confirmation.tsx diff --git a/examples/fastapi-vite/README.md b/examples/fastapi-vite/README.md index 7f75c5e5..e8552512 100644 --- a/examples/fastapi-vite/README.md +++ b/examples/fastapi-vite/README.md @@ -1,12 +1,29 @@ # fastapi-chat Chat demo using the Python Vercel AI SDK with a FastAPI backend and React frontend. +Includes **human-in-the-loop tool approval** — every tool call is gated +behind user confirmation before execution. ## Stack - **Backend:** FastAPI + vercel-ai-sdk (Python 3.12) - **Frontend:** Vite + React + AI SDK UI + AI Elements +## Human-in-the-Loop + +The agent graph in `backend/agent.py` uses the `ToolApproval` hook to +suspend execution whenever the LLM wants to call a tool. The flow is: + +1. LLM emits a tool call +2. Backend creates a `ToolApproval` hook — this emits an + `approval-requested` event on the SSE stream and suspends execution +3. The frontend renders Approve / Reject buttons via the + `` component (from AI Elements) +4. When the user clicks a button, `addToolApprovalResponse()` patches + the message and sends a new request with the decision +5. The backend resumes from the checkpoint and either executes the tool + or marks it as denied + ## Setup ```bash diff --git a/examples/fastapi-vite/backend/agent.py b/examples/fastapi-vite/backend/agent.py index 94ecd714..023a66d2 100644 --- a/examples/fastapi-vite/backend/agent.py +++ b/examples/fastapi-vite/backend/agent.py @@ -1,5 +1,10 @@ -"""Agent logic for the chat demo.""" +"""Agent logic for the chat demo. +Demonstrates human-in-the-loop tool approval using ToolApproval hooks. +Every tool call is gated behind user approval before execution. +""" + +import asyncio from typing import Any import vercel_ai_sdk as ai @@ -19,16 +24,49 @@ def get_llm() -> ai.LanguageModel: TOOLS: list[ai.Tool[..., Any]] = [talk_to_mothership] +async def _execute_with_approval( + tc: ai.ToolPart, message: ai.Message | None = None +) -> None: + """Execute a tool call only after the user grants approval. + + Creates a ToolApproval hook that suspends execution until the + frontend responds with an approve/reject decision. + """ + approval = await ai.ToolApproval.create( # type: ignore[attr-defined] + f"approve_{tc.tool_call_id}", + metadata={"tool_name": tc.tool_name, "tool_args": tc.tool_args}, + ) + + if approval.granted: + await ai.execute_tool(tc, message=message) + else: + tc.set_error("Tool call was denied by the user.") + + async def graph( llm: ai.LanguageModel, messages: list[ai.Message], tools: list[ai.Tool[..., Any]], ) -> ai.StreamResult: - """ - Agent graph: stream LLM, execute tools, repeat until done. + """Agent graph with human-in-the-loop tool approval. - This is a plain async function that goes through the Runtime queue - via stream_loop. When hooks are added later, they slot in here - between tool calls — no structural change needed. + Loops: stream LLM -> request approval -> execute tools -> repeat. + The ToolApproval hook suspends execution and emits an approval- + request event on the SSE stream. The frontend displays Approve / + Reject buttons and sends the decision back on the next request. """ - return await ai.stream_loop(llm, messages, tools) + local_messages = list(messages) + + while True: + result = await ai.stream_step(llm, local_messages, tools) + + if not result.tool_calls: + return result + + last_msg = result.last_message + assert last_msg is not None + local_messages.append(last_msg) + + await asyncio.gather( + *(_execute_with_approval(tc, message=last_msg) for tc in result.tool_calls) + ) diff --git a/examples/fastapi-vite/backend/main.py b/examples/fastapi-vite/backend/main.py index 9ce7211a..8175dbbb 100644 --- a/examples/fastapi-vite/backend/main.py +++ b/examples/fastapi-vite/backend/main.py @@ -44,6 +44,29 @@ class ChatRequest(pydantic.BaseModel): session_id: str | None = None +def _has_matching_approval( + ui_messages: list[ai.ai_sdk_ui.UIMessage], + pending_hooks: list[str], +) -> bool: + """True when the incoming messages resolve at least one pending hook. + + Hook labels follow the ``approve_{tool_call_id}`` convention set by + ``_execute_with_approval`` in the agent graph. + """ + pending = set(pending_hooks) + for msg in ui_messages: + for part in msg.parts: + state = getattr(part, "state", None) + tcid = getattr(part, "tool_call_id", None) + if ( + state == "approval-responded" + and tcid is not None + and f"approve_{tcid}" in pending + ): + return True + return False + + @app.post("/chat") async def chat(request: ChatRequest) -> fastapi.responses.StreamingResponse: """Handle chat requests and stream responses.""" @@ -53,24 +76,48 @@ async def chat(request: ChatRequest) -> fastapi.responses.StreamingResponse: llm = agent.get_llm() - # Checkpoints resume an *interrupted* run (e.g. a hook that needed - # user input in serverless mode). Each normal chat turn is a fresh - # run — the frontend carries the full message history — so we only - # load a checkpoint when one was saved from a previous incomplete run. + # Only load a checkpoint when this request is actually resuming + # an interrupted run — i.e. the frontend is sending back an + # approval response that matches a pending hook. Otherwise + # discard stale checkpoints so fresh turns aren't poisoned. + checkpoint = None saved = await file_storage.get(checkpoint_key) - checkpoint = ai.Checkpoint.model_validate(saved) if saved else None + if saved: + pending = saved.get("pending_hooks", []) + if _has_matching_approval(request.messages, pending): + checkpoint = ai.Checkpoint.model_validate(saved["checkpoint"]) + # The frontend sends the full message history including the + # assistant message from the interrupted run. The checkpoint + # will replay that same step, so strip the trailing assistant + # message to avoid sending a duplicate tool_use to the LLM. + if messages and messages[-1].role == "assistant": + messages = messages[:-1] + else: + await file_storage.delete(checkpoint_key) - result = ai.run(agent.graph, llm, messages, agent.TOOLS, checkpoint=checkpoint) + result = ai.run( + agent.graph, + llm, + messages, + agent.TOOLS, + checkpoint=checkpoint, + cancel_on_hooks=True, + ) async def stream_response() -> AsyncGenerator[str]: async for chunk in ai.ai_sdk_ui.to_sse_stream(result): yield chunk - # If the run completed (no pending hooks), clear the checkpoint - # so the next request starts fresh. If hooks are pending, save - # the checkpoint so the next request can resume from here. + # Save checkpoint + pending hook labels so the next request + # can decide whether it's a resume or a fresh turn. if result.pending_hooks: - await file_storage.put(checkpoint_key, result.checkpoint.model_dump()) + await file_storage.put( + checkpoint_key, + { + "checkpoint": result.checkpoint.model_dump(), + "pending_hooks": list(result.pending_hooks.keys()), + }, + ) else: await file_storage.delete(checkpoint_key) diff --git a/examples/fastapi-vite/frontend/src/App.tsx b/examples/fastapi-vite/frontend/src/App.tsx index a7142590..0fedfd98 100644 --- a/examples/fastapi-vite/frontend/src/App.tsx +++ b/examples/fastapi-vite/frontend/src/App.tsx @@ -1,8 +1,21 @@ import { useChat } from "@ai-sdk/react"; -import { DefaultChatTransport } from "ai"; +import { + DefaultChatTransport, + lastAssistantMessageIsCompleteWithApprovalResponses, +} from "ai"; import type { ToolUIPart } from "ai"; +import { CheckIcon, XIcon } from "lucide-react"; import { Fragment } from "react"; +import { + Confirmation, + ConfirmationAccepted, + ConfirmationAction, + ConfirmationActions, + ConfirmationRejected, + ConfirmationRequest, + ConfirmationTitle, +} from "@/components/ai-elements/confirmation"; import { Conversation, ConversationContent, @@ -29,11 +42,16 @@ import { import { TooltipProvider } from "@/components/ui/tooltip"; export default function App() { - const { messages, sendMessage, status, stop } = useChat({ - transport: new DefaultChatTransport({ - api: "/api/chat", - }), - }); + const { messages, sendMessage, addToolApprovalResponse, status, stop } = + useChat({ + transport: new DefaultChatTransport({ + api: "/api/chat", + }), + // After the user approves/rejects a tool, automatically send the + // updated messages back to the backend so it can resume execution. + sendAutomaticallyWhen: + lastAssistantMessageIsCompleteWithApprovalResponses, + }); const isLoading = status === "submitted" || status === "streaming"; @@ -63,7 +81,8 @@ export default function App() { // Handle tool parts (type starts with "tool-") if (part.type.startsWith("tool-")) { const toolPart = part as ToolUIPart; - const isComplete = toolPart.state === "output-available"; + const isComplete = + toolPart.state === "output-available"; return ( + + {/* Human-in-the-loop approval UI */} + + + + Allow this tool to run? + + + + Approved + + + + Rejected + + + + + addToolApprovalResponse({ + id: toolPart.approval!.id, + approved: false, + }) + } + > + Reject + + + addToolApprovalResponse({ + id: toolPart.approval!.id, + approved: true, + }) + } + > + Approve + + + + (null); + +const useConfirmation = () => { + const ctx = useContext(ConfirmationContext); + if (!ctx) throw new Error("Confirmation components must be used within "); + return ctx; +}; + +/* ------------------------------------------------------------------ */ +/* */ +/* ------------------------------------------------------------------ */ + +export type ConfirmationProps = ComponentProps<"div"> & { + approval?: ToolUIPartApproval; + state: ToolUIPart["state"]; +}; + +export const Confirmation = ({ + className, + approval, + state, + children, + ...props +}: ConfirmationProps) => { + if (!approval || state === "input-streaming" || state === "input-available") { + return null; + } + + return ( + +
+ {children} +
+
+ ); +}; + +/* ------------------------------------------------------------------ */ +/* */ +/* ------------------------------------------------------------------ */ + +export type ConfirmationTitleProps = ComponentProps<"p">; + +export const ConfirmationTitle = ({ + className, + ...props +}: ConfirmationTitleProps) => ( +

+); + +/* ------------------------------------------------------------------ */ +/* State-conditional wrappers */ +/* ------------------------------------------------------------------ */ + +export const ConfirmationRequest = ({ children }: { children?: ReactNode }) => { + const { state } = useConfirmation(); + return state === "approval-requested" ? <>{children} : null; +}; + +export const ConfirmationAccepted = ({ children }: { children?: ReactNode }) => { + const { approval, state } = useConfirmation(); + const show = + approval?.approved === true && + (state === "approval-responded" || + state === "output-available" || + state === "output-denied"); + return show ? <>{children} : null; +}; + +export const ConfirmationRejected = ({ children }: { children?: ReactNode }) => { + const { approval, state } = useConfirmation(); + const show = + approval?.approved === false && + (state === "approval-responded" || + state === "output-available" || + state === "output-denied"); + return show ? <>{children} : null; +}; + +/* ------------------------------------------------------------------ */ +/* Actions */ +/* ------------------------------------------------------------------ */ + +export type ConfirmationActionsProps = ComponentProps<"div">; + +export const ConfirmationActions = ({ + className, + ...props +}: ConfirmationActionsProps) => { + const { state } = useConfirmation(); + if (state !== "approval-requested") return null; + + return ( +

+ ); +}; + +export type ConfirmationActionProps = ComponentProps; + +export const ConfirmationAction = (props: ConfirmationActionProps) => ( +