From a7f9d692f41707910d7476305d54ffb1f3255ea9 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Fri, 13 Feb 2026 13:18:55 -0800 Subject: [PATCH 1/4] Patch MCP implementation --- src/vercel_ai_sdk/mcp/client.py | 5 +- tests/mcp/__init__.py | 0 tests/mcp/test_client.py | 112 ++++++++++++++++++++++++++++++++ 3 files changed, 116 insertions(+), 1 deletion(-) create mode 100644 tests/mcp/__init__.py create mode 100644 tests/mcp/test_client.py diff --git a/src/vercel_ai_sdk/mcp/client.py b/src/vercel_ai_sdk/mcp/client.py index ee8692a9..97524fe4 100644 --- a/src/vercel_ai_sdk/mcp/client.py +++ b/src/vercel_ai_sdk/mcp/client.py @@ -231,12 +231,15 @@ def _mcp_tool_to_native( if tool_prefix: name = f"{tool_prefix}_{name}" - return core.tools.Tool( + t = core.tools.Tool( name=name, description=mcp_tool.description or "", tool_schema=mcp_tool.inputSchema, fn=_make_tool_fn(connection_key, mcp_tool.name, transport_factory), ) + # Register so execute_tool() can find it by name + core.tools._tool_registry[name] = t + return t async def close_connections() -> None: diff --git a/tests/mcp/__init__.py b/tests/mcp/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/mcp/test_client.py b/tests/mcp/test_client.py new file mode 100644 index 00000000..7853fb97 --- /dev/null +++ b/tests/mcp/test_client.py @@ -0,0 +1,112 @@ +"""MCP client: tool registration in global registry, end-to-end execution.""" + +import asyncio + +import mcp.types +import pytest + +import vercel_ai_sdk as ai +from vercel_ai_sdk.core.tools import _tool_registry, get_tool +from vercel_ai_sdk.mcp.client import _mcp_tool_to_native + +from ..conftest import MockLLM, text_msg, tool_msg + + +def _fake_mcp_tool( + name: str = "mcp_echo", description: str = "Echo input" +) -> mcp.types.Tool: + """Build a minimal mcp.types.Tool for testing.""" + return mcp.types.Tool( + name=name, + description=description, + inputSchema={ + "type": "object", + "properties": {"text": {"type": "string"}}, + "required": ["text"], + }, + ) + + +def _noop_transport_factory(): + """Dummy transport factory — never actually called in these tests.""" + raise NotImplementedError("should not be called") + + +# -- _mcp_tool_to_native registers in global registry ---------------------- + + +def test_mcp_tool_to_native_registers_in_global_registry(): + """Converting an MCP tool to native registers it in _tool_registry.""" + mcp_tool = _fake_mcp_tool(name="mcp_reg_test") + native = _mcp_tool_to_native(mcp_tool, "test:key", _noop_transport_factory, None) + + assert native.name == "mcp_reg_test" + assert get_tool("mcp_reg_test") is native + assert _tool_registry["mcp_reg_test"] is native + + +def test_mcp_tool_to_native_with_prefix(): + """Tool prefix is prepended to the name and both name forms are correct.""" + mcp_tool = _fake_mcp_tool(name="echo") + native = _mcp_tool_to_native(mcp_tool, "test:key", _noop_transport_factory, "ctx7") + + assert native.name == "ctx7_echo" + assert get_tool("ctx7_echo") is native + + +def test_mcp_tool_to_native_schema_preserved(): + """The inputSchema from the MCP tool is passed through as tool_schema.""" + mcp_tool = _fake_mcp_tool() + native = _mcp_tool_to_native(mcp_tool, "test:key", _noop_transport_factory, None) + + assert native.tool_schema == mcp_tool.inputSchema + assert native.description == "Echo input" + + +# -- End-to-end: MCP tool executes through stream_loop -------------------- + + +@pytest.mark.asyncio +async def test_mcp_tool_executes_through_stream_loop(): + """An MCP-style tool registered via _mcp_tool_to_native can be called by the agent loop.""" + call_log: list[dict] = [] + + async def fake_fn(**kwargs): + call_log.append(kwargs) + return f"echoed: {kwargs.get('text', '')}" + + # Build and register a tool the same way the MCP client does, + # but with a fake fn so we don't need a real MCP server. + mcp_tool = _fake_mcp_tool(name="mcp_e2e_echo") + native = _mcp_tool_to_native(mcp_tool, "test:key", _noop_transport_factory, None) + # Replace the real fn (which would try to connect) with our fake + native.fn = fake_fn + _tool_registry[native.name] = native + + async def graph(llm: ai.LanguageModel): + return await ai.stream_loop( + llm, + messages=ai.make_messages(user="echo hello"), + tools=[native], + ) + + call1 = [tool_msg(tc_id="tc-mcp-1", name="mcp_e2e_echo", args='{"text": "hello"}')] + call2 = [text_msg("Done.", id="msg-2")] + llm = MockLLM([call1, call2]) + + result = ai.run(graph, llm) + msgs = [m async for m in result] + + # Tool was called with the right args + assert len(call_log) == 1 + assert call_log[0] == {"text": "hello"} + + # Tool result is visible in messages + tool_results = [ + m for m in msgs if m.tool_calls and m.tool_calls[0].status == "result" + ] + assert len(tool_results) >= 1 + assert tool_results[0].tool_calls[0].result == "echoed: hello" + + # LLM was called twice (tool call + final text) + assert llm.call_count == 2 From 2adc70196f4edc0b90b88ae7998d4aaee8996f8d Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Fri, 13 Feb 2026 13:26:23 -0800 Subject: [PATCH 2/4] Add get_hook_part to message --- examples/multiagent-textual/client.py | 21 +++++++-------------- examples/multiagent-textual/server.py | 10 +--------- examples/samples/agent.py | 5 ++--- examples/samples/hooks.py | 21 ++++++++++----------- src/vercel_ai_sdk/core/messages.py | 8 ++++++++ tests/core/test_messages.py | 27 +++++++++++++++++++++++++++ 6 files changed, 55 insertions(+), 37 deletions(-) diff --git a/examples/multiagent-textual/client.py b/examples/multiagent-textual/client.py index 6a85fb37..5f2bfde8 100644 --- a/examples/multiagent-textual/client.py +++ b/examples/multiagent-textual/client.py @@ -158,15 +158,15 @@ async def run_websocket(self) -> None: # ------------------------------------------------------------------ def _handle_message(self, msg: ai.Message) -> None: - hook_part = _get_hook_part(msg) label = msg.label or "unknown" - if hook_part and hook_part.status == "pending": - self._on_hook_pending(hook_part) - return - if hook_part and hook_part.status == "resolved": - self._on_hook_resolved(hook_part) - return + if (hook_part := msg.get_hook_part()) is not None: + if hook_part.status == "pending": + self._on_hook_pending(hook_part) + return + if hook_part.status == "resolved": + self._on_hook_resolved(hook_part) + return panel = self._get_panel(label) if panel is None: @@ -303,13 +303,6 @@ def _set_input_placeholder(self, text: str) -> None: inp.placeholder = text -def _get_hook_part(msg: ai.Message) -> ai.HookPart | None: - for part in msg.parts: - if isinstance(part, ai.HookPart): - return part - return None - - if __name__ == "__main__": app = MultiAgentApp() app.run() diff --git a/examples/multiagent-textual/server.py b/examples/multiagent-textual/server.py index 2b07b5e4..0e253b96 100644 --- a/examples/multiagent-textual/server.py +++ b/examples/multiagent-textual/server.py @@ -154,13 +154,6 @@ async def multiagent(llm: ai.LanguageModel, query: str): # --------------------------------------------------------------------------- -def _get_hook_part(msg: ai.Message) -> ai.HookPart | None: - for part in msg.parts: - if isinstance(part, ai.HookPart): - return part - return None - - def _normalise_message(data: dict) -> dict: """Ensure ToolPart.result is always a dict for safe deserialisation.""" for part in data.get("parts", []): @@ -208,8 +201,7 @@ async def read_resolutions(): data = _normalise_message(msg.model_dump()) await websocket.send_json(data) - hook_part = _get_hook_part(msg) - if hook_part: + if hook_part := msg.get_hook_part(): print(f" Hook {hook_part.status}: {hook_part.hook_id}") finally: reader.cancel() diff --git a/examples/samples/agent.py b/examples/samples/agent.py index 98cfdab8..ebca2c6a 100644 --- a/examples/samples/agent.py +++ b/examples/samples/agent.py @@ -24,9 +24,8 @@ async def main(): async for msg in coding_agent.run(messages): # Auto-approve all tool calls - for part in msg.parts: - if isinstance(part, ai.HookPart) and part.status == "pending": - agent.ToolApproval.resolve(part.hook_id, {"granted": True}) + if (hook := msg.get_hook_part()) and hook.status == "pending": + agent.ToolApproval.resolve(hook.hook_id, {"granted": True}) if msg.text_delta: print(msg.text_delta, end="", flush=True) diff --git a/examples/samples/hooks.py b/examples/samples/hooks.py index 9b3a51c4..0fc89353 100644 --- a/examples/samples/hooks.py +++ b/examples/samples/hooks.py @@ -61,17 +61,16 @@ async def main(): async for msg in ai.run(graph, llm, "When will the robots take over?"): # Hook parts arrive as pending, waiting for resolution - for part in msg.parts: - if isinstance(part, ai.HookPart) and part.status == "pending": - answer = input(f"Approve {part.hook_id}? [y/n] ") - CommunicationApproval.resolve( - part.hook_id, - { - "granted": answer.strip().lower() in ("y", "yes"), - "reason": "operator decision", - }, - ) - continue + if (hook := msg.get_hook_part()) and hook.status == "pending": + answer = input(f"Approve {hook.hook_id}? [y/n] ") + CommunicationApproval.resolve( + hook.hook_id, + { + "granted": answer.strip().lower() in ("y", "yes"), + "reason": "operator decision", + }, + ) + continue if msg.text_delta: print(msg.text_delta, end="", flush=True) diff --git a/src/vercel_ai_sdk/core/messages.py b/src/vercel_ai_sdk/core/messages.py index b8f5dd6d..09a2684a 100644 --- a/src/vercel_ai_sdk/core/messages.py +++ b/src/vercel_ai_sdk/core/messages.py @@ -146,6 +146,14 @@ def get_tool_part(self, tool_call_id: str) -> ToolPart | None: return part return None + def get_hook_part(self, hook_id: str | None = None) -> HookPart | None: + """Find a HookPart by hook_id, or return the first HookPart if no id given.""" + for part in self.parts: + if isinstance(part, HookPart): + if hook_id is None or part.hook_id == hook_id: + return part + return None + def make_messages(*, system: str | None = None, user: str) -> list[Message]: """Convenience builder for common system + user message pattern.""" diff --git a/tests/core/test_messages.py b/tests/core/test_messages.py index c5cb1789..612e801f 100644 --- a/tests/core/test_messages.py +++ b/tests/core/test_messages.py @@ -1,6 +1,7 @@ """Message model: properties, ToolPart.set_result, make_messages.""" from vercel_ai_sdk.core.messages import ( + HookPart, Message, ReasoningPart, TextPart, @@ -151,6 +152,32 @@ def test_get_tool_part_missing(): assert m.get_tool_part("tc-nope") is None +# -- get_hook_part --------------------------------------------------------- + + +def test_get_hook_part_found(): + """get_hook_part returns the HookPart when present.""" + hook = HookPart(hook_id="h1", hook_type="Approval", status="pending") + m = Message(id="m1", role="assistant", parts=[hook]) + assert m.get_hook_part() is hook + assert m.get_hook_part("h1") is hook + + +def test_get_hook_part_by_id(): + """get_hook_part with a specific hook_id skips non-matching hooks.""" + h1 = HookPart(hook_id="h1", hook_type="Approval", status="pending") + h2 = HookPart(hook_id="h2", hook_type="Approval", status="resolved") + m = Message(id="m1", role="assistant", parts=[h1, h2]) + assert m.get_hook_part("h2") is h2 + + +def test_get_hook_part_missing(): + """get_hook_part returns None when no HookPart exists.""" + m = Message(id="m1", role="assistant", parts=[TextPart(text="no hooks")]) + assert m.get_hook_part() is None + assert m.get_hook_part("h-nope") is None + + # -- ToolPart.set_result --------------------------------------------------- From 240eb3a736c40a22195c38236b6b8b9fdb4de6c5 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Fri, 13 Feb 2026 13:42:11 -0800 Subject: [PATCH 3/4] Update the readme and bump the version --- README.md | 132 ++++++++++++++++++++++++++++++++++++++----------- pyproject.toml | 2 +- 2 files changed, 103 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index 2d460c43..6b5e543f 100644 --- a/README.md +++ b/README.md @@ -9,6 +9,7 @@ uv add vercel-ai-sdk ``` ```python +import os import vercel_ai_sdk as ai @ai.tool @@ -27,7 +28,7 @@ async def agent(llm, query): ) llm = ai.openai.OpenAIModel( - model="anthropic/claude-sonnet-4", + model="anthropic/claude-opus-4.6", base_url="https://ai-gateway.vercel.sh/v1", api_key=os.environ["AI_GATEWAY_API_KEY"], ) @@ -40,15 +41,21 @@ async for msg in ai.run(agent, llm, "When will the robots take over?"): ### Core Primitives -#### `ai.run(root, *args)` +#### `ai.run(root, *args, checkpoint=None, cancel_on_hooks=False)` -Entry point. Executes an async function, yields all `Message` objects from nested streams. +Entry point. Starts `root` as a background task, processes the step/hook queue, yields `Message` objects. Returns a `RunResult`. ```python -async for msg in ai.run(my_agent, llm, "hello"): +result = ai.run(my_agent, llm, "hello") +async for msg in result: print(msg.text_delta, end="") + +result.checkpoint # Checkpoint with all completed work +result.pending_hooks # dict of unresolved hooks (empty if run completed) ``` +If `root` declares a `runtime: ai.Runtime` parameter, it's auto-injected. + #### `@ai.tool` Decorator that turns an async function into a `Tool`. Parameters extracted from type hints, docstring becomes description. @@ -60,7 +67,7 @@ async def search(query: str, limit: int = 10) -> list[str]: ... ``` -If a tool declares a `runtime: ai.Runtime` parameter, it's auto-injected: +If a tool declares a `runtime: ai.Runtime` parameter, it's auto-injected (not passed by the LLM): ```python @ai.tool @@ -83,6 +90,8 @@ async def my_custom_step(llm, messages): result = await my_custom_step(llm, messages) # returns StreamResult ``` +Must be called within `ai.run()` (needs a Runtime context). + #### `@ai.hook` Decorator that creates a suspension point from a pydantic model. The model defines the resolution schema. @@ -92,20 +101,38 @@ Decorator that creates a suspension point from a pydantic model. The model defin class Approval(pydantic.BaseModel): granted: bool reason: str +``` -# In your agent - blocks until resolved -approval = await Approval.create(metadata={"tool": "send_email"}) +Inside your agent — blocks until resolved: + +```python +approval = await Approval.create("approve_send_email", metadata={"tool": "send_email"}) if approval.granted: ... +``` + +From outside (API handler, websocket, iterator loop, etc.): -# From outside (API handler, websocket, etc.) -Approval.resolve(hook_id, {"granted": True, "reason": "User approved"}) +```python +Approval.resolve("approve_send_email", {"granted": True, "reason": "User approved"}) +Approval.cancel("approve_send_email") # or cancel it ``` -For serverless (raises `HookPending` if resolution not provided): +**Long-running mode** (`cancel_on_hooks=False`, the default): the `await` in `create()` blocks until `resolve()` or `cancel()` is called from external code. + +**Serverless mode** (`cancel_on_hooks=True`): if no resolution is available, the hook's future is cancelled and the branch dies. Inspect `result.pending_hooks` and `result.checkpoint` to resume later: ```python -approval = Approval.create_or_raise(f"approval_{tool_call_id}", resolutions=saved_resolutions) +result = ai.run(my_agent, llm, query, cancel_on_hooks=True) +async for msg in result: + ... + +if result.pending_hooks: + # Save result.checkpoint, collect resolutions, then re-enter: + Approval.resolve("approve_send_email", {"granted": True, "reason": "User approved"}) + result = ai.run(my_agent, llm, query, checkpoint=result.checkpoint) + async for msg in result: + ... ``` ### Convenience Functions @@ -127,14 +154,16 @@ Full agent loop: calls LLM, executes tools, repeats until no more tool calls. Re result = await ai.stream_loop(llm, messages, tools=[search, get_weather]) ``` -#### `ToolPart.execute()` +#### `ai.execute_tool(tool_call, message=None)` -Execute a tool call. Tools are looked up from the global registry (populated by `@ai.tool`). +Execute a single tool call. Looks up the tool from the global registry (populated by `@ai.tool`). Updates the `ToolPart` with the result. If `message` is provided, emits it to the Runtime queue so the UI sees the status change. ```python -await asyncio.gather(*(tc.execute() for tc in result.tool_calls)) +await asyncio.gather(*(ai.execute_tool(tc, message=last_msg) for tc in result.tool_calls)) ``` +Supports checkpoint replay — returns the cached result without re-executing if one exists. + #### `ai.make_messages(*, system=None, user)` Build a message list from system + user strings. @@ -143,6 +172,33 @@ Build a message list from system + user strings. messages = ai.make_messages(system="You are helpful.", user="Hello!") ``` +#### `ai.get_checkpoint()` + +Get the current `Checkpoint` from the active Runtime context. Call this from within `ai.run()`. + +```python +checkpoint = ai.get_checkpoint() +``` + +### Checkpoints + +`Checkpoint` records completed work (LLM steps, tool executions, hook resolutions) so a run can be replayed without re-executing already-finished operations. + +```python +# After a run completes or suspends +checkpoint = result.checkpoint +data = checkpoint.serialize() # dict, JSON-safe + +# Later: restore and resume +checkpoint = ai.Checkpoint.deserialize(data) +result = ai.run(my_agent, llm, query, checkpoint=checkpoint) +``` + +Three event types are tracked: +- **Steps** — LLM call results (replayed without calling the model) +- **Tools** — tool execution results (replayed without re-executing) +- **Hooks** — hook resolutions (replayed without re-suspending) + ### Adapters #### LLM Providers @@ -150,7 +206,7 @@ messages = ai.make_messages(system="You are helpful.", user="Hello!") ```python # OpenAI-compatible (including Vercel AI Gateway) llm = ai.openai.OpenAIModel( - model="anthropic/claude-sonnet-4", + model="anthropic/claude-opus-4.6", base_url="https://ai-gateway.vercel.sh/v1", api_key=os.environ["AI_GATEWAY_API_KEY"], thinking=True, # enable reasoning output @@ -159,7 +215,7 @@ llm = ai.openai.OpenAIModel( # Anthropic (native client) llm = ai.anthropic.AnthropicModel( - model="claude-sonnet-4-5-20250929", + model="claude-opus-4.6-20250916", thinking=True, budget_tokens=10000, ) @@ -182,6 +238,8 @@ tools = await ai.mcp.get_stdio_tools( ) ``` +MCP connections are pooled per `ai.run()` and cleaned up automatically. + #### AI SDK UI For streaming to AI SDK frontend (`useChat`, etc.): @@ -204,26 +262,40 @@ return StreamingResponse(stream_response(), headers=UI_MESSAGE_STREAM_HEADERS) | Type | Description | |------|-------------| -| `Message` | Universal message with `role`, `parts`, `label`. Properties: `text`, `text_delta`, `reasoning_delta`, `tool_deltas`, `is_done` | +| `Message` | Universal message with `role`, `parts`, `label`. Properties: `text`, `text_delta`, `reasoning_delta`, `tool_deltas`, `tool_calls`, `is_done` | | `TextPart` | Text content with streaming `state` and `delta` | -| `ToolPart` | Tool call with `tool_call_id`, `tool_name`, `tool_args`, `status`, `result`. Has `.execute()` method | +| `ToolPart` | Tool call with `tool_call_id`, `tool_name`, `tool_args`, `status`, `result`. Has `.set_result()` | +| `ToolDelta` | Tool argument streaming delta (`tool_call_id`, `tool_name`, `args_delta`) | | `ReasoningPart` | Model reasoning/thinking with optional `signature` (Anthropic) | -| `HookPart` | Hook suspension with `hook_id`, `hook_type`, `status`, `metadata`, `resolution` | -| `PartState` | Literal type: `"streaming"` or `"done"` | -| `StreamResult` | Result of a stream: `messages`, `tool_calls`, `text`, `last_message` | +| `HookPart` | Hook suspension with `hook_id`, `hook_type`, `status` (`pending`/`resolved`/`cancelled`), `metadata`, `resolution` | +| `Part` | Union: `TextPart \| ToolPart \| ReasoningPart \| HookPart` | +| `PartState` | Literal: `"streaming"` \| `"done"` | +| `StreamResult` | Result of a stream step: `messages`, `tool_calls`, `text`, `last_message` | | `Tool` | Tool definition: `name`, `description`, `schema`, `fn` | -| `Runtime` | Step queue with `put_message()`, `get_all_hooks()` | +| `ToolSchema` | Serializable tool description: `name`, `description`, `tool_schema` (no `fn`) | +| `Runtime` | Central coordinator for the agent loop. Step queue, message queue, checkpoint replay/record | +| `RunResult` | Return type of `run()`. Async-iterable for messages, then `.checkpoint` and `.pending_hooks` | +| `HookInfo` | Pending hook info: `label`, `hook_type`, `metadata` | +| `Hook` | Generic hook base with `.create()`, `.resolve()`, `.cancel()` class methods | +| `Checkpoint` | Serializable snapshot of completed work: `steps[]`, `tools[]`, `hooks[]`. Has `.serialize()` / `.deserialize()` | | `LanguageModel` | Abstract base class for LLM providers | -| `HookPending` | Exception raised by `Hook.create_or_raise()` when resolution needed | ## Examples See the `examples/` directory: -- `run_agent.py` - Basic agent with tools -- `run_multiagent.py` - Parallel agents with live display -- `run_hooks.py` - Human-in-the-loop approval flow -- `run_streaming_tool.py` - Tool that streams progress via Runtime -- `run_custom_loop.py` - Custom step with `@ai.stream` -- `run_mcp.py` - MCP integration -- `run_fake_serverless.py` - Suspend/resume with `HookPending` +**Samples** (`examples/samples/`): + +- `simple.py` — Basic agent with tools and `stream_loop` +- `agent.py` — Coding agent with local filesystem tools +- `hooks.py` — Human-in-the-loop approval flow +- `streaming_tool.py` — Tool that streams progress via Runtime +- `multiagent.py` — Parallel agents with labels, then summarization +- `custom_loop.py` — Custom step with `@ai.stream` +- `mcp.py` — MCP integration (Context7) + +**Projects**: + +- `examples/fastapi-vite/` — Full-stack chat app (FastAPI + Vite + AI SDK UI) +- `examples/temporal-durable/` — Durable execution with Temporal workflows +- `examples/multiagent-textual/` — Multi-agent TUI with Textual diff --git a/pyproject.toml b/pyproject.toml index 1507c16a..fea893c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "vercel-ai-sdk" -version = "0.0.1.dev3" +version = "0.0.1.dev4" description = "The AI Toolkit for Python" readme = "README.md" authors = [ From b46550d65de26e210a7e18bc13ea829e0d41ad01 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Fri, 13 Feb 2026 13:50:30 -0800 Subject: [PATCH 4/4] Fix the typing issue and clean up the FastAPI Vite example --- examples/fastapi-vite/backend/routes/chat.py | 12 +----------- examples/fastapi-vite/backend/uv.lock | 2 +- src/vercel_ai_sdk/ai_sdk_ui/adapter.py | 8 ++++---- 3 files changed, 6 insertions(+), 16 deletions(-) diff --git a/examples/fastapi-vite/backend/routes/chat.py b/examples/fastapi-vite/backend/routes/chat.py index 6da6431d..b62f4847 100644 --- a/examples/fastapi-vite/backend/routes/chat.py +++ b/examples/fastapi-vite/backend/routes/chat.py @@ -2,8 +2,6 @@ from __future__ import annotations -from collections.abc import AsyncGenerator - from fastapi import APIRouter from fastapi.responses import StreamingResponse from pydantic import BaseModel @@ -30,14 +28,6 @@ class ChatRequest(BaseModel): session_id: str | None = None -async def _iter_result( - result: ai.RunResult, -) -> AsyncGenerator[ai.Message, None]: - """Unwrap RunResult into an AsyncGenerator for to_sse_stream.""" - async for msg in result: - yield msg - - @router.post("/chat") async def chat(request: ChatRequest): """Handle chat requests and stream responses.""" @@ -57,7 +47,7 @@ async def chat(request: ChatRequest): result = ai.run(graph, llm, messages, TOOLS, checkpoint=checkpoint) async def stream_response(): - async for chunk in to_sse_stream(_iter_result(result)): + async for chunk in to_sse_stream(result): yield chunk # If the run completed (no pending hooks), clear the checkpoint diff --git a/examples/fastapi-vite/backend/uv.lock b/examples/fastapi-vite/backend/uv.lock index 3af4dba3..0450deef 100644 --- a/examples/fastapi-vite/backend/uv.lock +++ b/examples/fastapi-vite/backend/uv.lock @@ -1318,7 +1318,7 @@ wheels = [ [[package]] name = "vercel-ai-sdk" -version = "0.0.1.dev3" +version = "0.0.1.dev4" source = { editable = "../../../" } dependencies = [ { name = "anthropic" }, diff --git a/src/vercel_ai_sdk/ai_sdk_ui/adapter.py b/src/vercel_ai_sdk/ai_sdk_ui/adapter.py index af40b500..417b6fe3 100644 --- a/src/vercel_ai_sdk/ai_sdk_ui/adapter.py +++ b/src/vercel_ai_sdk/ai_sdk_ui/adapter.py @@ -7,7 +7,7 @@ import dataclasses import json import uuid -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, AsyncIterable from typing import Any, Literal from .. import core @@ -127,7 +127,7 @@ def begin_message( async def to_ui_message_stream( - messages: AsyncGenerator[core.messages.Message, None], + messages: AsyncIterable[core.messages.Message], ) -> AsyncGenerator[protocol.UIMessageStreamPart, None]: """ Convert a proto_sdk message stream into AI SDK UI message stream parts. @@ -257,7 +257,7 @@ async def to_ui_message_stream( async def filter_by_label( - messages: AsyncGenerator[core.messages.Message, None], + messages: AsyncIterable[core.messages.Message], label: str | None = None, ) -> AsyncGenerator[core.messages.Message, None]: """Filter a message stream to a single agent label. @@ -273,7 +273,7 @@ async def filter_by_label( async def to_sse_stream( - messages: AsyncGenerator[core.messages.Message, None], + messages: AsyncIterable[core.messages.Message], ) -> AsyncGenerator[str, None]: """Convert a proto_sdk message stream directly into SSE-formatted strings.""" async for part in to_ui_message_stream(messages):