From bad64005830aca82a0e7f2bbea60874ac9421ba8 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Tue, 24 Mar 2026 08:42:38 -0700 Subject: [PATCH 01/18] Update claude's proompt --- CLAUDE.md | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/CLAUDE.md b/CLAUDE.md index acb130e2..0605ce5b 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,4 +1,35 @@ +# ai + +## development guidelines + 1. use `uv` to manage the project; `uv add` and `uv remove` to manage dependencies, `uv run` to run 2. after making changes run lint and typecheck: `uv run ruff check --fix src tests` and `uv run mypy src tests` 3. import by module (except `typing`) to improve readability via namespacing 4. treat `stream_step` and `stream_loop` as user code. they are convenience functions that could be reimplemented by the user, they *must* stay clean. + +## design principles + +### 1. maximize composability + +provide simple lego bricks that the user can build their feature with. each block should do one thing and be reasonably decoupled from the rest. +expose correct primitives to make it easy to modify behavior without rewriting it from scratch. + +- *example*: `agents` module provides `@ai.stream`, `@ai.tool` and `@ai.hook` that can be combined into an arbitrarily complex agent graph using plain python. +- *can the user rewrite this feature in plain python using the existing primitives?* + +### 2. minimize dsl-ness and frameworkiness + +express features in a way that doesn't require the user to read documentation and learn the framework. glue things together using python. +handle complexity inside the framework instead of delegating it to users. + +- *example*: `Runtime` does the heavy lifting so that multi-agent graphs can be expressed using python `asyncio`. +- *does this require the user to learn a framework-specific concept that has a direct python equivalent?* + +### 3. 
keep data model simple + +ensure state is easy to serialize and deserialize, modify, and compose at any level of granularity. +move normalization and translation complexity inside the framework and keep the public data model minimal. + +- *example*: public data model consists of a single unified `Message` type. the framework does not expose events and other intermediate steps unless the user is writing a custom adapter. + + From 1ca0978ad51cb92e9b201b96bc90893bbe541f05 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Tue, 24 Mar 2026 09:29:23 -0700 Subject: [PATCH 02/18] Move stream_events to abc, make stream concrete --- src/vercel_ai_sdk/models/ai_gateway/llm.py | 29 +-------- src/vercel_ai_sdk/models/anthropic/llm.py | 27 +------- src/vercel_ai_sdk/models/core/llm.py | 30 ++++++++- src/vercel_ai_sdk/models/openai/llm.py | 28 +------- tests/adapters/ai_sdk_ui/test_adapter.py | 12 +++- tests/conftest.py | 75 ++++++++++++++-------- 6 files changed, 87 insertions(+), 114 deletions(-) diff --git a/src/vercel_ai_sdk/models/ai_gateway/llm.py b/src/vercel_ai_sdk/models/ai_gateway/llm.py index b4e0d38f..fa918c96 100644 --- a/src/vercel_ai_sdk/models/ai_gateway/llm.py +++ b/src/vercel_ai_sdk/models/ai_gateway/llm.py @@ -87,6 +87,7 @@ async def _raise_for_status(self, response: httpx.Response) -> None: # -- Stream events ------------------------------------------------------- + @override async def stream_events( self, messages: list[messages_.Message], @@ -144,34 +145,6 @@ async def stream_events( cause=exc, ) from exc - # -- LanguageModel interface --------------------------------------------- - - @override - async def stream( - self, - messages: list[messages_.Message], - tools: Sequence[tools_.ToolLike] | None = None, - output_type: type[pydantic.BaseModel] | None = None, - ) -> AsyncGenerator[messages_.Message]: - handler = llm_.StreamHandler() - msg: messages_.Message | None = None - async for event in self.stream_events(messages, tools, output_type): - msg = 
handler.handle_event(event) - yield msg - - if output_type is not None and msg is not None and msg.text: - data = json.loads(msg.text) - output_type.model_validate(data) - part = messages_.StructuredOutputPart( - data=data, - output_type_name=( - f"{output_type.__module__}.{output_type.__qualname__}" - ), - ) - msg = msg.model_copy() - msg.parts = [*msg.parts, part] - yield msg - # --------------------------------------------------------------------------- # Shared helpers for image/video models diff --git a/src/vercel_ai_sdk/models/anthropic/llm.py b/src/vercel_ai_sdk/models/anthropic/llm.py index f2812b31..59efe4f5 100644 --- a/src/vercel_ai_sdk/models/anthropic/llm.py +++ b/src/vercel_ai_sdk/models/anthropic/llm.py @@ -233,6 +233,7 @@ def __init__( resolved_key = api_key or os.environ.get("ANTHROPIC_API_KEY") or "" self._client = anthropic.AsyncAnthropic(base_url=base_url, api_key=resolved_key) + @override async def stream_events( self, messages: list[messages_.Message], @@ -340,29 +341,3 @@ async def stream_events( raw=sdk_usage.model_dump(exclude_none=True) or None, ) yield llm_.MessageDone(usage=usage) - - @override - async def stream( - self, - messages: list[messages_.Message], - tools: Sequence[tools_.ToolLike] | None = None, - output_type: type[pydantic.BaseModel] | None = None, - ) -> AsyncGenerator[messages_.Message]: - """Stream Messages (uses StreamHandler internally).""" - handler = llm_.StreamHandler() - msg: messages_.Message | None = None - async for event in self.stream_events(messages, tools, output_type=output_type): - msg = handler.handle_event(event) - yield msg - - # After stream completes, validate and attach structured output part - if output_type is not None and msg is not None and msg.text: - data = json.loads(msg.text) - output_type.model_validate(data) # fail fast on bad data - part = messages_.StructuredOutputPart( - data=data, - output_type_name=f"{output_type.__module__}.{output_type.__qualname__}", - ) - msg = msg.model_copy() - 
msg.parts = [*msg.parts, part] - yield msg diff --git a/src/vercel_ai_sdk/models/core/llm.py b/src/vercel_ai_sdk/models/core/llm.py index 765abf36..737a467f 100644 --- a/src/vercel_ai_sdk/models/core/llm.py +++ b/src/vercel_ai_sdk/models/core/llm.py @@ -2,6 +2,7 @@ import abc import dataclasses +import json from collections.abc import AsyncGenerator, Sequence import pydantic @@ -238,15 +239,40 @@ def _build_message( class LanguageModel(abc.ABC): @abc.abstractmethod - async def stream( + async def stream_events( self, messages: list[messages_.Message], tools: Sequence[tools_.ToolLike] | None = None, output_type: type[pydantic.BaseModel] | None = None, - ) -> AsyncGenerator[messages_.Message]: + ) -> AsyncGenerator[StreamEvent]: raise NotImplementedError yield + async def stream( + self, + messages: list[messages_.Message], + tools: Sequence[tools_.ToolLike] | None = None, + output_type: type[pydantic.BaseModel] | None = None, + ) -> AsyncGenerator[messages_.Message]: + """Stream Messages (uses StreamHandler internally).""" + handler = StreamHandler() + msg: messages_.Message | None = None + async for event in self.stream_events(messages, tools, output_type=output_type): + msg = handler.handle_event(event) + yield msg + + # After stream completes, validate and attach structured output part + if output_type is not None and msg is not None and msg.text: + data = json.loads(msg.text) + output_type.model_validate(data) # fail fast on bad data + part = messages_.StructuredOutputPart( + data=data, + output_type_name=f"{output_type.__module__}.{output_type.__qualname__}", + ) + msg = msg.model_copy() + msg.parts = [*msg.parts, part] + yield msg + async def buffer( self, messages: list[messages_.Message], diff --git a/src/vercel_ai_sdk/models/openai/llm.py b/src/vercel_ai_sdk/models/openai/llm.py index 3404cf27..57774324 100644 --- a/src/vercel_ai_sdk/models/openai/llm.py +++ b/src/vercel_ai_sdk/models/openai/llm.py @@ -1,6 +1,5 @@ from __future__ import annotations -import 
json import os from collections.abc import AsyncGenerator, Sequence from typing import Any, override @@ -212,6 +211,7 @@ def __init__( resolved_key = api_key or os.environ.get("OPENAI_API_KEY") or "" self._client = openai.AsyncOpenAI(base_url=base_url, api_key=resolved_key) + @override async def stream_events( self, messages: list[messages_.Message], @@ -363,29 +363,3 @@ async def stream_events( # finish_reason. We'll emit MessageDone after the loop. yield llm_.MessageDone(finish_reason=finish_reason, usage=usage) - - @override - async def stream( - self, - messages: list[messages_.Message], - tools: Sequence[tools_.ToolLike] | None = None, - output_type: type[pydantic.BaseModel] | None = None, - ) -> AsyncGenerator[messages_.Message]: - """Stream Messages (uses StreamHandler internally).""" - handler = llm_.StreamHandler() - msg: messages_.Message | None = None - async for event in self.stream_events(messages, tools, output_type=output_type): - msg = handler.handle_event(event) - yield msg - - # After stream completes, validate and attach structured output part - if output_type is not None and msg is not None and msg.text: - data = json.loads(msg.text) - output_type.model_validate(data) # fail fast on bad data - part = messages_.StructuredOutputPart( - data=data, - output_type_name=f"{output_type.__module__}.{output_type.__qualname__}", - ) - msg = msg.model_copy() - msg.parts = [*msg.parts, part] - yield msg diff --git a/tests/adapters/ai_sdk_ui/test_adapter.py b/tests/adapters/ai_sdk_ui/test_adapter.py index db717241..d3cf677e 100644 --- a/tests/adapters/ai_sdk_ui/test_adapter.py +++ b/tests/adapters/ai_sdk_ui/test_adapter.py @@ -308,22 +308,27 @@ async def test_runtime_tool_roundtrip() -> None: ] # This is what SHOULD happen: - # 1. First step yields tool call with status="pending" - # -> tool-input-start, tool-input-available + # 1. First step streams tool call args then completes + # -> tool-input-start, tool-input-delta, tool-input-available # 2. 
After tool execution, we yield the same message with # status="result" -> tool-output-available # (same step because same message ID) - # 3. Second LLM step yields final text -> text-start, text-end + # 3. Second LLM step streams text then completes + # -> text-start, text-delta, text-end, (final done msg) text-start, text-end expected = [ "start", "start-step", "tool-input-start", + "tool-input-delta", "tool-input-available", "tool-output-available", # Same step as input (same message ID) "finish-step", # Second LLM call (new message ID = new step) "start-step", "text-start", + "text-delta", + "text-end", + "text-start", # Final done message re-emits completed text "text-end", "finish-step", "finish", @@ -697,6 +702,7 @@ async def approve_and_execute(tc: ai.ToolPart) -> None: "start", "start-step", "tool-input-start", + "tool-input-delta", "tool-input-available", "tool-approval-request", "finish-step", diff --git a/tests/conftest.py b/tests/conftest.py index 981f9db3..31fc755f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,58 +1,77 @@ from __future__ import annotations -import json from collections.abc import AsyncGenerator, Sequence import pydantic import vercel_ai_sdk as ai -from vercel_ai_sdk.types import messages -from vercel_ai_sdk.types.messages import StructuredOutputPart +from vercel_ai_sdk.models.core import llm as llm_ +from vercel_ai_sdk.types import messages as messages_ class MockLLM(ai.LanguageModel): - """LLM that yields pre-configured response sequences, one per call.""" + """LLM that yields pre-configured response sequences, one per call. - def __init__(self, responses: list[list[messages.Message]]) -> None: + Converts pre-configured ``Message`` objects into ``StreamEvent`` sequences + so the base-class ``stream()`` (which uses ``StreamHandler``) can + reconstruct them. 
+ """ + + def __init__(self, responses: list[list[messages_.Message]]) -> None: self._responses = list(responses) self._call_index = 0 self.call_count = 0 - async def stream( + async def stream_events( self, - messages: list[messages.Message], + messages: list[messages_.Message], tools: Sequence[ai.ToolLike] | None = None, output_type: type[pydantic.BaseModel] | None = None, - ) -> AsyncGenerator[messages.Message]: + ) -> AsyncGenerator[llm_.StreamEvent]: if self._call_index >= len(self._responses): raise RuntimeError("MockLLM: no more responses configured") self.call_count += 1 seq = self._responses[self._call_index] self._call_index += 1 - msg = None + for msg in seq: - yield msg - - # Simulate structured output validation (matching real provider behavior) - if output_type is not None and msg is not None and msg.text: - data = json.loads(msg.text) - output_type.model_validate(data) # fail fast on bad data - part = StructuredOutputPart( - data=data, - output_type_name=f"{output_type.__module__}.{output_type.__qualname__}", - ) - msg = msg.model_copy() - msg.parts = [*msg.parts, part] - yield msg + for i, part in enumerate(msg.parts): + if isinstance(part, messages_.TextPart): + bid = f"text-{i}" + yield llm_.TextStart(block_id=bid) + if part.text: + yield llm_.TextDelta(block_id=bid, delta=part.text) + yield llm_.TextEnd(block_id=bid) + + elif isinstance(part, messages_.ReasoningPart): + bid = f"reasoning-{i}" + yield llm_.ReasoningStart(block_id=bid) + if part.text: + yield llm_.ReasoningDelta(block_id=bid, delta=part.text) + yield llm_.ReasoningEnd(block_id=bid, signature=part.signature) + + elif isinstance(part, messages_.ToolPart): + yield llm_.ToolStart( + tool_call_id=part.tool_call_id, + tool_name=part.tool_name, + ) + if part.tool_args: + yield llm_.ToolArgsDelta( + tool_call_id=part.tool_call_id, + delta=part.tool_args, + ) + yield llm_.ToolEnd(tool_call_id=part.tool_call_id) + + yield llm_.MessageDone() def text_msg( text: str, *, id: str = "msg-1", 
state: str = "done", delta: str | None = None -) -> messages.Message: - return messages.Message( +) -> messages_.Message: + return messages_.Message( id=id, role="assistant", - parts=[messages.TextPart(text=text, state=state, delta=delta)], + parts=[messages_.TextPart(text=text, state=state, delta=delta)], ) @@ -64,12 +83,12 @@ def tool_msg( args: str = "{}", status: str = "pending", result: dict[str, object] | None = None, -) -> messages.Message: - return messages.Message( +) -> messages_.Message: + return messages_.Message( id=id, role="assistant", parts=[ - messages.ToolPart( + messages_.ToolPart( tool_call_id=tc_id, tool_name=name, tool_args=args, From d6644453a7a7b06b04f87d2b97997a14b83fc989 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Tue, 24 Mar 2026 12:18:06 -0700 Subject: [PATCH 03/18] Clean up openai and anthropic code --- src/vercel_ai_sdk/models/anthropic/llm.py | 246 +++++++++++----------- src/vercel_ai_sdk/models/openai/llm.py | 130 ++++++------ 2 files changed, 188 insertions(+), 188 deletions(-) diff --git a/src/vercel_ai_sdk/models/anthropic/llm.py b/src/vercel_ai_sdk/models/anthropic/llm.py index 59efe4f5..b6c4ff52 100644 --- a/src/vercel_ai_sdk/models/anthropic/llm.py +++ b/src/vercel_ai_sdk/models/anthropic/llm.py @@ -95,84 +95,77 @@ def _file_part_to_anthropic(part: messages_.FilePart) -> dict[str, Any]: async def _messages_to_anthropic( messages: list[messages_.Message], ) -> tuple[str | None, list[dict[str, Any]]]: - """Convert internal messages to Anthropic API format. - - Returns (system_prompt, messages) tuple since Anthropic handles - system prompts separately. - - Converts to the Anthropic wire format: - - - ``tool_use`` blocks in assistant messages - - ``tool_result`` blocks in user messages (immediately after) - - A final merge pass ensures strictly alternating roles (Anthropic - rejects consecutive same-role messages). 
- """ + """Convert internal messages to Anthropic API format.""" system_prompt: str | None = None result: list[dict[str, Any]] = [] for msg in messages: - if msg.role == "system": - system_prompt = "".join( - p.text for p in msg.parts if isinstance(p, messages_.TextPart) - ) - elif msg.role == "assistant": - content: list[dict[str, Any]] = [] - tool_results: list[dict[str, Any]] = [] - - for part in msg.parts: - if isinstance(part, messages_.ReasoningPart): - if part.signature: - content.append( - { - "type": "thinking", - "thinking": part.text, - "signature": part.signature, - } - ) - elif isinstance(part, messages_.TextPart): - content.append({"type": "text", "text": part.text}) - elif isinstance(part, messages_.ToolPart): - tool_input = json.loads(part.tool_args) if part.tool_args else {} - content.append( - { - "type": "tool_use", - "id": part.tool_call_id, - "name": part.tool_name, - "input": tool_input, - } - ) - if part.status in ("result", "error"): - entry: dict[str, Any] = { - "type": "tool_result", - "tool_use_id": part.tool_call_id, - "content": str(part.result) - if part.result is not None - else "", - } - if part.status == "error": - entry["is_error"] = True - tool_results.append(entry) - - if content: - result.append({"role": "assistant", "content": content}) - if tool_results: - result.append({"role": "user", "content": tool_results}) - elif msg.role == "user": - has_files = any(isinstance(p, messages_.FilePart) for p in msg.parts) - if not has_files: - content_text = "".join( + match msg.role: + case "system": + system_prompt = "".join( p.text for p in msg.parts if isinstance(p, messages_.TextPart) ) - result.append({"role": "user", "content": content_text}) - else: - user_content: list[dict[str, Any]] = [] - for p in msg.parts: - if isinstance(p, messages_.TextPart): - user_content.append({"type": "text", "text": p.text}) - elif isinstance(p, messages_.FilePart): - user_content.append(_file_part_to_anthropic(p)) - result.append({"role": "user", 
"content": user_content}) + case "assistant": + content: list[dict[str, Any]] = [] + tool_results: list[dict[str, Any]] = [] + + for part in msg.parts: + match part: + case messages_.ReasoningPart(text=text, signature=signature): + if signature: + content.append( + { + "type": "thinking", + "thinking": text, + "signature": signature, + } + ) + case messages_.TextPart(text=text): + content.append({"type": "text", "text": text}) + case messages_.ToolPart(): + tool_input = ( + json.loads(part.tool_args) if part.tool_args else {} + ) + content.append( + { + "type": "tool_use", + "id": part.tool_call_id, + "name": part.tool_name, + "input": tool_input, + } + ) + if part.status in ("result", "error"): + entry: dict[str, Any] = { + "type": "tool_result", + "tool_use_id": part.tool_call_id, + "content": str(part.result) + if part.result is not None + else "", + } + if part.status == "error": + entry["is_error"] = True + tool_results.append(entry) + + if content: + result.append({"role": "assistant", "content": content}) + if tool_results: + result.append({"role": "user", "content": tool_results}) + case "user": + has_files = any(isinstance(p, messages_.FilePart) for p in msg.parts) + if not has_files: + content_text = "".join( + p.text for p in msg.parts if isinstance(p, messages_.TextPart) + ) + result.append({"role": "user", "content": content_text}) + else: + user_content: list[dict[str, Any]] = [] + for p in msg.parts: + match p: + case messages_.TextPart(text=text): + user_content.append({"type": "text", "text": text}) + case messages_.FilePart(): + user_content.append(_file_part_to_anthropic(p)) + result.append({"role": "user", "content": user_content}) # Merge consecutive same-role messages (e.g. synthetic user(tool_result) # followed by a real user message). 
@@ -273,58 +266,63 @@ async def stream_events( async with stream_cm as stream: async for event in stream: - if event.type == "content_block_start": - block = event.content_block - idx = event.index - block_types[idx] = block.type - - if block.type == "text": - yield llm_.TextStart(block_id=str(idx)) - elif block.type == "thinking": - yield llm_.ReasoningStart(block_id=str(idx)) - elif block.type == "tool_use": - tool_ids[idx] = block.id - yield llm_.ToolStart( - tool_call_id=block.id, tool_name=block.name - ) - - elif event.type == "content_block_delta": - delta = event.delta - idx = event.index - - if delta.type == "text_delta": - yield llm_.TextDelta(block_id=str(idx), delta=delta.text) - elif delta.type == "thinking_delta": - yield llm_.ReasoningDelta( - block_id=str(idx), delta=delta.thinking - ) - elif delta.type == "signature_delta": - # Accumulate signature for ReasoningEnd - signature_buffer[idx] = ( - signature_buffer.get(idx, "") + delta.signature - ) - elif delta.type == "input_json_delta": - tool_id = tool_ids.get(idx) - if tool_id: - yield llm_.ToolArgsDelta( - tool_call_id=tool_id, delta=delta.partial_json - ) - - elif event.type == "content_block_stop": - idx = event.index - block_type = block_types.get(idx) - - if block_type == "text": - yield llm_.TextEnd(block_id=str(idx)) - elif block_type == "thinking": - yield llm_.ReasoningEnd( - block_id=str(idx), - signature=signature_buffer.get(idx), - ) - elif block_type == "tool_use": - tool_id = tool_ids.get(idx) - if tool_id: - yield llm_.ToolEnd(tool_call_id=tool_id) + match event.type: + case "content_block_start": + block = event.content_block + idx = event.index + block_types[idx] = block.type + + match block.type: + case "text": + yield llm_.TextStart(block_id=str(idx)) + case "thinking": + yield llm_.ReasoningStart(block_id=str(idx)) + case "tool_use": + tool_ids[idx] = block.id + yield llm_.ToolStart( + tool_call_id=block.id, tool_name=block.name + ) + + case "content_block_delta": + delta = 
event.delta + idx = event.index + + match delta.type: + case "text_delta": + yield llm_.TextDelta( + block_id=str(idx), delta=delta.text + ) + case "thinking_delta": + yield llm_.ReasoningDelta( + block_id=str(idx), delta=delta.thinking + ) + case "signature_delta": + # Accumulate signature for ReasoningEnd + signature_buffer[idx] = ( + signature_buffer.get(idx, "") + delta.signature + ) + case "input_json_delta": + tool_id = tool_ids.get(idx) + if tool_id: + yield llm_.ToolArgsDelta( + tool_call_id=tool_id, + delta=delta.partial_json, + ) + + case "content_block_stop": + idx = event.index + match block_types.get(idx): + case "text": + yield llm_.TextEnd(block_id=str(idx)) + case "thinking": + yield llm_.ReasoningEnd( + block_id=str(idx), + signature=signature_buffer.get(idx), + ) + case "tool_use": + tool_id = tool_ids.get(idx) + if tool_id: + yield llm_.ToolEnd(tool_call_id=tool_id) # The Anthropic SDK accumulates usage across message_start and # message_delta events into current_message_snapshot. 
Read it diff --git a/src/vercel_ai_sdk/models/openai/llm.py b/src/vercel_ai_sdk/models/openai/llm.py index 57774324..46dd3a0d 100644 --- a/src/vercel_ai_sdk/models/openai/llm.py +++ b/src/vercel_ai_sdk/models/openai/llm.py @@ -102,72 +102,74 @@ async def _messages_to_openai( """ result: list[dict[str, Any]] = [] for msg in messages: - if msg.role == "assistant": - content = "" - reasoning = "" - tool_calls = [] - tool_results = [] - - for part in msg.parts: - if isinstance(part, messages_.ReasoningPart): - reasoning += part.text - elif isinstance(part, messages_.TextPart): - content += part.text - elif isinstance(part, messages_.ToolPart): - tool_calls.append( - { - "id": part.tool_call_id, - "type": "function", - "function": { - "name": part.tool_name, - "arguments": part.tool_args, - }, - } - ) - if part.status in ("result", "error"): - tool_results.append( - { - "role": "tool", - "tool_call_id": part.tool_call_id, - "content": str(part.result) - if part.result is not None - else "", - } - ) - - entry: dict[str, Any] = {"role": "assistant"} - if content: - entry["content"] = content - if reasoning: - entry["reasoning"] = reasoning - if tool_calls: - entry["tool_calls"] = tool_calls - result.append(entry) - - # Emit tool results as separate messages (OpenAI API format) - result.extend(tool_results) - elif msg.role == "system": - content = "".join( - p.text for p in msg.parts if isinstance(p, messages_.TextPart) - ) - result.append({"role": "system", "content": content}) - else: - # User messages — may contain multimodal FileParts - has_files = any(isinstance(p, messages_.FilePart) for p in msg.parts) - if not has_files: - # Text-only: keep simple string format (cheaper, no content array) - text = "".join( + match msg.role: + case "assistant": + content = "" + reasoning = "" + tool_calls = [] + tool_results = [] + + for part in msg.parts: + match part: + case messages_.ReasoningPart(text=text): + reasoning += text + case messages_.TextPart(text=text): + content += 
text + case messages_.ToolPart(): + tool_calls.append( + { + "id": part.tool_call_id, + "type": "function", + "function": { + "name": part.tool_name, + "arguments": part.tool_args, + }, + } + ) + if part.status in ("result", "error"): + tool_results.append( + { + "role": "tool", + "tool_call_id": part.tool_call_id, + "content": str(part.result) + if part.result is not None + else "", + } + ) + + entry: dict[str, Any] = {"role": "assistant"} + if content: + entry["content"] = content + if reasoning: + entry["reasoning"] = reasoning + if tool_calls: + entry["tool_calls"] = tool_calls + result.append(entry) + + # Emit tool results as separate messages (OpenAI API format) + result.extend(tool_results) + case "system": + content = "".join( p.text for p in msg.parts if isinstance(p, messages_.TextPart) ) - result.append({"role": "user", "content": text}) - else: - parts: list[dict[str, Any]] = [] - for p in msg.parts: - if isinstance(p, messages_.TextPart): - parts.append({"type": "text", "text": p.text}) - elif isinstance(p, messages_.FilePart): - parts.append(await _file_part_to_openai(p)) - result.append({"role": "user", "content": parts}) + result.append({"role": "system", "content": content}) + case "user": + has_files = any(isinstance(p, messages_.FilePart) for p in msg.parts) + if not has_files: + # Text-only: keep simple string format (cheaper, no content array) + text = "".join( + p.text for p in msg.parts if isinstance(p, messages_.TextPart) + ) + result.append({"role": "user", "content": text}) + else: + parts: list[dict[str, Any]] = [] + for p in msg.parts: + match p: + case messages_.TextPart(text=text): + parts.append({"type": "text", "text": text}) + case messages_.FilePart(): + parts.append(await _file_part_to_openai(p)) + result.append({"role": "user", "content": parts}) return result From f1ba74d319b493a5f34c7777b542ed9b32a4e0e2 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Wed, 25 Mar 2026 11:44:24 -0700 Subject: [PATCH 04/18] Add basic model 
wiring --- src/vercel_ai_sdk/__init__.py | 7 +- src/vercel_ai_sdk/models/__init__.py | 79 +++++++++++- src/vercel_ai_sdk/models/core/__init__.py | 20 +++ src/vercel_ai_sdk/models/core/model.py | 29 +++++ src/vercel_ai_sdk/models/core/protocol.py | 145 ++++++++++++++++++++++ src/vercel_ai_sdk/models/core/registry.py | 54 ++++++++ 6 files changed, 332 insertions(+), 2 deletions(-) create mode 100644 src/vercel_ai_sdk/models/core/model.py create mode 100644 src/vercel_ai_sdk/models/core/protocol.py create mode 100644 src/vercel_ai_sdk/models/core/registry.py diff --git a/src/vercel_ai_sdk/__init__.py b/src/vercel_ai_sdk/__init__.py index d0bdbe1b..37946271 100644 --- a/src/vercel_ai_sdk/__init__.py +++ b/src/vercel_ai_sdk/__init__.py @@ -1,4 +1,4 @@ -from . import adapters, telemetry +from . import adapters, models, telemetry from .adapters import ai_sdk_ui from .agents import ( Checkpoint, @@ -25,6 +25,8 @@ LanguageModel, MediaModel, MediaResult, + Model, + Stream, VideoModel, ai_gateway, anthropic, @@ -66,11 +68,14 @@ "Usage", "make_messages", # Models (from models/) + "Model", + "Stream", "LanguageModel", "MediaModel", "MediaResult", "ImageModel", "VideoModel", + "models", # Agents (from agents/) "Tool", "Runtime", diff --git a/src/vercel_ai_sdk/models/__init__.py b/src/vercel_ai_sdk/models/__init__.py index c1e6000a..47c129aa 100644 --- a/src/vercel_ai_sdk/models/__init__.py +++ b/src/vercel_ai_sdk/models/__init__.py @@ -2,21 +2,98 @@ Provides the LanguageModel ABC and concrete provider adapters. Depends only on types/, never on agents/. + +Module-level API +~~~~~~~~~~~~~~~~ + +.. code-block:: python + + import vercel_ai_sdk as ai + + model = ai.models.Model(id="gpt-4o", api="openai", provider="openai") + s = ai.models.stream(model, messages) + async for msg in s: + ... 
+ + result = await ai.models.generate(model, messages, n=2) """ +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any + +import pydantic + +from ..types import messages as messages_ +from ..types import tools as tools_ from . import ai_gateway, anthropic, core, openai from .core import ( + GenerateFn, ImageModel, LanguageModel, MediaModel, MediaResult, + Model, + Stream, StreamEvent, + StreamFn, StreamHandler, VideoModel, + get_generate_fn, + get_stream_fn, + register_generate, + register_stream, ) +# ── Module-level dispatch functions ─────────────────────────────── + + +def stream( + model: Model, + messages: list[messages_.Message], + tools: Sequence[tools_.ToolLike] | None = None, + output_type: type[pydantic.BaseModel] | None = None, +) -> Stream: + """Stream an LLM response for the given model. + + Looks up the registered :class:`StreamFn` for ``model.api`` and + returns a :class:`Stream` that can be async-iterated *or* awaited. + """ + fn = get_stream_fn(model.api) + return Stream(fn(model, messages, tools=tools, output_type=output_type)) + + +async def generate( + model: Model, + messages: list[messages_.Message], + **kwargs: Any, +) -> messages_.Message: + """Generate a response (image, video, etc.) for the given model. + + Looks up the registered :class:`GenerateFn` for ``model.api`` and + returns the resulting :class:`Message`. 
+ """ + fn = get_generate_fn(model.api) + return await fn(model, messages, **kwargs) + + __all__ = [ - # Core abstractions + # Model data + "Model", + # Execution protocols + "StreamFn", + "GenerateFn", + "Stream", + # Registry + "register_stream", + "register_generate", + "get_stream_fn", + "get_generate_fn", + # Dispatch + "stream", + "generate", + # Legacy ABCs (still in use) "LanguageModel", "StreamEvent", "StreamHandler", diff --git a/src/vercel_ai_sdk/models/core/__init__.py b/src/vercel_ai_sdk/models/core/__init__.py index 63c6c89f..be7a0f38 100644 --- a/src/vercel_ai_sdk/models/core/__init__.py +++ b/src/vercel_ai_sdk/models/core/__init__.py @@ -4,9 +4,29 @@ from .image import ImageModel from .llm import LanguageModel, StreamEvent, StreamHandler from .media.base import MediaModel, MediaResult +from .model import Model +from .protocol import GenerateFn, Stream, StreamFn +from .registry import ( + get_generate_fn, + get_stream_fn, + register_generate, + register_stream, +) from .video import VideoModel __all__ = [ + # Model data + "Model", + # Execution protocols + "StreamFn", + "GenerateFn", + "Stream", + # Registry + "register_stream", + "register_generate", + "get_stream_fn", + "get_generate_fn", + # Legacy ABCs (still in use) "LanguageModel", "StreamEvent", "StreamHandler", diff --git a/src/vercel_ai_sdk/models/core/model.py b/src/vercel_ai_sdk/models/core/model.py new file mode 100644 index 00000000..3b6d8797 --- /dev/null +++ b/src/vercel_ai_sdk/models/core/model.py @@ -0,0 +1,29 @@ +"""Model — pure data describing a model, no execution logic.""" + +from __future__ import annotations + +import dataclasses + + +@dataclasses.dataclass(frozen=True) +class Model: + """Immutable description of a model. + + ``id`` + The model identifier sent to the provider + (e.g. ``"claude-sonnet-4-20250514"``, ``"gpt-4o"``). + + ``api`` + Wire protocol discriminator used to look up the execution function + (e.g. ``"anthropic"``, ``"openai"``, ``"ai-gateway"``). 
+ A single ``api`` value may be shared by multiple providers that speak + the same wire format. + + ``provider`` + The actual host / provider name + (e.g. ``"anthropic"``, ``"azure"``, ``"ai-gateway"``). + """ + + id: str + api: str + provider: str diff --git a/src/vercel_ai_sdk/models/core/protocol.py b/src/vercel_ai_sdk/models/core/protocol.py new file mode 100644 index 00000000..56d3eb8a --- /dev/null +++ b/src/vercel_ai_sdk/models/core/protocol.py @@ -0,0 +1,145 @@ +"""Execution protocols and the Stream result type. + +``StreamFn`` and ``GenerateFn`` define the execution contract that +provider adapters must satisfy. ``Stream`` wraps an async generator +of :class:`Message` objects into an async-iterable *and* awaitable +result with convenience properties. +""" + +from __future__ import annotations + +from collections.abc import AsyncGenerator, Generator, Sequence +from typing import Any, Protocol, runtime_checkable + +import pydantic + +from ...types import messages as messages_ +from ...types import tools as tools_ +from .model import Model + +# ── Execution protocols ─────────────────────────────────────────── + + +@runtime_checkable +class StreamFn(Protocol): + """Protocol for streaming LLM calls. + + Implementations accept a :class:`Model`, messages, and optional tools / + output type, and return an async generator that yields + :class:`Message` snapshots as the response streams in. + """ + + def __call__( + self, + model: Model, + messages: list[messages_.Message], + tools: Sequence[tools_.ToolLike] | None = None, + output_type: type[pydantic.BaseModel] | None = None, + ) -> AsyncGenerator[messages_.Message]: ... + + +@runtime_checkable +class GenerateFn(Protocol): + """Protocol for non-streaming generation (image, video, etc.). + + Implementations accept a :class:`Model`, messages, and arbitrary + keyword arguments forwarded from the caller. 
+ """ + + async def __call__( + self, + model: Model, + messages: list[messages_.Message], + **kwargs: Any, + ) -> messages_.Message: ... + + +# ── Stream result ───────────────────────────────────────────────── + + +class Stream: + """Async-iterable *and* awaitable wrapper around a message generator. + + Usage:: + + # Streaming + stream = Stream(gen) + async for msg in stream: + print(msg.text) + + # Or just await the final result + stream = Stream(gen) + await stream + stream.result # last Message + stream.text # concatenated text + """ + + def __init__(self, generator: AsyncGenerator[messages_.Message]) -> None: + self._generator = generator + self._messages: list[messages_.Message] = [] + self._done = False + + # ── Async iteration ─────────────────────────────────────────── + + async def __aiter__(self) -> AsyncGenerator[messages_.Message]: + if self._done: + # Already consumed — replay from buffer + for msg in self._messages: + yield msg + return + + async for msg in self._generator: + self._messages.append(msg) + yield msg + self._done = True + + # ── Awaitable ───────────────────────────────────────────────── + + def __await__(self) -> Generator[Any, None, Stream]: + return self._drain().__await__() + + async def _drain(self) -> Stream: + """Consume the entire generator, populating result fields.""" + if not self._done: + async for _ in self: + pass + return self + + # ── Result properties (available after iteration / await) ───── + + @property + def messages(self) -> list[messages_.Message]: + """All messages yielded during streaming.""" + return list(self._messages) + + @property + def result(self) -> messages_.Message | None: + """The last message (final snapshot), or ``None`` if empty.""" + return self._messages[-1] if self._messages else None + + @property + def tool_calls(self) -> list[messages_.ToolPart]: + """Tool-call parts from the final message.""" + if not self._messages: + return [] + return [ + p for p in self._messages[-1].parts if 
isinstance(p, messages_.ToolPart) + ] + + @property + def text(self) -> str: + """Concatenated text from the final message.""" + if not self._messages: + return "" + return "".join( + p.text + for p in self._messages[-1].parts + if isinstance(p, messages_.TextPart) + ) + + @property + def usage(self) -> messages_.Usage | None: + """Usage from the final message, if available.""" + if not self._messages: + return None + return self._messages[-1].usage diff --git a/src/vercel_ai_sdk/models/core/registry.py b/src/vercel_ai_sdk/models/core/registry.py new file mode 100644 index 00000000..033dba45 --- /dev/null +++ b/src/vercel_ai_sdk/models/core/registry.py @@ -0,0 +1,54 @@ +"""Registry mapping ``api`` strings to execution functions. + +Provider adapters call :func:`register_stream` / :func:`register_generate` +to make themselves available. The module-level ``stream()`` and +``generate()`` functions in :mod:`vercel_ai_sdk.models` use +:func:`get_stream_fn` / :func:`get_generate_fn` to dispatch. +""" + +from __future__ import annotations + +from .protocol import GenerateFn, StreamFn + +_stream_fns: dict[str, StreamFn] = {} +_generate_fns: dict[str, GenerateFn] = {} + + +def register_stream(api: str, fn: StreamFn) -> None: + """Register a :class:`StreamFn` for the given wire-protocol ``api``.""" + _stream_fns[api] = fn + + +def register_generate(api: str, fn: GenerateFn) -> None: + """Register a :class:`GenerateFn` for the given wire-protocol ``api``.""" + _generate_fns[api] = fn + + +def get_stream_fn(api: str) -> StreamFn: + """Look up the registered :class:`StreamFn` for ``api``. + + Raises :class:`KeyError` with a descriptive message if no function + has been registered for the given ``api``. + """ + try: + return _stream_fns[api] + except KeyError: + registered = ", ".join(sorted(_stream_fns)) or "(none)" + raise KeyError( + f"No StreamFn registered for api={api!r}. 
Registered: {registered}" + ) from None + + +def get_generate_fn(api: str) -> GenerateFn: + """Look up the registered :class:`GenerateFn` for ``api``. + + Raises :class:`KeyError` with a descriptive message if no function + has been registered for the given ``api``. + """ + try: + return _generate_fns[api] + except KeyError: + registered = ", ".join(sorted(_generate_fns)) or "(none)" + raise KeyError( + f"No GenerateFn registered for api={api!r}. Registered: {registered}" + ) from None From 9a4c9419f6a18a3ddd67a73c2ef45d86e3ddcec0 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Tue, 31 Mar 2026 16:40:22 -0700 Subject: [PATCH 05/18] Implement the first approximation of the updated model API --- examples/models2/buffer.py | 32 + examples/models2/direct_wire.py | 42 ++ examples/models2/explicit_client.py | 41 ++ examples/models2/image_generation.py | 52 ++ examples/models2/inline_image.py | 74 +++ examples/models2/multimodal_input.py | 38 ++ examples/models2/stream.py | 33 + examples/models2/structured_output.py | 45 ++ examples/models2/tools.py | 50 ++ examples/models2/video_generation.py | 54 ++ src/vercel_ai_sdk/models2/__init__.py | 171 +++++ src/vercel_ai_sdk/models2/core/__init__.py | 13 + src/vercel_ai_sdk/models2/core/client.py | 45 ++ .../models2/core/helpers/media.py | 370 +++++++++++ .../models2/core/helpers/streaming.py | 264 ++++++++ src/vercel_ai_sdk/models2/core/model.py | 34 + src/vercel_ai_sdk/models2/core/wire.py | 54 ++ src/vercel_ai_sdk/models2/wires/__init__.py | 1 + .../models2/wires/ai_gateway_v3.py | 594 ++++++++++++++++++ 19 files changed, 2007 insertions(+) create mode 100644 examples/models2/buffer.py create mode 100644 examples/models2/direct_wire.py create mode 100644 examples/models2/explicit_client.py create mode 100644 examples/models2/image_generation.py create mode 100644 examples/models2/inline_image.py create mode 100644 examples/models2/multimodal_input.py create mode 100644 examples/models2/stream.py create mode 100644 
examples/models2/structured_output.py create mode 100644 examples/models2/tools.py create mode 100644 examples/models2/video_generation.py create mode 100644 src/vercel_ai_sdk/models2/__init__.py create mode 100644 src/vercel_ai_sdk/models2/core/__init__.py create mode 100644 src/vercel_ai_sdk/models2/core/client.py create mode 100644 src/vercel_ai_sdk/models2/core/helpers/media.py create mode 100644 src/vercel_ai_sdk/models2/core/helpers/streaming.py create mode 100644 src/vercel_ai_sdk/models2/core/model.py create mode 100644 src/vercel_ai_sdk/models2/core/wire.py create mode 100644 src/vercel_ai_sdk/models2/wires/__init__.py create mode 100644 src/vercel_ai_sdk/models2/wires/ai_gateway_v3.py diff --git a/examples/models2/buffer.py b/examples/models2/buffer.py new file mode 100644 index 00000000..5b05f32c --- /dev/null +++ b/examples/models2/buffer.py @@ -0,0 +1,32 @@ +"""Buffered response — drain the stream, get the final message.""" + +import asyncio + +from vercel_ai_sdk import models2 as m +from vercel_ai_sdk.types import messages as messages_ + +model = m.Model( + id="anthropic/claude-sonnet-4", + api="ai-gateway", + provider="ai-gateway", +) + +messages = [ + messages_.Message( + role="user", + parts=[messages_.TextPart(text="What is 2 + 2?")], + ), +] + + +async def main() -> None: + result = await m.buffer(m.stream(model, messages)) + print(result.text) + if result.usage: + print( + f"tokens: {result.usage.input_tokens} in, {result.usage.output_tokens} out" + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/models2/direct_wire.py b/examples/models2/direct_wire.py new file mode 100644 index 00000000..66175483 --- /dev/null +++ b/examples/models2/direct_wire.py @@ -0,0 +1,42 @@ +"""Direct wire call — bypass the registry, call the wire function directly.""" + +import asyncio +import os + +from vercel_ai_sdk import models2 as m +from vercel_ai_sdk.models2.wires import ai_gateway_v3 +from vercel_ai_sdk.types import messages as 
messages_ + +model = m.Model( + id="anthropic/claude-sonnet-4", + api="ai-gateway", + provider="ai-gateway", +) + +client = m.Client( + base_url="https://ai-gateway.vercel.sh/v3/ai", + api_key=os.environ["AI_GATEWAY_API_KEY"], +) + +messages = [ + messages_.Message( + role="user", + parts=[messages_.TextPart(text="Say hello in three languages.")], + ), +] + + +async def main() -> None: + # Call the wire function directly — no registry lookup, no auto-client. + # This is the lowest level of the API. + try: + async for msg in ai_gateway_v3.stream(client, model, messages): + if msg.text_delta: + print(msg.text_delta, end="", flush=True) + print() + finally: + await client.aclose() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/models2/explicit_client.py b/examples/models2/explicit_client.py new file mode 100644 index 00000000..baa319cc --- /dev/null +++ b/examples/models2/explicit_client.py @@ -0,0 +1,41 @@ +"""Explicit client — bring your own auth and base URL.""" + +import asyncio +import os + +from vercel_ai_sdk import models2 as m +from vercel_ai_sdk.types import messages as messages_ + +model = m.Model( + id="anthropic/claude-sonnet-4", + api="ai-gateway", + provider="ai-gateway", +) + +# Explicit client — useful for custom auth, proxies, or self-hosted gateways. 
+client = m.Client( + base_url="https://ai-gateway.vercel.sh/v3/ai", + api_key=os.environ["AI_GATEWAY_API_KEY"], + headers={"X-Custom-Header": "example"}, +) + +messages = [ + messages_.Message( + role="user", + parts=[messages_.TextPart(text="Hello!")], + ), +] + + +async def main() -> None: + try: + async for msg in m.stream(model, messages, client=client): + if msg.text_delta: + print(msg.text_delta, end="", flush=True) + print() + finally: + await client.aclose() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/models2/image_generation.py b/examples/models2/image_generation.py new file mode 100644 index 00000000..565cf78b --- /dev/null +++ b/examples/models2/image_generation.py @@ -0,0 +1,52 @@ +"""Image generation — dedicated image model via generate().""" + +import asyncio +import base64 +import pathlib + +from vercel_ai_sdk import models2 as m +from vercel_ai_sdk.types import messages as messages_ + +model = m.Model( + id="google/imagen-4.0-generate-001", + api="ai-gateway", + provider="ai-gateway", + capabilities=("image",), +) + +messages = [ + messages_.Message( + role="user", + parts=[ + messages_.TextPart( + text=( + "Anime girl with twin tails and cat ears, wearing a " + "sailor school uniform, striking a victory pose in front " + "of a futuristic Tokyo skyline at night, neon lights " + "reflecting in her eyes, digital art style" + ) + ), + ], + ), +] + + +async def main() -> None: + result = await m.generate(model, messages, n=2, aspect_ratio="16:9") + + print(f"Generated {len(result.images)} image(s)") + for i, img in enumerate(result.images): + filename = f"generated_{i}.png" + data = img.data if isinstance(img.data, bytes) else base64.b64decode(img.data) + pathlib.Path(filename).write_bytes(data) + print(f" {filename}: {img.media_type}, {len(data)} bytes") + + if result.usage: + print( + f"Usage: {result.usage.input_tokens} input, " + f"{result.usage.output_tokens} output tokens" + ) + + +if __name__ == "__main__": + 
asyncio.run(main()) diff --git a/examples/models2/inline_image.py b/examples/models2/inline_image.py new file mode 100644 index 00000000..17dfb7e6 --- /dev/null +++ b/examples/models2/inline_image.py @@ -0,0 +1,74 @@ +"""Inline image generation — LLM that outputs images alongside text. + +Models like Gemini 3 Pro Image can generate images as part of their +language model response. The images arrive as FileParts in the streamed +Message. +""" + +import asyncio +import base64 +import pathlib + +from vercel_ai_sdk import models2 as m +from vercel_ai_sdk.types import messages as messages_ + +# This is a language model that can also output images inline. +model = m.Model( + id="google/gemini-3-pro-image", + api="ai-gateway", + provider="ai-gateway", + capabilities=("text", "image"), +) + +messages = [ + messages_.Message( + role="system", + parts=[ + messages_.TextPart( + text=( + "You are an anime art assistant. When asked to draw or create " + "an image, generate it in a soft pastel anime style." + ) + ), + ], + ), + messages_.Message( + role="user", + parts=[ + messages_.TextPart( + text=( + "Draw an anime girl with long silver hair and violet eyes, " + "sitting in a field of cherry blossoms at sunset." 
+ ) + ), + ], + ), +] + + +async def main() -> None: + last_msg: messages_.Message | None = None + + # Stream — text deltas arrive as usual, images arrive as FileParts + async for msg in m.stream(model, messages): + if msg.text_delta: + print(msg.text_delta, end="", flush=True) + last_msg = msg + + print() + + # Check for images in the final message + if last_msg and last_msg.images: + for i, img in enumerate(last_msg.images): + filename = f"inline_{i}.png" + data = ( + img.data if isinstance(img.data, bytes) else base64.b64decode(img.data) + ) + pathlib.Path(filename).write_bytes(data) + print(f"Saved {filename} ({img.media_type}, {len(data)} bytes)") + else: + print("No images were generated in this response.") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/models2/multimodal_input.py b/examples/models2/multimodal_input.py new file mode 100644 index 00000000..21a5f43d --- /dev/null +++ b/examples/models2/multimodal_input.py @@ -0,0 +1,38 @@ +"""Multimodal input — send a local image to the model and ask about it.""" + +import asyncio +import pathlib + +from vercel_ai_sdk import models2 as m +from vercel_ai_sdk.types import messages as messages_ + +model = m.Model( + id="anthropic/claude-sonnet-4", + api="ai-gateway", + provider="ai-gateway", +) + +# Load a local image file (replace with your own path) +image_path = pathlib.Path("sample_image.jpg") +image_data = image_path.read_bytes() + +messages = [ + messages_.Message( + role="user", + parts=[ + messages_.TextPart(text="Describe this image in detail."), + messages_.FilePart(data=image_data, media_type="image/jpeg"), + ], + ), +] + + +async def main() -> None: + async for msg in m.stream(model, messages): + if msg.text_delta: + print(msg.text_delta, end="", flush=True) + print() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/models2/stream.py b/examples/models2/stream.py new file mode 100644 index 00000000..76155401 --- /dev/null +++ 
b/examples/models2/stream.py @@ -0,0 +1,33 @@ +"""Basic streaming — print text deltas as they arrive.""" + +import asyncio + +from vercel_ai_sdk import models2 as m +from vercel_ai_sdk.types import messages as messages_ + +model = m.Model( + id="anthropic/claude-sonnet-4", + api="ai-gateway", + provider="ai-gateway", +) + +messages = [ + messages_.Message(role="system", parts=[messages_.TextPart(text="Be concise.")]), + messages_.Message( + role="user", + parts=[ + messages_.TextPart(text="Explain why the sky is blue in two sentences.") + ], + ), +] + + +async def main() -> None: + async for msg in m.stream(model, messages): + if msg.text_delta: + print(msg.text_delta, end="", flush=True) + print() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/models2/structured_output.py b/examples/models2/structured_output.py new file mode 100644 index 00000000..c984d672 --- /dev/null +++ b/examples/models2/structured_output.py @@ -0,0 +1,45 @@ +"""Structured output — get validated JSON from the model.""" + +import asyncio + +import pydantic + +from vercel_ai_sdk import models2 as m +from vercel_ai_sdk.types import messages as messages_ + +model = m.Model( + id="anthropic/claude-sonnet-4", + api="ai-gateway", + provider="ai-gateway", +) + + +class Recipe(pydantic.BaseModel): + name: str + ingredients: list[str] + steps: list[str] + prep_time_minutes: int + + +messages = [ + messages_.Message( + role="user", + parts=[messages_.TextPart(text="Give me a simple pancake recipe.")], + ), +] + + +async def main() -> None: + # Stream with structured output — watch JSON arrive, get validated at the end + async for msg in m.stream(model, messages, output_type=Recipe): + if msg.text_delta: + print(msg.text_delta, end="", flush=True) + if msg.output: + recipe: Recipe = msg.output + print(f"\n\nParsed recipe: {recipe.name}") + print(f" Ingredients: {', '.join(recipe.ingredients)}") + print(f" Prep time: {recipe.prep_time_minutes} min") + + +if __name__ == 
"__main__": + asyncio.run(main()) diff --git a/examples/models2/tools.py b/examples/models2/tools.py new file mode 100644 index 00000000..aa1d389f --- /dev/null +++ b/examples/models2/tools.py @@ -0,0 +1,50 @@ +"""Tools — pass tool schemas to the model.""" + +import asyncio + +from vercel_ai_sdk import models2 as m +from vercel_ai_sdk.types import messages as messages_ +from vercel_ai_sdk.types import tools as tools_ + +model = m.Model( + id="anthropic/claude-sonnet-4", + api="ai-gateway", + provider="ai-gateway", +) + +# Define a tool schema — anything matching the ToolLike protocol works. +get_weather = tools_.ToolSchema( + name="get_weather", + description="Get the current weather for a city.", + param_schema={ + "type": "object", + "properties": { + "city": {"type": "string", "description": "The city name"}, + }, + "required": ["city"], + }, + return_type=str, +) + +messages = [ + messages_.Message( + role="user", + parts=[messages_.TextPart(text="What's the weather in Tokyo?")], + ), +] + + +async def main() -> None: + # Stream with tools — the model may emit tool calls + async for msg in m.stream(model, messages, tools=[get_weather]): + if msg.text_delta: + print(msg.text_delta, end="", flush=True) + + for tc in msg.tool_calls: + if tc.state == "done": + print(f"\nTool call: {tc.tool_name}({tc.tool_args})") + print() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/models2/video_generation.py b/examples/models2/video_generation.py new file mode 100644 index 00000000..8b9e17d1 --- /dev/null +++ b/examples/models2/video_generation.py @@ -0,0 +1,54 @@ +"""Video generation — dedicated video model via generate().""" + +import asyncio +import base64 +import pathlib + +from vercel_ai_sdk import models2 as m +from vercel_ai_sdk.types import messages as messages_ + +model = m.Model( + id="google/veo-3.0-generate-001", + api="ai-gateway", + provider="ai-gateway", + capabilities=("video",), +) + +messages = [ + messages_.Message( + 
role="user", + parts=[ + messages_.TextPart( + text=( + "An anime girl with long pink hair and a flowing white " + "dress stands on a hilltop at golden hour. A warm breeze " + "lifts her hair as she releases a paper lantern into the " + "sunset sky. Soft cel-shaded anime art style, warm palette." + ) + ), + ], + ), +] + + +async def main() -> None: + print("Generating video (this may take a minute or two)...") + + result = await m.generate( + model, + messages, + aspect_ratio="16:9", + duration=8, + ) + + print(f"Generated {len(result.videos)} video(s)") + for i, vid in enumerate(result.videos): + ext = "mp4" if "mp4" in vid.media_type else "webm" + filename = f"generated_{i}.{ext}" + data = vid.data if isinstance(vid.data, bytes) else base64.b64decode(vid.data) + pathlib.Path(filename).write_bytes(data) + print(f" {filename}: {vid.media_type}, {len(data)} bytes") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/vercel_ai_sdk/models2/__init__.py b/src/vercel_ai_sdk/models2/__init__.py new file mode 100644 index 00000000..b41252d0 --- /dev/null +++ b/src/vercel_ai_sdk/models2/__init__.py @@ -0,0 +1,171 @@ +"""models2 — composable model layer. + +Usage:: + + from vercel_ai_sdk import models2 as m + from vercel_ai_sdk.types import Message, TextPart + + model = m.Model( + id="anthropic/claude-sonnet-4", + api="ai-gateway", + provider="ai-gateway", + ) + msgs = [Message(role="user", parts=[TextPart(text="hello")])] + + # stream — auto-creates client from env vars + async for msg in m.stream(model, msgs): + print(msg.text_delta, end="") + + # buffer the whole response + result = await m.buffer(m.stream(model, msgs)) + print(result.text) + + # explicit client + client = m.Client(base_url="https://custom.example.com/v3/ai", api_key="sk-...") + async for msg in m.stream(model, msgs, client=client): + ... 
+""" + +from __future__ import annotations + +import os +from collections.abc import AsyncGenerator, Sequence +from typing import Any + +import pydantic + +from ..types import messages as messages_ +from ..types import tools as tools_ +from .core.client import Client +from .core.model import Model, ModelCost +from .core.wire import GenerateFn, StreamFn + +# --------------------------------------------------------------------------- +# Wire registry — maps api string → wire function. +# Wire modules are imported lazily on first use. +# --------------------------------------------------------------------------- + +_stream_wires: dict[str, StreamFn] = {} +_generate_wires: dict[str, GenerateFn] = {} +_wires_loaded = False + + +def _ensure_wires() -> None: + """Lazily register built-in wire functions on first call.""" + global _wires_loaded # noqa: PLW0603 + if _wires_loaded: + return + _wires_loaded = True + + from .wires import ai_gateway_v3 + + _stream_wires["ai-gateway"] = ai_gateway_v3.stream + _generate_wires["ai-gateway"] = ai_gateway_v3.generate + + +# --------------------------------------------------------------------------- +# Provider defaults — base URLs and env var names for auto-client creation. +# --------------------------------------------------------------------------- + +_PROVIDER_DEFAULTS: dict[str, tuple[str, str]] = { + "ai-gateway": ("https://ai-gateway.vercel.sh/v3/ai", "AI_GATEWAY_API_KEY"), + "anthropic": ("https://api.anthropic.com/v1", "ANTHROPIC_API_KEY"), + "openai": ("https://api.openai.com/v1", "OPENAI_API_KEY"), +} + + +def _auto_client(model: Model) -> Client: + """Create a :class:`Client` from env vars for the given model's provider.""" + defaults = _PROVIDER_DEFAULTS.get(model.provider) + if defaults is None: + raise ValueError( + f"No default client config for provider {model.provider!r}. " + f"Pass an explicit client= argument." 
+ ) + base_url, env_var = defaults + return Client(base_url=base_url, api_key=os.environ.get(env_var)) + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +async def stream( + model: Model, + messages: list[messages_.Message], + *, + tools: Sequence[tools_.ToolLike] | None = None, + output_type: type[pydantic.BaseModel] | None = None, + client: Client | None = None, + **kwargs: Any, +) -> AsyncGenerator[messages_.Message]: + """Stream an LLM response. + + Resolves the wire function from ``model.api``, auto-creates a + :class:`Client` from env vars if none is provided, and yields + ``Message`` snapshots. + """ + _ensure_wires() + c = client or _auto_client(model) + wire_fn = _stream_wires.get(model.api) + if wire_fn is None: + registered = ", ".join(sorted(_stream_wires)) or "(none)" + raise KeyError( + f"No stream wire registered for api={model.api!r}. Registered: {registered}" + ) + async for msg in wire_fn( + c, model, messages, tools=tools, output_type=output_type, **kwargs + ): + yield msg + + +async def generate( + model: Model, + messages: list[messages_.Message], + *, + client: Client | None = None, + **kwargs: Any, +) -> messages_.Message: + """Generate a response (images, video, etc.). + + Resolves the wire function from ``model.api``, auto-creates a + :class:`Client` from env vars if none is provided. + """ + _ensure_wires() + c = client or _auto_client(model) + wire_fn = _generate_wires.get(model.api) + if wire_fn is None: + registered = ", ".join(sorted(_generate_wires)) or "(none)" + raise KeyError( + f"No generate wire registered for api={model.api!r}. " + f"Registered: {registered}" + ) + return await wire_fn(c, model, messages, **kwargs) + + +async def buffer(gen: AsyncGenerator[messages_.Message]) -> messages_.Message: + """Drain a stream and return the final ``Message``. 
+ + Raises :class:`ValueError` if the stream yields nothing. + """ + result: messages_.Message | None = None + async for msg in gen: + result = msg + if result is None: + raise ValueError("empty stream") + return result + + +__all__ = [ + # Core types + "Client", + "GenerateFn", + "Model", + "ModelCost", + "StreamFn", + # Public API + "buffer", + "generate", + "stream", +] diff --git a/src/vercel_ai_sdk/models2/core/__init__.py b/src/vercel_ai_sdk/models2/core/__init__.py new file mode 100644 index 00000000..38cdf5c6 --- /dev/null +++ b/src/vercel_ai_sdk/models2/core/__init__.py @@ -0,0 +1,13 @@ +"""Core types for models2.""" + +from .client import Client +from .model import Model, ModelCost +from .wire import GenerateFn, StreamFn + +__all__ = [ + "Client", + "GenerateFn", + "Model", + "ModelCost", + "StreamFn", +] diff --git a/src/vercel_ai_sdk/models2/core/client.py b/src/vercel_ai_sdk/models2/core/client.py new file mode 100644 index 00000000..88ae5454 --- /dev/null +++ b/src/vercel_ai_sdk/models2/core/client.py @@ -0,0 +1,45 @@ +"""HTTP client for wire functions.""" + +from __future__ import annotations + +import dataclasses + +import httpx + + +@dataclasses.dataclass +class Client: + """Connection parameters for a provider API. + + Wire functions receive a ``Client`` instead of creating their own HTTP + session. This keeps auth and base URL decoupled from the wire logic. + + The :attr:`http` property lazily creates a shared + :class:`httpx.AsyncClient` so that consecutive calls reuse the same + connection pool. 
+ """ + + base_url: str + api_key: str | None = None + headers: dict[str, str] = dataclasses.field(default_factory=dict) + + _http: httpx.AsyncClient | None = dataclasses.field( + default=None, repr=False, compare=False + ) + + @property + def http(self) -> httpx.AsyncClient: + """Lazy-init shared httpx client.""" + if self._http is None or self._http.is_closed: + self._http = httpx.AsyncClient( + base_url=self.base_url, + headers=self.headers, + timeout=httpx.Timeout(timeout=300.0, connect=10.0), + ) + return self._http + + async def aclose(self) -> None: + """Close the underlying HTTP client if open.""" + if self._http is not None and not self._http.is_closed: + await self._http.aclose() + self._http = None diff --git a/src/vercel_ai_sdk/models2/core/helpers/media.py b/src/vercel_ai_sdk/models2/core/helpers/media.py new file mode 100644 index 00000000..3fc3e793 --- /dev/null +++ b/src/vercel_ai_sdk/models2/core/helpers/media.py @@ -0,0 +1,370 @@ +from __future__ import annotations + +import base64 +import base64 as _b64 +import mimetypes + +import httpx + +# -- URL helpers ----------------------------------------------------------- + + +def is_url(data: str) -> bool: + """Return True if *data* looks like a URL rather than raw base-64.""" + return data.startswith(("http://", "https://", "data:")) + + +def is_downloadable_url(data: str) -> bool: + """Return True if *data* is an ``http(s)://`` URL that can be fetched.""" + return data.startswith(("http://", "https://")) + + +def split_data_url(url: str) -> tuple[str | None, str | None]: + """Parse a ``data:`` URL into ``(media_type, base64_content)``. + + Returns ``(None, None)`` if the input is not a valid ``data:`` URL. 
+ + Example:: + + >>> split_data_url("data:image/png;base64,iVBOR...") + ("image/png", "iVBOR...") + """ + if not url.startswith("data:"): + return None, None + try: + header, b64_content = url.split(",", 1) + # header = "data:image/png;base64" + mt = header.split(";")[0].split(":", 1)[1] + return (mt or None), (b64_content or None) + except (ValueError, IndexError): + return None, None + + +# -- encoding helpers ------------------------------------------------------ + + +def data_to_base64(data: str | bytes) -> str: + """Ensure *data* is a base-64 encoded string. + + * ``bytes`` -> base-64 encoded. + * ``str`` that is a ``data:`` URL -> base-64 content extracted. + * ``str`` that is an ``http(s)://`` URL -> returned as-is (caller + must handle). + * ``str`` that is not a URL -> assumed to already be base-64. + """ + if isinstance(data, bytes): + return base64.b64encode(data).decode("ascii") + if data.startswith("data:"): + _, b64 = split_data_url(data) + if b64 is not None: + return b64 + return data + + +def data_to_data_url(data: str | bytes, media_type: str) -> str: + """Convert *data* to a ``data:`` URL. Passes through existing URLs.""" + if isinstance(data, str) and is_url(data): + return data + b64 = data_to_base64(data) + return f"data:{media_type};base64,{b64}" + + +# -- media-type inference -------------------------------------------------- + + +def infer_media_type(url: str) -> str: + """Infer IANA media type from a URL. + + * ``data:image/png;base64,...`` -> ``"image/png"`` + * ``https://example.com/cat.jpg`` -> ``"image/jpeg"`` (via :mod:`mimetypes`) + * Unknown -> raises :class:`ValueError` + """ + if url.startswith("data:"): + # data:[<mediatype>][;base64],<data> + rest = url[5:] # strip "data:" + sep = rest.find(",") + meta = rest[:sep] if sep != -1 else rest + mt = meta.split(";")[0] + if mt: + return mt + else: + guessed, _ = mimetypes.guess_type(url) + if guessed: + return guessed + raise ValueError( + f"Cannot infer media_type from URL: {url!r}. 
Provide media_type explicitly." + ) + + +# --------------------------------------------------------------------------- +# Signature definitions +# --------------------------------------------------------------------------- + +# Each signature is a tuple of (media_type, byte_prefix) where byte_prefix +# is a tuple of ``int | None`` values. ``None`` is a wildcard that matches +# any byte (mirrors the TS SDK's ``null`` sentinel). + +_Signature = tuple[str, tuple[int | None, ...]] + +IMAGE_SIGNATURES: list[_Signature] = [ + ("image/gif", (0x47, 0x49, 0x46)), + ("image/png", (0x89, 0x50, 0x4E, 0x47)), + ("image/jpeg", (0xFF, 0xD8)), + ( + "image/webp", + (0x52, 0x49, 0x46, 0x46, None, None, None, None, 0x57, 0x45, 0x42, 0x50), + ), + ("image/bmp", (0x42, 0x4D)), + ("image/tiff", (0x49, 0x49, 0x2A, 0x00)), # little-endian + ("image/tiff", (0x4D, 0x4D, 0x00, 0x2A)), # big-endian + ( + "image/avif", + (0x00, 0x00, 0x00, 0x20, 0x66, 0x74, 0x79, 0x70, 0x61, 0x76, 0x69, 0x66), + ), + ( + "image/heic", + (0x00, 0x00, 0x00, 0x20, 0x66, 0x74, 0x79, 0x70, 0x68, 0x65, 0x69, 0x63), + ), +] + +AUDIO_SIGNATURES: list[_Signature] = [ + ("audio/mpeg", (0xFF, 0xFB)), + ("audio/mpeg", (0xFF, 0xFA)), + ("audio/mpeg", (0xFF, 0xF3)), + ("audio/mpeg", (0xFF, 0xF2)), + ("audio/mpeg", (0xFF, 0xE3)), + ("audio/mpeg", (0xFF, 0xE2)), + ( + "audio/wav", + (0x52, 0x49, 0x46, 0x46, None, None, None, None, 0x57, 0x41, 0x56, 0x45), + ), + ("audio/ogg", (0x4F, 0x67, 0x67, 0x53)), + ("audio/flac", (0x66, 0x4C, 0x61, 0x43)), + ("audio/aac", (0x40, 0x15, 0x00, 0x00)), + ("audio/mp4", (0x66, 0x74, 0x79, 0x70)), + ("audio/webm", (0x1A, 0x45, 0xDF, 0xA3)), +] + +VIDEO_SIGNATURES: list[_Signature] = [ + ("video/mp4", (0x00, 0x00, 0x00, None, 0x66, 0x74, 0x79, 0x70)), + ("video/webm", (0x1A, 0x45, 0xDF, 0xA3)), + ( + "video/quicktime", + (0x00, 0x00, 0x00, 0x14, 0x66, 0x74, 0x79, 0x70, 0x71, 0x74), + ), + ("video/x-msvideo", (0x52, 0x49, 0x46, 0x46)), +] + + +# 
--------------------------------------------------------------------------- +# ID3 tag stripping (for MP3 files that start with ID3v2 metadata) +# --------------------------------------------------------------------------- + +_ID3_HEADER = bytes([0x49, 0x44, 0x33]) # "ID3" +_ID3_BASE64 = "SUQz" # base64("ID3") + + +def _strip_id3_tags(data: bytes) -> bytes: + """Strip an ID3v2 tag header if present, returning the audio data.""" + if len(data) < 10 or data[:3] != _ID3_HEADER: + return data + # Syncsafe integer: 4 bytes, 7 bits each + size = ( + (data[6] & 0x7F) << 21 + | (data[7] & 0x7F) << 14 + | (data[8] & 0x7F) << 7 + | (data[9] & 0x7F) + ) + offset = size + 10 + return data[offset:] if offset < len(data) else data + + +def _strip_id3_tags_base64(data: str) -> str: + """Strip an ID3v2 tag from base64-encoded data if present.""" + if not data.startswith(_ID3_BASE64): + return data + # Decode enough to read the ID3 header (10 bytes = ~16 base64 chars) + try: + header = _b64.b64decode(data[:16]) + except Exception: + return data + if len(header) < 10 or header[:3] != _ID3_HEADER: + return data + size = ( + (header[6] & 0x7F) << 21 + | (header[7] & 0x7F) << 14 + | (header[8] & 0x7F) << 7 + | (header[9] & 0x7F) + ) + offset = size + 10 + # Re-encode: decode full data, strip, re-encode + try: + full = _b64.b64decode(data) + stripped = full[offset:] if offset < len(full) else full + return _b64.b64encode(stripped).decode("ascii") + except Exception: + return data + + +# --------------------------------------------------------------------------- +# Core detection +# --------------------------------------------------------------------------- + + +def _to_bytes(data: bytes | str, *, max_bytes: int = 24) -> bytes: + """Convert *data* to bytes for signature comparison. + + For ``str`` input (base-64), decodes only the first *max_bytes* + characters worth of data to avoid decoding large payloads. 
+ """ + if isinstance(data, bytes): + return data[:max_bytes] + # base-64: 4 chars → 3 bytes. Decode ~32 chars to get enough bytes. + chunk = data[: max_bytes * 2] + # Pad to multiple of 4 for valid base64 + padded = chunk + "=" * (-len(chunk) % 4) + try: + return _b64.b64decode(padded)[:max_bytes] + except Exception: + return b"" + + +def detect_media_type( + data: bytes | str, + signatures: list[_Signature], +) -> str | None: + """Detect media type from magic bytes. + + Args: + data: Raw bytes or a base-64 encoded string. + signatures: List of ``(media_type, byte_prefix)`` tuples to + match against (e.g. :data:`IMAGE_SIGNATURES`). + + Returns: + The matched IANA media type, or ``None`` if no signature matches. + """ + # Strip ID3 tags for audio detection + if signatures is AUDIO_SIGNATURES: + if isinstance(data, bytes): + data = _strip_id3_tags(data) + else: + data = _strip_id3_tags_base64(data) + + raw = _to_bytes(data) + if not raw: + return None + + for media_type, prefix in signatures: + if len(raw) < len(prefix): + continue + if all( + expected is None or raw[i] == expected for i, expected in enumerate(prefix) + ): + return media_type + + return None + + +def detect_image_media_type(data: bytes | str) -> str | None: + """Detect image format from magic bytes.""" + return detect_media_type(data, IMAGE_SIGNATURES) + + +def detect_audio_media_type(data: bytes | str) -> str | None: + """Detect audio format from magic bytes.""" + return detect_media_type(data, AUDIO_SIGNATURES) + + +DEFAULT_MAX_BYTES = 100 * 1024 * 1024 # 100 MiB (matches TS SDK) +_ALLOWED_SCHEMES = frozenset({"http", "https"}) + + +class DownloadError(Exception): + """Raised when a URL download fails.""" + + def __init__( + self, + url: str, + *, + status_code: int | None = None, + status_text: str | None = None, + cause: BaseException | None = None, + ) -> None: + parts = [f"Failed to download {url!r}"] + if status_code is not None: + parts.append(f"status={status_code}") + if status_text: + 
parts.append(status_text) + super().__init__(": ".join(parts)) + self.url = url + self.status_code = status_code + if cause is not None: + self.__cause__ = cause + + +def _validate_url(url: str) -> None: + """Reject non-HTTP(S) URLs (SSRF prevention).""" + from urllib.parse import urlparse + + parsed = urlparse(url) + if parsed.scheme not in _ALLOWED_SCHEMES: + raise DownloadError( + url, status_text=f"Unsupported URL scheme: {parsed.scheme!r}" + ) + + +async def download( + url: str, + *, + max_bytes: int = DEFAULT_MAX_BYTES, +) -> tuple[bytes, str | None]: + """Download *url* and return ``(data, content_type)``. + + Args: + url: The URL to fetch (must be ``http`` or ``https``). + max_bytes: Maximum response size. Defaults to 100 MiB. + + Returns: + A tuple of ``(raw_bytes, content_type_or_None)``. + + Raises: + DownloadError: On any failure (network, HTTP status, size, etc.). + """ + _validate_url(url) + + try: + async with httpx.AsyncClient(follow_redirects=True) as client: + resp = await client.get(url) + + # Validate redirect target + if resp.url is not None and str(resp.url) != url: + _validate_url(str(resp.url)) + + if resp.status_code >= 400: + raise DownloadError( + url, + status_code=resp.status_code, + status_text=resp.reason_phrase or "", + ) + + data = resp.content + if len(data) > max_bytes: + raise DownloadError( + url, + status_text=( + f"Response exceeds maximum size " + f"({len(data)} > {max_bytes} bytes)" + ), + ) + + content_type = resp.headers.get("content-type") + # Strip charset/parameters: "image/png; charset=..." 
→ "image/png" + if content_type: + content_type = content_type.split(";")[0].strip() + + return data, content_type or None + + except DownloadError: + raise + except Exception as exc: + raise DownloadError(url, cause=exc) from exc diff --git a/src/vercel_ai_sdk/models2/core/helpers/streaming.py b/src/vercel_ai_sdk/models2/core/helpers/streaming.py new file mode 100644 index 00000000..11d27006 --- /dev/null +++ b/src/vercel_ai_sdk/models2/core/helpers/streaming.py @@ -0,0 +1,264 @@ +from __future__ import annotations + +import dataclasses +import json +from collections.abc import AsyncGenerator + +import pydantic + +from ....types import messages as messages_ + + +@dataclasses.dataclass +class TextStart: + block_id: str + + +@dataclasses.dataclass +class TextDelta: + block_id: str + delta: str + + +@dataclasses.dataclass +class TextEnd: + block_id: str + + +@dataclasses.dataclass +class ReasoningStart: + block_id: str + + +@dataclasses.dataclass +class ReasoningDelta: + block_id: str + delta: str + + +@dataclasses.dataclass +class ReasoningEnd: + block_id: str + signature: str | None = None + + +@dataclasses.dataclass +class ToolStart: + tool_call_id: str + tool_name: str + + +@dataclasses.dataclass +class ToolArgsDelta: + tool_call_id: str + delta: str + + +@dataclasses.dataclass +class ToolEnd: + tool_call_id: str + + +@dataclasses.dataclass +class FileEvent: + """A complete generated file from the LLM (e.g. 
inline image from Gemini/GPT).""" + + block_id: str + media_type: str + data: str # base64 string or data-URL from the gateway + + +@dataclasses.dataclass +class MessageDone: + finish_reason: str | None = None + usage: messages_.Usage | None = None + + +StreamEvent = ( + TextStart + | TextDelta + | TextEnd + | ReasoningStart + | ReasoningDelta + | ReasoningEnd + | ToolStart + | ToolArgsDelta + | ToolEnd + | FileEvent + | MessageDone +) + + +@dataclasses.dataclass +class StreamHandler: + """ + Accumulates LLM adapter events and produces Messages with stateful parts. + + This is the normalization layer between LLM adapters and the rest of the system. + """ + + message_id: str = dataclasses.field(default_factory=messages_._gen_id) + + # Accumulators + _text_blocks: dict[str, str] = dataclasses.field(default_factory=dict) + _reasoning_blocks: dict[str, tuple[str, str | None]] = dataclasses.field( + default_factory=dict + ) # (text, signature) + _tool_calls: dict[str, tuple[str, str]] = dataclasses.field( + default_factory=dict + ) # (name, args) + _files: dict[str, tuple[str, str]] = dataclasses.field( + default_factory=dict + ) # block_id -> (media_type, data) + + # Active tracking + _active_text_id: str | None = None + _active_reasoning_id: str | None = None + _active_tool_ids: set[str] = dataclasses.field(default_factory=set) + + _is_done: bool = False + _usage: messages_.Usage | None = None + + def handle_event(self, event: StreamEvent) -> messages_.Message: + """Process event and return current Message state.""" + + # Current deltas (reset each call) + text_delta: str | None = None + reasoning_delta: str | None = None + tool_deltas: dict[str, str] = {} # tool_call_id -> delta + + match event: + case TextStart(block_id=bid): + self._text_blocks[bid] = "" + self._active_text_id = bid + + case TextDelta(block_id=bid, delta=d): + self._text_blocks[bid] += d + text_delta = d + + case TextEnd(block_id=bid): + if self._active_text_id == bid: + self._active_text_id = None 
+ + case ReasoningStart(block_id=bid): + self._reasoning_blocks[bid] = ("", None) + self._active_reasoning_id = bid + + case ReasoningDelta(block_id=bid, delta=d): + text, sig = self._reasoning_blocks[bid] + self._reasoning_blocks[bid] = (text + d, sig) + reasoning_delta = d + + case ReasoningEnd(block_id=bid, signature=sig): + text, _ = self._reasoning_blocks[bid] + self._reasoning_blocks[bid] = (text, sig) + if self._active_reasoning_id == bid: + self._active_reasoning_id = None + + case ToolStart(tool_call_id=tcid, tool_name=name): + self._tool_calls[tcid] = (name, "") + self._active_tool_ids.add(tcid) + + case ToolArgsDelta(tool_call_id=tcid, delta=d): + name, args = self._tool_calls[tcid] + self._tool_calls[tcid] = (name, args + d) + tool_deltas[tcid] = d + + case ToolEnd(tool_call_id=tcid): + self._active_tool_ids.discard(tcid) + + case FileEvent(block_id=bid, media_type=mt, data=d): + self._files[bid] = (mt, d) + + case MessageDone(usage=usage): + self._is_done = True + self._usage = usage + self._active_text_id = None + self._active_reasoning_id = None + self._active_tool_ids.clear() + + return self._build_message(text_delta, reasoning_delta, tool_deltas) + + def _build_message( + self, + text_delta: str | None, + reasoning_delta: str | None, + tool_deltas: dict[str, str], + ) -> messages_.Message: + parts: list[messages_.Part] = [] + + # Reasoning parts first (like thinking blocks) + for bid, (text, sig) in self._reasoning_blocks.items(): + is_active = bid == self._active_reasoning_id + parts.append( + messages_.ReasoningPart( + text=text, + signature=sig, + state="streaming" if is_active else "done", + delta=reasoning_delta if is_active else None, + ) + ) + + # Text parts + for bid, text in self._text_blocks.items(): + is_active = bid == self._active_text_id + parts.append( + messages_.TextPart( + text=text, + state="streaming" if is_active else "done", + delta=text_delta if is_active else None, + ) + ) + + # Tool parts + for tcid, (name, args) in 
self._tool_calls.items(): + is_active = tcid in self._active_tool_ids + parts.append( + messages_.ToolPart( + tool_call_id=tcid, + tool_name=name, + tool_args=args, + state="streaming" if is_active else "done", + args_delta=tool_deltas.get(tcid), + ) + ) + + # File parts (inline images/videos from LLMs like Gemini, GPT-5) + for _bid, (media_type, data) in self._files.items(): + parts.append(messages_.FilePart(data=data, media_type=media_type)) + + return messages_.Message( + id=self.message_id, + role="assistant", + parts=parts, + usage=self._usage if self._is_done else None, + ) + + +async def events_to_messages( + events: AsyncGenerator[StreamEvent], + output_type: type[pydantic.BaseModel] | None = None, +) -> AsyncGenerator[messages_.Message]: + """Convert a stream of events into Message snapshots. + + This is the standalone version of the logic that ``LanguageModel.stream()`` + uses. Wire functions call this to turn their ``StreamEvent`` generators + into ``Message`` generators suitable for ``Stream``. 
+ """ + handler = StreamHandler() + msg: messages_.Message | None = None + async for event in events: + msg = handler.handle_event(event) + yield msg + + # After stream completes, validate and attach structured output part + if output_type is not None and msg is not None and msg.text: + data = json.loads(msg.text) + output_type.model_validate(data) # fail fast on bad data + part = messages_.StructuredOutputPart( + data=data, + output_type_name=f"{output_type.__module__}.{output_type.__qualname__}", + ) + msg = msg.model_copy() + msg.parts = [*msg.parts, part] + yield msg diff --git a/src/vercel_ai_sdk/models2/core/model.py b/src/vercel_ai_sdk/models2/core/model.py new file mode 100644 index 00000000..e9a4d129 --- /dev/null +++ b/src/vercel_ai_sdk/models2/core/model.py @@ -0,0 +1,34 @@ +"""Model metadata types.""" + +from __future__ import annotations + +import dataclasses + + +@dataclasses.dataclass(frozen=True) +class ModelCost: + """Per-million-token pricing.""" + + input: float = 0.0 + output: float = 0.0 + cache_read: float = 0.0 + cache_write: float = 0.0 + + +@dataclasses.dataclass(frozen=True) +class Model: + """Pure-data description of a model. + + * ``id`` — identifier sent to the provider (e.g. ``"claude-sonnet-4-20250514"``). + * ``api`` — wire protocol key (e.g. ``"ai-gateway"``, ``"anthropic-messages"``). + * ``provider`` — hosting service (e.g. ``"ai-gateway"``, ``"anthropic"``). + """ + + id: str + api: str + provider: str + name: str = "" + capabilities: tuple[str, ...] = ("text",) + context_window: int = 0 + max_output_tokens: int = 0 + cost: ModelCost | None = None diff --git a/src/vercel_ai_sdk/models2/core/wire.py b/src/vercel_ai_sdk/models2/core/wire.py new file mode 100644 index 00000000..bd557bac --- /dev/null +++ b/src/vercel_ai_sdk/models2/core/wire.py @@ -0,0 +1,54 @@ +"""Wire function protocols. + +A *wire function* translates between our ``Message`` types and a specific +provider API (e.g. ``"ai-gateway"``, ``"anthropic-messages"``). 
+ +Wire functions are plain async generators / coroutines — no base class +required. The protocols below exist only for static type-checking. +""" + +from __future__ import annotations + +from collections.abc import AsyncGenerator, Sequence +from typing import Any, Protocol, runtime_checkable + +import pydantic + +from ...types import messages as messages_ +from ...types import tools as tools_ +from .client import Client +from .model import Model + + +@runtime_checkable +class StreamFn(Protocol): + """Protocol for streaming wire functions. + + Implementations yield ``Message`` snapshots as the response streams + in. Each snapshot is a complete, self-contained message reflecting + the accumulated state up to that point. + """ + + def __call__( + self, + client: Client, + model: Model, + messages: list[messages_.Message], + *, + tools: Sequence[tools_.ToolLike] | None = None, + output_type: type[pydantic.BaseModel] | None = None, + **kwargs: Any, + ) -> AsyncGenerator[messages_.Message]: ... + + +@runtime_checkable +class GenerateFn(Protocol): + """Protocol for non-streaming wire functions (images, video, etc.).""" + + async def __call__( + self, + client: Client, + model: Model, + messages: list[messages_.Message], + **kwargs: Any, + ) -> messages_.Message: ... diff --git a/src/vercel_ai_sdk/models2/wires/__init__.py b/src/vercel_ai_sdk/models2/wires/__init__.py new file mode 100644 index 00000000..77c50464 --- /dev/null +++ b/src/vercel_ai_sdk/models2/wires/__init__.py @@ -0,0 +1 @@ +"""Wire implementations for provider APIs.""" diff --git a/src/vercel_ai_sdk/models2/wires/ai_gateway_v3.py b/src/vercel_ai_sdk/models2/wires/ai_gateway_v3.py new file mode 100644 index 00000000..fbd1efd7 --- /dev/null +++ b/src/vercel_ai_sdk/models2/wires/ai_gateway_v3.py @@ -0,0 +1,594 @@ +"""Vercel AI Gateway v3 wire — streaming and generation. + +Wire protocol for the AI Gateway's v3 endpoints: + +* ``/language-model`` — streaming text/tool/reasoning responses. 
+* ``/image-model`` — dedicated image generation. +* ``/video-model`` — dedicated video generation (SSE response). +""" + +from __future__ import annotations + +import base64 +import json +from collections.abc import AsyncGenerator, Sequence +from typing import Any + +import httpx +import pydantic + +from ...types import messages as messages_ +from ...types import tools as tools_ +from ..core import client as client_ +from ..core import model as model_ +from ..core.helpers import media as media_ +from ..core.helpers import streaming as streaming_ + +_PROTOCOL_VERSION = "0.0.1" + +# --------------------------------------------------------------------------- +# Request building — Message list → v3 prompt +# --------------------------------------------------------------------------- + + +async def _file_part_to_v3(part: messages_.FilePart) -> dict[str, Any]: + """Convert a :class:`FilePart` to a v3 ``file`` content part.""" + data = part.data + if isinstance(data, str) and media_.is_downloadable_url(data): + downloaded, _ = await media_.download(data) + data = downloaded + + entry: dict[str, Any] = { + "type": "file", + "mediaType": part.media_type, + "data": media_.data_to_data_url(data, part.media_type), + } + if part.filename is not None: + entry["filename"] = part.filename + return entry + + +async def _messages_to_prompt( + messages: list[messages_.Message], +) -> list[dict[str, Any]]: + """Convert ``Message`` list to the v3 prompt wire format.""" + result: list[dict[str, Any]] = [] + + for msg in messages: + match msg.role: + case "system": + text = "".join( + p.text for p in msg.parts if isinstance(p, messages_.TextPart) + ) + result.append({"role": "system", "content": text}) + + case "user": + content: list[dict[str, Any]] = [] + for p in msg.parts: + if isinstance(p, messages_.TextPart): + content.append({"type": "text", "text": p.text}) + elif isinstance(p, messages_.FilePart): + content.append(await _file_part_to_v3(p)) + result.append({"role": "user", 
"content": content}) + + case "assistant": + assistant_content: list[dict[str, Any]] = [] + tool_results: list[dict[str, Any]] = [] + + for part in msg.parts: + match part: + case messages_.ReasoningPart(text=text): + assistant_content.append( + {"type": "reasoning", "text": text} + ) + + case messages_.TextPart(text=text): + assistant_content.append({"type": "text", "text": text}) + + case messages_.ToolPart() as tp: + tool_input: Any = ( + json.loads(tp.tool_args) if tp.tool_args else {} + ) + assistant_content.append( + { + "type": "tool-call", + "toolCallId": tp.tool_call_id, + "toolName": tp.tool_name, + "input": tool_input, + } + ) + if tp.status in ("result", "error"): + output = ( + { + "type": "error-text", + "value": ( + str(tp.result) + if tp.result is not None + else "" + ), + } + if tp.status == "error" + else { + "type": "json", + "value": tp.result, + } + ) + tool_results.append( + { + "type": "tool-result", + "toolCallId": tp.tool_call_id, + "toolName": tp.tool_name, + "output": output, + } + ) + + result.append({"role": "assistant", "content": assistant_content}) + if tool_results: + result.append({"role": "tool", "content": tool_results}) + + return result + + +async def _build_request_body( + messages: list[messages_.Message], + tools: Sequence[tools_.ToolLike] | None = None, + output_type: type[Any] | None = None, + **kwargs: Any, +) -> dict[str, Any]: + """Build the ``LanguageModelV3CallOptions`` request body.""" + body: dict[str, Any] = { + "prompt": await _messages_to_prompt(messages), + } + if tools: + body["tools"] = [ + { + "type": "function", + "name": tool.name, + "description": tool.description, + "inputSchema": tool.param_schema, + } + for tool in tools + ] + if output_type is not None and issubclass(output_type, pydantic.BaseModel): + body["responseFormat"] = { + "type": "json", + "schema": output_type.model_json_schema(), + "name": output_type.__name__, + } + if kwargs.get("provider_options"): + body["providerOptions"] = 
kwargs["provider_options"] + return body + + +# --------------------------------------------------------------------------- +# SSE response parsing — v3 stream parts → StreamEvent +# --------------------------------------------------------------------------- + + +def _expand_tool_call(data: dict[str, Any]) -> list[streaming_.StreamEvent]: + """Expand a complete ``tool-call`` part into Start + ArgsDelta + End.""" + tc_id = data.get("toolCallId", "") + tool_name = data.get("toolName", "") + tool_input = data.get("input", "") + args_str = tool_input if isinstance(tool_input, str) else json.dumps(tool_input) + return [ + streaming_.ToolStart(tool_call_id=tc_id, tool_name=tool_name), + streaming_.ToolArgsDelta(tool_call_id=tc_id, delta=args_str), + streaming_.ToolEnd(tool_call_id=tc_id), + ] + + +def _parse_usage(data: Any) -> messages_.Usage: + """Parse v3 usage data into an internal ``Usage``.""" + if not isinstance(data, dict): + return messages_.Usage() + + input_tokens_obj = data.get("inputTokens") + output_tokens_obj = data.get("outputTokens") + + if isinstance(input_tokens_obj, dict) or isinstance(output_tokens_obj, dict): + inp = input_tokens_obj if isinstance(input_tokens_obj, dict) else {} + out = output_tokens_obj if isinstance(output_tokens_obj, dict) else {} + return messages_.Usage( + input_tokens=inp.get("total") or 0, + output_tokens=out.get("total") or 0, + reasoning_tokens=out.get("reasoning"), + cache_read_tokens=inp.get("cacheRead"), + cache_write_tokens=inp.get("cacheWrite"), + raw=data, + ) + + return messages_.Usage( + input_tokens=data.get("prompt_tokens") or data.get("inputTokens") or 0, + output_tokens=(data.get("completion_tokens") or data.get("outputTokens") or 0), + raw=data, + ) + + +def _parse_stream_part(data: dict[str, Any]) -> list[streaming_.StreamEvent]: + """Convert a ``LanguageModelV3StreamPart`` to internal events.""" + match data.get("type", ""): + case "text-start": + return [streaming_.TextStart(block_id=data.get("id", "text"))] 
+ + case "text-delta": + return [ + streaming_.TextDelta( + block_id=data.get("id", "text"), + delta=data.get("textDelta", data.get("delta", "")), + ) + ] + + case "text-end": + return [streaming_.TextEnd(block_id=data.get("id", "text"))] + + case "reasoning-start": + return [streaming_.ReasoningStart(block_id=data.get("id", "reasoning"))] + + case "reasoning-delta": + return [ + streaming_.ReasoningDelta( + block_id=data.get("id", "reasoning"), + delta=data.get("delta", ""), + ) + ] + + case "reasoning-end": + return [streaming_.ReasoningEnd(block_id=data.get("id", "reasoning"))] + + case "tool-input-start": + return [ + streaming_.ToolStart( + tool_call_id=data.get("id", ""), + tool_name=data.get("toolName", ""), + ) + ] + + case "tool-input-delta": + return [ + streaming_.ToolArgsDelta( + tool_call_id=data.get("id", ""), + delta=data.get("delta", ""), + ) + ] + + case "tool-input-end": + return [streaming_.ToolEnd(tool_call_id=data.get("id", ""))] + + case "tool-call": + return _expand_tool_call(data) + + case "file": + return [ + streaming_.FileEvent( + block_id=data.get("id", f"file-{len(data)}"), + media_type=data.get("mediaType", "application/octet-stream"), + data=data.get("data", ""), + ) + ] + + case "finish": + usage_data = data.get("usage") + usage = _parse_usage(usage_data) if usage_data else None + match data.get("finishReason"): + case dict() as d: + finish_reason = d.get("unified", "stop") + case str() as s: + finish_reason = s + case _: + finish_reason = "stop" + return [streaming_.MessageDone(finish_reason=finish_reason, usage=usage)] + + case _: + return [] + + +# --------------------------------------------------------------------------- +# Headers +# --------------------------------------------------------------------------- + + +def _request_headers( + client: client_.Client, + model: model_.Model, + *, + streaming: bool, +) -> dict[str, str]: + """Build gateway-specific request headers.""" + h: dict[str, str] = { + "Content-Type": 
"application/json", + "ai-gateway-protocol-version": _PROTOCOL_VERSION, + "ai-language-model-specification-version": "3", + "ai-language-model-id": model.id, + "ai-language-model-streaming": str(streaming).lower(), + } + if client.api_key: + h["Authorization"] = f"Bearer {client.api_key}" + h["ai-gateway-auth-method"] = "api-key" + return h + + +# --------------------------------------------------------------------------- +# Public wire functions +# --------------------------------------------------------------------------- + + +async def stream( + client: client_.Client, + model: model_.Model, + messages: list[messages_.Message], + *, + tools: Sequence[tools_.ToolLike] | None = None, + output_type: type[pydantic.BaseModel] | None = None, + **kwargs: Any, +) -> AsyncGenerator[messages_.Message]: + """Stream an LLM response through the AI Gateway v3 protocol. + + Yields ``Message`` snapshots as the response streams in. Each + snapshot is a complete, self-contained message reflecting the + accumulated state up to that point. 
+ """ + body = await _build_request_body( + messages, tools=tools, output_type=output_type, **kwargs + ) + headers = _request_headers(client, model, streaming=True) + url = f"{client.base_url.rstrip('/')}/language-model" + + handler = streaming_.StreamHandler() + + async with client.http.stream( + "POST", + url, + json=body, + headers=headers, + ) as response: + if response.status_code >= 400: + await response.aread() + raise RuntimeError( + f"AI Gateway returned HTTP {response.status_code}: {response.text}" + ) + + async for line in response.aiter_lines(): + line = line.strip() + if not line.startswith("data: "): + continue + payload = line[len("data: ") :] + if payload == "[DONE]": + break + try: + data = json.loads(payload) + except json.JSONDecodeError: + continue + + for event in _parse_stream_part(data): + msg = handler.handle_event(event) + yield msg + + +# --------------------------------------------------------------------------- +# Generate — image / video (non-streaming media generation) +# --------------------------------------------------------------------------- + + +def _file_part_to_wire(part: messages_.FilePart) -> dict[str, Any]: + """Convert a :class:`FilePart` to the gateway wire format for input files.""" + data = part.data + if isinstance(data, str) and media_.is_url(data): + return {"type": "url", "url": data} + if isinstance(data, bytes): + b64 = base64.b64encode(data).decode("ascii") + elif isinstance(data, str): + b64 = data + else: + b64 = str(data) + return {"type": "file", "data": b64, "mediaType": part.media_type} + + +def _extract_prompt(messages: list[messages_.Message]) -> str: + """Concatenate all text from user/system messages.""" + parts: list[str] = [] + for msg in messages: + if msg.role in ("user", "system"): + for p in msg.parts: + if isinstance(p, messages_.TextPart): + parts.append(p.text) + return " ".join(parts) + + +def _extract_input_files(messages: list[messages_.Message]) -> list[messages_.FilePart]: + """Collect all 
file parts from user messages.""" + files: list[messages_.FilePart] = [] + for msg in messages: + if msg.role == "user": + for p in msg.parts: + if isinstance(p, messages_.FilePart): + files.append(p) + return files + + +def _generate_headers( + client: client_.Client, + model: model_.Model, + *, + spec_version_header: str, +) -> dict[str, str]: + """Build gateway request headers for generate endpoints.""" + h: dict[str, str] = { + "Content-Type": "application/json", + "ai-gateway-protocol-version": _PROTOCOL_VERSION, + spec_version_header: "3", + "ai-model-id": model.id, + } + if client.api_key: + h["Authorization"] = f"Bearer {client.api_key}" + h["ai-gateway-auth-method"] = "api-key" + return h + + +async def _generate_image( + client: client_.Client, + model: model_.Model, + messages: list[messages_.Message], + **kwargs: Any, +) -> messages_.Message: + """Hit ``/image-model`` and return a Message with FileParts.""" + prompt = _extract_prompt(messages) + input_files = _extract_input_files(messages) + + body: dict[str, Any] = { + "prompt": prompt, + "n": kwargs.get("n", 1), + "providerOptions": kwargs.get("provider_options", {}), + } + if kwargs.get("size") is not None: + body["size"] = kwargs["size"] + if kwargs.get("aspect_ratio") is not None: + body["aspectRatio"] = kwargs["aspect_ratio"] + if kwargs.get("seed") is not None: + body["seed"] = kwargs["seed"] + if input_files: + body["files"] = [_file_part_to_wire(f) for f in input_files] + + url = f"{client.base_url.rstrip('/')}/image-model" + headers = _generate_headers( + client, model, spec_version_header="ai-image-model-specification-version" + ) + + response = await client.http.post( + url, + json=body, + headers=headers, + ) + if response.status_code >= 400: + raise RuntimeError( + f"AI Gateway image-model returned HTTP {response.status_code}: " + f"{response.text}" + ) + + data = response.json() + raw_images: list[str] = data.get("images", []) + usage_data = data.get("usage") + usage = None + if 
usage_data: + usage = messages_.Usage( + input_tokens=usage_data.get("inputTokens") or 0, + output_tokens=usage_data.get("outputTokens") or 0, + ) + + files: list[messages_.FilePart] = [] + for img_b64 in raw_images: + media_type = media_.detect_image_media_type(img_b64) or "image/png" + files.append(messages_.FilePart(data=img_b64, media_type=media_type)) + + return messages_.Message(role="assistant", parts=files, usage=usage) + + +async def _generate_video( + client: client_.Client, + model: model_.Model, + messages: list[messages_.Message], + **kwargs: Any, +) -> messages_.Message: + """Hit ``/video-model`` (SSE) and return a Message with FileParts.""" + prompt = _extract_prompt(messages) + input_files = _extract_input_files(messages) + + body: dict[str, Any] = { + "prompt": prompt, + "n": kwargs.get("n", 1), + "providerOptions": kwargs.get("provider_options", {}), + } + if kwargs.get("aspect_ratio") is not None: + body["aspectRatio"] = kwargs["aspect_ratio"] + if kwargs.get("resolution") is not None: + body["resolution"] = kwargs["resolution"] + if kwargs.get("duration") is not None: + body["duration"] = kwargs["duration"] + if kwargs.get("fps") is not None: + body["fps"] = kwargs["fps"] + if kwargs.get("seed") is not None: + body["seed"] = kwargs["seed"] + if input_files: + body["image"] = _file_part_to_wire(input_files[0]) + + url = f"{client.base_url.rstrip('/')}/video-model" + headers = _generate_headers( + client, model, spec_version_header="ai-video-model-specification-version" + ) + headers["accept"] = "text/event-stream" + + async with client.http.stream( + "POST", + url, + json=body, + headers=headers, + timeout=httpx.Timeout(timeout=600.0, connect=10.0), + ) as response: + if response.status_code >= 400: + await response.aread() + raise RuntimeError( + f"AI Gateway video-model returned HTTP {response.status_code}: " + f"{response.text}" + ) + + # Read first SSE data event — the gateway sends a single result event. 
+ event_data: dict[str, Any] = {} + async for line in response.aiter_lines(): + line = line.strip() + if not line.startswith("data: "): + continue + payload = line[len("data: ") :] + if payload == "[DONE]": + break + try: + event_data = json.loads(payload) + break + except json.JSONDecodeError: + continue + + if event_data.get("type") == "error": + raise RuntimeError( + f"AI Gateway video generation error: " + f"{event_data.get('message', 'unknown error')}" + ) + + raw_videos: list[dict[str, Any]] = event_data.get("videos", []) + files: list[messages_.FilePart] = [] + for video_data in raw_videos: + vtype = video_data.get("type", "base64") + media_type = video_data.get("mediaType", "video/mp4") + + if vtype == "url": + downloaded_bytes, content_type = await media_.download(video_data["url"]) + if content_type: + media_type = content_type + files.append( + messages_.FilePart(data=downloaded_bytes, media_type=media_type) + ) + else: + raw_data = video_data.get("data", "") + files.append(messages_.FilePart(data=raw_data, media_type=media_type)) + + return messages_.Message(role="assistant", parts=files) + + +async def generate( + client: client_.Client, + model: model_.Model, + messages: list[messages_.Message], + **kwargs: Any, +) -> messages_.Message: + """Generate media (images or video) through the AI Gateway. + + Dispatches to ``/image-model`` or ``/video-model`` based on the + model's capabilities. + + Keyword args are forwarded to the underlying endpoint and may include + ``n``, ``size``, ``aspect_ratio``, ``seed``, ``duration``, ``fps``, + ``resolution``, ``provider_options``. 
+ """ + caps = model.capabilities + if "video" in caps: + return await _generate_video(client, model, messages, **kwargs) + # Default to image generation + return await _generate_image(client, model, messages, **kwargs) From c228556da47287b22a3875cd8e06ae2d3856e279 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Thu, 2 Apr 2026 18:21:51 -0700 Subject: [PATCH 06/18] Rename wires / apis to adapters for consistency, move gateway implementation into its own directory --- examples/models2/buffer.py | 2 +- .../{direct_wire.py => direct_adapter.py} | 8 +-- examples/models2/explicit_client.py | 2 +- examples/models2/image_generation.py | 2 +- examples/models2/inline_image.py | 2 +- examples/models2/multimodal_input.py | 2 +- examples/models2/stream.py | 2 +- examples/models2/structured_output.py | 2 +- examples/models2/tools.py | 2 +- examples/models2/video_generation.py | 2 +- src/vercel_ai_sdk/models2/__init__.py | 59 ++++++++++--------- .../models2/ai_gateway/__init__.py | 5 ++ .../adapter.py} | 6 +- src/vercel_ai_sdk/models2/core/__init__.py | 2 +- src/vercel_ai_sdk/models2/core/client.py | 6 +- src/vercel_ai_sdk/models2/core/model.py | 4 +- .../models2/core/{wire.py => proto.py} | 12 ++-- src/vercel_ai_sdk/models2/wires/__init__.py | 1 - 18 files changed, 63 insertions(+), 58 deletions(-) rename examples/models2/{direct_wire.py => direct_adapter.py} (75%) create mode 100644 src/vercel_ai_sdk/models2/ai_gateway/__init__.py rename src/vercel_ai_sdk/models2/{wires/ai_gateway_v3.py => ai_gateway/adapter.py} (99%) rename src/vercel_ai_sdk/models2/core/{wire.py => proto.py} (75%) delete mode 100644 src/vercel_ai_sdk/models2/wires/__init__.py diff --git a/examples/models2/buffer.py b/examples/models2/buffer.py index 5b05f32c..5cb88c7e 100644 --- a/examples/models2/buffer.py +++ b/examples/models2/buffer.py @@ -7,7 +7,7 @@ model = m.Model( id="anthropic/claude-sonnet-4", - api="ai-gateway", + adapter="ai-gateway-v3", provider="ai-gateway", ) diff --git 
a/examples/models2/direct_wire.py b/examples/models2/direct_adapter.py similarity index 75% rename from examples/models2/direct_wire.py rename to examples/models2/direct_adapter.py index 66175483..a9ff3b86 100644 --- a/examples/models2/direct_wire.py +++ b/examples/models2/direct_adapter.py @@ -1,15 +1,15 @@ -"""Direct wire call — bypass the registry, call the wire function directly.""" +"""Direct adapter call — bypass the registry, call the adapter function directly.""" import asyncio import os from vercel_ai_sdk import models2 as m -from vercel_ai_sdk.models2.wires import ai_gateway_v3 +from vercel_ai_sdk.models2.ai_gateway import adapter as ai_gateway_v3 from vercel_ai_sdk.types import messages as messages_ model = m.Model( id="anthropic/claude-sonnet-4", - api="ai-gateway", + adapter="ai-gateway-v3", provider="ai-gateway", ) @@ -27,7 +27,7 @@ async def main() -> None: - # Call the wire function directly — no registry lookup, no auto-client. + # Call the adapter function directly — no registry lookup, no auto-client. # This is the lowest level of the API. 
try: async for msg in ai_gateway_v3.stream(client, model, messages): diff --git a/examples/models2/explicit_client.py b/examples/models2/explicit_client.py index baa319cc..e4539623 100644 --- a/examples/models2/explicit_client.py +++ b/examples/models2/explicit_client.py @@ -8,7 +8,7 @@ model = m.Model( id="anthropic/claude-sonnet-4", - api="ai-gateway", + adapter="ai-gateway-v3", provider="ai-gateway", ) diff --git a/examples/models2/image_generation.py b/examples/models2/image_generation.py index 565cf78b..7d9da9c0 100644 --- a/examples/models2/image_generation.py +++ b/examples/models2/image_generation.py @@ -9,7 +9,7 @@ model = m.Model( id="google/imagen-4.0-generate-001", - api="ai-gateway", + adapter="ai-gateway-v3", provider="ai-gateway", capabilities=("image",), ) diff --git a/examples/models2/inline_image.py b/examples/models2/inline_image.py index 17dfb7e6..4686b33e 100644 --- a/examples/models2/inline_image.py +++ b/examples/models2/inline_image.py @@ -15,7 +15,7 @@ # This is a language model that can also output images inline. 
model = m.Model( id="google/gemini-3-pro-image", - api="ai-gateway", + adapter="ai-gateway-v3", provider="ai-gateway", capabilities=("text", "image"), ) diff --git a/examples/models2/multimodal_input.py b/examples/models2/multimodal_input.py index 21a5f43d..f5d3b475 100644 --- a/examples/models2/multimodal_input.py +++ b/examples/models2/multimodal_input.py @@ -8,7 +8,7 @@ model = m.Model( id="anthropic/claude-sonnet-4", - api="ai-gateway", + adapter="ai-gateway-v3", provider="ai-gateway", ) diff --git a/examples/models2/stream.py b/examples/models2/stream.py index 76155401..7e1d08a9 100644 --- a/examples/models2/stream.py +++ b/examples/models2/stream.py @@ -7,7 +7,7 @@ model = m.Model( id="anthropic/claude-sonnet-4", - api="ai-gateway", + adapter="ai-gateway-v3", provider="ai-gateway", ) diff --git a/examples/models2/structured_output.py b/examples/models2/structured_output.py index c984d672..c1572988 100644 --- a/examples/models2/structured_output.py +++ b/examples/models2/structured_output.py @@ -9,7 +9,7 @@ model = m.Model( id="anthropic/claude-sonnet-4", - api="ai-gateway", + adapter="ai-gateway-v3", provider="ai-gateway", ) diff --git a/examples/models2/tools.py b/examples/models2/tools.py index aa1d389f..2e25eb96 100644 --- a/examples/models2/tools.py +++ b/examples/models2/tools.py @@ -8,7 +8,7 @@ model = m.Model( id="anthropic/claude-sonnet-4", - api="ai-gateway", + adapter="ai-gateway-v3", provider="ai-gateway", ) diff --git a/examples/models2/video_generation.py b/examples/models2/video_generation.py index 8b9e17d1..77900c89 100644 --- a/examples/models2/video_generation.py +++ b/examples/models2/video_generation.py @@ -9,7 +9,7 @@ model = m.Model( id="google/veo-3.0-generate-001", - api="ai-gateway", + adapter="ai-gateway-v3", provider="ai-gateway", capabilities=("video",), ) diff --git a/src/vercel_ai_sdk/models2/__init__.py b/src/vercel_ai_sdk/models2/__init__.py index b41252d0..09f92f5d 100644 --- a/src/vercel_ai_sdk/models2/__init__.py +++ 
b/src/vercel_ai_sdk/models2/__init__.py @@ -7,7 +7,7 @@ model = m.Model( id="anthropic/claude-sonnet-4", - api="ai-gateway", + adapter="ai-gateway-v3", provider="ai-gateway", ) msgs = [Message(role="user", parts=[TextPart(text="hello")])] @@ -38,29 +38,29 @@ from ..types import tools as tools_ from .core.client import Client from .core.model import Model, ModelCost -from .core.wire import GenerateFn, StreamFn +from .core.proto import GenerateFn, StreamFn # --------------------------------------------------------------------------- -# Wire registry — maps api string → wire function. -# Wire modules are imported lazily on first use. +# Adapter registry — maps adapter string → adapter function. +# Adapter modules are imported lazily on first use. # --------------------------------------------------------------------------- -_stream_wires: dict[str, StreamFn] = {} -_generate_wires: dict[str, GenerateFn] = {} -_wires_loaded = False +_stream_adapters: dict[str, StreamFn] = {} +_generate_adapters: dict[str, GenerateFn] = {} +_adapters_loaded = False -def _ensure_wires() -> None: - """Lazily register built-in wire functions on first call.""" - global _wires_loaded # noqa: PLW0603 - if _wires_loaded: +def _ensure_adapters() -> None: + """Lazily register built-in adapter functions on first call.""" + global _adapters_loaded # noqa: PLW0603 + if _adapters_loaded: return - _wires_loaded = True + _adapters_loaded = True - from .wires import ai_gateway_v3 + from .ai_gateway import adapter as ai_gateway_v3 - _stream_wires["ai-gateway"] = ai_gateway_v3.stream - _generate_wires["ai-gateway"] = ai_gateway_v3.generate + _stream_adapters["ai-gateway-v3"] = ai_gateway_v3.stream + _generate_adapters["ai-gateway-v3"] = ai_gateway_v3.generate # --------------------------------------------------------------------------- @@ -102,19 +102,20 @@ async def stream( ) -> AsyncGenerator[messages_.Message]: """Stream an LLM response. 
- Resolves the wire function from ``model.api``, auto-creates a + Resolves the adapter function from ``model.adapter``, auto-creates a :class:`Client` from env vars if none is provided, and yields ``Message`` snapshots. """ - _ensure_wires() + _ensure_adapters() c = client or _auto_client(model) - wire_fn = _stream_wires.get(model.api) - if wire_fn is None: - registered = ", ".join(sorted(_stream_wires)) or "(none)" + adapter_fn = _stream_adapters.get(model.adapter) + if adapter_fn is None: + registered = ", ".join(sorted(_stream_adapters)) or "(none)" raise KeyError( - f"No stream wire registered for api={model.api!r}. Registered: {registered}" + f"No stream adapter registered for adapter={model.adapter!r}. " + f"Registered: {registered}" ) - async for msg in wire_fn( + async for msg in adapter_fn( c, model, messages, tools=tools, output_type=output_type, **kwargs ): yield msg @@ -129,19 +130,19 @@ async def generate( ) -> messages_.Message: """Generate a response (images, video, etc.). - Resolves the wire function from ``model.api``, auto-creates a + Resolves the adapter function from ``model.adapter``, auto-creates a :class:`Client` from env vars if none is provided. """ - _ensure_wires() + _ensure_adapters() c = client or _auto_client(model) - wire_fn = _generate_wires.get(model.api) - if wire_fn is None: - registered = ", ".join(sorted(_generate_wires)) or "(none)" + adapter_fn = _generate_adapters.get(model.adapter) + if adapter_fn is None: + registered = ", ".join(sorted(_generate_adapters)) or "(none)" raise KeyError( - f"No generate wire registered for api={model.api!r}. " + f"No generate adapter registered for adapter={model.adapter!r}. 
" f"Registered: {registered}" ) - return await wire_fn(c, model, messages, **kwargs) + return await adapter_fn(c, model, messages, **kwargs) async def buffer(gen: AsyncGenerator[messages_.Message]) -> messages_.Message: diff --git a/src/vercel_ai_sdk/models2/ai_gateway/__init__.py b/src/vercel_ai_sdk/models2/ai_gateway/__init__.py new file mode 100644 index 00000000..8c149cb7 --- /dev/null +++ b/src/vercel_ai_sdk/models2/ai_gateway/__init__.py @@ -0,0 +1,5 @@ +"""AI Gateway provider — adapter for the Vercel AI Gateway v3 protocol.""" + +from .adapter import generate, stream + +__all__ = ["generate", "stream"] diff --git a/src/vercel_ai_sdk/models2/wires/ai_gateway_v3.py b/src/vercel_ai_sdk/models2/ai_gateway/adapter.py similarity index 99% rename from src/vercel_ai_sdk/models2/wires/ai_gateway_v3.py rename to src/vercel_ai_sdk/models2/ai_gateway/adapter.py index fbd1efd7..288a3a04 100644 --- a/src/vercel_ai_sdk/models2/wires/ai_gateway_v3.py +++ b/src/vercel_ai_sdk/models2/ai_gateway/adapter.py @@ -1,6 +1,6 @@ -"""Vercel AI Gateway v3 wire — streaming and generation. +"""Vercel AI Gateway v3 adapter — streaming and generation. -Wire protocol for the AI Gateway's v3 endpoints: +Adapter for the AI Gateway's v3 endpoints: * ``/language-model`` — streaming text/tool/reasoning responses. * ``/image-model`` — dedicated image generation. 
@@ -309,7 +309,7 @@ def _request_headers( # --------------------------------------------------------------------------- -# Public wire functions +# Public adapter functions # --------------------------------------------------------------------------- diff --git a/src/vercel_ai_sdk/models2/core/__init__.py b/src/vercel_ai_sdk/models2/core/__init__.py index 38cdf5c6..a99a9797 100644 --- a/src/vercel_ai_sdk/models2/core/__init__.py +++ b/src/vercel_ai_sdk/models2/core/__init__.py @@ -2,7 +2,7 @@ from .client import Client from .model import Model, ModelCost -from .wire import GenerateFn, StreamFn +from .proto import GenerateFn, StreamFn __all__ = [ "Client", diff --git a/src/vercel_ai_sdk/models2/core/client.py b/src/vercel_ai_sdk/models2/core/client.py index 88ae5454..6cb0fb12 100644 --- a/src/vercel_ai_sdk/models2/core/client.py +++ b/src/vercel_ai_sdk/models2/core/client.py @@ -1,4 +1,4 @@ -"""HTTP client for wire functions.""" +"""HTTP client for adapter functions.""" from __future__ import annotations @@ -11,8 +11,8 @@ class Client: """Connection parameters for a provider API. - Wire functions receive a ``Client`` instead of creating their own HTTP - session. This keeps auth and base URL decoupled from the wire logic. + Adapter functions receive a ``Client`` instead of creating their own HTTP + session. This keeps auth and base URL decoupled from the adapter logic. The :pyattr:`http` property lazily creates a shared :class:`httpx.AsyncClient` so that consecutive calls reuse the same diff --git a/src/vercel_ai_sdk/models2/core/model.py b/src/vercel_ai_sdk/models2/core/model.py index e9a4d129..cbf59f50 100644 --- a/src/vercel_ai_sdk/models2/core/model.py +++ b/src/vercel_ai_sdk/models2/core/model.py @@ -20,12 +20,12 @@ class Model: """Pure-data description of a model. * ``id`` — identifier sent to the provider (e.g. ``"claude-sonnet-4-20250514"``). - * ``api`` — wire protocol key (e.g. ``"ai-gateway"``, ``"anthropic-messages"``). + * ``adapter`` — adapter key (e.g. 
``"ai-gateway-v3"``, ``"anthropic-messages"``). * ``provider`` — hosting service (e.g. ``"ai-gateway"``, ``"anthropic"``). """ id: str - api: str + adapter: str provider: str name: str = "" capabilities: tuple[str, ...] = ("text",) diff --git a/src/vercel_ai_sdk/models2/core/wire.py b/src/vercel_ai_sdk/models2/core/proto.py similarity index 75% rename from src/vercel_ai_sdk/models2/core/wire.py rename to src/vercel_ai_sdk/models2/core/proto.py index bd557bac..b8c0b993 100644 --- a/src/vercel_ai_sdk/models2/core/wire.py +++ b/src/vercel_ai_sdk/models2/core/proto.py @@ -1,9 +1,9 @@ -"""Wire function protocols. +"""Adapter function protocols. -A *wire function* translates between our ``Message`` types and a specific -provider API (e.g. ``"ai-gateway"``, ``"anthropic-messages"``). +An *adapter function* translates between our ``Message`` types and a specific +provider API (e.g. ``"ai-gateway-v3"``, ``"anthropic-messages"``). -Wire functions are plain async generators / coroutines — no base class +Adapter functions are plain async generators / coroutines — no base class required. The protocols below exist only for static type-checking. """ @@ -22,7 +22,7 @@ @runtime_checkable class StreamFn(Protocol): - """Protocol for streaming wire functions. + """Protocol for streaming adapter functions. Implementations yield ``Message`` snapshots as the response streams in. 
Each snapshot is a complete, self-contained message reflecting @@ -43,7 +43,7 @@ def __call__( @runtime_checkable class GenerateFn(Protocol): - """Protocol for non-streaming wire functions (images, video, etc.).""" + """Protocol for non-streaming adapter functions (images, video, etc.).""" async def __call__( self, diff --git a/src/vercel_ai_sdk/models2/wires/__init__.py b/src/vercel_ai_sdk/models2/wires/__init__.py deleted file mode 100644 index 77c50464..00000000 --- a/src/vercel_ai_sdk/models2/wires/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Wire implementations for provider APIs.""" From 19a0f5d74166e42638b68b7aa5362008b1a31566 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Fri, 3 Apr 2026 08:45:23 -0700 Subject: [PATCH 07/18] Refactor the adapter portion to improve clarity --- examples/models2/direct_adapter.py | 2 +- examples/models2/image_generation.py | 2 +- examples/models2/video_generation.py | 3 +- src/vercel_ai_sdk/models2/__init__.py | 21 +- .../models2/ai_gateway/__init__.py | 11 +- .../models2/ai_gateway/_common.py | 145 +++++++++ .../models2/ai_gateway/generate.py | 241 +++++++++++++++ .../ai_gateway/{adapter.py => stream.py} | 286 +----------------- src/vercel_ai_sdk/models2/core/proto.py | 9 +- 9 files changed, 429 insertions(+), 291 deletions(-) create mode 100644 src/vercel_ai_sdk/models2/ai_gateway/_common.py create mode 100644 src/vercel_ai_sdk/models2/ai_gateway/generate.py rename src/vercel_ai_sdk/models2/ai_gateway/{adapter.py => stream.py} (55%) diff --git a/examples/models2/direct_adapter.py b/examples/models2/direct_adapter.py index a9ff3b86..fe680dae 100644 --- a/examples/models2/direct_adapter.py +++ b/examples/models2/direct_adapter.py @@ -4,7 +4,7 @@ import os from vercel_ai_sdk import models2 as m -from vercel_ai_sdk.models2.ai_gateway import adapter as ai_gateway_v3 +from vercel_ai_sdk.models2 import ai_gateway as ai_gateway_v3 from vercel_ai_sdk.types import messages as messages_ model = m.Model( diff --git 
a/examples/models2/image_generation.py b/examples/models2/image_generation.py index 7d9da9c0..8ba4d318 100644 --- a/examples/models2/image_generation.py +++ b/examples/models2/image_generation.py @@ -32,7 +32,7 @@ async def main() -> None: - result = await m.generate(model, messages, n=2, aspect_ratio="16:9") + result = await m.generate(model, messages, m.ImageParams(n=2, aspect_ratio="16:9")) print(f"Generated {len(result.images)} image(s)") for i, img in enumerate(result.images): diff --git a/examples/models2/video_generation.py b/examples/models2/video_generation.py index 77900c89..ece777ad 100644 --- a/examples/models2/video_generation.py +++ b/examples/models2/video_generation.py @@ -37,8 +37,7 @@ async def main() -> None: result = await m.generate( model, messages, - aspect_ratio="16:9", - duration=8, + m.VideoParams(aspect_ratio="16:9", duration=8), ) print(f"Generated {len(result.videos)} video(s)") diff --git a/src/vercel_ai_sdk/models2/__init__.py b/src/vercel_ai_sdk/models2/__init__.py index 09f92f5d..92ef3c22 100644 --- a/src/vercel_ai_sdk/models2/__init__.py +++ b/src/vercel_ai_sdk/models2/__init__.py @@ -36,6 +36,7 @@ from ..types import messages as messages_ from ..types import tools as tools_ +from .ai_gateway.generate import GenerateParams, ImageParams, VideoParams from .core.client import Client from .core.model import Model, ModelCost from .core.proto import GenerateFn, StreamFn @@ -57,10 +58,11 @@ def _ensure_adapters() -> None: return _adapters_loaded = True - from .ai_gateway import adapter as ai_gateway_v3 + from .ai_gateway import generate as ai_gw_generate + from .ai_gateway import stream as ai_gw_stream - _stream_adapters["ai-gateway-v3"] = ai_gateway_v3.stream - _generate_adapters["ai-gateway-v3"] = ai_gateway_v3.generate + _stream_adapters["ai-gateway-v3"] = ai_gw_stream + _generate_adapters["ai-gateway-v3"] = ai_gw_generate # --------------------------------------------------------------------------- @@ -124,14 +126,20 @@ async def 
stream( async def generate( model: Model, messages: list[messages_.Message], + params: GenerateParams | None = None, *, client: Client | None = None, - **kwargs: Any, ) -> messages_.Message: """Generate a response (images, video, etc.). Resolves the adapter function from ``model.adapter``, auto-creates a :class:`Client` from env vars if none is provided. + + ``params`` controls the generation type: + + * :class:`ImageParams` — image generation (``/image-model``). + * :class:`VideoParams` — video generation (``/video-model``). + * ``None`` — auto-detect from ``model.capabilities``. """ _ensure_adapters() c = client or _auto_client(model) @@ -142,7 +150,7 @@ async def generate( f"No generate adapter registered for adapter={model.adapter!r}. " f"Registered: {registered}" ) - return await adapter_fn(c, model, messages, **kwargs) + return await adapter_fn(c, model, messages, params=params) async def buffer(gen: AsyncGenerator[messages_.Message]) -> messages_.Message: @@ -162,9 +170,12 @@ async def buffer(gen: AsyncGenerator[messages_.Message]) -> messages_.Message: # Core types "Client", "GenerateFn", + "GenerateParams", + "ImageParams", "Model", "ModelCost", "StreamFn", + "VideoParams", # Public API "buffer", "generate", diff --git a/src/vercel_ai_sdk/models2/ai_gateway/__init__.py b/src/vercel_ai_sdk/models2/ai_gateway/__init__.py index 8c149cb7..0963a906 100644 --- a/src/vercel_ai_sdk/models2/ai_gateway/__init__.py +++ b/src/vercel_ai_sdk/models2/ai_gateway/__init__.py @@ -1,5 +1,12 @@ """AI Gateway provider — adapter for the Vercel AI Gateway v3 protocol.""" -from .adapter import generate, stream +from .generate import GenerateParams, ImageParams, VideoParams, generate +from .stream import stream -__all__ = ["generate", "stream"] +__all__ = [ + "GenerateParams", + "ImageParams", + "VideoParams", + "generate", + "stream", +] diff --git a/src/vercel_ai_sdk/models2/ai_gateway/_common.py b/src/vercel_ai_sdk/models2/ai_gateway/_common.py new file mode 100644 index 
00000000..0031661f --- /dev/null +++ b/src/vercel_ai_sdk/models2/ai_gateway/_common.py @@ -0,0 +1,145 @@ +"""Shared helpers for the AI Gateway v3 adapter. + +Contains utilities used by both the streaming (language-model) and generation +(image-model, video-model) endpoints. + +.. note:: + + Several helpers here are candidates for lifting to framework-level: + + - ``extract_prompt`` / ``extract_input_files`` → ``Message`` methods + - ``parse_sse_lines`` → ``core/helpers/sse.py`` +""" + +from __future__ import annotations + +import base64 +import json +from collections.abc import AsyncGenerator +from typing import Any + +import httpx + +from ...types import messages as messages_ +from ..core import client as client_ +from ..core import model as model_ +from ..core.helpers import media as media_ + +_PROTOCOL_VERSION = "0.0.1" + + +# --------------------------------------------------------------------------- +# Message extraction helpers +# --------------------------------------------------------------------------- +# TODO: lift to Message methods — these are universally useful. 
+ + +def extract_prompt(messages: list[messages_.Message]) -> str: + """Concatenate all text from user/system messages into a single prompt string.""" + parts: list[str] = [] + for msg in messages: + if msg.role in ("user", "system"): + for p in msg.parts: + if isinstance(p, messages_.TextPart): + parts.append(p.text) + return " ".join(parts) + + +def extract_input_files(messages: list[messages_.Message]) -> list[messages_.FilePart]: + """Collect all file parts from user messages.""" + files: list[messages_.FilePart] = [] + for msg in messages: + if msg.role == "user": + for p in msg.parts: + if isinstance(p, messages_.FilePart): + files.append(p) + return files + + +# --------------------------------------------------------------------------- +# Wire format helpers +# --------------------------------------------------------------------------- + + +def file_part_to_wire(part: messages_.FilePart) -> dict[str, Any]: + """Convert a :class:`FilePart` to the gateway wire format for input files.""" + data = part.data + if isinstance(data, str) and media_.is_url(data): + return {"type": "url", "url": data} + if isinstance(data, bytes): + b64 = base64.b64encode(data).decode("ascii") + elif isinstance(data, str): + b64 = data + else: + b64 = str(data) + return {"type": "file", "data": b64, "mediaType": part.media_type} + + +# --------------------------------------------------------------------------- +# Request headers +# --------------------------------------------------------------------------- + + +def request_headers( + client: client_.Client, + model: model_.Model, + *, + model_type: str = "language", + streaming: bool = False, +) -> dict[str, str]: + """Build gateway-specific request headers. + + Args: + client: The HTTP client (provides api_key). + model: The model (provides id). + model_type: One of ``"language"``, ``"image"``, ``"video"``. + streaming: Whether this is a streaming request (language-model only). 
+ """ + h: dict[str, str] = { + "Content-Type": "application/json", + "ai-gateway-protocol-version": _PROTOCOL_VERSION, + } + + if model_type == "language": + h["ai-language-model-specification-version"] = "3" + h["ai-language-model-id"] = model.id + h["ai-language-model-streaming"] = str(streaming).lower() + elif model_type == "image": + h["ai-image-model-specification-version"] = "3" + h["ai-model-id"] = model.id + elif model_type == "video": + h["ai-video-model-specification-version"] = "3" + h["ai-model-id"] = model.id + + if client.api_key: + h["Authorization"] = f"Bearer {client.api_key}" + h["ai-gateway-auth-method"] = "api-key" + + return h + + +# --------------------------------------------------------------------------- +# SSE parsing +# --------------------------------------------------------------------------- +# TODO: lift to core/helpers/sse.py — any SSE-based adapter will need this. + + +async def parse_sse_lines( + response: httpx.Response, +) -> AsyncGenerator[dict[str, Any]]: + """Yield parsed JSON dicts from an SSE response stream. + + Handles the ``data: `` / ``data: [DONE]`` protocol used by the + AI Gateway's streaming endpoints. + """ + async for line in response.aiter_lines(): + line = line.strip() + if not line.startswith("data: "): + continue + payload = line[len("data: ") :] + if payload == "[DONE]": + break + try: + yield json.loads(payload) + except json.JSONDecodeError: + continue diff --git a/src/vercel_ai_sdk/models2/ai_gateway/generate.py b/src/vercel_ai_sdk/models2/ai_gateway/generate.py new file mode 100644 index 00000000..fd65819f --- /dev/null +++ b/src/vercel_ai_sdk/models2/ai_gateway/generate.py @@ -0,0 +1,241 @@ +"""AI Gateway v3 generation adapter — image-model and video-model endpoints. + +Provides typed parameter objects (:class:`ImageParams`, :class:`VideoParams`) +and a unified :func:`generate` entry point that dispatches based on param type +and validates against model capabilities. 
+""" + +from __future__ import annotations + +from typing import Any + +import httpx +import pydantic + +from ...types import messages as messages_ +from ..core import client as client_ +from ..core import model as model_ +from ..core.helpers import media as media_ +from . import _common + +# --------------------------------------------------------------------------- +# Parameter types +# --------------------------------------------------------------------------- + +_PARAMS_CONFIG = pydantic.ConfigDict(frozen=True, populate_by_name=True) + + +class ImageParams(pydantic.BaseModel): + """Parameters for image generation (``/image-model`` endpoint).""" + + model_config = _PARAMS_CONFIG + + n: int = 1 + size: str | None = None + aspect_ratio: str | None = pydantic.Field(None, alias="aspectRatio") + seed: int | None = None + provider_options: dict[str, Any] = pydantic.Field( + default_factory=dict, alias="providerOptions" + ) + + +class VideoParams(pydantic.BaseModel): + """Parameters for video generation (``/video-model`` endpoint).""" + + model_config = _PARAMS_CONFIG + + n: int = 1 + aspect_ratio: str | None = pydantic.Field(None, alias="aspectRatio") + resolution: str | None = None + duration: int | None = None + fps: int | None = None + seed: int | None = None + provider_options: dict[str, Any] = pydantic.Field( + default_factory=dict, alias="providerOptions" + ) + + +GenerateParams = ImageParams | VideoParams + + +# --------------------------------------------------------------------------- +# Image generation — /image-model +# --------------------------------------------------------------------------- + + +async def _generate_image( + client: client_.Client, + model: model_.Model, + messages: list[messages_.Message], + params: ImageParams, +) -> messages_.Message: + """Hit ``/image-model`` and return a Message with FileParts.""" + prompt = _common.extract_prompt(messages) + input_files = _common.extract_input_files(messages) + + body: dict[str, Any] = { + 
"prompt": prompt, + **params.model_dump(by_alias=True, exclude_none=True), + } + if input_files: + body["files"] = [_common.file_part_to_wire(f) for f in input_files] + + url = f"{client.base_url.rstrip('/')}/image-model" + headers = _common.request_headers(client, model, model_type="image") + + response = await client.http.post(url, json=body, headers=headers) + if response.status_code >= 400: + raise RuntimeError( + f"AI Gateway image-model returned HTTP {response.status_code}: " + f"{response.text}" + ) + + data = response.json() + raw_images: list[str] = data.get("images", []) + usage_data = data.get("usage") + usage = None + if usage_data: + usage = messages_.Usage( + input_tokens=usage_data.get("inputTokens") or 0, + output_tokens=usage_data.get("outputTokens") or 0, + ) + + files: list[messages_.FilePart] = [] + for img_b64 in raw_images: + media_type = media_.detect_image_media_type(img_b64) or "image/png" + files.append(messages_.FilePart(data=img_b64, media_type=media_type)) + + return messages_.Message(role="assistant", parts=files, usage=usage) + + +# --------------------------------------------------------------------------- +# Video generation — /video-model (SSE response) +# --------------------------------------------------------------------------- + + +async def _generate_video( + client: client_.Client, + model: model_.Model, + messages: list[messages_.Message], + params: VideoParams, +) -> messages_.Message: + """Hit ``/video-model`` (SSE) and return a Message with FileParts.""" + prompt = _common.extract_prompt(messages) + input_files = _common.extract_input_files(messages) + + body: dict[str, Any] = { + "prompt": prompt, + **params.model_dump(by_alias=True, exclude_none=True), + } + if input_files: + body["image"] = _common.file_part_to_wire(input_files[0]) + + url = f"{client.base_url.rstrip('/')}/video-model" + headers = _common.request_headers(client, model, model_type="video") + headers["accept"] = "text/event-stream" + + async with 
client.http.stream( + "POST", + url, + json=body, + headers=headers, + timeout=httpx.Timeout(timeout=600.0, connect=10.0), + ) as response: + if response.status_code >= 400: + await response.aread() + raise RuntimeError( + f"AI Gateway video-model returned HTTP {response.status_code}: " + f"{response.text}" + ) + + # Read first SSE data event — the gateway sends a single result event. + event_data: dict[str, Any] = {} + async for parsed in _common.parse_sse_lines(response): + event_data = parsed + break + + if event_data.get("type") == "error": + raise RuntimeError( + f"AI Gateway video generation error: " + f"{event_data.get('message', 'unknown error')}" + ) + + raw_videos: list[dict[str, Any]] = event_data.get("videos", []) + files: list[messages_.FilePart] = [] + for video_data in raw_videos: + vtype = video_data.get("type", "base64") + media_type = video_data.get("mediaType", "video/mp4") + + if vtype == "url": + downloaded_bytes, content_type = await media_.download(video_data["url"]) + if content_type: + media_type = content_type + files.append( + messages_.FilePart(data=downloaded_bytes, media_type=media_type) + ) + else: + raw_data = video_data.get("data", "") + files.append(messages_.FilePart(data=raw_data, media_type=media_type)) + + return messages_.Message(role="assistant", parts=files) + + +# --------------------------------------------------------------------------- +# Public adapter function +# --------------------------------------------------------------------------- + + +def _check_capabilities( + model: model_.Model, + params: GenerateParams, +) -> None: + """Validate that model capabilities match the requested generation type.""" + caps = model.capabilities + + if isinstance(params, VideoParams): + if "video" not in caps: + raise ValueError( + f"Model {model.id!r} does not have 'video' capability " + f"(capabilities={caps}). Use ImageParams for image models." 
+ ) + if "text" in caps and "video" not in caps: + raise ValueError( + f"Model {model.id!r} is a text model (capabilities={caps}). " + f"Use stream() for text generation, not generate()." + ) + elif isinstance(params, ImageParams): + if "video" in caps and "image" not in caps: + raise ValueError( + f"Model {model.id!r} has 'video' capability but not 'image' " + f"(capabilities={caps}). Use VideoParams for video models." + ) + if "text" in caps and "image" not in caps: + raise ValueError( + f"Model {model.id!r} is a text model (capabilities={caps}). " + f"Use stream() for text generation, not generate()." + ) + + +async def generate( + client: client_.Client, + model: model_.Model, + messages: list[messages_.Message], + params: GenerateParams | None = None, +) -> messages_.Message: + """Generate media (images or video) through the AI Gateway. + + Dispatches to ``/image-model`` or ``/video-model`` based on ``params`` + type, with fallback to model capabilities when ``params`` is ``None``. + + Raises :class:`ValueError` if the model capabilities don't match the + requested generation type. + """ + # Auto-detect from capabilities when no params provided + if params is None: + params = VideoParams() if "video" in model.capabilities else ImageParams() + + _check_capabilities(model, params) + + if isinstance(params, VideoParams): + return await _generate_video(client, model, messages, params) + return await _generate_image(client, model, messages, params) diff --git a/src/vercel_ai_sdk/models2/ai_gateway/adapter.py b/src/vercel_ai_sdk/models2/ai_gateway/stream.py similarity index 55% rename from src/vercel_ai_sdk/models2/ai_gateway/adapter.py rename to src/vercel_ai_sdk/models2/ai_gateway/stream.py index 288a3a04..0c330dfa 100644 --- a/src/vercel_ai_sdk/models2/ai_gateway/adapter.py +++ b/src/vercel_ai_sdk/models2/ai_gateway/stream.py @@ -1,20 +1,14 @@ -"""Vercel AI Gateway v3 adapter — streaming and generation. 
+"""AI Gateway v3 streaming adapter — language-model endpoint. -Adapter for the AI Gateway's v3 endpoints: - -* ``/language-model`` — streaming text/tool/reasoning responses. -* ``/image-model`` — dedicated image generation. -* ``/video-model`` — dedicated video generation (SSE response). +Handles text, tool-call, reasoning, and inline file streaming via SSE. """ from __future__ import annotations -import base64 import json from collections.abc import AsyncGenerator, Sequence from typing import Any -import httpx import pydantic from ...types import messages as messages_ @@ -23,8 +17,7 @@ from ..core import model as model_ from ..core.helpers import media as media_ from ..core.helpers import streaming as streaming_ - -_PROTOCOL_VERSION = "0.0.1" +from . import _common # --------------------------------------------------------------------------- # Request building — Message list → v3 prompt @@ -284,32 +277,7 @@ def _parse_stream_part(data: dict[str, Any]) -> list[streaming_.StreamEvent]: # --------------------------------------------------------------------------- -# Headers -# --------------------------------------------------------------------------- - - -def _request_headers( - client: client_.Client, - model: model_.Model, - *, - streaming: bool, -) -> dict[str, str]: - """Build gateway-specific request headers.""" - h: dict[str, str] = { - "Content-Type": "application/json", - "ai-gateway-protocol-version": _PROTOCOL_VERSION, - "ai-language-model-specification-version": "3", - "ai-language-model-id": model.id, - "ai-language-model-streaming": str(streaming).lower(), - } - if client.api_key: - h["Authorization"] = f"Bearer {client.api_key}" - h["ai-gateway-auth-method"] = "api-key" - return h - - -# --------------------------------------------------------------------------- -# Public adapter functions +# Public adapter function # --------------------------------------------------------------------------- @@ -331,7 +299,9 @@ async def stream( body = await 
_build_request_body( messages, tools=tools, output_type=output_type, **kwargs ) - headers = _request_headers(client, model, streaming=True) + headers = _common.request_headers( + client, model, model_type="language", streaming=True + ) url = f"{client.base_url.rstrip('/')}/language-model" handler = streaming_.StreamHandler() @@ -348,247 +318,7 @@ async def stream( f"AI Gateway returned HTTP {response.status_code}: {response.text}" ) - async for line in response.aiter_lines(): - line = line.strip() - if not line.startswith("data: "): - continue - payload = line[len("data: ") :] - if payload == "[DONE]": - break - try: - data = json.loads(payload) - except json.JSONDecodeError: - continue - + async for data in _common.parse_sse_lines(response): for event in _parse_stream_part(data): msg = handler.handle_event(event) yield msg - - -# --------------------------------------------------------------------------- -# Generate — image / video (non-streaming media generation) -# --------------------------------------------------------------------------- - - -def _file_part_to_wire(part: messages_.FilePart) -> dict[str, Any]: - """Convert a :class:`FilePart` to the gateway wire format for input files.""" - data = part.data - if isinstance(data, str) and media_.is_url(data): - return {"type": "url", "url": data} - if isinstance(data, bytes): - b64 = base64.b64encode(data).decode("ascii") - elif isinstance(data, str): - b64 = data - else: - b64 = str(data) - return {"type": "file", "data": b64, "mediaType": part.media_type} - - -def _extract_prompt(messages: list[messages_.Message]) -> str: - """Concatenate all text from user/system messages.""" - parts: list[str] = [] - for msg in messages: - if msg.role in ("user", "system"): - for p in msg.parts: - if isinstance(p, messages_.TextPart): - parts.append(p.text) - return " ".join(parts) - - -def _extract_input_files(messages: list[messages_.Message]) -> list[messages_.FilePart]: - """Collect all file parts from user messages.""" 
- files: list[messages_.FilePart] = [] - for msg in messages: - if msg.role == "user": - for p in msg.parts: - if isinstance(p, messages_.FilePart): - files.append(p) - return files - - -def _generate_headers( - client: client_.Client, - model: model_.Model, - *, - spec_version_header: str, -) -> dict[str, str]: - """Build gateway request headers for generate endpoints.""" - h: dict[str, str] = { - "Content-Type": "application/json", - "ai-gateway-protocol-version": _PROTOCOL_VERSION, - spec_version_header: "3", - "ai-model-id": model.id, - } - if client.api_key: - h["Authorization"] = f"Bearer {client.api_key}" - h["ai-gateway-auth-method"] = "api-key" - return h - - -async def _generate_image( - client: client_.Client, - model: model_.Model, - messages: list[messages_.Message], - **kwargs: Any, -) -> messages_.Message: - """Hit ``/image-model`` and return a Message with FileParts.""" - prompt = _extract_prompt(messages) - input_files = _extract_input_files(messages) - - body: dict[str, Any] = { - "prompt": prompt, - "n": kwargs.get("n", 1), - "providerOptions": kwargs.get("provider_options", {}), - } - if kwargs.get("size") is not None: - body["size"] = kwargs["size"] - if kwargs.get("aspect_ratio") is not None: - body["aspectRatio"] = kwargs["aspect_ratio"] - if kwargs.get("seed") is not None: - body["seed"] = kwargs["seed"] - if input_files: - body["files"] = [_file_part_to_wire(f) for f in input_files] - - url = f"{client.base_url.rstrip('/')}/image-model" - headers = _generate_headers( - client, model, spec_version_header="ai-image-model-specification-version" - ) - - response = await client.http.post( - url, - json=body, - headers=headers, - ) - if response.status_code >= 400: - raise RuntimeError( - f"AI Gateway image-model returned HTTP {response.status_code}: " - f"{response.text}" - ) - - data = response.json() - raw_images: list[str] = data.get("images", []) - usage_data = data.get("usage") - usage = None - if usage_data: - usage = messages_.Usage( - 
input_tokens=usage_data.get("inputTokens") or 0, - output_tokens=usage_data.get("outputTokens") or 0, - ) - - files: list[messages_.FilePart] = [] - for img_b64 in raw_images: - media_type = media_.detect_image_media_type(img_b64) or "image/png" - files.append(messages_.FilePart(data=img_b64, media_type=media_type)) - - return messages_.Message(role="assistant", parts=files, usage=usage) - - -async def _generate_video( - client: client_.Client, - model: model_.Model, - messages: list[messages_.Message], - **kwargs: Any, -) -> messages_.Message: - """Hit ``/video-model`` (SSE) and return a Message with FileParts.""" - prompt = _extract_prompt(messages) - input_files = _extract_input_files(messages) - - body: dict[str, Any] = { - "prompt": prompt, - "n": kwargs.get("n", 1), - "providerOptions": kwargs.get("provider_options", {}), - } - if kwargs.get("aspect_ratio") is not None: - body["aspectRatio"] = kwargs["aspect_ratio"] - if kwargs.get("resolution") is not None: - body["resolution"] = kwargs["resolution"] - if kwargs.get("duration") is not None: - body["duration"] = kwargs["duration"] - if kwargs.get("fps") is not None: - body["fps"] = kwargs["fps"] - if kwargs.get("seed") is not None: - body["seed"] = kwargs["seed"] - if input_files: - body["image"] = _file_part_to_wire(input_files[0]) - - url = f"{client.base_url.rstrip('/')}/video-model" - headers = _generate_headers( - client, model, spec_version_header="ai-video-model-specification-version" - ) - headers["accept"] = "text/event-stream" - - async with client.http.stream( - "POST", - url, - json=body, - headers=headers, - timeout=httpx.Timeout(timeout=600.0, connect=10.0), - ) as response: - if response.status_code >= 400: - await response.aread() - raise RuntimeError( - f"AI Gateway video-model returned HTTP {response.status_code}: " - f"{response.text}" - ) - - # Read first SSE data event — the gateway sends a single result event. 
- event_data: dict[str, Any] = {} - async for line in response.aiter_lines(): - line = line.strip() - if not line.startswith("data: "): - continue - payload = line[len("data: ") :] - if payload == "[DONE]": - break - try: - event_data = json.loads(payload) - break - except json.JSONDecodeError: - continue - - if event_data.get("type") == "error": - raise RuntimeError( - f"AI Gateway video generation error: " - f"{event_data.get('message', 'unknown error')}" - ) - - raw_videos: list[dict[str, Any]] = event_data.get("videos", []) - files: list[messages_.FilePart] = [] - for video_data in raw_videos: - vtype = video_data.get("type", "base64") - media_type = video_data.get("mediaType", "video/mp4") - - if vtype == "url": - downloaded_bytes, content_type = await media_.download(video_data["url"]) - if content_type: - media_type = content_type - files.append( - messages_.FilePart(data=downloaded_bytes, media_type=media_type) - ) - else: - raw_data = video_data.get("data", "") - files.append(messages_.FilePart(data=raw_data, media_type=media_type)) - - return messages_.Message(role="assistant", parts=files) - - -async def generate( - client: client_.Client, - model: model_.Model, - messages: list[messages_.Message], - **kwargs: Any, -) -> messages_.Message: - """Generate media (images or video) through the AI Gateway. - - Dispatches to ``/image-model`` or ``/video-model`` based on the - model's capabilities. - - Keyword args are forwarded to the underlying endpoint and may include - ``n``, ``size``, ``aspect_ratio``, ``seed``, ``duration``, ``fps``, - ``resolution``, ``provider_options``. 
- """ - caps = model.capabilities - if "video" in caps: - return await _generate_video(client, model, messages, **kwargs) - # Default to image generation - return await _generate_image(client, model, messages, **kwargs) diff --git a/src/vercel_ai_sdk/models2/core/proto.py b/src/vercel_ai_sdk/models2/core/proto.py index b8c0b993..1ceb8ff2 100644 --- a/src/vercel_ai_sdk/models2/core/proto.py +++ b/src/vercel_ai_sdk/models2/core/proto.py @@ -43,12 +43,17 @@ def __call__( @runtime_checkable class GenerateFn(Protocol): - """Protocol for non-streaming adapter functions (images, video, etc.).""" + """Protocol for non-streaming adapter functions (images, video, etc.). + + ``params`` is typed as ``Any`` at the protocol level because each adapter + defines its own parameter types (e.g. ``ImageParams | VideoParams``). + Type safety is enforced at the top-level ``generate()`` function. + """ async def __call__( self, client: Client, model: Model, messages: list[messages_.Message], - **kwargs: Any, + params: Any = None, ) -> messages_.Message: ... From 07327eb3ed444d654cee020c80f0203ff919ff0b Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Fri, 3 Apr 2026 09:54:09 -0700 Subject: [PATCH 08/18] Port custom error handling from the old module --- .../models2/ai_gateway/__init__.py | 2 + .../models2/ai_gateway/errors.py | 305 ++++++++++++++++++ .../models2/ai_gateway/generate.py | 26 +- .../models2/ai_gateway/stream.py | 46 ++- 4 files changed, 354 insertions(+), 25 deletions(-) create mode 100644 src/vercel_ai_sdk/models2/ai_gateway/errors.py diff --git a/src/vercel_ai_sdk/models2/ai_gateway/__init__.py b/src/vercel_ai_sdk/models2/ai_gateway/__init__.py index 0963a906..7cc9f429 100644 --- a/src/vercel_ai_sdk/models2/ai_gateway/__init__.py +++ b/src/vercel_ai_sdk/models2/ai_gateway/__init__.py @@ -1,5 +1,6 @@ """AI Gateway provider — adapter for the Vercel AI Gateway v3 protocol.""" +from . 
import errors from .generate import GenerateParams, ImageParams, VideoParams, generate from .stream import stream @@ -7,6 +8,7 @@ "GenerateParams", "ImageParams", "VideoParams", + "errors", "generate", "stream", ] diff --git a/src/vercel_ai_sdk/models2/ai_gateway/errors.py b/src/vercel_ai_sdk/models2/ai_gateway/errors.py new file mode 100644 index 00000000..d0dade24 --- /dev/null +++ b/src/vercel_ai_sdk/models2/ai_gateway/errors.py @@ -0,0 +1,305 @@ +"""Vercel AI Gateway error hierarchy. + +Maps HTTP error responses from the gateway server to typed Python exceptions. +Each error class corresponds to a specific ``error.type`` value in the +gateway's JSON error response format:: + + { + "error": { + "message": "...", + "type": "authentication_error" | "invalid_request_error" | ..., + "param": ..., + "code": ... + }, + "generationId": "..." + } +""" + +import json +from typing import Any, Self + +_KEY_URL = "https://vercel.com/d?to=%2F%5Bteam%5D%2F%7E%2Fai%2Fapi-keys" + + +# --------------------------------------------------------------------------- +# Base class +# --------------------------------------------------------------------------- + + +class GatewayError(Exception): + """Base class for all Vercel AI Gateway errors.""" + + type: str = "gateway_error" + + def __init__( + self, + message: str = "", + *, + status_code: int = 500, + cause: BaseException | None = None, + generation_id: str | None = None, + ) -> None: + display = f"{message} [{generation_id}]" if generation_id else message + super().__init__(display) + self.status_code = status_code + self.generation_id = generation_id + if cause is not None: + self.__cause__ = cause + + +# --------------------------------------------------------------------------- +# Concrete errors — thin subclasses that set type + default status_code +# --------------------------------------------------------------------------- + + +class GatewayAuthenticationError(GatewayError): + """Authentication failed (HTTP 401).""" + + 
type = "authentication_error" + + def __init__( + self, + message: str = "Authentication failed", + *, + status_code: int = 401, + cause: BaseException | None = None, + generation_id: str | None = None, + ) -> None: + super().__init__( + message, + status_code=status_code, + cause=cause, + generation_id=generation_id, + ) + + @classmethod + def create_contextual( + cls, + *, + api_key_provided: bool, + status_code: int = 401, + cause: BaseException | None = None, + generation_id: str | None = None, + ) -> Self: + """Build a helpful message based on which auth method was used.""" + if api_key_provided: + msg = ( + "AI Gateway authentication failed: Invalid API key.\n\n" + f"Create a new API key: {_KEY_URL}\n\n" + "Provide via 'api_key' option or " + "'AI_GATEWAY_API_KEY' environment variable." + ) + else: + msg = ( + "AI Gateway authentication failed: " + "No authentication provided.\n\n" + f"Create an API key: {_KEY_URL}\n" + "Provide via 'api_key' option or " + "'AI_GATEWAY_API_KEY' environment variable." 
+ ) + return cls( + msg, + status_code=status_code, + cause=cause, + generation_id=generation_id, + ) + + +class GatewayInvalidRequestError(GatewayError): + """Malformed or invalid request (HTTP 400).""" + + type = "invalid_request_error" + + def __init__( + self, + message: str = "Invalid request", + *, + status_code: int = 400, + **kwargs: Any, + ) -> None: + super().__init__(message, status_code=status_code, **kwargs) + + +class GatewayRateLimitError(GatewayError): + """Rate limit exceeded (HTTP 429).""" + + type = "rate_limit_exceeded" + + def __init__( + self, + message: str = "Rate limit exceeded", + *, + status_code: int = 429, + **kwargs: Any, + ) -> None: + super().__init__(message, status_code=status_code, **kwargs) + + +class GatewayModelNotFoundError(GatewayError): + """Requested model was not found (HTTP 404).""" + + type = "model_not_found" + + def __init__( + self, + message: str = "Model not found", + *, + status_code: int = 404, + model_id: str | None = None, + cause: BaseException | None = None, + generation_id: str | None = None, + ) -> None: + super().__init__( + message, + status_code=status_code, + cause=cause, + generation_id=generation_id, + ) + self.model_id = model_id + + +class GatewayInternalServerError(GatewayError): + """Internal error on the gateway server (HTTP 500).""" + + type = "internal_server_error" + + def __init__( + self, + message: str = "Internal server error", + *, + status_code: int = 500, + **kwargs: Any, + ) -> None: + super().__init__(message, status_code=status_code, **kwargs) + + +class GatewayResponseError(GatewayError): + """Malformed or unparseable response (HTTP 502).""" + + type = "response_error" + + def __init__( + self, + message: str = "Invalid response", + *, + status_code: int = 502, + response: Any = None, + validation_error: Any = None, + cause: BaseException | None = None, + generation_id: str | None = None, + ) -> None: + super().__init__( + message, + status_code=status_code, + cause=cause, + 
generation_id=generation_id, + ) + self.response = response + self.validation_error = validation_error + + +class GatewayTimeoutError(GatewayError): + """Gateway request timed out (HTTP 408).""" + + type = "timeout_error" + + def __init__( + self, + message: str = "Request timed out", + *, + status_code: int = 408, + **kwargs: Any, + ) -> None: + super().__init__(message, status_code=status_code, **kwargs) + + +# --------------------------------------------------------------------------- +# Error factory +# --------------------------------------------------------------------------- + +_TYPE_MAP: dict[str, type[GatewayError]] = { + "authentication_error": GatewayAuthenticationError, + "invalid_request_error": GatewayInvalidRequestError, + "rate_limit_exceeded": GatewayRateLimitError, + "model_not_found": GatewayModelNotFoundError, + "internal_server_error": GatewayInternalServerError, +} + +_MALFORMED = "Invalid error response format: Gateway request failed" + + +def create_gateway_error( + *, + response_body: Any, + status_code: int, + api_key_provided: bool = False, + cause: BaseException | None = None, +) -> GatewayError: + """Create a typed error from a gateway JSON error response. + + Falls back to :class:`GatewayResponseError` when the body doesn't + match the expected ``{"error": {"message": ..., "type": ...}}`` + shape. 
+ """ + # Parse the response body + body: Any = response_body + if isinstance(body, (str, bytes)): + try: + body = json.loads(body) + except (json.JSONDecodeError, ValueError): + return GatewayResponseError( + message=_MALFORMED, + status_code=status_code, + response=response_body, + validation_error="Response body is not valid JSON", + cause=cause, + ) + + # Validate shape + error_obj = body.get("error") if isinstance(body, dict) else None + if not isinstance(error_obj, dict) or "message" not in error_obj: + reason = ( + "Missing 'error' field in response" + if not isinstance(error_obj, dict) + else "Missing 'message' field in error object" + ) + return GatewayResponseError( + message=_MALFORMED, + status_code=status_code, + response=body, + validation_error=reason, + cause=cause, + ) + + message: str = error_obj["message"] + error_type: str | None = error_obj.get("type") + generation_id: str | None = body.get("generationId") + + match error_type: + case "authentication_error": + return GatewayAuthenticationError.create_contextual( + api_key_provided=api_key_provided, + status_code=status_code, + cause=cause, + generation_id=generation_id, + ) + + case "model_not_found": + param = error_obj.get("param") + model_id = param.get("modelId") if isinstance(param, dict) else None + return GatewayModelNotFoundError( + message=message, + status_code=status_code, + model_id=model_id, + cause=cause, + generation_id=generation_id, + ) + + case _: + cls = _TYPE_MAP.get(error_type or "", GatewayInternalServerError) + return cls( + message=message, + status_code=status_code, + cause=cause, + generation_id=generation_id, + ) diff --git a/src/vercel_ai_sdk/models2/ai_gateway/generate.py b/src/vercel_ai_sdk/models2/ai_gateway/generate.py index fd65819f..304c8bce 100644 --- a/src/vercel_ai_sdk/models2/ai_gateway/generate.py +++ b/src/vercel_ai_sdk/models2/ai_gateway/generate.py @@ -17,6 +17,7 @@ from ..core import model as model_ from ..core.helpers import media as media_ from . 
import _common +from . import errors as errors_ # --------------------------------------------------------------------------- # Parameter types @@ -85,9 +86,10 @@ async def _generate_image( response = await client.http.post(url, json=body, headers=headers) if response.status_code >= 400: - raise RuntimeError( - f"AI Gateway image-model returned HTTP {response.status_code}: " - f"{response.text}" + raise errors_.create_gateway_error( + response_body=response.text, + status_code=response.status_code, + api_key_provided=bool(client.api_key), ) data = response.json() @@ -143,9 +145,10 @@ async def _generate_video( ) as response: if response.status_code >= 400: await response.aread() - raise RuntimeError( - f"AI Gateway video-model returned HTTP {response.status_code}: " - f"{response.text}" + raise errors_.create_gateway_error( + response_body=response.text, + status_code=response.status_code, + api_key_provided=bool(client.api_key), ) # Read first SSE data event — the gateway sends a single result event. 
@@ -154,10 +157,15 @@ async def _generate_video( event_data = parsed break + if not event_data: + raise errors_.GatewayResponseError( + "SSE stream ended without any data events", + ) + if event_data.get("type") == "error": - raise RuntimeError( - f"AI Gateway video generation error: " - f"{event_data.get('message', 'unknown error')}" + raise errors_.GatewayInvalidRequestError( + message=event_data.get("message", "unknown error"), + status_code=event_data.get("statusCode", 400), ) raw_videos: list[dict[str, Any]] = event_data.get("videos", []) diff --git a/src/vercel_ai_sdk/models2/ai_gateway/stream.py b/src/vercel_ai_sdk/models2/ai_gateway/stream.py index 0c330dfa..92a63266 100644 --- a/src/vercel_ai_sdk/models2/ai_gateway/stream.py +++ b/src/vercel_ai_sdk/models2/ai_gateway/stream.py @@ -9,6 +9,7 @@ from collections.abc import AsyncGenerator, Sequence from typing import Any +import httpx import pydantic from ...types import messages as messages_ @@ -18,6 +19,7 @@ from ..core.helpers import media as media_ from ..core.helpers import streaming as streaming_ from . import _common +from . 
import errors as errors_ # --------------------------------------------------------------------------- # Request building — Message list → v3 prompt @@ -306,19 +308,31 @@ async def stream( handler = streaming_.StreamHandler() - async with client.http.stream( - "POST", - url, - json=body, - headers=headers, - ) as response: - if response.status_code >= 400: - await response.aread() - raise RuntimeError( - f"AI Gateway returned HTTP {response.status_code}: {response.text}" - ) - - async for data in _common.parse_sse_lines(response): - for event in _parse_stream_part(data): - msg = handler.handle_event(event) - yield msg + try: + async with client.http.stream( + "POST", + url, + json=body, + headers=headers, + ) as response: + if response.status_code >= 400: + await response.aread() + raise errors_.create_gateway_error( + response_body=response.text, + status_code=response.status_code, + api_key_provided=bool(client.api_key), + ) + + async for data in _common.parse_sse_lines(response): + for event in _parse_stream_part(data): + msg = handler.handle_event(event) + yield msg + except errors_.GatewayError: + raise + except httpx.TimeoutException as exc: + raise errors_.GatewayTimeoutError(cause=exc) from exc + except Exception as exc: + raise errors_.GatewayResponseError( + message=f"Unexpected error during streaming: {exc}", + cause=exc, + ) from exc From c21201edf9bc23e4455ea4f6c9f0c6052924beb3 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Fri, 3 Apr 2026 09:54:32 -0700 Subject: [PATCH 09/18] Port tests --- tests/models2/__init__.py | 0 tests/models2/ai_gateway/__init__.py | 0 tests/models2/ai_gateway/test_errors.py | 139 ++++++ .../models2/ai_gateway/test_generate_image.py | 293 +++++++++++ .../models2/ai_gateway/test_generate_video.py | 400 +++++++++++++++ tests/models2/ai_gateway/test_protocol.py | 460 ++++++++++++++++++ tests/models2/ai_gateway/test_stream.py | 430 ++++++++++++++++ tests/models2/core/__init__.py | 0 tests/models2/core/test_media.py | 372 
++++++++++++++ tests/models2/core/test_streaming.py | 257 ++++++++++ 10 files changed, 2351 insertions(+) create mode 100644 tests/models2/__init__.py create mode 100644 tests/models2/ai_gateway/__init__.py create mode 100644 tests/models2/ai_gateway/test_errors.py create mode 100644 tests/models2/ai_gateway/test_generate_image.py create mode 100644 tests/models2/ai_gateway/test_generate_video.py create mode 100644 tests/models2/ai_gateway/test_protocol.py create mode 100644 tests/models2/ai_gateway/test_stream.py create mode 100644 tests/models2/core/__init__.py create mode 100644 tests/models2/core/test_media.py create mode 100644 tests/models2/core/test_streaming.py diff --git a/tests/models2/__init__.py b/tests/models2/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/models2/ai_gateway/__init__.py b/tests/models2/ai_gateway/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/models2/ai_gateway/test_errors.py b/tests/models2/ai_gateway/test_errors.py new file mode 100644 index 00000000..01f151ef --- /dev/null +++ b/tests/models2/ai_gateway/test_errors.py @@ -0,0 +1,139 @@ +"""Tests for the gateway error factory. + +The factory ``create_gateway_error`` is the real point of contact: +it parses the JSON error response from the gateway server and +dispatches to the correct error class. These tests use payloads +matching the actual gateway wire format. 
+""" + +from __future__ import annotations + +import json + +from vercel_ai_sdk.models2.ai_gateway import errors + + +class TestGatewayErrorBase: + """Base class behaviour that all concrete errors inherit.""" + + def test_isinstance_hierarchy(self) -> None: + err = errors.GatewayRateLimitError("nope") + assert isinstance(err, errors.GatewayError) + assert isinstance(err, Exception) + + def test_generation_id_in_message(self) -> None: + err = errors.GatewayInternalServerError("boom", generation_id="gen-123") + assert "[gen-123]" in str(err) + assert err.generation_id == "gen-123" + + def test_cause_chained(self) -> None: + original = ValueError("original") + err = errors.GatewayInternalServerError("boom", cause=original) + assert err.__cause__ is original + + +class TestCreateGatewayError: + """The factory must dispatch on ``error.type`` from the response.""" + + def test_authentication_error_from_json_string(self) -> None: + body = json.dumps( + { + "error": { + "message": "Invalid API key", + "type": "authentication_error", + } + } + ) + err = errors.create_gateway_error( + response_body=body, + status_code=401, + api_key_provided=True, + ) + assert isinstance(err, errors.GatewayAuthenticationError) + assert err.status_code == 401 + # contextual message includes the key URL + assert "vercel.com/d?to=" in str(err) + + def test_invalid_request_error(self) -> None: + body = { + "error": { + "message": "Bad format", + "type": "invalid_request_error", + } + } + err = errors.create_gateway_error(response_body=body, status_code=400) + assert isinstance(err, errors.GatewayInvalidRequestError) + assert err.status_code == 400 + + def test_rate_limit_error(self) -> None: + body = { + "error": { + "message": "Rate limit exceeded", + "type": "rate_limit_exceeded", + } + } + err = errors.create_gateway_error(response_body=body, status_code=429) + assert isinstance(err, errors.GatewayRateLimitError) + + def test_model_not_found_extracts_model_id(self) -> None: + body = { + 
"error": { + "message": "Model xyz not found", + "type": "model_not_found", + "param": {"modelId": "xyz"}, + } + } + err = errors.create_gateway_error(response_body=body, status_code=404) + assert isinstance(err, errors.GatewayModelNotFoundError) + assert err.model_id == "xyz" + + def test_model_not_found_without_param(self) -> None: + body = { + "error": { + "message": "Not found", + "type": "model_not_found", + } + } + err = errors.create_gateway_error(response_body=body, status_code=404) + assert isinstance(err, errors.GatewayModelNotFoundError) + assert err.model_id is None + + def test_internal_server_error(self) -> None: + body = { + "error": { + "message": "Database down", + "type": "internal_server_error", + } + } + err = errors.create_gateway_error(response_body=body, status_code=500) + assert isinstance(err, errors.GatewayInternalServerError) + + def test_unknown_type_falls_back_to_internal(self) -> None: + body = { + "error": { + "message": "Something weird", + "type": "alien_error", + } + } + err = errors.create_gateway_error(response_body=body, status_code=500) + assert isinstance(err, errors.GatewayInternalServerError) + + def test_malformed_json_string(self) -> None: + err = errors.create_gateway_error(response_body="Not JSON", status_code=500) + assert isinstance(err, errors.GatewayResponseError) + + def test_missing_error_field(self) -> None: + body = {"ferror": {"message": "oops"}} + err = errors.create_gateway_error(response_body=body, status_code=404) + assert isinstance(err, errors.GatewayResponseError) + + def test_generation_id_extracted(self) -> None: + body = { + "error": { + "message": "Rate limit", + "type": "rate_limit_exceeded", + }, + "generationId": "gen-abc", + } + err = errors.create_gateway_error(response_body=body, status_code=429) + assert err.generation_id == "gen-abc" diff --git a/tests/models2/ai_gateway/test_generate_image.py b/tests/models2/ai_gateway/test_generate_image.py new file mode 100644 index 00000000..2d8d0f82 --- 
/dev/null +++ b/tests/models2/ai_gateway/test_generate_image.py @@ -0,0 +1,293 @@ +"""Integration tests for the AI Gateway v3 image generation adapter. + +Every test exercises the real ``generate()`` function with a ``Client`` +wired to an ``httpx.MockTransport``, so the full production code path +is covered: + + generate(client, model, messages, ImageParams(...)) + → extract prompt/images from messages + → httpx POST (mock) to /image-model + → JSON response parsing + → media type detection + → return Message with FileParts +""" + +from __future__ import annotations + +import base64 +import json +from typing import Any + +import httpx +import pytest + +from vercel_ai_sdk.models2.ai_gateway import errors +from vercel_ai_sdk.models2.ai_gateway.generate import ( + ImageParams, + generate, +) +from vercel_ai_sdk.models2.core import client as client_ +from vercel_ai_sdk.models2.core import model as model_ +from vercel_ai_sdk.types import messages + +# 1x1 transparent PNG (minimal valid PNG for magic-byte detection) +_PNG_HEADER = bytes([0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A]) +_PNG_B64 = base64.b64encode(_PNG_HEADER).decode() + +# 1x1 JPEG header +_JPEG_HEADER = bytes([0xFF, 0xD8, 0xFF, 0xE0]) +_JPEG_B64 = base64.b64encode(_JPEG_HEADER).decode() + +_IMAGE_MODEL = model_.Model( + id="google/imagen-4.0-generate-001", + adapter="ai-gateway-v3", + provider="ai-gateway", + capabilities=("image",), +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _client( + handler: httpx.MockTransport, *, api_key: str = "test-key" +) -> client_.Client: + c = client_.Client(base_url="https://gw.test/v3/ai", api_key=api_key) + c._http = httpx.AsyncClient(transport=handler) + return c + + +def _user(text: str) -> messages.Message: + return messages.Message( + role="user", + parts=[messages.TextPart(text=text)], + ) + + +# 
--------------------------------------------------------------------------- +# Basic generation +# --------------------------------------------------------------------------- + + +class TestGenerate: + @pytest.mark.asyncio + async def test_basic_image_generation(self) -> None: + """Simple prompt -> one PNG image back.""" + + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response( + 200, + json={"images": [_PNG_B64]}, + ) + + client = _client(httpx.MockTransport(handler)) + msg = await generate(client, _IMAGE_MODEL, [_user("A sunset over Tokyo")]) + + assert msg.role == "assistant" + assert len(msg.images) == 1 + assert msg.images[0].data == _PNG_B64 + assert msg.images[0].media_type == "image/png" + + @pytest.mark.asyncio + async def test_multiple_images(self) -> None: + """Request n=3 images.""" + + def handler(req: httpx.Request) -> httpx.Response: + body = json.loads(req.content) + assert body["n"] == 3 + return httpx.Response( + 200, + json={"images": [_PNG_B64, _JPEG_B64, _PNG_B64]}, + ) + + client = _client(httpx.MockTransport(handler)) + msg = await generate( + client, + _IMAGE_MODEL, + [_user("Three cats")], + params=ImageParams(n=3), + ) + + assert len(msg.images) == 3 + assert msg.images[0].media_type == "image/png" + assert msg.images[1].media_type == "image/jpeg" + assert msg.images[2].media_type == "image/png" + + @pytest.mark.asyncio + async def test_usage_parsing(self) -> None: + """Usage data from response surfaces on the Message.""" + + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response( + 200, + json={ + "images": [_PNG_B64], + "usage": {"inputTokens": 50, "outputTokens": 100}, + }, + ) + + client = _client(httpx.MockTransport(handler)) + msg = await generate(client, _IMAGE_MODEL, [_user("a dog")]) + + assert msg.usage is not None + assert msg.usage.input_tokens == 50 + assert msg.usage.output_tokens == 100 + + +# --------------------------------------------------------------------------- +# Request 
format +# --------------------------------------------------------------------------- + + +class TestRequest: + @pytest.mark.asyncio + async def test_protocol_headers(self) -> None: + captured: dict[str, str] = {} + + def handler(req: httpx.Request) -> httpx.Response: + captured.update(dict(req.headers)) + return httpx.Response(200, json={"images": [_PNG_B64]}) + + model = model_.Model( + id="openai/gpt-image-1", + adapter="ai-gateway-v3", + provider="ai-gateway", + capabilities=("image",), + ) + client = _client(httpx.MockTransport(handler), api_key="sk-test") + await generate(client, model, [_user("Hi")]) + + assert captured["authorization"] == "Bearer sk-test" + assert captured["ai-image-model-specification-version"] == "3" + assert captured["ai-model-id"] == "openai/gpt-image-1" + assert captured["ai-gateway-auth-method"] == "api-key" + + @pytest.mark.asyncio + async def test_parameters_forwarded(self) -> None: + captured_body: dict[str, Any] = {} + + def handler(req: httpx.Request) -> httpx.Response: + captured_body.update(json.loads(req.content)) + return httpx.Response(200, json={"images": [_PNG_B64]}) + + client = _client(httpx.MockTransport(handler)) + await generate( + client, + _IMAGE_MODEL, + [_user("landscape")], + params=ImageParams( + n=2, + size="1024x1024", + aspect_ratio="16:9", + seed=42, + provider_options={"google": {"style": "vivid"}}, + ), + ) + + assert captured_body["prompt"] == "landscape" + assert captured_body["n"] == 2 + assert captured_body["size"] == "1024x1024" + assert captured_body["aspectRatio"] == "16:9" + assert captured_body["seed"] == 42 + assert captured_body["providerOptions"] == {"google": {"style": "vivid"}} + + @pytest.mark.asyncio + async def test_input_images_forwarded(self) -> None: + """Input images from user messages -> files in request body.""" + captured_body: dict[str, Any] = {} + + def handler(req: httpx.Request) -> httpx.Response: + captured_body.update(json.loads(req.content)) + return httpx.Response(200, 
json={"images": [_PNG_B64]}) + + user_msg = messages.Message( + role="user", + parts=[ + messages.TextPart(text="Edit this"), + messages.FilePart(data=_PNG_B64, media_type="image/png"), + ], + ) + client = _client(httpx.MockTransport(handler)) + await generate(client, _IMAGE_MODEL, [user_msg]) + + assert captured_body["prompt"] == "Edit this" + assert "files" in captured_body + assert len(captured_body["files"]) == 1 + assert captured_body["files"][0]["type"] == "file" + assert captured_body["files"][0]["mediaType"] == "image/png" + + @pytest.mark.asyncio + async def test_url_posts_to_image_model_endpoint(self) -> None: + captured_url: list[str] = [] + + def handler(req: httpx.Request) -> httpx.Response: + captured_url.append(str(req.url)) + return httpx.Response(200, json={"images": [_PNG_B64]}) + + client = _client(httpx.MockTransport(handler)) + await generate(client, _IMAGE_MODEL, [_user("test")]) + + assert captured_url[0] == "https://gw.test/v3/ai/image-model" + + +# --------------------------------------------------------------------------- +# Error handling +# --------------------------------------------------------------------------- + + +class TestErrors: + @pytest.mark.asyncio + async def test_401_authentication_error(self) -> None: + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response( + 401, + json={ + "error": { + "message": "Invalid API key", + "type": "authentication_error", + } + }, + ) + + with pytest.raises(errors.GatewayAuthenticationError): + await generate( + _client(httpx.MockTransport(handler)), + _IMAGE_MODEL, + [_user("test")], + ) + + @pytest.mark.asyncio + async def test_429_rate_limit_error(self) -> None: + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response( + 429, + json={ + "error": { + "message": "Rate limited", + "type": "rate_limit_exceeded", + } + }, + ) + + with pytest.raises(errors.GatewayRateLimitError): + await generate( + _client(httpx.MockTransport(handler)), + _IMAGE_MODEL, + 
[_user("test")], + ) + + @pytest.mark.asyncio + async def test_empty_images_returns_empty_message(self) -> None: + """Gateway returns empty images array -> message with no parts.""" + + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response(200, json={"images": []}) + + msg = await generate( + _client(httpx.MockTransport(handler)), + _IMAGE_MODEL, + [_user("test")], + ) + assert len(msg.images) == 0 diff --git a/tests/models2/ai_gateway/test_generate_video.py b/tests/models2/ai_gateway/test_generate_video.py new file mode 100644 index 00000000..331aad07 --- /dev/null +++ b/tests/models2/ai_gateway/test_generate_video.py @@ -0,0 +1,400 @@ +"""Integration tests for the AI Gateway v3 video generation adapter. + +Every test exercises the real ``generate()`` function with a ``Client`` +wired to an ``httpx.MockTransport``, so the full production code path +is covered: + + generate(client, model, messages, VideoParams(...)) + → extract prompt/image from messages + → httpx POST (mock) to /video-model with SSE accept + → SSE event parsing + → video data handling (base64 or URL download) + → return Message with FileParts +""" + +from __future__ import annotations + +import base64 +import json +from typing import Any +from unittest.mock import AsyncMock, patch + +import httpx +import pytest + +from vercel_ai_sdk.models2.ai_gateway import errors +from vercel_ai_sdk.models2.ai_gateway.generate import ( + VideoParams, + generate, +) +from vercel_ai_sdk.models2.core import client as client_ +from vercel_ai_sdk.models2.core import model as model_ +from vercel_ai_sdk.types import messages + +# MP4 magic bytes (ftyp box) +_MP4_HEADER = bytes( + [0x00, 0x00, 0x00, 0x20, 0x66, 0x74, 0x79, 0x70, 0x69, 0x73, 0x6F, 0x6D] +) +_MP4_B64 = base64.b64encode(_MP4_HEADER).decode() + +# WebM magic bytes +_WEBM_HEADER = bytes([0x1A, 0x45, 0xDF, 0xA3]) +_WEBM_B64 = base64.b64encode(_WEBM_HEADER).decode() + +_VIDEO_MODEL = model_.Model( + id="google/veo-3.0-generate-001", + 
adapter="ai-gateway-v3", + provider="ai-gateway", + capabilities=("video",), +) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _sse(*events: dict[str, Any]) -> str: + """Build SSE response text from event dicts.""" + return "".join(f"data: {json.dumps(e)}\n\n" for e in events) + + +def _client( + handler: httpx.MockTransport, *, api_key: str = "test-key" +) -> client_.Client: + c = client_.Client(base_url="https://gw.test/v3/ai", api_key=api_key) + c._http = httpx.AsyncClient(transport=handler) + return c + + +def _user(text: str) -> messages.Message: + return messages.Message( + role="user", + parts=[messages.TextPart(text=text)], + ) + + +# --------------------------------------------------------------------------- +# Basic generation +# --------------------------------------------------------------------------- + + +class TestGenerate: + @pytest.mark.asyncio + async def test_basic_video_generation_base64(self) -> None: + """Simple prompt -> one MP4 video back via base64.""" + body = _sse( + { + "type": "result", + "videos": [ + {"type": "base64", "data": _MP4_B64, "mediaType": "video/mp4"} + ], + } + ) + + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response(200, text=body) + + client = _client(httpx.MockTransport(handler)) + msg = await generate( + client, + _VIDEO_MODEL, + [_user("A cat walking on a beach")], + params=VideoParams(), + ) + + assert msg.role == "assistant" + assert len(msg.videos) == 1 + assert msg.videos[0].data == _MP4_B64 + assert msg.videos[0].media_type == "video/mp4" + + @pytest.mark.asyncio + async def test_video_generation_url(self) -> None: + """Video returned as URL -> downloaded automatically.""" + body = _sse( + { + "type": "result", + "videos": [ + { + "type": "url", + "url": "https://storage.example.com/video.mp4", + "mediaType": "video/mp4", + } + ], + } + ) + + def 
handler(req: httpx.Request) -> httpx.Response: + return httpx.Response(200, text=body) + + client = _client(httpx.MockTransport(handler)) + + with patch( + "vercel_ai_sdk.models2.core.helpers.media.download", + new_callable=AsyncMock, + return_value=(_MP4_HEADER, "video/mp4"), + ) as mock_dl: + msg = await generate( + client, + _VIDEO_MODEL, + [_user("A sunset timelapse")], + params=VideoParams(), + ) + + mock_dl.assert_called_once_with("https://storage.example.com/video.mp4") + assert len(msg.videos) == 1 + assert msg.videos[0].data == _MP4_HEADER + assert msg.videos[0].media_type == "video/mp4" + + @pytest.mark.asyncio + async def test_multiple_videos(self) -> None: + body = _sse( + { + "type": "result", + "videos": [ + {"type": "base64", "data": _MP4_B64, "mediaType": "video/mp4"}, + {"type": "base64", "data": _WEBM_B64, "mediaType": "video/webm"}, + ], + } + ) + + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response(200, text=body) + + msg = await generate( + _client(httpx.MockTransport(handler)), + _VIDEO_MODEL, + [_user("Two versions")], + params=VideoParams(n=2), + ) + assert len(msg.videos) == 2 + assert msg.videos[0].media_type == "video/mp4" + assert msg.videos[1].media_type == "video/webm" + + +# --------------------------------------------------------------------------- +# Request format +# --------------------------------------------------------------------------- + + +class TestRequest: + @pytest.mark.asyncio + async def test_protocol_headers(self) -> None: + captured: dict[str, str] = {} + + def handler(req: httpx.Request) -> httpx.Response: + captured.update(dict(req.headers)) + return httpx.Response( + 200, + text=_sse( + { + "type": "result", + "videos": [ + { + "type": "base64", + "data": _MP4_B64, + "mediaType": "video/mp4", + } + ], + } + ), + ) + + client = _client(httpx.MockTransport(handler), api_key="sk-test") + await generate( + client, + _VIDEO_MODEL, + [_user("test")], + params=VideoParams(), + ) + + assert 
captured["authorization"] == "Bearer sk-test" + assert captured["ai-video-model-specification-version"] == "3" + assert captured["ai-model-id"] == "google/veo-3.0-generate-001" + assert captured["accept"] == "text/event-stream" + assert captured["ai-gateway-auth-method"] == "api-key" + + @pytest.mark.asyncio + async def test_parameters_forwarded(self) -> None: + captured_body: dict[str, Any] = {} + + def handler(req: httpx.Request) -> httpx.Response: + captured_body.update(json.loads(req.content)) + return httpx.Response( + 200, + text=_sse( + { + "type": "result", + "videos": [ + { + "type": "base64", + "data": _MP4_B64, + "mediaType": "video/mp4", + } + ], + } + ), + ) + + client = _client(httpx.MockTransport(handler)) + await generate( + client, + _VIDEO_MODEL, + [_user("sunset")], + params=VideoParams( + n=2, + aspect_ratio="16:9", + resolution="1920x1080", + duration=5, + fps=30, + seed=42, + provider_options={"google": {"enhancePrompt": True}}, + ), + ) + + assert captured_body["prompt"] == "sunset" + assert captured_body["n"] == 2 + assert captured_body["aspectRatio"] == "16:9" + assert captured_body["resolution"] == "1920x1080" + assert captured_body["duration"] == 5 + assert captured_body["fps"] == 30 + assert captured_body["seed"] == 42 + assert captured_body["providerOptions"] == {"google": {"enhancePrompt": True}} + + @pytest.mark.asyncio + async def test_url_posts_to_video_model_endpoint(self) -> None: + captured_url: list[str] = [] + + def handler(req: httpx.Request) -> httpx.Response: + captured_url.append(str(req.url)) + return httpx.Response( + 200, + text=_sse( + { + "type": "result", + "videos": [ + { + "type": "base64", + "data": _MP4_B64, + "mediaType": "video/mp4", + } + ], + } + ), + ) + + client = _client(httpx.MockTransport(handler)) + await generate( + client, + _VIDEO_MODEL, + [_user("test")], + params=VideoParams(), + ) + + assert captured_url[0] == "https://gw.test/v3/ai/video-model" + + @pytest.mark.asyncio + async def 
test_image_to_video_input(self) -> None: + """Image in user message -> image field in request body.""" + captured_body: dict[str, Any] = {} + + def handler(req: httpx.Request) -> httpx.Response: + captured_body.update(json.loads(req.content)) + return httpx.Response( + 200, + text=_sse( + { + "type": "result", + "videos": [ + { + "type": "base64", + "data": _MP4_B64, + "mediaType": "video/mp4", + } + ], + } + ), + ) + + png_b64 = base64.b64encode(b"\x89PNG").decode() + user_msg = messages.Message( + role="user", + parts=[ + messages.TextPart(text="Animate this"), + messages.FilePart(data=png_b64, media_type="image/png"), + ], + ) + client = _client(httpx.MockTransport(handler)) + await generate(client, _VIDEO_MODEL, [user_msg], params=VideoParams()) + + assert captured_body["prompt"] == "Animate this" + assert "image" in captured_body + assert captured_body["image"]["type"] == "file" + assert captured_body["image"]["mediaType"] == "image/png" + + +# --------------------------------------------------------------------------- +# Error handling +# --------------------------------------------------------------------------- + + +class TestErrors: + @pytest.mark.asyncio + async def test_sse_error_event(self) -> None: + """Gateway returns an SSE error event -> raises.""" + body = _sse( + { + "type": "error", + "message": "Content policy violation", + "errorType": "content_filter", + "statusCode": 400, + "param": None, + } + ) + + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response(200, text=body) + + with pytest.raises(errors.GatewayInvalidRequestError, match="Content policy"): + await generate( + _client(httpx.MockTransport(handler)), + _VIDEO_MODEL, + [_user("test")], + params=VideoParams(), + ) + + @pytest.mark.asyncio + async def test_401_authentication_error(self) -> None: + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response( + 401, + json={ + "error": { + "message": "Bad key", + "type": "authentication_error", + } + }, 
+ ) + + with pytest.raises(errors.GatewayAuthenticationError): + await generate( + _client(httpx.MockTransport(handler)), + _VIDEO_MODEL, + [_user("test")], + params=VideoParams(), + ) + + @pytest.mark.asyncio + async def test_empty_sse_stream(self) -> None: + """SSE stream with no data events -> raises.""" + + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response(200, text="") + + with pytest.raises(errors.GatewayResponseError, match="SSE stream ended"): + await generate( + _client(httpx.MockTransport(handler)), + _VIDEO_MODEL, + [_user("test")], + params=VideoParams(), + ) diff --git a/tests/models2/ai_gateway/test_protocol.py b/tests/models2/ai_gateway/test_protocol.py new file mode 100644 index 00000000..c3afbc9a --- /dev/null +++ b/tests/models2/ai_gateway/test_protocol.py @@ -0,0 +1,460 @@ +"""Tests for the v3 protocol serialization and deserialization. + +Focus areas: +- ``_messages_to_prompt``: the critical outgoing translation layer +- ``_build_request_body``: using real ``@tool`` +- ``_parse_stream_part``: the critical incoming translation layer +- ``_parse_usage``: the two distinct wire formats +""" + +from __future__ import annotations + +import importlib +import json +from unittest.mock import AsyncMock, patch + +import pydantic +import pytest + +import vercel_ai_sdk as ai +from vercel_ai_sdk.models2.core.helpers import streaming +from vercel_ai_sdk.types import messages + +# The ai_gateway __init__.py re-exports `stream` as a function, which +# shadows the module. Use importlib to get the actual module. 
+stream_mod = importlib.import_module("vercel_ai_sdk.models2.ai_gateway.stream") + +# --------------------------------------------------------------------------- +# _messages_to_prompt +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestMessagesToPrompt: + async def test_system_message(self) -> None: + msgs = [ + messages.Message( + role="system", + parts=[messages.TextPart(text="You are helpful.")], + ) + ] + result = await stream_mod._messages_to_prompt(msgs) + assert result == [{"role": "system", "content": "You are helpful."}] + + async def test_user_message(self) -> None: + msgs = [ + messages.Message( + role="user", + parts=[messages.TextPart(text="Hello")], + ) + ] + result = await stream_mod._messages_to_prompt(msgs) + assert result == [ + { + "role": "user", + "content": [{"type": "text", "text": "Hello"}], + } + ] + + async def test_assistant_with_reasoning_and_text(self) -> None: + msgs = [ + messages.Message( + role="assistant", + parts=[ + messages.ReasoningPart(text="Let me think..."), + messages.TextPart(text="42"), + ], + ) + ] + result = await stream_mod._messages_to_prompt(msgs) + content = result[0]["content"] + assert content[0] == {"type": "reasoning", "text": "Let me think..."} + assert content[1] == {"type": "text", "text": "42"} + + async def test_tool_call_with_result_produces_two_messages(self) -> None: + """A completed tool call must produce an assistant message + (with the tool-call) AND a tool message (with the result).""" + msgs = [ + messages.Message( + role="assistant", + parts=[ + messages.ToolPart( + tool_call_id="tc-1", + tool_name="get_weather", + tool_args='{"city": "SF"}', + status="result", + result={"temp": 72}, + ) + ], + ) + ] + result = await stream_mod._messages_to_prompt(msgs) + assert len(result) == 2 + + # Assistant message has the tool-call + tc = result[0]["content"][0] + assert tc["type"] == "tool-call" + assert tc["toolCallId"] == "tc-1" + assert 
tc["input"] == {"city": "SF"} + + # Tool message has the result + tr = result[1]["content"][0] + assert tr["type"] == "tool-result" + assert tr["output"] == {"type": "json", "value": {"temp": 72}} + + async def test_tool_error_result(self) -> None: + msgs = [ + messages.Message( + role="assistant", + parts=[ + messages.ToolPart( + tool_call_id="tc-1", + tool_name="get_weather", + tool_args="{}", + status="error", + result="Connection timeout", + ) + ], + ) + ] + result = await stream_mod._messages_to_prompt(msgs) + tr = result[1]["content"][0] + assert tr["output"]["type"] == "error-text" + assert tr["output"]["value"] == "Connection timeout" + + async def test_user_message_with_image_url(self) -> None: + """FilePart with image URL -> downloaded and converted to data: URL.""" + fake_jpeg = b"\xff\xd8\xff\xe0" + msgs = [ + messages.Message( + role="user", + parts=[ + messages.TextPart(text="Look at this"), + messages.FilePart( + data="https://example.com/cat.jpg", media_type="image/jpeg" + ), + ], + ) + ] + with patch( + "vercel_ai_sdk.models2.core.helpers.media.download", + new_callable=AsyncMock, + return_value=(fake_jpeg, "image/jpeg"), + ): + result = await stream_mod._messages_to_prompt(msgs) + content = result[0]["content"] + assert content[0] == {"type": "text", "text": "Look at this"} + assert content[1]["type"] == "file" + assert content[1]["mediaType"] == "image/jpeg" + assert content[1]["data"].startswith("data:image/jpeg;base64,") + + async def test_user_message_with_file_bytes(self) -> None: + """FilePart with bytes -> v3 file content part with data URL.""" + msgs = [ + messages.Message( + role="user", + parts=[ + messages.FilePart( + data=b"\x89PNG", media_type="image/png", filename="pic.png" + ), + ], + ) + ] + result = await stream_mod._messages_to_prompt(msgs) + part = result[0]["content"][0] + assert part["type"] == "file" + assert part["mediaType"] == "image/png" + assert part["data"].startswith("data:image/png;base64,") + assert part["filename"] 
== "pic.png" + + async def test_user_message_text_only_unchanged(self) -> None: + """Regression: text-only user messages still work.""" + msgs = [ + messages.Message( + role="user", + parts=[messages.TextPart(text="Hello")], + ) + ] + result = await stream_mod._messages_to_prompt(msgs) + assert result == [ + {"role": "user", "content": [{"type": "text", "text": "Hello"}]} + ] + + async def test_pending_tool_call_no_tool_message(self) -> None: + """A pending tool call should NOT produce a tool-result message.""" + msgs = [ + messages.Message( + role="assistant", + parts=[ + messages.ToolPart( + tool_call_id="tc-1", + tool_name="search", + tool_args="{}", + status="pending", + ) + ], + ) + ] + result = await stream_mod._messages_to_prompt(msgs) + assert len(result) == 1 + assert result[0]["role"] == "assistant" + + +# --------------------------------------------------------------------------- +# _build_request_body — using real @tool +# --------------------------------------------------------------------------- + + +@ai.tool +async def get_weather(city: str, units: str = "celsius") -> str: + """Get the current weather for a city.""" + return f"Sunny in {city}" + + +@pytest.mark.asyncio +class TestBuildRequestBody: + async def test_with_real_tool(self) -> None: + """Verify @tool-produced schema round-trips through + _build_request_body -> JSON -> gateway wire format.""" + msgs = [ + messages.Message( + role="user", + parts=[messages.TextPart(text="What's the weather?")], + ) + ] + body = await stream_mod._build_request_body(msgs, tools=[get_weather]) + + assert "tools" in body + tool_def = body["tools"][0] + assert tool_def["type"] == "function" + assert tool_def["name"] == "get_weather" + assert tool_def["description"] == ("Get the current weather for a city.") + # The schema comes from pydantic — verify structure, not exact dict + schema = tool_def["inputSchema"] + assert "properties" in schema + assert "city" in schema["properties"] + assert "units" in 
schema["properties"] + # 'city' is required (no default), 'units' is not (has default) + assert "city" in schema.get("required", []) + + async def test_with_output_type(self) -> None: + class WeatherResult(pydantic.BaseModel): + temp: float + condition: str + + msgs = [ + messages.Message( + role="user", + parts=[messages.TextPart(text="Weather?")], + ) + ] + body = await stream_mod._build_request_body(msgs, output_type=WeatherResult) + + assert "responseFormat" in body + rf = body["responseFormat"] + assert rf["type"] == "json" + assert rf["name"] == "WeatherResult" + assert "properties" in rf["schema"] + assert "temp" in rf["schema"]["properties"] + + async def test_provider_options_passthrough(self) -> None: + msgs = [ + messages.Message( + role="user", + parts=[messages.TextPart(text="Hi")], + ) + ] + opts = {"gateway": {"order": ["bedrock", "openai"]}} + body = await stream_mod._build_request_body(msgs, provider_options=opts) + assert body["providerOptions"] == opts + + +# --------------------------------------------------------------------------- +# _parse_stream_part — parametrized simple 1:1 mappings +# --------------------------------------------------------------------------- + +_SIMPLE_STREAM_PARTS = [ + ( + {"type": "text-start", "id": "t1"}, + streaming.TextStart(block_id="t1"), + ), + ( + {"type": "text-end", "id": "t1"}, + streaming.TextEnd(block_id="t1"), + ), + ( + {"type": "reasoning-start", "id": "r1"}, + streaming.ReasoningStart(block_id="r1"), + ), + ( + {"type": "reasoning-delta", "id": "r1", "delta": "hmm"}, + streaming.ReasoningDelta(block_id="r1", delta="hmm"), + ), + ( + {"type": "reasoning-end", "id": "r1"}, + streaming.ReasoningEnd(block_id="r1"), + ), + ( + {"type": "tool-input-start", "id": "tc-1", "toolName": "search"}, + streaming.ToolStart(tool_call_id="tc-1", tool_name="search"), + ), + ( + {"type": "tool-input-delta", "id": "tc-1", "delta": '{"q"'}, + streaming.ToolArgsDelta(tool_call_id="tc-1", delta='{"q"'), + ), + ( + {"type": 
"tool-input-end", "id": "tc-1"}, + streaming.ToolEnd(tool_call_id="tc-1"), + ), +] + + +@pytest.mark.parametrize( + ("wire", "expected"), + _SIMPLE_STREAM_PARTS, + ids=[w["type"] for w, _ in _SIMPLE_STREAM_PARTS], +) +def test_parse_stream_part_simple( + wire: dict[str, object], expected: streaming.StreamEvent +) -> None: + events = stream_mod._parse_stream_part(wire) + assert len(events) == 1 + assert events[0] == expected + + +@pytest.mark.asyncio +class TestParseStreamPartComplex: + async def test_text_delta_uses_textDelta_key(self) -> None: + """The gateway sends ``textDelta`` (camelCase), not ``delta``.""" + events = stream_mod._parse_stream_part( + {"type": "text-delta", "id": "t1", "textDelta": "Hello"} + ) + assert isinstance(events[0], streaming.TextDelta) + assert events[0].delta == "Hello" + + async def test_tool_call_expands_to_three_events(self) -> None: + """A complete ``tool-call`` part must expand into + ToolStart -> ToolArgsDelta -> ToolEnd.""" + events = stream_mod._parse_stream_part( + { + "type": "tool-call", + "toolCallId": "tc-1", + "toolName": "get_weather", + "input": {"city": "SF"}, + } + ) + assert len(events) == 3 + assert isinstance(events[0], streaming.ToolStart) + assert events[0].tool_name == "get_weather" + assert isinstance(events[1], streaming.ToolArgsDelta) + assert json.loads(events[1].delta) == {"city": "SF"} + assert isinstance(events[2], streaming.ToolEnd) + + async def test_finish_flat_usage(self) -> None: + events = stream_mod._parse_stream_part( + { + "type": "finish", + "finishReason": "stop", + "usage": { + "prompt_tokens": 10, + "completion_tokens": 20, + }, + } + ) + done = events[0] + assert isinstance(done, streaming.MessageDone) + assert done.finish_reason == "stop" + assert done.usage is not None + assert done.usage.input_tokens == 10 + assert done.usage.output_tokens == 20 + + async def test_finish_v3_nested_usage(self) -> None: + events = stream_mod._parse_stream_part( + { + "type": "finish", + "finishReason": { + 
"unified": "tool-calls", + "raw": "tool_calls", + }, + "usage": { + "inputTokens": { + "total": 100, + "cacheRead": 50, + }, + "outputTokens": { + "total": 200, + "reasoning": 30, + }, + }, + } + ) + done = events[0] + assert isinstance(done, streaming.MessageDone) + assert done.finish_reason == "tool-calls" + assert done.usage is not None + assert done.usage.input_tokens == 100 + assert done.usage.cache_read_tokens == 50 + assert done.usage.reasoning_tokens == 30 + + async def test_file_part(self) -> None: + """A ``file`` stream part (inline image from Gemini/GPT-5) + must produce a FileEvent.""" + events = stream_mod._parse_stream_part( + { + "type": "file", + "id": "f1", + "mediaType": "image/png", + "data": "iVBORw0KGgo=", + } + ) + assert len(events) == 1 + assert isinstance(events[0], streaming.FileEvent) + assert events[0].block_id == "f1" + assert events[0].media_type == "image/png" + assert events[0].data == "iVBORw0KGgo=" + + async def test_file_part_defaults(self) -> None: + """A minimal ``file`` part uses sensible defaults.""" + events = stream_mod._parse_stream_part({"type": "file", "data": "somedata"}) + assert len(events) == 1 + assert isinstance(events[0], streaming.FileEvent) + assert events[0].media_type == "application/octet-stream" + + async def test_unknown_types_produce_no_events(self) -> None: + for t in ("stream-start", "raw", "response-metadata", "banana"): + assert stream_mod._parse_stream_part({"type": t}) == [] + + +# --------------------------------------------------------------------------- +# _parse_usage +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestParseUsage: + async def test_flat_format(self) -> None: + usage = stream_mod._parse_usage({"prompt_tokens": 10, "completion_tokens": 20}) + assert usage.input_tokens == 10 + assert usage.output_tokens == 20 + + async def test_v3_nested_format(self) -> None: + usage = stream_mod._parse_usage( + { + "inputTokens": { + 
"total": 100, + "cacheRead": 30, + "cacheWrite": 5, + }, + "outputTokens": {"total": 50, "reasoning": 10}, + } + ) + assert usage.input_tokens == 100 + assert usage.output_tokens == 50 + assert usage.cache_read_tokens == 30 + assert usage.cache_write_tokens == 5 + assert usage.reasoning_tokens == 10 + + async def test_non_dict_returns_empty(self) -> None: + usage = stream_mod._parse_usage("not a dict") + assert usage.input_tokens == 0 + assert usage.output_tokens == 0 diff --git a/tests/models2/ai_gateway/test_stream.py b/tests/models2/ai_gateway/test_stream.py new file mode 100644 index 00000000..559f4874 --- /dev/null +++ b/tests/models2/ai_gateway/test_stream.py @@ -0,0 +1,430 @@ +"""Integration tests for the AI Gateway v3 streaming adapter. + +Every test exercises the real ``stream()`` function with a ``Client`` +wired to an ``httpx.MockTransport``, so the full production code path +is covered: + + stream(client, model, messages) + → _build_request_body() + → httpx POST (mock) + → SSE line parsing + → _parse_stream_part() + → StreamHandler + → yield Message +""" + +from __future__ import annotations + +import importlib +import json +from typing import Any + +import httpx +import pytest + +import vercel_ai_sdk as ai +from vercel_ai_sdk.models2.ai_gateway import errors +from vercel_ai_sdk.models2.core import client as client_ +from vercel_ai_sdk.models2.core import model as model_ +from vercel_ai_sdk.types import messages + +# The ai_gateway __init__.py re-exports `stream` as a function, which +# shadows the module. Use importlib to get the actual module. 
+stream_mod = importlib.import_module("vercel_ai_sdk.models2.ai_gateway.stream") + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_TEST_MODEL = model_.Model( + id="test-provider/test-model", + adapter="ai-gateway-v3", + provider="ai-gateway", +) + + +def _sse(*events: dict[str, Any]) -> str: + """Build SSE response text from event dicts.""" + return "".join(f"data: {json.dumps(e)}\n\n" for e in events) + + +def _client( + handler: httpx.MockTransport, *, api_key: str = "test-key" +) -> client_.Client: + """Create a Client wired to a mock transport.""" + c = client_.Client(base_url="https://gw.test/v3/ai", api_key=api_key) + c._http = httpx.AsyncClient(transport=handler) + return c + + +async def _collect( + client: client_.Client, + msgs: list[messages.Message], + model: model_.Model = _TEST_MODEL, + **kwargs: Any, +) -> list[messages.Message]: + """Drain ``stream()`` and return all yielded messages.""" + result: list[messages.Message] = [] + async for msg in stream_mod.stream(client, model, msgs, **kwargs): + result.append(msg) + return result + + +def _user(text: str) -> messages.Message: + return messages.Message( + role="user", + parts=[messages.TextPart(text=text)], + ) + + +# --------------------------------------------------------------------------- +# Streaming: text, reasoning, tool calls +# --------------------------------------------------------------------------- + + +class TestStreaming: + @pytest.mark.asyncio + async def test_text_stream(self) -> None: + body = _sse( + {"type": "text-start", "id": "t1"}, + {"type": "text-delta", "id": "t1", "textDelta": "Hello"}, + {"type": "text-delta", "id": "t1", "textDelta": " World"}, + {"type": "text-end", "id": "t1"}, + { + "type": "finish", + "finishReason": "stop", + "usage": { + "prompt_tokens": 5, + "completion_tokens": 2, + }, + }, + ) + + def handler(req: httpx.Request) -> 
httpx.Response: + return httpx.Response(200, text=body) + + client = _client(httpx.MockTransport(handler)) + msgs = await _collect(client, [_user("Hi")]) + + final = msgs[-1] + assert final.text == "Hello World" + assert final.is_done + assert final.usage is not None + assert final.usage.input_tokens == 5 + assert final.usage.output_tokens == 2 + + @pytest.mark.asyncio + async def test_reasoning_then_text(self) -> None: + body = _sse( + {"type": "reasoning-start", "id": "r1"}, + {"type": "reasoning-delta", "id": "r1", "delta": "think"}, + {"type": "reasoning-end", "id": "r1"}, + {"type": "text-start", "id": "t1"}, + {"type": "text-delta", "id": "t1", "textDelta": "42"}, + {"type": "text-end", "id": "t1"}, + {"type": "finish", "finishReason": "stop", "usage": {}}, + ) + + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response(200, text=body) + + final = (await _collect(_client(httpx.MockTransport(handler)), [_user("?")]))[ + -1 + ] + assert final.reasoning == "think" + assert final.text == "42" + + @pytest.mark.asyncio + async def test_streaming_tool_call(self) -> None: + body = _sse( + { + "type": "tool-input-start", + "id": "tc-1", + "toolName": "search", + }, + {"type": "tool-input-delta", "id": "tc-1", "delta": '{"q":'}, + {"type": "tool-input-delta", "id": "tc-1", "delta": '"hi"}'}, + {"type": "tool-input-end", "id": "tc-1"}, + { + "type": "finish", + "finishReason": "tool-calls", + "usage": {}, + }, + ) + + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response(200, text=body) + + final = ( + await _collect(_client(httpx.MockTransport(handler)), [_user("search")]) + )[-1] + tc = final.tool_calls + assert len(tc) == 1 + assert tc[0].tool_name == "search" + assert tc[0].tool_args == '{"q":"hi"}' + + @pytest.mark.asyncio + async def test_inline_file_stream(self) -> None: + """Models like Gemini-3-pro-image return inline file parts + alongside text in the language model stream.""" + body = _sse( + {"type": "text-start", 
"id": "t1"}, + {"type": "text-delta", "id": "t1", "textDelta": "Here is an image:"}, + {"type": "text-end", "id": "t1"}, + { + "type": "file", + "id": "f1", + "mediaType": "image/png", + "data": "iVBORw0KGgo=", + }, + { + "type": "finish", + "finishReason": "stop", + "usage": {"prompt_tokens": 10, "completion_tokens": 20}, + }, + ) + + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response(200, text=body) + + final = ( + await _collect(_client(httpx.MockTransport(handler)), [_user("draw me")]) + )[-1] + assert final.text == "Here is an image:" + assert len(final.images) == 1 + assert final.images[0].media_type == "image/png" + assert final.images[0].data == "iVBORw0KGgo=" + assert final.is_done + + @pytest.mark.asyncio + async def test_complete_tool_call_part(self) -> None: + """Non-streaming ``tool-call`` part (one shot) must also work.""" + body = _sse( + { + "type": "tool-call", + "toolCallId": "tc-1", + "toolName": "get_weather", + "input": {"city": "SF"}, + }, + { + "type": "finish", + "finishReason": "tool-calls", + "usage": {}, + }, + ) + + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response(200, text=body) + + final = ( + await _collect(_client(httpx.MockTransport(handler)), [_user("weather")]) + )[-1] + assert len(final.tool_calls) == 1 + assert json.loads(final.tool_calls[0].tool_args) == {"city": "SF"} + + +# --------------------------------------------------------------------------- +# Request: headers, body, tools +# --------------------------------------------------------------------------- + + +class TestRequest: + @pytest.mark.asyncio + async def test_protocol_headers(self) -> None: + captured: dict[str, str] = {} + + def handler(req: httpx.Request) -> httpx.Response: + captured.update(dict(req.headers)) + return httpx.Response( + 200, + text=_sse({"type": "finish", "finishReason": "stop", "usage": {}}), + ) + + model = model_.Model( + id="anthropic/claude-sonnet-4", + adapter="ai-gateway-v3", + 
provider="ai-gateway", + ) + client = _client(httpx.MockTransport(handler), api_key="sk-test") + await _collect(client, [_user("Hi")], model=model) + + assert captured["authorization"] == "Bearer sk-test" + assert captured["ai-gateway-protocol-version"] == "0.0.1" + assert captured["ai-language-model-specification-version"] == "3" + assert captured["ai-language-model-id"] == "anthropic/claude-sonnet-4" + assert captured["ai-language-model-streaming"] == "true" + assert captured["ai-gateway-auth-method"] == "api-key" + + @pytest.mark.asyncio + async def test_body_prompt_format(self) -> None: + captured_body: dict[str, Any] = {} + + def handler(req: httpx.Request) -> httpx.Response: + captured_body.update(json.loads(req.content)) + return httpx.Response( + 200, + text=_sse({"type": "finish", "finishReason": "stop", "usage": {}}), + ) + + await _collect(_client(httpx.MockTransport(handler)), [_user("Hello")]) + + assert captured_body["prompt"] == [ + { + "role": "user", + "content": [{"type": "text", "text": "Hello"}], + } + ] + + @pytest.mark.asyncio + async def test_provider_options_in_body(self) -> None: + captured_body: dict[str, Any] = {} + + def handler(req: httpx.Request) -> httpx.Response: + captured_body.update(json.loads(req.content)) + return httpx.Response( + 200, + text=_sse({"type": "finish", "finishReason": "stop", "usage": {}}), + ) + + opts = {"gateway": {"order": ["bedrock", "openai"]}} + await _collect( + _client(httpx.MockTransport(handler)), + [_user("Hi")], + provider_options=opts, + ) + + assert captured_body["providerOptions"] == opts + + @pytest.mark.asyncio + async def test_real_tool_in_request_body(self) -> None: + """A real ``@tool``-decorated function must appear correctly + in the request body sent to the gateway.""" + + @ai.tool + async def lookup(query: str) -> str: + """Search the database.""" + return "result" + + captured_body: dict[str, Any] = {} + + def handler(req: httpx.Request) -> httpx.Response: + 
captured_body.update(json.loads(req.content)) + return httpx.Response( + 200, + text=_sse({"type": "finish", "finishReason": "stop", "usage": {}}), + ) + + await _collect( + _client(httpx.MockTransport(handler)), + [_user("find something")], + tools=[lookup], + ) + + assert "tools" in captured_body + td = captured_body["tools"][0] + assert td["name"] == "lookup" + assert td["type"] == "function" + assert "query" in td["inputSchema"]["properties"] + + @pytest.mark.asyncio + async def test_multi_turn_request_body(self) -> None: + """A multi-turn conversation including a tool result must + serialize correctly into the v3 prompt format.""" + captured_body: dict[str, Any] = {} + + def handler(req: httpx.Request) -> httpx.Response: + captured_body.update(json.loads(req.content)) + return httpx.Response( + 200, + text=_sse({"type": "finish", "finishReason": "stop", "usage": {}}), + ) + + tool_part = messages.ToolPart( + tool_call_id="tc-1", + tool_name="search", + tool_args='{"q": "weather"}', + status="result", + result={"temp": 72}, + ) + conversation = [ + _user("What's the weather?"), + messages.Message(role="assistant", parts=[tool_part]), + _user("Thanks, and tomorrow?"), + ] + + await _collect(_client(httpx.MockTransport(handler)), conversation) + + prompt = captured_body["prompt"] + # user → assistant (tool-call) → tool (tool-result) → user + assert len(prompt) == 4 + assert prompt[0]["role"] == "user" + assert prompt[1]["role"] == "assistant" + assert prompt[1]["content"][0]["type"] == "tool-call" + assert prompt[2]["role"] == "tool" + assert prompt[2]["content"][0]["type"] == "tool-result" + assert prompt[3]["role"] == "user" + + +# --------------------------------------------------------------------------- +# Error handling +# --------------------------------------------------------------------------- + + +class TestErrors: + @pytest.mark.asyncio + async def test_401_authentication_error(self) -> None: + def handler(req: httpx.Request) -> httpx.Response: + 
return httpx.Response( + 401, + json={ + "error": { + "message": "Invalid API key", + "type": "authentication_error", + } + }, + ) + + with pytest.raises(errors.GatewayAuthenticationError): + await _collect(_client(httpx.MockTransport(handler)), [_user("Hi")]) + + @pytest.mark.asyncio + async def test_429_rate_limit_error(self) -> None: + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response( + 429, + json={ + "error": { + "message": "Rate limit exceeded", + "type": "rate_limit_exceeded", + } + }, + ) + + with pytest.raises(errors.GatewayRateLimitError): + await _collect(_client(httpx.MockTransport(handler)), [_user("Hi")]) + + @pytest.mark.asyncio + async def test_404_model_not_found(self) -> None: + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response( + 404, + json={ + "error": { + "message": "Model xyz not found", + "type": "model_not_found", + "param": {"modelId": "xyz"}, + } + }, + ) + + with pytest.raises(errors.GatewayModelNotFoundError) as exc_info: + await _collect(_client(httpx.MockTransport(handler)), [_user("Hi")]) + assert exc_info.value.model_id == "xyz" + + @pytest.mark.asyncio + async def test_500_malformed_response(self) -> None: + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response(500, text="Not JSON") + + with pytest.raises(errors.GatewayResponseError): + await _collect(_client(httpx.MockTransport(handler)), [_user("Hi")]) diff --git a/tests/models2/core/__init__.py b/tests/models2/core/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/models2/core/test_media.py b/tests/models2/core/test_media.py new file mode 100644 index 00000000..1ac85cdc --- /dev/null +++ b/tests/models2/core/test_media.py @@ -0,0 +1,372 @@ +"""Tests for media data helpers and magic-byte media type detection. + +Covers ``is_url``, ``data_to_base64``, ``data_to_data_url``, +``split_data_url``, ``detect_image_media_type``, ``detect_audio_media_type``, +and edge cases. 
+""" + +from __future__ import annotations + +import base64 + +from vercel_ai_sdk.models2.core.helpers.media import ( + data_to_base64, + data_to_data_url, + detect_audio_media_type, + detect_image_media_type, + is_url, + split_data_url, +) + +# --------------------------------------------------------------------------- +# is_url +# --------------------------------------------------------------------------- + + +class TestIsUrl: + def test_http(self) -> None: + assert is_url("https://example.com/img.png") is True + assert is_url("http://example.com/img.png") is True + + def test_data(self) -> None: + assert is_url("data:image/png;base64,iVBOR") is True + + def test_base64(self) -> None: + assert is_url("iVBORw0KGgo=") is False + + +# --------------------------------------------------------------------------- +# data_to_base64 +# --------------------------------------------------------------------------- + + +class TestDataToBase64: + def test_bytes(self) -> None: + raw = b"\x89PNG" + result = data_to_base64(raw) + assert base64.b64decode(result) == raw + + def test_passthrough(self) -> None: + b64 = base64.b64encode(b"hello").decode() + assert data_to_base64(b64) == b64 + + def test_extracts_from_data_url(self) -> None: + payload = base64.b64encode(b"hello").decode() + data_url = f"data:image/png;base64,{payload}" + assert data_to_base64(data_url) == payload + + def test_passthrough_http_url(self) -> None: + url = "https://example.com/image.png" + assert data_to_base64(url) == url + + +# --------------------------------------------------------------------------- +# data_to_data_url +# --------------------------------------------------------------------------- + + +class TestDataToDataUrl: + def test_from_bytes(self) -> None: + raw = b"\x89PNG" + result = data_to_data_url(raw, "image/png") + assert result.startswith("data:image/png;base64,") + assert base64.b64decode(result.split(",", 1)[1]) == raw + + def test_passthrough_url(self) -> None: + url = 
"https://example.com/image.png" + assert data_to_data_url(url, "image/png") == url + + +# --------------------------------------------------------------------------- +# split_data_url +# --------------------------------------------------------------------------- + + +class TestSplitDataUrl: + def test_valid(self) -> None: + media_type, content = split_data_url("data:image/png;base64,iVBOR") + assert media_type == "image/png" + assert content == "iVBOR" + + def test_non_data_url(self) -> None: + assert split_data_url("https://example.com") == (None, None) + + def test_malformed(self) -> None: + assert split_data_url("data:nope") == (None, None) + + +# --------------------------------------------------------------------------- +# Image detection +# --------------------------------------------------------------------------- + + +class TestGif: + def test_from_bytes(self) -> None: + assert detect_image_media_type(bytes([0x47, 0x49, 0x46])) == "image/gif" + + def test_from_base64(self) -> None: + assert ( + detect_image_media_type( + base64.b64encode(bytes([0x47, 0x49, 0x46])).decode() + ) + == "image/gif" + ) + + +class TestPng: + def test_from_bytes(self) -> None: + assert detect_image_media_type(bytes([0x89, 0x50, 0x4E, 0x47])) == "image/png" + + def test_from_base64(self) -> None: + assert ( + detect_image_media_type( + base64.b64encode(bytes([0x89, 0x50, 0x4E, 0x47])).decode() + ) + == "image/png" + ) + + +class TestJpeg: + def test_from_bytes(self) -> None: + assert detect_image_media_type(bytes([0xFF, 0xD8, 0xFF])) == "image/jpeg" + + def test_from_base64(self) -> None: + assert ( + detect_image_media_type( + base64.b64encode(bytes([0xFF, 0xD8, 0xFF])).decode() + ) + == "image/jpeg" + ) + + +class TestWebp: + _RIFF_WEBP = bytes( + [0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57, 0x45, 0x42, 0x50] + ) + + def test_from_bytes(self) -> None: + assert detect_image_media_type(self._RIFF_WEBP) == "image/webp" + + def test_from_base64(self) -> None: + assert ( + 
detect_image_media_type(base64.b64encode(self._RIFF_WEBP).decode()) + == "image/webp" + ) + + def test_riff_wave_not_webp_bytes(self) -> None: + riff_wave = bytes( + [0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57, 0x41, 0x56, 0x45] + ) + assert detect_image_media_type(riff_wave) is None + + def test_riff_wave_not_webp_base64(self) -> None: + riff_wave = bytes( + [0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57, 0x41, 0x56, 0x45] + ) + assert detect_image_media_type(base64.b64encode(riff_wave).decode()) is None + + +class TestBmp: + def test_from_bytes(self) -> None: + assert detect_image_media_type(bytes([0x42, 0x4D])) == "image/bmp" + + def test_from_base64(self) -> None: + assert ( + detect_image_media_type(base64.b64encode(bytes([0x42, 0x4D])).decode()) + == "image/bmp" + ) + + +class TestTiff: + def test_little_endian_from_bytes(self) -> None: + assert detect_image_media_type(bytes([0x49, 0x49, 0x2A, 0x00])) == "image/tiff" + + def test_big_endian_from_bytes(self) -> None: + assert detect_image_media_type(bytes([0x4D, 0x4D, 0x00, 0x2A])) == "image/tiff" + + def test_little_endian_from_base64(self) -> None: + assert ( + detect_image_media_type( + base64.b64encode(bytes([0x49, 0x49, 0x2A, 0x00])).decode() + ) + == "image/tiff" + ) + + def test_big_endian_from_base64(self) -> None: + assert ( + detect_image_media_type( + base64.b64encode(bytes([0x4D, 0x4D, 0x00, 0x2A])).decode() + ) + == "image/tiff" + ) + + +class TestAvif: + _AVIF = bytes( + [0x00, 0x00, 0x00, 0x20, 0x66, 0x74, 0x79, 0x70, 0x61, 0x76, 0x69, 0x66] + ) + + def test_from_bytes(self) -> None: + assert detect_image_media_type(self._AVIF) == "image/avif" + + def test_from_base64(self) -> None: + assert ( + detect_image_media_type(base64.b64encode(self._AVIF).decode()) + == "image/avif" + ) + + +class TestHeic: + _HEIC = bytes( + [0x00, 0x00, 0x00, 0x20, 0x66, 0x74, 0x79, 0x70, 0x68, 0x65, 0x69, 0x63] + ) + + def test_from_bytes(self) -> None: + assert detect_image_media_type(self._HEIC) == 
"image/heic" + + def test_from_base64(self) -> None: + assert ( + detect_image_media_type(base64.b64encode(self._HEIC).decode()) + == "image/heic" + ) + + +# --------------------------------------------------------------------------- +# Audio detection +# --------------------------------------------------------------------------- + + +class TestMp3: + def test_from_bytes(self) -> None: + assert detect_audio_media_type(bytes([0xFF, 0xFB])) == "audio/mpeg" + + def test_from_base64(self) -> None: + assert ( + detect_audio_media_type(base64.b64encode(bytes([0xFF, 0xFB])).decode()) + == "audio/mpeg" + ) + + def test_with_id3_tags_bytes(self) -> None: + # ID3v2 header (10 bytes) + MP3 sync bytes + id3_header = bytes([0x49, 0x44, 0x33, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) + mp3_data = id3_header + bytes([0xFF, 0xFB]) + assert detect_audio_media_type(mp3_data) == "audio/mpeg" + + def test_with_id3_tags_base64(self) -> None: + id3_header = bytes([0x49, 0x44, 0x33, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) + mp3_data = id3_header + bytes([0xFF, 0xFB]) + assert ( + detect_audio_media_type(base64.b64encode(mp3_data).decode()) == "audio/mpeg" + ) + + +class TestWav: + _RIFF_WAVE = bytes( + [0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57, 0x41, 0x56, 0x45] + ) + + def test_from_bytes(self) -> None: + assert detect_audio_media_type(self._RIFF_WAVE) == "audio/wav" + + def test_from_base64(self) -> None: + assert ( + detect_audio_media_type(base64.b64encode(self._RIFF_WAVE).decode()) + == "audio/wav" + ) + + def test_riff_webp_not_wav_bytes(self) -> None: + riff_webp = bytes( + [0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57, 0x45, 0x42, 0x50] + ) + assert detect_audio_media_type(riff_webp) is None + + def test_riff_webp_not_wav_base64(self) -> None: + riff_webp = bytes( + [0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57, 0x45, 0x42, 0x50] + ) + assert detect_audio_media_type(base64.b64encode(riff_webp).decode()) is None + + +class TestOgg: + def 
test_from_bytes(self) -> None: + assert detect_audio_media_type(b"OggS") == "audio/ogg" + + def test_from_base64(self) -> None: + assert ( + detect_audio_media_type(base64.b64encode(b"OggS").decode()) == "audio/ogg" + ) + + +class TestFlac: + def test_from_bytes(self) -> None: + assert detect_audio_media_type(b"fLaC") == "audio/flac" + + def test_from_base64(self) -> None: + assert ( + detect_audio_media_type(base64.b64encode(b"fLaC").decode()) == "audio/flac" + ) + + +class TestAac: + def test_from_bytes(self) -> None: + assert detect_audio_media_type(bytes([0x40, 0x15, 0x00, 0x00])) == "audio/aac" + + def test_from_base64(self) -> None: + assert ( + detect_audio_media_type( + base64.b64encode(bytes([0x40, 0x15, 0x00, 0x00])).decode() + ) + == "audio/aac" + ) + + +class TestMp4Audio: + # The audio/mp4 signature starts at the `ftyp` atom directly (no box size prefix). + _FTYP = bytes([0x66, 0x74, 0x79, 0x70]) + + def test_from_bytes(self) -> None: + assert detect_audio_media_type(self._FTYP) == "audio/mp4" + + def test_from_base64(self) -> None: + assert ( + detect_audio_media_type(base64.b64encode(self._FTYP).decode()) + == "audio/mp4" + ) + + +class TestWebmAudio: + _WEBM = bytes([0x1A, 0x45, 0xDF, 0xA3]) + + def test_from_bytes(self) -> None: + assert detect_audio_media_type(self._WEBM) == "audio/webm" + + def test_from_base64(self) -> None: + assert ( + detect_audio_media_type(base64.b64encode(self._WEBM).decode()) + == "audio/webm" + ) + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +class TestEdgeCases: + def test_unknown_image_format(self) -> None: + assert detect_image_media_type(bytes([0x00, 0x01, 0x02, 0x03])) is None + + def test_unknown_audio_format(self) -> None: + assert detect_audio_media_type(bytes([0x00, 0x01, 0x02, 0x03])) is None + + def test_empty_bytes_image(self) -> None: + assert detect_image_media_type(b"") is 
None + + def test_empty_bytes_audio(self) -> None: + assert detect_audio_media_type(b"") is None + + def test_short_bytes_image(self) -> None: + assert detect_image_media_type(bytes([0x89])) is None + + def test_short_bytes_audio(self) -> None: + assert detect_audio_media_type(bytes([0xFF])) is None diff --git a/tests/models2/core/test_streaming.py b/tests/models2/core/test_streaming.py new file mode 100644 index 00000000..a0927608 --- /dev/null +++ b/tests/models2/core/test_streaming.py @@ -0,0 +1,257 @@ +"""StreamHandler: event accumulation, state transitions, message building.""" + + +from vercel_ai_sdk.models2.core.helpers.streaming import ( + FileEvent, + MessageDone, + ReasoningDelta, + ReasoningEnd, + ReasoningStart, + StreamHandler, + TextDelta, + TextEnd, + TextStart, + ToolArgsDelta, + ToolEnd, + ToolStart, +) +from vercel_ai_sdk.types.messages import ( + FilePart, + ReasoningPart, + TextPart, + ToolPart, + Usage, +) + +# -- Text streaming -------------------------------------------------------- + + +def test_text_lifecycle() -> None: + h = StreamHandler(message_id="m1") + m = h.handle_event(TextStart(block_id="b1")) + assert len(m.parts) == 1 + part = m.parts[0] + assert isinstance(part, TextPart) + assert part.state == "streaming" + assert part.text == "" + + m = h.handle_event(TextDelta(block_id="b1", delta="Hello")) + part = m.parts[0] + assert isinstance(part, TextPart) + assert part.text == "Hello" + assert part.delta == "Hello" + assert part.state == "streaming" + + m = h.handle_event(TextDelta(block_id="b1", delta=" world")) + part = m.parts[0] + assert isinstance(part, TextPart) + assert part.text == "Hello world" + assert part.delta == " world" + + m = h.handle_event(TextEnd(block_id="b1")) + part = m.parts[0] + assert isinstance(part, TextPart) + assert part.state == "done" + assert part.delta is None + + +# -- Reasoning streaming --------------------------------------------------- + + +def test_reasoning_lifecycle() -> None: + h = 
StreamHandler(message_id="m1") + h.handle_event(ReasoningStart(block_id="r1")) + m = h.handle_event(ReasoningDelta(block_id="r1", delta="thinking")) + part = m.parts[0] + assert isinstance(part, ReasoningPart) + assert part.text == "thinking" + assert part.state == "streaming" + + m = h.handle_event(ReasoningEnd(block_id="r1", signature="sig123")) + part = m.parts[0] + assert isinstance(part, ReasoningPart) + assert part.state == "done" + assert part.signature == "sig123" + + +# -- Tool streaming -------------------------------------------------------- + + +def test_tool_lifecycle() -> None: + h = StreamHandler(message_id="m1") + h.handle_event(ToolStart(tool_call_id="tc1", tool_name="get_weather")) + m = h.handle_event(ToolArgsDelta(tool_call_id="tc1", delta='{"ci')) + part = m.parts[0] + assert isinstance(part, ToolPart) + assert part.tool_name == "get_weather" + assert part.tool_args == '{"ci' + assert part.state == "streaming" + assert part.args_delta == '{"ci' + + m = h.handle_event(ToolArgsDelta(tool_call_id="tc1", delta='ty":"London"}')) + part = m.parts[0] + assert isinstance(part, ToolPart) + assert part.tool_args == '{"city":"London"}' + + m = h.handle_event(ToolEnd(tool_call_id="tc1")) + part = m.parts[0] + assert isinstance(part, ToolPart) + assert part.state == "done" + assert part.args_delta is None + + +# -- Multi-part messages --------------------------------------------------- + + +def test_reasoning_then_text_then_tool() -> None: + """Full message: reasoning block, text block, tool call.""" + h = StreamHandler(message_id="m1") + h.handle_event(ReasoningStart(block_id="r1")) + h.handle_event(ReasoningDelta(block_id="r1", delta="Let me think")) + h.handle_event(ReasoningEnd(block_id="r1")) + + h.handle_event(TextStart(block_id="t1")) + h.handle_event(TextDelta(block_id="t1", delta="I'll check")) + h.handle_event(TextEnd(block_id="t1")) + + h.handle_event(ToolStart(tool_call_id="tc1", tool_name="search")) + 
h.handle_event(ToolArgsDelta(tool_call_id="tc1", delta='{"q":"test"}')) + m = h.handle_event(ToolEnd(tool_call_id="tc1")) + + assert len(m.parts) == 3 + assert isinstance(m.parts[0], ReasoningPart) + assert isinstance(m.parts[1], TextPart) + assert isinstance(m.parts[2], ToolPart) + assert all( + p.state == "done" + for p in m.parts + if isinstance(p, (TextPart, ToolPart, ReasoningPart)) + ) + + +def test_multiple_tool_calls() -> None: + """Parallel tool calls in one message.""" + h = StreamHandler(message_id="m1") + h.handle_event(ToolStart(tool_call_id="tc1", tool_name="read_file")) + h.handle_event(ToolStart(tool_call_id="tc2", tool_name="list_files")) + + m = h.handle_event(ToolArgsDelta(tool_call_id="tc1", delta='{"path":"a.py"}')) + # Both tools should be in parts + tool_parts = [p for p in m.parts if isinstance(p, ToolPart)] + assert len(tool_parts) == 2 + # tc1 has args, tc2 is empty + assert tool_parts[0].tool_args == '{"path":"a.py"}' + assert tool_parts[1].tool_args == "" + + h.handle_event(ToolArgsDelta(tool_call_id="tc2", delta='{"dir":"."}')) + h.handle_event(ToolEnd(tool_call_id="tc1")) + m = h.handle_event(ToolEnd(tool_call_id="tc2")) + assert all( + p.state == "done" + for p in m.parts + if isinstance(p, (TextPart, ToolPart, ReasoningPart)) + ) + + +# -- MessageDone ----------------------------------------------------------- + + +def test_message_done_finalizes_all() -> None: + h = StreamHandler(message_id="m1") + h.handle_event(TextStart(block_id="t1")) + h.handle_event(TextDelta(block_id="t1", delta="hello")) + # Don't send TextEnd -- MessageDone should finalize everything + m = h.handle_event(MessageDone(finish_reason="end_turn")) + part = m.parts[0] + assert isinstance(part, TextPart) + assert part.state == "done" + assert m.is_done + + +def test_message_done_propagates_usage() -> None: + """Usage on MessageDone surfaces on the built Message.""" + usage = Usage(input_tokens=10, output_tokens=20) + h = StreamHandler(message_id="m1") + 
h.handle_event(TextStart(block_id="t1")) + h.handle_event(TextDelta(block_id="t1", delta="hi")) + + # Before MessageDone, usage should not be on the message + m = h.handle_event(TextEnd(block_id="t1")) + assert m.usage is None + + m = h.handle_event(MessageDone(usage=usage)) + assert m.usage is not None + assert m.usage.input_tokens == 10 + assert m.usage.output_tokens == 20 + assert m.usage.total_tokens == 30 + + +# -- Message properties propagate ------------------------------------------ + + +def test_message_id_propagates() -> None: + h = StreamHandler(message_id="custom-id") + m = h.handle_event(TextStart(block_id="b1")) + assert m.id == "custom-id" + + +def test_deltas_only_on_active_blocks() -> None: + """Delta should be None on inactive blocks, present only on active.""" + h = StreamHandler(message_id="m1") + h.handle_event(TextStart(block_id="t1")) + h.handle_event(TextDelta(block_id="t1", delta="first")) + h.handle_event(TextEnd(block_id="t1")) + + h.handle_event(TextStart(block_id="t2")) + m = h.handle_event(TextDelta(block_id="t2", delta="second")) + + text_parts = [p for p in m.parts if isinstance(p, TextPart)] + assert text_parts[0].delta is None # t1 is done + assert text_parts[1].delta == "second" # t2 is active + + +# -- File event (inline images from LLMs like Gemini/GPT-5) --------------- + + +def test_file_event_accumulates() -> None: + """FileEvent should produce a FilePart in the message.""" + h = StreamHandler(message_id="m1") + m = h.handle_event( + FileEvent(block_id="f1", media_type="image/png", data="iVBORw0KGgo=") + ) + file_parts = [p for p in m.parts if isinstance(p, FilePart)] + assert len(file_parts) == 1 + assert file_parts[0].media_type == "image/png" + assert file_parts[0].data == "iVBORw0KGgo=" + + +def test_file_event_with_text() -> None: + """A message can have both text and file parts (e.g. 
Gemini image gen).""" + h = StreamHandler(message_id="m1") + h.handle_event(TextStart(block_id="t1")) + h.handle_event(TextDelta(block_id="t1", delta="Here is your image:")) + h.handle_event(TextEnd(block_id="t1")) + h.handle_event( + FileEvent(block_id="f1", media_type="image/png", data="iVBORw0KGgo=") + ) + m = h.handle_event(MessageDone(finish_reason="stop")) + + assert len(m.parts) == 2 + assert isinstance(m.parts[0], TextPart) + assert m.parts[0].text == "Here is your image:" + assert isinstance(m.parts[1], FilePart) + assert m.parts[1].media_type == "image/png" + assert m.is_done + + +def test_multiple_file_events() -> None: + """Multiple FileEvents produce multiple FileParts.""" + h = StreamHandler(message_id="m1") + h.handle_event(FileEvent(block_id="f1", media_type="image/png", data="png_data")) + m = h.handle_event( + FileEvent(block_id="f2", media_type="image/jpeg", data="jpeg_data") + ) + file_parts = [p for p in m.parts if isinstance(p, FilePart)] + assert len(file_parts) == 2 + assert file_parts[0].media_type == "image/png" + assert file_parts[1].media_type == "image/jpeg" From 6e828f307bf079772e277337cd7243914313147a Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Fri, 3 Apr 2026 10:38:41 -0700 Subject: [PATCH 10/18] Port openai and anthropic adapters --- src/vercel_ai_sdk/models2/__init__.py | 4 + .../models2/anthropic/__init__.py | 7 + .../models2/anthropic/adapter.py | 389 ++++++++++++++++++ src/vercel_ai_sdk/models2/openai/__init__.py | 7 + src/vercel_ai_sdk/models2/openai/adapter.py | 386 +++++++++++++++++ 5 files changed, 793 insertions(+) create mode 100644 src/vercel_ai_sdk/models2/anthropic/__init__.py create mode 100644 src/vercel_ai_sdk/models2/anthropic/adapter.py create mode 100644 src/vercel_ai_sdk/models2/openai/__init__.py create mode 100644 src/vercel_ai_sdk/models2/openai/adapter.py diff --git a/src/vercel_ai_sdk/models2/__init__.py b/src/vercel_ai_sdk/models2/__init__.py index 92ef3c22..43636a99 100644 --- 
a/src/vercel_ai_sdk/models2/__init__.py +++ b/src/vercel_ai_sdk/models2/__init__.py @@ -60,9 +60,13 @@ def _ensure_adapters() -> None: from .ai_gateway import generate as ai_gw_generate from .ai_gateway import stream as ai_gw_stream + from .anthropic.adapter import stream as anthropic_stream + from .openai.adapter import stream as openai_stream _stream_adapters["ai-gateway-v3"] = ai_gw_stream _generate_adapters["ai-gateway-v3"] = ai_gw_generate + _stream_adapters["openai"] = openai_stream + _stream_adapters["anthropic"] = anthropic_stream # --------------------------------------------------------------------------- diff --git a/src/vercel_ai_sdk/models2/anthropic/__init__.py b/src/vercel_ai_sdk/models2/anthropic/__init__.py new file mode 100644 index 00000000..a9a0436b --- /dev/null +++ b/src/vercel_ai_sdk/models2/anthropic/__init__.py @@ -0,0 +1,7 @@ +"""Anthropic provider — adapter for the Anthropic messages API.""" + +from .adapter import stream + +__all__ = [ + "stream", +] diff --git a/src/vercel_ai_sdk/models2/anthropic/adapter.py b/src/vercel_ai_sdk/models2/anthropic/adapter.py new file mode 100644 index 00000000..7ad3d25c --- /dev/null +++ b/src/vercel_ai_sdk/models2/anthropic/adapter.py @@ -0,0 +1,389 @@ +"""Anthropic adapter — messages API. + +Message/tool conversion and streaming via the official ``anthropic`` SDK. +The SDK client is constructed from :class:`Client` params on each call. 
+""" + +from __future__ import annotations + +import json +from collections.abc import AsyncGenerator, Sequence +from typing import Any + +import anthropic +import pydantic + +from ...types import messages as messages_ +from ...types import tools as tools_ +from ..core import client as client_ +from ..core import model as model_ +from ..core.helpers import media as media_ +from ..core.helpers import streaming as streaming_ + +# --------------------------------------------------------------------------- +# Message / tool conversion — internal types → Anthropic wire format +# --------------------------------------------------------------------------- + + +def _tools_to_anthropic( + tools: Sequence[tools_.ToolLike], +) -> list[dict[str, Any]]: + """Convert internal Tool objects to Anthropic tool schema format.""" + return [ + { + "name": tool.name, + "description": tool.description, + "input_schema": tool.param_schema, + } + for tool in tools + ] + + +def _file_part_to_anthropic( + part: messages_.FilePart, +) -> dict[str, Any]: + """Convert a :class:`FilePart` to an Anthropic content block. 
+ + * ``image/*`` -> ``{"type": "image", "source": ...}`` + * ``application/pdf`` -> ``{"type": "document", "source": ...}`` + * ``text/plain`` -> ``{"type": "document", "source": ...}`` + * anything else -> ``ValueError`` + """ + mt = part.media_type + + if mt.startswith("image/"): + media_type = "image/jpeg" if mt == "image/*" else mt + if isinstance(part.data, str) and media_.is_url(part.data): + return { + "type": "image", + "source": {"type": "url", "url": part.data}, + } + return { + "type": "image", + "source": { + "type": "base64", + "media_type": media_type, + "data": media_.data_to_base64(part.data), + }, + } + + if mt == "application/pdf": + if isinstance(part.data, str) and media_.is_url(part.data): + return { + "type": "document", + "source": {"type": "url", "url": part.data}, + } + return { + "type": "document", + "source": { + "type": "base64", + "media_type": "application/pdf", + "data": media_.data_to_base64(part.data), + }, + } + + if mt == "text/plain": + if isinstance(part.data, bytes): + text_data = part.data.decode("utf-8") + elif media_.is_url(part.data): + return { + "type": "document", + "source": {"type": "url", "url": part.data}, + } + else: + import base64 as _b64 + + text_data = _b64.b64decode(part.data).decode("utf-8") + return { + "type": "document", + "source": { + "type": "text", + "media_type": "text/plain", + "data": text_data, + }, + } + + raise ValueError(f"Unsupported media type for Anthropic: {mt}") + + +async def _messages_to_anthropic( + messages: list[messages_.Message], +) -> tuple[str | None, list[dict[str, Any]]]: + """Convert internal messages to Anthropic API format. + + Returns ``(system_prompt, messages_list)``. The system prompt is + extracted separately because the Anthropic API takes it as a + top-level parameter. 
+ """ + system_prompt: str | None = None + result: list[dict[str, Any]] = [] + + for msg in messages: + match msg.role: + case "system": + system_prompt = "".join( + p.text for p in msg.parts if isinstance(p, messages_.TextPart) + ) + case "assistant": + content: list[dict[str, Any]] = [] + tool_results: list[dict[str, Any]] = [] + + for part in msg.parts: + match part: + case messages_.ReasoningPart(text=text, signature=signature): + if signature: + content.append( + { + "type": "thinking", + "thinking": text, + "signature": signature, + } + ) + case messages_.TextPart(text=text): + content.append({"type": "text", "text": text}) + case messages_.ToolPart(): + tool_input = ( + json.loads(part.tool_args) if part.tool_args else {} + ) + content.append( + { + "type": "tool_use", + "id": part.tool_call_id, + "name": part.tool_name, + "input": tool_input, + } + ) + if part.status in ("result", "error"): + entry: dict[str, Any] = { + "type": "tool_result", + "tool_use_id": part.tool_call_id, + "content": str(part.result) + if part.result is not None + else "", + } + if part.status == "error": + entry["is_error"] = True + tool_results.append(entry) + + if content: + result.append({"role": "assistant", "content": content}) + if tool_results: + result.append({"role": "user", "content": tool_results}) + + case "user": + has_files = any(isinstance(p, messages_.FilePart) for p in msg.parts) + if not has_files: + content_text = "".join( + p.text for p in msg.parts if isinstance(p, messages_.TextPart) + ) + result.append({"role": "user", "content": content_text}) + else: + user_content: list[dict[str, Any]] = [] + for p in msg.parts: + match p: + case messages_.TextPart(text=text): + user_content.append({"type": "text", "text": text}) + case messages_.FilePart(): + user_content.append(_file_part_to_anthropic(p)) + result.append({"role": "user", "content": user_content}) + + result = _merge_consecutive_roles(result) + return system_prompt, result + + +def 
_merge_consecutive_roles( + messages: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """Merge consecutive messages that share the same role. + + Anthropic requires strictly alternating user/assistant roles. + """ + if not messages: + return messages + + merged: list[dict[str, Any]] = [messages[0]] + + for msg in messages[1:]: + if msg["role"] == merged[-1]["role"]: + prev = _to_content_list(merged[-1]["content"]) + cur = _to_content_list(msg["content"]) + merged[-1]["content"] = prev + cur + else: + merged.append(msg) + + return merged + + +def _to_content_list(content: Any) -> list[dict[str, Any]]: + """Normalize Anthropic message content to list-of-blocks.""" + if isinstance(content, list): + return list(content) + return [{"type": "text", "text": content}] + + +# --------------------------------------------------------------------------- +# SDK client factory +# --------------------------------------------------------------------------- + + +def _make_client( + client: client_.Client, +) -> anthropic.AsyncAnthropic: + """Construct an ``AsyncAnthropic`` from our generic ``Client``.""" + return anthropic.AsyncAnthropic( + base_url=client.base_url, + api_key=client.api_key or "", + ) + + +# --------------------------------------------------------------------------- +# Public adapter function +# --------------------------------------------------------------------------- + + +async def stream( + client: client_.Client, + model: model_.Model, + messages: list[messages_.Message], + *, + tools: Sequence[tools_.ToolLike] | None = None, + output_type: type[pydantic.BaseModel] | None = None, + thinking: bool = False, + budget_tokens: int = 10000, + **kwargs: Any, +) -> AsyncGenerator[messages_.Message]: + """Stream an LLM response via the Anthropic messages API. + + Yields ``Message`` snapshots as the response streams in. + + Extra keyword arguments beyond the ``StreamFn`` protocol: + + * ``thinking`` — enable extended thinking output. 
+ * ``budget_tokens`` — max tokens for thinking (default 10000). + """ + sdk_client = _make_client(client) + system_prompt, anthropic_messages = await _messages_to_anthropic(messages) + anthropic_tools = _tools_to_anthropic(tools) if tools else None + + api_kwargs: dict[str, Any] = { + "model": model.id, + "messages": anthropic_messages, + "max_tokens": 8192, + } + if system_prompt: + api_kwargs["system"] = system_prompt + if anthropic_tools: + api_kwargs["tools"] = anthropic_tools + + if thinking: + api_kwargs["thinking"] = { + "type": "enabled", + "budget_tokens": budget_tokens, + } + + if output_type is not None: + api_kwargs["output_format"] = output_type + + handler = streaming_.StreamHandler() + + block_types: dict[int, str] = {} + tool_ids: dict[int, str] = {} + signature_buffer: dict[int, str] = {} + + try: + stream_cm = sdk_client.messages.stream(**api_kwargs) + + async with stream_cm as sdk_stream: + async for event in sdk_stream: + match event.type: + case "content_block_start": + block = event.content_block + idx = event.index + block_types[idx] = block.type + + match block.type: + case "text": + yield handler.handle_event( + streaming_.TextStart(block_id=str(idx)) + ) + case "thinking": + yield handler.handle_event( + streaming_.ReasoningStart(block_id=str(idx)) + ) + case "tool_use": + tool_ids[idx] = block.id + yield handler.handle_event( + streaming_.ToolStart( + tool_call_id=block.id, + tool_name=block.name, + ) + ) + + case "content_block_delta": + delta = event.delta + idx = event.index + + match delta.type: + case "text_delta": + yield handler.handle_event( + streaming_.TextDelta( + block_id=str(idx), + delta=delta.text, + ) + ) + case "thinking_delta": + yield handler.handle_event( + streaming_.ReasoningDelta( + block_id=str(idx), + delta=delta.thinking, + ) + ) + case "signature_delta": + signature_buffer[idx] = ( + signature_buffer.get(idx, "") + delta.signature + ) + case "input_json_delta": + tool_id = tool_ids.get(idx) + if tool_id: + 
yield handler.handle_event( + streaming_.ToolArgsDelta( + tool_call_id=tool_id, + delta=delta.partial_json, + ) + ) + + case "content_block_stop": + idx = event.index + match block_types.get(idx): + case "text": + yield handler.handle_event( + streaming_.TextEnd(block_id=str(idx)) + ) + case "thinking": + yield handler.handle_event( + streaming_.ReasoningEnd( + block_id=str(idx), + signature=signature_buffer.get(idx), + ) + ) + case "tool_use": + tool_id = tool_ids.get(idx) + if tool_id: + yield handler.handle_event( + streaming_.ToolEnd(tool_call_id=tool_id) + ) + + snapshot = sdk_stream.current_message_snapshot + sdk_usage = snapshot.usage + usage = messages_.Usage( + input_tokens=sdk_usage.input_tokens or 0, + output_tokens=sdk_usage.output_tokens or 0, + cache_read_tokens=getattr(sdk_usage, "cache_read_input_tokens", None), + cache_write_tokens=getattr( + sdk_usage, "cache_creation_input_tokens", None + ), + raw=sdk_usage.model_dump(exclude_none=True) or None, + ) + yield handler.handle_event(streaming_.MessageDone(usage=usage)) + finally: + await sdk_client.close() diff --git a/src/vercel_ai_sdk/models2/openai/__init__.py b/src/vercel_ai_sdk/models2/openai/__init__.py new file mode 100644 index 00000000..bd01bcd1 --- /dev/null +++ b/src/vercel_ai_sdk/models2/openai/__init__.py @@ -0,0 +1,7 @@ +"""OpenAI provider — adapter for the OpenAI chat completions API.""" + +from .adapter import stream + +__all__ = [ + "stream", +] diff --git a/src/vercel_ai_sdk/models2/openai/adapter.py b/src/vercel_ai_sdk/models2/openai/adapter.py new file mode 100644 index 00000000..8f63c244 --- /dev/null +++ b/src/vercel_ai_sdk/models2/openai/adapter.py @@ -0,0 +1,386 @@ +"""OpenAI adapter — chat completions API. + +Message/tool conversion and streaming via the official ``openai`` SDK. +The SDK client is constructed from :class:`Client` params on each call. 
+""" + +from __future__ import annotations + +from collections.abc import AsyncGenerator, Sequence +from typing import Any + +import openai +import pydantic + +from ...types import messages as messages_ +from ...types import tools as tools_ +from ..core import client as client_ +from ..core import model as model_ +from ..core.helpers import media as media_ +from ..core.helpers import streaming as streaming_ + +# --------------------------------------------------------------------------- +# Message / tool conversion — internal types → OpenAI wire format +# --------------------------------------------------------------------------- + + +def _tools_to_openai( + tools: Sequence[tools_.ToolLike], +) -> list[dict[str, Any]]: + """Convert internal Tool objects to OpenAI tool schema format.""" + return [ + { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.param_schema, + }, + } + for tool in tools + ] + + +async def _file_part_to_openai( + part: messages_.FilePart, +) -> dict[str, Any]: + """Convert a :class:`FilePart` to an OpenAI content-array element. 
async def _file_part_to_openai(
    part: messages_.FilePart,
) -> dict[str, Any]:
    """Convert a :class:`FilePart` to an OpenAI content-array element.

    * ``image/*`` -> ``image_url`` (URL or ``data:`` URL)
    * ``audio/*`` -> ``input_audio`` (base-64 only; URLs auto-downloaded)
    * ``application/pdf`` -> ``file`` (base-64 only; URLs auto-downloaded)
    * ``text/*`` -> ``text`` (decoded to string)
    * anything else -> ``ValueError``

    For audio, the ``format`` field is the MIME subtype translated to the
    name OpenAI uses (``audio/mpeg`` -> ``mp3``, the WAV variants ->
    ``wav``); subtypes without a known alias are passed through unchanged.

    Raises:
        ValueError: If the media type is none of the supported families.
    """
    mt = part.media_type
    data = part.data

    if mt.startswith("image/"):
        # A wildcard image type needs a concrete type for the data URL.
        media_type = "image/jpeg" if mt == "image/*" else mt
        url = media_.data_to_data_url(data, media_type)
        return {"type": "image_url", "image_url": {"url": url}}

    if mt.startswith("audio/"):
        if isinstance(data, str) and media_.is_downloadable_url(data):
            downloaded, _ = await media_.download(data)
            data = downloaded
        # The MIME subtype is not always the API's format name: MP3 is
        # registered as ``audio/mpeg`` and WAV has several spellings.
        audio_format_aliases = {
            "mpeg": "mp3",
            "x-wav": "wav",
            "wave": "wav",
            "vnd.wave": "wav",
        }
        subtype = mt.split("/", 1)[1]
        fmt = audio_format_aliases.get(subtype.lower(), subtype)
        b64 = media_.data_to_base64(data)
        return {
            "type": "input_audio",
            "input_audio": {"data": b64, "format": fmt},
        }

    if mt == "application/pdf":
        if isinstance(data, str) and media_.is_downloadable_url(data):
            downloaded, _ = await media_.download(data)
            data = downloaded
        data_url = media_.data_to_data_url(data, mt)
        filename = part.filename or "document.pdf"
        return {
            "type": "file",
            "file": {"filename": filename, "file_data": data_url},
        }

    if mt.startswith("text/"):
        if isinstance(data, bytes):
            text_content = data.decode("utf-8")
        elif media_.is_url(data):
            # NOTE(review): a URL is forwarded as the text itself rather
            # than downloaded like audio/PDF — confirm this is intended.
            text_content = data
        else:
            import base64 as _b64

            # Remaining case: a string that is not a URL is treated as
            # base-64 encoded UTF-8.
            text_content = _b64.b64decode(data).decode("utf-8")
        return {"type": "text", "text": text_content}

    raise ValueError(f"Unsupported media type for OpenAI: {mt}")
+ + * ``tool_calls`` on assistant messages + * tool results as separate ``role: "tool"`` messages + """ + result: list[dict[str, Any]] = [] + for msg in messages: + match msg.role: + case "assistant": + content = "" + reasoning = "" + tool_calls: list[dict[str, Any]] = [] + tool_results: list[dict[str, Any]] = [] + + for part in msg.parts: + match part: + case messages_.ReasoningPart(text=text): + reasoning += text + case messages_.TextPart(text=text): + content += text + case messages_.ToolPart(): + tool_calls.append( + { + "id": part.tool_call_id, + "type": "function", + "function": { + "name": part.tool_name, + "arguments": part.tool_args, + }, + } + ) + if part.status in ("result", "error"): + tool_results.append( + { + "role": "tool", + "tool_call_id": part.tool_call_id, + "content": str(part.result) + if part.result is not None + else "", + } + ) + + entry: dict[str, Any] = {"role": "assistant"} + if content: + entry["content"] = content + if reasoning: + entry["reasoning"] = reasoning + if tool_calls: + entry["tool_calls"] = tool_calls + result.append(entry) + result.extend(tool_results) + + case "system": + content_text = "".join( + p.text for p in msg.parts if isinstance(p, messages_.TextPart) + ) + result.append({"role": "system", "content": content_text}) + + case "user": + has_files = any(isinstance(p, messages_.FilePart) for p in msg.parts) + if not has_files: + text = "".join( + p.text for p in msg.parts if isinstance(p, messages_.TextPart) + ) + result.append({"role": "user", "content": text}) + else: + parts: list[dict[str, Any]] = [] + for p in msg.parts: + match p: + case messages_.TextPart(text=text): + parts.append({"type": "text", "text": text}) + case messages_.FilePart(): + parts.append(await _file_part_to_openai(p)) + result.append({"role": "user", "content": parts}) + return result + + +# --------------------------------------------------------------------------- +# SDK client factory +# 
--------------------------------------------------------------------------- + + +def _make_client(client: client_.Client) -> openai.AsyncOpenAI: + """Construct an ``AsyncOpenAI`` from our generic ``Client``.""" + return openai.AsyncOpenAI( + base_url=client.base_url, + api_key=client.api_key or "", + ) + + +# --------------------------------------------------------------------------- +# Public adapter function +# --------------------------------------------------------------------------- + + +async def stream( + client: client_.Client, + model: model_.Model, + messages: list[messages_.Message], + *, + tools: Sequence[tools_.ToolLike] | None = None, + output_type: type[pydantic.BaseModel] | None = None, + thinking: bool = False, + budget_tokens: int | None = None, + reasoning_effort: str | None = None, + **kwargs: Any, +) -> AsyncGenerator[messages_.Message]: + """Stream an LLM response via the OpenAI chat completions API. + + Yields ``Message`` snapshots as the response streams in. + + Extra keyword arguments beyond the ``StreamFn`` protocol: + + * ``thinking`` — enable reasoning/thinking output. + * ``budget_tokens`` — max tokens for reasoning (mutually exclusive + with ``reasoning_effort``). + * ``reasoning_effort`` — effort level: ``"none"``, ``"minimal"``, + ``"low"``, ``"medium"``, ``"high"``, ``"xhigh"`` + (mutually exclusive with ``budget_tokens``). 
+ """ + sdk_client = _make_client(client) + openai_messages = await _messages_to_openai(messages) + openai_tools = _tools_to_openai(tools) if tools else None + + api_kwargs: dict[str, Any] = { + "model": model.id, + "messages": openai_messages, + "stream": True, + "stream_options": {"include_usage": True}, + } + if openai_tools: + api_kwargs["tools"] = openai_tools + + if output_type is not None: + from openai.lib._pydantic import to_strict_json_schema + + api_kwargs["response_format"] = { + "type": "json_schema", + "json_schema": { + "name": output_type.__name__, + "schema": to_strict_json_schema(output_type), + "strict": True, + }, + } + + # Enable reasoning/thinking via Vercel AI Gateway's unified format + if thinking: + reasoning_config: dict[str, Any] = {"enabled": True} + if budget_tokens is not None: + reasoning_config["max_tokens"] = budget_tokens + elif reasoning_effort is not None: + reasoning_config["effort"] = reasoning_effort + api_kwargs["extra_body"] = {"reasoning": reasoning_config} + + handler = streaming_.StreamHandler() + + try: + sdk_stream = await sdk_client.chat.completions.create(**api_kwargs) + + text_started = False + reasoning_started = False + tc_state: dict[int, dict[str, Any]] = {} + finish_reason: str | None = None + usage: messages_.Usage | None = None + + async for chunk in sdk_stream: + if chunk.usage is not None: + raw = chunk.usage.model_dump(exclude_none=True) + reasoning_tokens: int | None = None + cache_read: int | None = None + cd = getattr( + chunk.usage, + "completion_tokens_details", + None, + ) + if cd: + reasoning_tokens = getattr(cd, "reasoning_tokens", None) + pd = getattr( + chunk.usage, + "prompt_tokens_details", + None, + ) + if pd: + cache_read = getattr(pd, "cached_tokens", None) + usage = messages_.Usage( + input_tokens=chunk.usage.prompt_tokens or 0, + output_tokens=chunk.usage.completion_tokens or 0, + reasoning_tokens=reasoning_tokens, + cache_read_tokens=cache_read, + raw=raw, + ) + + if not chunk.choices: + 
continue + + choice = chunk.choices[0] + delta = choice.delta + + # Reasoning / thinking content + reasoning_value = None + if hasattr(delta, "reasoning") and delta.reasoning: + reasoning_value = delta.reasoning + elif hasattr(delta, "model_extra") and delta.model_extra: + reasoning_value = delta.model_extra.get("reasoning") + + if reasoning_value: + if not reasoning_started: + reasoning_started = True + yield handler.handle_event( + streaming_.ReasoningStart(block_id="reasoning") + ) + yield handler.handle_event( + streaming_.ReasoningDelta( + block_id="reasoning", delta=reasoning_value + ) + ) + + if delta.content: + if reasoning_started: + yield handler.handle_event( + streaming_.ReasoningEnd(block_id="reasoning") + ) + reasoning_started = False + + if not text_started: + text_started = True + yield handler.handle_event(streaming_.TextStart(block_id="text")) + yield handler.handle_event( + streaming_.TextDelta(block_id="text", delta=delta.content) + ) + + if delta.tool_calls: + for tc in delta.tool_calls: + idx = tc.index + if idx not in tc_state: + tc_state[idx] = { + "id": tc.id, + "name": None, + "started": False, + } + if tc.id: + tc_state[idx]["id"] = tc.id + if tc.function: + if tc.function.name: + tc_state[idx]["name"] = tc.function.name + if tc.function.arguments: + tid = tc_state[idx]["id"] + tname = tc_state[idx]["name"] or "" + + if not tc_state[idx]["started"] and tid: + tc_state[idx]["started"] = True + yield handler.handle_event( + streaming_.ToolStart( + tool_call_id=tid, + tool_name=tname, + ) + ) + + if tid: + yield handler.handle_event( + streaming_.ToolArgsDelta( + tool_call_id=tid, + delta=tc.function.arguments, + ) + ) + + if choice.finish_reason is not None: + finish_reason = choice.finish_reason + if reasoning_started: + yield handler.handle_event( + streaming_.ReasoningEnd(block_id="reasoning") + ) + if text_started: + yield handler.handle_event(streaming_.TextEnd(block_id="text")) + for tc in tc_state.values(): + if tc["started"] and 
tc["id"]: + yield handler.handle_event( + streaming_.ToolEnd(tool_call_id=tc["id"]) + ) + + yield handler.handle_event( + streaming_.MessageDone(finish_reason=finish_reason, usage=usage) + ) + finally: + await sdk_client.close() From 3cf005becd9cd0fc4ae62afec76197cbd5b792f6 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Fri, 3 Apr 2026 11:03:23 -0700 Subject: [PATCH 11/18] Wire the updated models module into the agents module --- src/vercel_ai_sdk/__init__.py | 30 ++----- src/vercel_ai_sdk/agents/__init__.py | 2 +- src/vercel_ai_sdk/agents/runtime.py | 14 +-- src/vercel_ai_sdk/models2/__init__.py | 18 ++++ tests/adapters/ai_sdk_ui/test_adapter.py | 18 ++-- tests/agents/mcp/test_client.py | 10 +-- tests/agents/test_checkpoint.py | 59 ++++++------ tests/agents/test_hooks.py | 41 ++++----- tests/agents/test_runtime.py | 61 +++++++------ tests/agents/test_streams.py | 17 ++-- tests/conftest.py | 110 ++++++++++++++++++++++- tests/telemetry/test_otel_handler.py | 19 ++-- tests/telemetry/test_telemetry.py | 52 ++++++----- 13 files changed, 293 insertions(+), 158 deletions(-) diff --git a/src/vercel_ai_sdk/__init__.py b/src/vercel_ai_sdk/__init__.py index 37946271..ee63f2a8 100644 --- a/src/vercel_ai_sdk/__init__.py +++ b/src/vercel_ai_sdk/__init__.py @@ -1,4 +1,4 @@ -from . import adapters, models, telemetry +from . 
import adapters, models, models2, telemetry from .adapters import ai_sdk_ui from .agents import ( Checkpoint, @@ -20,18 +20,7 @@ stream_step, tool, ) -from .models import ( - ImageModel, - LanguageModel, - MediaModel, - MediaResult, - Model, - Stream, - VideoModel, - ai_gateway, - anthropic, - openai, -) +from .models2 import Client, Model, ModelCost # Re-export core types from .types import ( @@ -67,14 +56,12 @@ "ToolSchema", "Usage", "make_messages", - # Models (from models/) + # Models (from models2/) "Model", - "Stream", - "LanguageModel", - "MediaModel", - "MediaResult", - "ImageModel", - "VideoModel", + "ModelCost", + "Client", + "models2", + # Legacy (from models/) — kept during transition "models", # Agents (from agents/) "Tool", @@ -97,9 +84,6 @@ "hook", # Submodules "telemetry", - "ai_gateway", - "anthropic", - "openai", "mcp", "ai_sdk_ui", "adapters", diff --git a/src/vercel_ai_sdk/agents/__init__.py b/src/vercel_ai_sdk/agents/__init__.py index d33640a2..5822431d 100644 --- a/src/vercel_ai_sdk/agents/__init__.py +++ b/src/vercel_ai_sdk/agents/__init__.py @@ -1,6 +1,6 @@ """Agent loop orchestration — tools, hooks, runtime, and streaming. -Depends on types/ and models/. Provides the loop machinery that +Depends on types/ and models2/. Provides the loop machinery that plugs a model into a tool-calling loop with hooks and checkpoints. """ diff --git a/src/vercel_ai_sdk/agents/runtime.py b/src/vercel_ai_sdk/agents/runtime.py index e9d8aadf..4c156c34 100644 --- a/src/vercel_ai_sdk/agents/runtime.py +++ b/src/vercel_ai_sdk/agents/runtime.py @@ -10,7 +10,7 @@ import pydantic -from ..models.core import llm as llm_ +from .. import models2 from ..telemetry import events as telemetry_ from ..types import messages as messages_ from . 
import checkpoint as checkpoint_ @@ -213,15 +213,16 @@ def _find_runtime_param(fn: Callable[..., Any]) -> str | None: @streams_.stream async def stream_step( - llm: llm_.LanguageModel, + model: models2.Model, messages: list[messages_.Message], tools: Sequence[tools_.ToolLike] | None = None, label: str | None = None, output_type: type[pydantic.BaseModel] | None = None, + **kwargs: Any, ) -> AsyncGenerator[messages_.Message]: """Single LLM call that streams to Runtime.""" - async for msg in llm.stream( - messages=messages, tools=tools, output_type=output_type + async for msg in models2.stream( + model, messages, tools=tools, output_type=output_type, **kwargs ): msg.label = label yield msg @@ -301,18 +302,19 @@ async def execute_tool( async def stream_loop( - llm: llm_.LanguageModel, + model: models2.Model, messages: list[messages_.Message], tools: Sequence[tools_.ToolLike], label: str | None = None, output_type: type[pydantic.BaseModel] | None = None, + **kwargs: Any, ) -> streams_.StreamResult: """Agent loop: stream LLM, execute tools, repeat until done.""" local_messages = list(messages) while True: result = await stream_step( - llm, local_messages, tools, label=label, output_type=output_type + model, local_messages, tools, label=label, output_type=output_type, **kwargs ) if not result.tool_calls: diff --git a/src/vercel_ai_sdk/models2/__init__.py b/src/vercel_ai_sdk/models2/__init__.py index 43636a99..fb78f460 100644 --- a/src/vercel_ai_sdk/models2/__init__.py +++ b/src/vercel_ai_sdk/models2/__init__.py @@ -69,6 +69,22 @@ def _ensure_adapters() -> None: _stream_adapters["anthropic"] = anthropic_stream +def register_stream(adapter: str, fn: StreamFn) -> None: + """Register a stream adapter function for the given adapter key. + + Use this to add custom adapters (or override built-in ones). + """ + _stream_adapters[adapter] = fn + + +def register_generate(adapter: str, fn: GenerateFn) -> None: + """Register a generate adapter function for the given adapter key. 
+ + Use this to add custom adapters (or override built-in ones). + """ + _generate_adapters[adapter] = fn + + # --------------------------------------------------------------------------- # Provider defaults — base URLs and env var names for auto-client creation. # --------------------------------------------------------------------------- @@ -183,5 +199,7 @@ async def buffer(gen: AsyncGenerator[messages_.Message]) -> messages_.Message: # Public API "buffer", "generate", + "register_generate", + "register_stream", "stream", ] diff --git a/tests/adapters/ai_sdk_ui/test_adapter.py b/tests/adapters/ai_sdk_ui/test_adapter.py index d3cf677e..ed2e2d36 100644 --- a/tests/adapters/ai_sdk_ui/test_adapter.py +++ b/tests/adapters/ai_sdk_ui/test_adapter.py @@ -12,7 +12,7 @@ from vercel_ai_sdk.agents import hooks from vercel_ai_sdk.types import messages -from ...conftest import MockLLM, tool_msg +from ...conftest import MOCK_MODEL, mock_llm, tool_msg async def get_event_types(msgs: list[messages.Message]) -> list[str]: @@ -241,12 +241,12 @@ async def get_weather(city: str) -> str: async def mock_agent( - llm: ai.LanguageModel, + model: ai.Model, user_query: str, ) -> ai.StreamResult: """Agent using stream_loop directly.""" return await ai.stream_loop( - llm, + model, messages=ai.make_messages(system="You are helpful.", user=user_query), tools=[get_weather], ) @@ -294,11 +294,11 @@ async def test_runtime_tool_roundtrip() -> None: ), ] - mock_llm = MockLLM([tool_call_response, final_text_response]) + mock_llm([tool_call_response, final_text_response]) # Collect all messages from the runtime runtime_messages: list[messages.Message] = [] - async for msg in ai.run(mock_agent, mock_llm, "What's the weather in London?"): + async for msg in ai.run(mock_agent, MOCK_MODEL, "What's the weather in London?"): runtime_messages.append(msg) # Stream through UI adapter @@ -643,9 +643,9 @@ async def dangerous_action(path: str) -> str: """Do something dangerous.""" return f"deleted {path}" - 
async def graph(llm: ai.LanguageModel) -> None: + async def graph(model: ai.Model) -> None: result = await ai.stream_step( - llm, + model, ai.make_messages(system="You are helpful.", user="delete /tmp"), [dangerous_action], ) @@ -667,7 +667,7 @@ async def approve_and_execute(tc: ai.ToolPart) -> None: await asyncio.gather(*(approve_and_execute(tc) for tc in result.tool_calls)) - mock_llm = MockLLM( + mock_llm( [ [ tool_msg( @@ -680,7 +680,7 @@ async def approve_and_execute(tc: ai.ToolPart) -> None: ) runtime_messages: list[messages.Message] = [] - result = ai.run(graph, mock_llm) + result = ai.run(graph, MOCK_MODEL) async for msg in result: runtime_messages.append(msg) diff --git a/tests/agents/mcp/test_client.py b/tests/agents/mcp/test_client.py index 8b5c5e9d..78bee949 100644 --- a/tests/agents/mcp/test_client.py +++ b/tests/agents/mcp/test_client.py @@ -10,7 +10,7 @@ from vercel_ai_sdk.agents.mcp.client import _mcp_tool_to_native from vercel_ai_sdk.agents.tools import _tool_registry, get_tool -from ...conftest import MockLLM, text_msg, tool_msg +from ...conftest import MOCK_MODEL, mock_llm, text_msg, tool_msg def _fake_mcp_tool( @@ -84,18 +84,18 @@ async def fake_fn(**kwargs: str) -> str: native._fn = fake_fn _tool_registry[native.name] = native - async def graph(llm: ai.LanguageModel) -> ai.StreamResult: + async def graph(model: ai.Model) -> ai.StreamResult: return await ai.stream_loop( - llm, + model, messages=ai.make_messages(user="echo hello"), tools=[native], ) call1 = [tool_msg(tc_id="tc-mcp-1", name="mcp_e2e_echo", args='{"text": "hello"}')] call2 = [text_msg("Done.", id="msg-2")] - llm = MockLLM([call1, call2]) + llm = mock_llm([call1, call2]) - result = ai.run(graph, llm) + result = ai.run(graph, MOCK_MODEL) msgs = [m async for m in result] # Tool was called with the right args diff --git a/tests/agents/test_checkpoint.py b/tests/agents/test_checkpoint.py index 4f84b243..c90ee925 100644 --- a/tests/agents/test_checkpoint.py +++ 
b/tests/agents/test_checkpoint.py @@ -9,7 +9,7 @@ import vercel_ai_sdk as ai from vercel_ai_sdk.agents.checkpoint import Checkpoint, HookEvent, StepEvent, ToolEvent -from ..conftest import MockLLM, text_msg, tool_msg +from ..conftest import MOCK_MODEL, mock_llm, text_msg, tool_msg @ai.hook @@ -23,19 +23,19 @@ class Approval(pydantic.BaseModel): @pytest.mark.asyncio async def test_step_replay_skips_llm() -> None: - async def graph(llm: ai.LanguageModel) -> ai.StreamResult: + async def graph(model: ai.Model) -> ai.StreamResult: return await ai.stream_step( - llm, messages=ai.make_messages(system="test", user="hello") + model, messages=ai.make_messages(system="test", user="hello") ) - llm1 = MockLLM([[text_msg("Hi there!")]]) - result1 = ai.run(graph, llm1) + llm1 = mock_llm([[text_msg("Hi there!")]]) + result1 = ai.run(graph, MOCK_MODEL) [msg async for msg in result1] assert llm1.call_count == 1 cp = result1.checkpoint - llm2 = MockLLM([]) - result2 = ai.run(graph, llm2, checkpoint=cp) + llm2 = mock_llm([]) + result2 = ai.run(graph, MOCK_MODEL, checkpoint=cp) [msg async for msg in result2] assert llm2.call_count == 0 @@ -51,8 +51,8 @@ async def counting_tool(x: int) -> int: execution_count += 1 return x + 1 - async def graph(llm: ai.LanguageModel) -> ai.StreamResult: - result = await ai.stream_step(llm, ai.make_messages(system="t", user="go")) + async def graph(model: ai.Model) -> ai.StreamResult: + result = await ai.stream_step(model, ai.make_messages(system="t", user="go")) if result.tool_calls: await asyncio.gather( *( @@ -62,14 +62,15 @@ async def graph(llm: ai.LanguageModel) -> ai.StreamResult: ) return result - llm1 = MockLLM([[tool_msg(tc_id="tc-1", name="counting_tool", args='{"x": 5}')]]) - result1 = ai.run(graph, llm1) + mock_llm([[tool_msg(tc_id="tc-1", name="counting_tool", args='{"x": 5}')]]) + result1 = ai.run(graph, MOCK_MODEL) [msg async for msg in result1] assert execution_count == 1 assert result1.checkpoint.tools[0].result == 6 execution_count = 0 
- result2 = ai.run(graph, MockLLM([]), checkpoint=result1.checkpoint) + mock_llm([]) + result2 = ai.run(graph, MOCK_MODEL, checkpoint=result1.checkpoint) [msg async for msg in result2] assert execution_count == 0 @@ -79,11 +80,12 @@ async def graph(llm: ai.LanguageModel) -> ai.StreamResult: @pytest.mark.asyncio async def test_hook_cancellation_pending() -> None: - async def graph(llm: ai.LanguageModel) -> Any: - await ai.stream_step(llm, ai.make_messages(system="t", user="go")) + async def graph(model: ai.Model) -> Any: + await ai.stream_step(model, ai.make_messages(system="t", user="go")) return await Approval.create("my_approval", metadata={"tool": "test"}) # type: ignore[attr-defined] - result = ai.run(graph, MockLLM([[text_msg("OK")]])) + mock_llm([[text_msg("OK")]]) + result = ai.run(graph, MOCK_MODEL) msgs = [msg async for msg in result] assert "my_approval" in result.pending_hooks hook_msgs = [m for m in msgs if any(isinstance(p, ai.HookPart) for p in m.parts)] @@ -92,17 +94,19 @@ async def graph(llm: ai.LanguageModel) -> Any: @pytest.mark.asyncio async def test_hook_resolution_on_reentry() -> None: - async def graph(llm: ai.LanguageModel) -> Any: - await ai.stream_step(llm, ai.make_messages(system="t", user="go")) + async def graph(model: ai.Model) -> Any: + await ai.stream_step(model, ai.make_messages(system="t", user="go")) return await Approval.create("my_approval") # type: ignore[attr-defined] resp = [text_msg("OK")] - result1 = ai.run(graph, MockLLM([resp])) + mock_llm([resp]) + result1 = ai.run(graph, MOCK_MODEL) [msg async for msg in result1] cp = result1.checkpoint Approval.resolve("my_approval", {"granted": True}) # type: ignore[attr-defined] - result2 = ai.run(graph, MockLLM([]), checkpoint=cp) + mock_llm([]) + result2 = ai.run(graph, MOCK_MODEL, checkpoint=cp) [msg async for msg in result2] assert len(result2.pending_hooks) == 0 assert result2.checkpoint.hooks[-1].label == "my_approval" @@ -110,8 +114,8 @@ async def graph(llm: ai.LanguageModel) 
-> Any: @pytest.mark.asyncio async def test_parallel_hooks_all_collected() -> None: - async def graph(llm: ai.LanguageModel) -> None: - await ai.stream_step(llm, ai.make_messages(system="t", user="go")) + async def graph(model: ai.Model) -> None: + await ai.stream_step(model, ai.make_messages(system="t", user="go")) async def a() -> Any: return await Approval.create("hook_a") # type: ignore[attr-defined] @@ -123,15 +127,16 @@ async def b() -> Any: tg.create_task(a()) tg.create_task(b()) - result = ai.run(graph, MockLLM([[text_msg("OK")]])) + mock_llm([[text_msg("OK")]]) + result = ai.run(graph, MOCK_MODEL) [msg async for msg in result] assert {"hook_a", "hook_b"} <= set(result.pending_hooks) @pytest.mark.asyncio async def test_parallel_hooks_resolve_on_reentry() -> None: - async def graph(llm: ai.LanguageModel) -> Any: - await ai.stream_step(llm, ai.make_messages(system="t", user="go")) + async def graph(model: ai.Model) -> Any: + await ai.stream_step(model, ai.make_messages(system="t", user="go")) async def a() -> Any: return await Approval.create("hook_a") # type: ignore[attr-defined] @@ -145,13 +150,15 @@ async def b() -> Any: return ta.result(), tb.result() resp = [text_msg("OK")] - result1 = ai.run(graph, MockLLM([resp])) + mock_llm([resp]) + result1 = ai.run(graph, MOCK_MODEL) [msg async for msg in result1] cp = result1.checkpoint Approval.resolve("hook_a", {"granted": True}) # type: ignore[attr-defined] Approval.resolve("hook_b", {"granted": False}) # type: ignore[attr-defined] - result2 = ai.run(graph, MockLLM([]), checkpoint=cp) + mock_llm([]) + result2 = ai.run(graph, MOCK_MODEL, checkpoint=cp) [msg async for msg in result2] assert len(result2.pending_hooks) == 0 diff --git a/tests/agents/test_hooks.py b/tests/agents/test_hooks.py index a51c2b55..1bd86bd9 100644 --- a/tests/agents/test_hooks.py +++ b/tests/agents/test_hooks.py @@ -8,7 +8,7 @@ import vercel_ai_sdk as ai -from ..conftest import MockLLM, text_msg +from ..conftest import MOCK_MODEL, mock_llm, 
text_msg @ai.hook @@ -32,15 +32,15 @@ async def test_resolve_live_future() -> None: """In long-running mode, Hook.resolve() unblocks the awaiting coroutine.""" resolved_value = None - async def graph(llm: ai.LanguageModel) -> None: + async def graph(model: ai.Model) -> None: nonlocal resolved_value - await ai.stream_step(llm, ai.make_messages(user="go")) + await ai.stream_step(model, ai.make_messages(user="go")) result = await Confirmation.create("confirm_1") # type: ignore[attr-defined] resolved_value = result - llm = MockLLM([[text_msg("OK")]]) + mock_llm([[text_msg("OK")]]) # Confirmation.cancels_future=False -> long-running mode - run_result = ai.run(graph, llm) + run_result = ai.run(graph, MOCK_MODEL) collected = [] async for msg in run_result: @@ -68,16 +68,16 @@ async def test_cancel_live_hook() -> None: """Hook.cancel() cancels the future, causing CancelledError in graph.""" was_cancelled = False - async def graph(llm: ai.LanguageModel) -> None: + async def graph(model: ai.Model) -> None: nonlocal was_cancelled - await ai.stream_step(llm, ai.make_messages(user="go")) + await ai.stream_step(model, ai.make_messages(user="go")) try: await Confirmation.create("cancel_me") # type: ignore[attr-defined] except asyncio.CancelledError: was_cancelled = True - llm = MockLLM([[text_msg("OK")]]) - run_result = ai.run(graph, llm) + mock_llm([[text_msg("OK")]]) + run_result = ai.run(graph, MOCK_MODEL) async for msg in run_result: if any(isinstance(p, ai.HookPart) and p.status == "pending" for p in msg.parts): @@ -102,16 +102,16 @@ async def test_cancel_nonexistent_raises() -> None: async def test_pre_registered_resolution_consumed() -> None: """Pre-registered resolution is consumed by Hook.create() without suspending.""" - async def graph(llm: ai.LanguageModel) -> Any: - await ai.stream_step(llm, ai.make_messages(user="go")) + async def graph(model: ai.Model) -> Any: + await ai.stream_step(model, ai.make_messages(user="go")) result = await Confirmation.create("pre_reg_1") 
# type: ignore[attr-defined] return result # Pre-register BEFORE run Confirmation.resolve("pre_reg_1", {"approved": True}) # type: ignore[attr-defined] - llm = MockLLM([[text_msg("OK")]]) - run_result = ai.run(graph, llm) + mock_llm([[text_msg("OK")]]) + run_result = ai.run(graph, MOCK_MODEL) [m async for m in run_result] # Should have completed with no pending hooks @@ -137,12 +137,12 @@ def test_resolve_validates_schema() -> None: async def test_resolved_hook_emits_message() -> None: """After resolution, a 'resolved' HookPart message is emitted.""" - async def graph(llm: ai.LanguageModel) -> None: - await ai.stream_step(llm, ai.make_messages(user="go")) + async def graph(model: ai.Model) -> None: + await ai.stream_step(model, ai.make_messages(user="go")) await Confirmation.create("emit_test") # type: ignore[attr-defined] - llm = MockLLM([[text_msg("OK")]]) - run_result = ai.run(graph, llm) + mock_llm([[text_msg("OK")]]) + run_result = ai.run(graph, MOCK_MODEL) msgs = [] async for msg in run_result: @@ -164,13 +164,14 @@ async def graph(llm: ai.LanguageModel) -> None: @pytest.mark.asyncio async def test_hook_metadata_in_pending() -> None: - async def graph(llm: ai.LanguageModel) -> None: - await ai.stream_step(llm, ai.make_messages(user="go")) + async def graph(model: ai.Model) -> None: + await ai.stream_step(model, ai.make_messages(user="go")) await CancellingConfirmation.create( # type: ignore[attr-defined] "meta_test", metadata={"tool": "rm -rf", "path": "/"} ) - run_result = ai.run(graph, MockLLM([[text_msg("OK")]])) + mock_llm([[text_msg("OK")]]) + run_result = ai.run(graph, MOCK_MODEL) [m async for m in run_result] info = run_result.pending_hooks["meta_test"] diff --git a/tests/agents/test_runtime.py b/tests/agents/test_runtime.py index 2445d7d3..e15d4f49 100644 --- a/tests/agents/test_runtime.py +++ b/tests/agents/test_runtime.py @@ -8,7 +8,7 @@ from vercel_ai_sdk.agents.runtime import Runtime from vercel_ai_sdk.types import messages -from ..conftest import 
MockLLM, text_msg, tool_msg +from ..conftest import MOCK_MODEL, mock_llm, text_msg, tool_msg # -- Tool definitions for tests -------------------------------------------- @@ -32,15 +32,15 @@ async def concat(a: str, b: str) -> str: async def test_stream_loop_text_only() -> None: """stream_loop with no tool calls returns after one LLM call.""" - async def graph(llm: ai.LanguageModel) -> ai.StreamResult: + async def graph(model: ai.Model) -> ai.StreamResult: return await ai.stream_loop( - llm, + model, messages=ai.make_messages(user="Hi"), tools=[double], ) - llm = MockLLM([[text_msg("Hello!")]]) - result = ai.run(graph, llm) + llm = mock_llm([[text_msg("Hello!")]]) + result = ai.run(graph, MOCK_MODEL) msgs = [m async for m in result] assert llm.call_count == 1 assert any(m.text == "Hello!" for m in msgs) @@ -53,18 +53,18 @@ async def graph(llm: ai.LanguageModel) -> ai.StreamResult: async def test_stream_loop_tool_then_text() -> None: """stream_loop calls tool, feeds result back, gets final text.""" - async def graph(llm: ai.LanguageModel) -> ai.StreamResult: + async def graph(model: ai.Model) -> ai.StreamResult: return await ai.stream_loop( - llm, + model, messages=ai.make_messages(user="Double 5"), tools=[double], ) call1 = [tool_msg(tc_id="tc-1", name="double", args='{"x": 5}')] call2 = [text_msg("The answer is 10.")] - llm = MockLLM([call1, call2]) + llm = mock_llm([call1, call2]) - result = ai.run(graph, llm) + result = ai.run(graph, MOCK_MODEL) msgs = [m async for m in result] assert llm.call_count == 2 # Tool should have been executed: 5 * 2 = 10 @@ -82,9 +82,9 @@ async def graph(llm: ai.LanguageModel) -> ai.StreamResult: async def test_stream_loop_parallel_tools() -> None: """LLM returns two tool calls in one message; both execute.""" - async def graph(llm: ai.LanguageModel) -> ai.StreamResult: + async def graph(model: ai.Model) -> ai.StreamResult: return await ai.stream_loop( - llm, + model, messages=ai.make_messages(user="Double 3 and 7"), tools=[double], ) 
@@ -110,9 +110,9 @@ async def graph(llm: ai.LanguageModel) -> ai.StreamResult: ], ) call2 = [text_msg("6 and 14", id="msg-2")] - llm = MockLLM([[two_tools], call2]) + llm = mock_llm([[two_tools], call2]) - result = ai.run(graph, llm) + result = ai.run(graph, MOCK_MODEL) msgs = [m async for m in result] assert llm.call_count == 2 # Both tools should have results @@ -131,9 +131,9 @@ async def graph(llm: ai.LanguageModel) -> ai.StreamResult: async def test_stream_loop_multi_turn() -> None: """LLM calls a tool, then calls another tool, then returns text.""" - async def graph(llm: ai.LanguageModel) -> ai.StreamResult: + async def graph(model: ai.Model) -> ai.StreamResult: return await ai.stream_loop( - llm, + model, messages=ai.make_messages(user="Concat then double"), tools=[double, concat], ) @@ -143,9 +143,9 @@ async def graph(llm: ai.LanguageModel) -> ai.StreamResult: ] turn2 = [tool_msg(tc_id="tc-2", name="double", args='{"x": 3}', id="msg-2")] turn3 = [text_msg("Done: hello world, 6", id="msg-3")] - llm = MockLLM([turn1, turn2, turn3]) + llm = mock_llm([turn1, turn2, turn3]) - result = ai.run(graph, llm) + result = ai.run(graph, MOCK_MODEL) [m async for m in result] assert llm.call_count == 3 @@ -163,10 +163,11 @@ async def test_execute_tool_missing_raises() -> None: tool_call_id="tc-1", tool_name="nonexistent_tool_zzz", tool_args="{}" ) - async def graph(llm: ai.LanguageModel) -> None: + async def graph(model: ai.Model) -> None: await ai.execute_tool(tc) - result = ai.run(graph, MockLLM([])) + mock_llm([]) + result = ai.run(graph, MOCK_MODEL) with pytest.raises(ExceptionGroup) as exc_info: [m async for m in result] assert any(isinstance(e, ValueError) for e in exc_info.value.exceptions) @@ -187,8 +188,8 @@ async def introspect(query: str, rt: Runtime) -> str: received_rt = rt return "ok" - async def graph(llm: ai.LanguageModel) -> None: - result = await ai.stream_step(llm, ai.make_messages(user="go")) + async def graph(model: ai.Model) -> None: + result = await 
ai.stream_step(model, ai.make_messages(user="go")) if result.tool_calls: await asyncio.gather( *( @@ -198,7 +199,8 @@ async def graph(llm: ai.LanguageModel) -> None: ) call = [tool_msg(tc_id="tc-1", name="introspect", args='{"query": "test"}')] - result = ai.run(graph, MockLLM([call])) + mock_llm([call]) + result = ai.run(graph, MOCK_MODEL) [m async for m in result] assert received_rt is not None assert isinstance(received_rt, Runtime) @@ -211,8 +213,8 @@ async def graph(llm: ai.LanguageModel) -> None: async def test_execute_tool_updates_message() -> None: """After execute_tool, the ToolPart in the message has status=result.""" - async def graph(llm: ai.LanguageModel) -> None: - result = await ai.stream_step(llm, ai.make_messages(user="go")) + async def graph(model: ai.Model) -> None: + result = await ai.stream_step(model, ai.make_messages(user="go")) if result.tool_calls: msg = result.last_message for tc in result.tool_calls: @@ -223,7 +225,8 @@ async def graph(llm: ai.LanguageModel) -> None: assert msg.tool_calls[0].result == 10 call = [tool_msg(tc_id="tc-1", name="double", args='{"x": 5}')] - result = ai.run(graph, MockLLM([call])) + mock_llm([call]) + result = ai.run(graph, MOCK_MODEL) [m async for m in result] @@ -234,18 +237,18 @@ async def graph(llm: ai.LanguageModel) -> None: async def test_stream_loop_checkpoint_records_tools() -> None: """stream_loop's tool executions are recorded in the checkpoint.""" - async def graph(llm: ai.LanguageModel) -> ai.StreamResult: + async def graph(model: ai.Model) -> ai.StreamResult: return await ai.stream_loop( - llm, + model, messages=ai.make_messages(user="Double 4"), tools=[double], ) call1 = [tool_msg(tc_id="tc-1", name="double", args='{"x": 4}')] call2 = [text_msg("8", id="msg-2")] - llm = MockLLM([call1, call2]) + mock_llm([call1, call2]) - result = ai.run(graph, llm) + result = ai.run(graph, MOCK_MODEL) [m async for m in result] cp = result.checkpoint diff --git a/tests/agents/test_streams.py 
b/tests/agents/test_streams.py index f67e5454..2eb3ef20 100644 --- a/tests/agents/test_streams.py +++ b/tests/agents/test_streams.py @@ -7,7 +7,7 @@ from vercel_ai_sdk.agents.streams import StreamResult from vercel_ai_sdk.types import messages -from ..conftest import MockLLM, text_msg +from ..conftest import MOCK_MODEL, mock_llm, text_msg class _Weather(pydantic.BaseModel): @@ -58,9 +58,10 @@ def test_stream_result_tool_calls() -> None: @pytest.mark.asyncio async def test_stream_outside_run_raises() -> None: """@stream-decorated fn called without ai.run() should raise.""" + mock_llm([[text_msg("hi")]]) with pytest.raises(ValueError, match="No Runtime context"): await ai.stream_step( - MockLLM([[text_msg("hi")]]), + MOCK_MODEL, ai.make_messages(user="test"), ) @@ -72,18 +73,18 @@ async def test_stream_outside_run_raises() -> None: async def test_stream_step_replays_from_checkpoint() -> None: """stream_step inside ai.run with a checkpoint replays without calling LLM.""" - async def graph(llm: ai.LanguageModel) -> ai.StreamResult: - return await ai.stream_step(llm, ai.make_messages(user="hello")) + async def graph(model: ai.Model) -> ai.StreamResult: + return await ai.stream_step(model, ai.make_messages(user="hello")) # First run - llm1 = MockLLM([[text_msg("Hi")]]) - r1 = ai.run(graph, llm1) + mock_llm([[text_msg("Hi")]]) + r1 = ai.run(graph, MOCK_MODEL) [msg async for msg in r1] cp = r1.checkpoint # Replay - llm2 = MockLLM([]) - r2 = ai.run(graph, llm2, checkpoint=cp) + llm2 = mock_llm([]) + r2 = ai.run(graph, MOCK_MODEL, checkpoint=cp) [msg async for msg in r2] assert llm2.call_count == 0 diff --git a/tests/conftest.py b/tests/conftest.py index 31fc755f..4a3d0363 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,20 +1,119 @@ from __future__ import annotations from collections.abc import AsyncGenerator, Sequence +from typing import Any import pydantic import vercel_ai_sdk as ai -from vercel_ai_sdk.models.core import llm as llm_ +from vercel_ai_sdk import 
models2 from vercel_ai_sdk.types import messages as messages_ +# A fixed Model used in tests — adapter="mock" dispatches to the mock adapter. +MOCK_MODEL = models2.Model(id="mock-model", adapter="mock", provider="mock") -class MockLLM(ai.LanguageModel): +# Register a dummy provider so _auto_client() doesn't error for provider="mock". +models2._PROVIDER_DEFAULTS["mock"] = ("http://mock.test", "MOCK_API_KEY") + + +class MockAdapter: + """Mock stream adapter that yields pre-configured response sequences. + + Each call to the adapter pops the next response list and yields the + messages through a StreamHandler (matching real adapter behavior). + Tracks ``call_count`` for assertions. + """ + + def __init__(self, responses: list[list[messages_.Message]]) -> None: + self._responses = list(responses) + self._call_index = 0 + self.call_count = 0 + + async def stream( + self, + client: models2.Client, + model: models2.Model, + messages: list[messages_.Message], + *, + tools: Sequence[ai.ToolLike] | None = None, + output_type: type[pydantic.BaseModel] | None = None, + **kwargs: Any, + ) -> AsyncGenerator[messages_.Message]: + if self._call_index >= len(self._responses): + raise RuntimeError("MockAdapter: no more responses configured") + self.call_count += 1 + seq = self._responses[self._call_index] + self._call_index += 1 + + from vercel_ai_sdk.models2.core.helpers import streaming as streaming_ + + handler = streaming_.StreamHandler() + + for msg in seq: + for i, part in enumerate(msg.parts): + if isinstance(part, messages_.TextPart): + bid = f"text-{i}" + yield handler.handle_event(streaming_.TextStart(block_id=bid)) + if part.text: + yield handler.handle_event( + streaming_.TextDelta(block_id=bid, delta=part.text) + ) + yield handler.handle_event(streaming_.TextEnd(block_id=bid)) + + elif isinstance(part, messages_.ReasoningPart): + bid = f"reasoning-{i}" + yield handler.handle_event(streaming_.ReasoningStart(block_id=bid)) + if part.text: + yield handler.handle_event( + 
streaming_.ReasoningDelta(block_id=bid, delta=part.text) + ) + yield handler.handle_event( + streaming_.ReasoningEnd(block_id=bid, signature=part.signature) + ) + + elif isinstance(part, messages_.ToolPart): + yield handler.handle_event( + streaming_.ToolStart( + tool_call_id=part.tool_call_id, + tool_name=part.tool_name, + ) + ) + if part.tool_args: + yield handler.handle_event( + streaming_.ToolArgsDelta( + tool_call_id=part.tool_call_id, + delta=part.tool_args, + ) + ) + yield handler.handle_event( + streaming_.ToolEnd(tool_call_id=part.tool_call_id) + ) + + yield handler.handle_event(streaming_.MessageDone()) + + +def mock_llm(responses: list[list[messages_.Message]]) -> MockAdapter: + """Create a MockAdapter and register it in the models2 adapter registry. + + Returns the adapter so tests can inspect ``call_count``. + """ + adapter = MockAdapter(responses) + models2.register_stream("mock", adapter.stream) + return adapter + + +# ── Legacy MockLLM (for tests/models/ that test the old LanguageModel ABC) ── + + +class MockLLM(ai.models.LanguageModel): """LLM that yields pre-configured response sequences, one per call. Converts pre-configured ``Message`` objects into ``StreamEvent`` sequences so the base-class ``stream()`` (which uses ``StreamHandler``) can reconstruct them. + + **Legacy** — kept for tests of the old ``models/`` module. + New agent tests should use :func:`mock_llm` + ``MOCK_MODEL`` instead. 
""" def __init__(self, responses: list[list[messages_.Message]]) -> None: @@ -27,7 +126,9 @@ async def stream_events( messages: list[messages_.Message], tools: Sequence[ai.ToolLike] | None = None, output_type: type[pydantic.BaseModel] | None = None, - ) -> AsyncGenerator[llm_.StreamEvent]: + ) -> AsyncGenerator[Any, None]: + from vercel_ai_sdk.models.core import llm as llm_ + if self._call_index >= len(self._responses): raise RuntimeError("MockLLM: no more responses configured") self.call_count += 1 @@ -65,6 +166,9 @@ async def stream_events( yield llm_.MessageDone() +# ── Helpers ────────────────────────────────────────────────────── + + def text_msg( text: str, *, id: str = "msg-1", state: str = "done", delta: str | None = None ) -> messages_.Message: diff --git a/tests/telemetry/test_otel_handler.py b/tests/telemetry/test_otel_handler.py index 5f2ff2b6..b7f329f5 100644 --- a/tests/telemetry/test_otel_handler.py +++ b/tests/telemetry/test_otel_handler.py @@ -12,7 +12,7 @@ import vercel_ai_sdk as ai from vercel_ai_sdk.telemetry.otel import OtelHandler -from ..conftest import MockLLM, text_msg, tool_msg +from ..conftest import MOCK_MODEL, mock_llm, text_msg, tool_msg @pytest.fixture @@ -36,10 +36,13 @@ async def double(x: int) -> int: async def test_text_only_spans(spans: InMemorySpanExporter) -> None: """Text-only run produces ai.run > ai.stream span hierarchy.""" - async def root(llm: ai.LanguageModel) -> ai.StreamResult: - return await ai.stream_loop(llm, messages=ai.make_messages(user="Hi"), tools=[]) + async def root(model: ai.Model) -> ai.StreamResult: + return await ai.stream_loop( + model, messages=ai.make_messages(user="Hi"), tools=[] + ) - result = ai.run(root, MockLLM([[text_msg("Hello!")]])) + mock_llm([[text_msg("Hello!")]]) + result = ai.run(root, MOCK_MODEL) [m async for m in result] finished = spans.get_finished_spans() @@ -63,18 +66,18 @@ async def root(llm: ai.LanguageModel) -> ai.StreamResult: async def test_tool_call_spans(spans: 
InMemorySpanExporter) -> None: """Tool-calling run produces ai.tool spans with correct attributes.""" - async def root(llm: ai.LanguageModel) -> ai.StreamResult: + async def root(model: ai.Model) -> ai.StreamResult: return await ai.stream_loop( - llm, messages=ai.make_messages(user="Double 5"), tools=[double] + model, messages=ai.make_messages(user="Double 5"), tools=[double] ) - llm = MockLLM( + mock_llm( [ [tool_msg(tc_id="tc-1", name="double", args='{"x": 5}')], [text_msg("10")], ] ) - result = ai.run(root, llm) + result = ai.run(root, MOCK_MODEL) [m async for m in result] finished = spans.get_finished_spans() diff --git a/tests/telemetry/test_telemetry.py b/tests/telemetry/test_telemetry.py index cf470334..8950bf63 100644 --- a/tests/telemetry/test_telemetry.py +++ b/tests/telemetry/test_telemetry.py @@ -17,7 +17,7 @@ ToolCallStartEvent, ) -from ..conftest import MockLLM, text_msg, tool_msg +from ..conftest import MOCK_MODEL, mock_llm, text_msg, tool_msg # ── Recording handler ──────────────────────────────────────────── @@ -56,10 +56,13 @@ async def double(x: int) -> int: async def test_text_only_run_events(handler: RecordingHandler) -> None: """Simplest run emits RunStart, StepStart, StepFinish, RunFinish.""" - async def root(llm: ai.LanguageModel) -> ai.StreamResult: - return await ai.stream_loop(llm, messages=ai.make_messages(user="Hi"), tools=[]) + async def root(model: ai.Model) -> ai.StreamResult: + return await ai.stream_loop( + model, messages=ai.make_messages(user="Hi"), tools=[] + ) - result = ai.run(root, MockLLM([[text_msg("Hello!")]])) + mock_llm([[text_msg("Hello!")]]) + result = ai.run(root, MOCK_MODEL) [m async for m in result] types = [type(e).__name__ for e in handler.events] @@ -79,18 +82,18 @@ async def root(llm: ai.LanguageModel) -> ai.StreamResult: async def test_tool_call_events(handler: RecordingHandler) -> None: """Tool-calling run emits tool events between steps with correct payloads.""" - async def root(llm: ai.LanguageModel) -> 
ai.StreamResult: + async def root(model: ai.Model) -> ai.StreamResult: return await ai.stream_loop( - llm, messages=ai.make_messages(user="Double 5"), tools=[double] + model, messages=ai.make_messages(user="Double 5"), tools=[double] ) - llm = MockLLM( + mock_llm( [ [tool_msg(tc_id="tc-1", name="double", args='{"x": 5}')], [text_msg("10")], ] ) - result = ai.run(root, llm) + result = ai.run(root, MOCK_MODEL) [m async for m in result] types = [type(e).__name__ for e in handler.events] @@ -131,12 +134,13 @@ def handle(self, event: TelemetryEvent) -> None: ai.telemetry.enable(Capture()) try: - async def root(llm: ai.LanguageModel) -> ai.StreamResult: + async def root(model: ai.Model) -> ai.StreamResult: return await ai.stream_loop( - llm, messages=ai.make_messages(user="Hi"), tools=[] + model, messages=ai.make_messages(user="Hi"), tools=[] ) - result = ai.run(root, MockLLM([[text_msg("Hello!")]])) + mock_llm([[text_msg("Hello!")]]) + result = ai.run(root, MOCK_MODEL) [m async for m in result] assert len(captured) == 16 finally: @@ -152,17 +156,21 @@ async def test_disable_reverts_to_noop() -> None: handler = RecordingHandler() ai.telemetry.enable(handler) - async def root(llm: ai.LanguageModel) -> ai.StreamResult: - return await ai.stream_loop(llm, messages=ai.make_messages(user="Hi"), tools=[]) + async def root(model: ai.Model) -> ai.StreamResult: + return await ai.stream_loop( + model, messages=ai.make_messages(user="Hi"), tools=[] + ) - result = ai.run(root, MockLLM([[text_msg("Hello!")]])) + mock_llm([[text_msg("Hello!")]]) + result = ai.run(root, MOCK_MODEL) [m async for m in result] assert len(handler.of_type(RunStartEvent)) == 1 ai.telemetry.disable() handler.events.clear() - result = ai.run(root, MockLLM([[text_msg("Hello!")]])) + mock_llm([[text_msg("Hello!")]]) + result = ai.run(root, MOCK_MODEL) [m async for m in result] assert len(handler.events) == 0 @@ -178,11 +186,14 @@ async def test_user_emitted_custom_event(handler: RecordingHandler) -> None: class 
CustomEvent(TelemetryEvent): message: str - async def root(llm: ai.LanguageModel) -> ai.StreamResult: + async def root(model: ai.Model) -> ai.StreamResult: ai.telemetry.handle(CustomEvent(message="hello")) - return await ai.stream_loop(llm, messages=ai.make_messages(user="Hi"), tools=[]) + return await ai.stream_loop( + model, messages=ai.make_messages(user="Hi"), tools=[] + ) - result = ai.run(root, MockLLM([[text_msg("Hello!")]])) + mock_llm([[text_msg("Hello!")]]) + result = ai.run(root, MOCK_MODEL) [m async for m in result] custom = [e for e in handler.events if isinstance(e, CustomEvent)] @@ -197,10 +208,11 @@ async def root(llm: ai.LanguageModel) -> ai.StreamResult: async def test_run_error_in_finish_event(handler: RecordingHandler) -> None: """RunFinishEvent captures the error when the root function raises.""" - async def root(llm: ai.LanguageModel) -> None: + async def root(model: ai.Model) -> None: raise ValueError("boom") - result = ai.run(root, MockLLM([])) + mock_llm([]) + result = ai.run(root, MOCK_MODEL) with pytest.raises(ExceptionGroup): [m async for m in result] From 9cac4d590eda96044491a5ca45078be66a5d0ecb Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Fri, 3 Apr 2026 14:05:59 -0700 Subject: [PATCH 12/18] Outline reworked agents module, split runtime into three distinct entities --- src/vercel_ai_sdk/agents2/__init__.py | 56 ++ src/vercel_ai_sdk/agents2/checkpoint.py | 48 ++ src/vercel_ai_sdk/agents2/hooks.py | 245 +++++++++ src/vercel_ai_sdk/agents2/mcp/__init__.py | 6 + src/vercel_ai_sdk/agents2/mcp/client.py | 281 ++++++++++ src/vercel_ai_sdk/agents2/runtime.py | 639 ++++++++++++++++++++++ src/vercel_ai_sdk/agents2/streams.py | 105 ++++ src/vercel_ai_sdk/agents2/tools.py | 109 ++++ 8 files changed, 1489 insertions(+) create mode 100644 src/vercel_ai_sdk/agents2/__init__.py create mode 100644 src/vercel_ai_sdk/agents2/checkpoint.py create mode 100644 src/vercel_ai_sdk/agents2/hooks.py create mode 100644 
src/vercel_ai_sdk/agents2/mcp/__init__.py create mode 100644 src/vercel_ai_sdk/agents2/mcp/client.py create mode 100644 src/vercel_ai_sdk/agents2/runtime.py create mode 100644 src/vercel_ai_sdk/agents2/streams.py create mode 100644 src/vercel_ai_sdk/agents2/tools.py diff --git a/src/vercel_ai_sdk/agents2/__init__.py b/src/vercel_ai_sdk/agents2/__init__.py new file mode 100644 index 00000000..748848c1 --- /dev/null +++ b/src/vercel_ai_sdk/agents2/__init__.py @@ -0,0 +1,56 @@ +"""Agent loop orchestration — tools, hooks, runtime, and streaming. + +Depends on types/ and models2/. Provides the loop machinery that +plugs a model into a tool-calling loop with hooks and checkpoints. +""" + +from . import mcp +from .checkpoint import Checkpoint, PendingHookInfo +from .hooks import Hook, ToolApproval, hook +from .runtime import ( + EventLog, + HookInfo, + LoopExecutor, + RunResult, + Runtime, + execute_tool, + get_checkpoint, + run, + stream_loop, + stream_step, +) +from .streams import StreamResult, stream +from .tools import Tool, ToolLike, ToolSchema, get_tool, tool + +__all__ = [ + # Core loop + "run", + "stream_step", + "stream_loop", + "execute_tool", + "get_checkpoint", + # Runtime (composition) + "Runtime", + "EventLog", + "LoopExecutor", + "RunResult", + "HookInfo", + # Stream + "stream", + "StreamResult", + # Tools + "Tool", + "ToolLike", + "ToolSchema", + "tool", + "get_tool", + # Hooks + "Hook", + "hook", + "ToolApproval", + # Checkpoint + "Checkpoint", + "PendingHookInfo", + # Submodules + "mcp", +] diff --git a/src/vercel_ai_sdk/agents2/checkpoint.py b/src/vercel_ai_sdk/agents2/checkpoint.py new file mode 100644 index 00000000..c3d079bc --- /dev/null +++ b/src/vercel_ai_sdk/agents2/checkpoint.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +from typing import Any + +import pydantic + +from ..types import messages as messages_ +from . 
import streams as streams_ + + +class StepEvent(pydantic.BaseModel): + """A completed @stream step.""" + + index: int + messages: list[messages_.Message] + + def to_stream_result(self) -> streams_.StreamResult: + return streams_.StreamResult(messages=list(self.messages)) + + +class ToolEvent(pydantic.BaseModel): + """A completed tool execution.""" + + tool_call_id: str + result: Any + status: str = "result" # "result" | "error" + + +class HookEvent(pydantic.BaseModel): + """A resolved hook.""" + + label: str + resolution: dict[str, Any] + + +class PendingHookInfo(pydantic.BaseModel): + """A hook that was suspended but not resolved when the run ended.""" + + label: str + hook_type: str + metadata: dict[str, Any] = {} + + +class Checkpoint(pydantic.BaseModel): + steps: list[StepEvent] = [] + tools: list[ToolEvent] = [] + hooks: list[HookEvent] = [] + pending_hooks: list[PendingHookInfo] = [] diff --git a/src/vercel_ai_sdk/agents2/hooks.py b/src/vercel_ai_sdk/agents2/hooks.py new file mode 100644 index 00000000..948539bb --- /dev/null +++ b/src/vercel_ai_sdk/agents2/hooks.py @@ -0,0 +1,245 @@ +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING, Any, ClassVar + +import pydantic + +from ..types import messages as messages_ + +if TYPE_CHECKING: + from . import runtime as runtime_ + + +# --------------------------------------------------------------------------- +# Module-level hook registries +# +# _live_hooks: +# Populated by Hook.create() when a hook suspends inside a running graph. +# Maps hook label -> (future, metadata dict, Runtime). +# Consumed by Hook.resolve() / Hook.cancel() to unblock the awaiting +# coroutine. Entries are removed when the hook resolves, cancels, or +# the run completes. +# +# _pending_resolutions: +# Populated by Hook.resolve() when no live hook exists yet (serverless +# re-entry: the user calls resolve() *before* ai.run() replays the graph). +# Maps hook label -> validated resolution dict. 
+# Consumed by Hook.create() at the start of graph execution — if a +# pre-registered resolution exists for the label, the hook returns +# immediately without suspending. Entries are removed on consumption. +# --------------------------------------------------------------------------- + +_live_hooks: dict[ + str, tuple[asyncio.Future[Any], dict[str, Any], runtime_.Runtime] +] = {} + +_pending_resolutions: dict[str, dict[str, Any]] = {} +# label -> validated resolution dict + + +def _cleanup_run(labels: set[str]) -> None: + """Remove all registry entries associated with a finished run.""" + for label in labels: + _live_hooks.pop(label, None) + _pending_resolutions.pop(label, None) + + +class Hook[T: pydantic.BaseModel]: + """Hook: a suspension point that requires external input to continue. + + Usage in graph code: + + approval = await ToolApproval.create("approve_delete", metadata={...}) + if approval.granted: + ... + + Resolution from outside the graph: + + ToolApproval.resolve("approve_delete", {"granted": True, ...}) + + Behavior depends on the ``cancels_future`` class variable: + + cancels_future=False (default, long-running): the await blocks until + Hook.resolve() is called from outside the graph (e.g., websocket + handler, API endpoint). + + cancels_future=True (serverless): if no resolution is available, the + hook's future is cancelled by run(). The branch receives CancelledError + and dies cleanly. On re-entry, call Hook.resolve() before ai.run() to + pre-register the resolution, then pass checkpoint= to replay. + """ + + _schema: ClassVar[type[pydantic.BaseModel]] + hook_type: ClassVar[str] + cancels_future: ClassVar[bool] = False + + @classmethod + async def create(cls, label: str, metadata: dict[str, Any] | None = None) -> T: + """Create a hook and await its resolution. + + The hook is submitted to the LoopExecutor's step queue. 
run() will + either: + - Resolve immediately (if a resolution is available from checkpoint + or pre-registered via Hook.resolve()) + - Cancel the future (cancels_future=True, serverless mode) + - Hold the future (cancels_future=False, long-running mode) + """ + from . import runtime as rt_mod + + rt = rt_mod._runtime.get(None) + if rt is None: + raise ValueError("No Runtime context - must be called within ai.run()") + + # Check pre-registered resolutions (serverless re-entry path) + pre_registered = _pending_resolutions.pop(label, None) + if pre_registered is not None: + rt.log.record_hook(label, pre_registered) + return cls._schema(**pre_registered) # type: ignore[return-value] + + # Check checkpoint for a previously resolved value + resolution = rt.log.get_hook_resolution(label) + if resolution is not None: + rt.log.record_hook(label, resolution) + return cls._schema(**resolution) # type: ignore[return-value] + + # Submit to executor queue — run() decides what to do + future: asyncio.Future[dict[str, Any]] = asyncio.Future() + suspension = rt_mod.HookSuspension( + label=label, + hook_type=cls.hook_type, + metadata=metadata or {}, + future=future, + cancels_future=cls.cancels_future, + ) + await rt.executor.put_hook(suspension) + + # Register in module-level registry for external resolution + hook_metadata = metadata or {} + _live_hooks[label] = (future, hook_metadata, rt) + rt.executor.track_hook_label(label) + + # Await resolution — may be resolved immediately by run(), + # cancelled by run() (serverless), or resolved later by + # Hook.resolve() (long-running). 
+ resolution = await future + + # Clean up + _live_hooks.pop(label, None) + + # Record for checkpoint + rt.log.record_hook(label, resolution) + + # Emit resolved message + await rt.executor.put_message( + messages_.Message( + role="assistant", + parts=[ + messages_.HookPart( + hook_id=label, + hook_type=cls.hook_type, + status="resolved", + metadata=hook_metadata, + resolution=resolution, + ) + ], + ) + ) + + return cls._schema(**resolution) # type: ignore[return-value] + + @classmethod + def resolve(cls, label: str, data: T | dict[str, Any]) -> None: + """Resolve a hook by label. + + Works in two modes: + + 1. Live hook exists (long-running): validates data, resolves the + future immediately, unblocking the awaiting coroutine. + + 2. No live hook yet (serverless re-entry): validates data and + stashes it in the pre-registration registry. When ai.run() + replays the graph and Hook.create() executes, it finds the + pre-registered resolution and returns without suspending. + """ + # Validate and normalize to dict + if isinstance(data, dict): + validated = cls._schema(**data) + resolution = validated.model_dump() + else: + if not isinstance(data, cls._schema): + raise TypeError( + f"Expected {cls._schema.__name__} or dict, " + f"got {type(data).__name__}" + ) + resolution = data.model_dump() + + # Path 1: live hook — resolve the future directly + if label in _live_hooks: + future, _, _rt = _live_hooks[label] + future.set_result(resolution) + return + + # Path 2: no live hook — pre-register for later consumption + _pending_resolutions[label] = resolution + + @classmethod + async def cancel(cls, label: str, reason: str | None = None) -> None: + """Cancel a pending hook. + + Only works for live hooks (long-running mode). Raises if the + hook is not currently pending. 
+ """ + if label not in _live_hooks: + raise ValueError(f"No pending hook with label: {label}") + + future, hook_metadata, rt = _live_hooks.pop(label) + future.cancel(reason) + + await rt.executor.put_message( + messages_.Message( + role="assistant", + parts=[ + messages_.HookPart( + hook_id=label, + hook_type=cls.hook_type, + status="cancelled", + metadata=hook_metadata, + ) + ], + ) + ) + + +def hook[T: pydantic.BaseModel](cls: type[T]) -> type[Hook[T]]: + """Decorator to create a Hook type from a pydantic model. + + The pydantic model defines the schema for the hook's resolution payload. + """ + hook_impl = type( + cls.__name__, + (Hook,), + { + "_schema": cls, + "hook_type": cls.__name__, + "cancels_future": cls.__dict__.get("cancels_future", False), + "__doc__": cls.__doc__, + }, + ) + + return hook_impl + + +@hook +class ToolApproval(pydantic.BaseModel): + """Prewired hook for tool call approval. + + Used by the AI SDK UI adapter to bridge the protocol's + tool-approval-request / approval-responded flow to the + hook system. 
+ """ + + cancels_future: ClassVar[bool] = True + + granted: bool + reason: str | None = None diff --git a/src/vercel_ai_sdk/agents2/mcp/__init__.py b/src/vercel_ai_sdk/agents2/mcp/__init__.py new file mode 100644 index 00000000..c1202f63 --- /dev/null +++ b/src/vercel_ai_sdk/agents2/mcp/__init__.py @@ -0,0 +1,6 @@ +from .client import get_http_tools, get_stdio_tools + +__all__ = [ + "get_stdio_tools", + "get_http_tools", +] diff --git a/src/vercel_ai_sdk/agents2/mcp/client.py b/src/vercel_ai_sdk/agents2/mcp/client.py new file mode 100644 index 00000000..c17a25a0 --- /dev/null +++ b/src/vercel_ai_sdk/agents2/mcp/client.py @@ -0,0 +1,281 @@ +from __future__ import annotations + +import asyncio +import contextlib +import contextvars +import dataclasses +import json +from collections.abc import Callable +from typing import Any + +import httpx +import mcp.client.session +import mcp.client.stdio +import mcp.client.streamable_http +import mcp.types + +from .. import tools as tools_ + +__all__ = [ + "get_stdio_tools", + "get_http_tools", + "close_connections", +] + + +@dataclasses.dataclass +class _Connection: + """Internal connection state - never exposed to users.""" + + client: mcp.client.session.ClientSession + exit_stack: contextlib.AsyncExitStack + + +# Connection pool stored in contextvar, scoped to execute() +# The pool is set by execute() and cleaned up when execute() finishes +_pool: contextvars.ContextVar[dict[str, _Connection] | None] = contextvars.ContextVar( + "mcp_connections", default=None +) + +_pool_lock = asyncio.Lock() + + +async def _get_or_create_connection( + key: str, + transport_factory: Callable[[], contextlib.AbstractAsyncContextManager[Any]], +) -> mcp.client.session.ClientSession: + """Get an existing connection or create a new one.""" + pool = _pool.get() + + if pool is None: + raise RuntimeError( + "MCP tools must be used inside ai.execute(). " + "The connection pool is not initialized." 
+ ) + + async with _pool_lock: + if key in pool: + return pool[key].client + + # Use AsyncExitStack for clean resource management + exit_stack = contextlib.AsyncExitStack() + + try: + # Enter the transport context + streams = await exit_stack.enter_async_context(transport_factory()) + + # Handle both (read, write) and (read, write, callback) returns + read_stream, write_stream = streams[0], streams[1] + + # Create and initialize the client session + client = mcp.client.session.ClientSession( + read_stream=read_stream, + write_stream=write_stream, + ) + await exit_stack.enter_async_context(client) + await client.initialize() + + pool[key] = _Connection(client=client, exit_stack=exit_stack) + return client + + except BaseException: + # Clean up on any error during setup + await exit_stack.aclose() + raise + + +def _make_tool_fn( + connection_key: str, + tool_name: str, + transport_factory: Callable[[], contextlib.AbstractAsyncContextManager[Any]], +) -> Callable[..., Any]: + """Create a tool function that manages its own connection.""" + + async def call_tool(**kwargs: Any) -> Any: + client = await _get_or_create_connection(connection_key, transport_factory) + try: + result = await asyncio.wait_for( + client.call_tool(tool_name, kwargs), + timeout=30.0, + ) + except TimeoutError as e: + raise RuntimeError( + f"MCP tool call timed out after 30 seconds: {tool_name}" + ) from e + + # Handle error responses + if result.isError: + error_text = " ".join( + part.text + for part in result.content + if isinstance(part, mcp.types.TextContent) + ) + raise RuntimeError(f"MCP tool error: {error_text or 'Unknown error'}") + + # Prefer structured content if available + if result.structuredContent is not None: + return result.structuredContent + + # Fall back to parsing content + for part in result.content: + if isinstance(part, mcp.types.TextContent): + text = part.text + # Try to parse JSON, otherwise return raw text + if text.startswith(("{", "[")): + try: + return 
json.loads(text) + except json.JSONDecodeError: + pass + return text + + return result.content + + return call_tool + + +async def get_stdio_tools( + command: str, + *args: str, + env: dict[str, str] | None = None, + cwd: str | None = None, + tool_prefix: str | None = None, +) -> list[tools_.Tool[..., Any]]: + """ + Get tools from an MCP server running as a subprocess. + + Connection is managed automatically - created on first use, cleaned up + when execute() finishes. + + Args: + command: The command to run (e.g., "npx", "python"). + *args: Arguments to pass to the command. + env: Environment variables for the subprocess. + cwd: Working directory for the subprocess. + tool_prefix: Optional prefix to add to all tool names. + + Returns: + List of Tool objects that can be passed to stream_loop. + + Example: + tools = await ai.mcp.get_stdio_tools( + "npx", "-y", "@anthropic/mcp-server-filesystem", "/tmp" + ) + """ + connection_key = f"stdio:{command}:{':'.join(args)}" + + def transport_factory() -> contextlib.AbstractAsyncContextManager[Any]: + return mcp.client.stdio.stdio_client( + mcp.client.stdio.StdioServerParameters( + command=command, + args=list(args), + env=env, + cwd=cwd, + ) + ) + + client = await _get_or_create_connection(connection_key, transport_factory) + result = await client.list_tools() + + return [ + _mcp_tool_to_native(mcp_tool, connection_key, transport_factory, tool_prefix) + for mcp_tool in result.tools + ] + + +async def get_http_tools( + url: str, + *, + headers: dict[str, str] | None = None, + tool_prefix: str | None = None, +) -> list[tools_.Tool[..., Any]]: + """ + Get tools from an MCP server over HTTP (Streamable HTTP transport). + + Connection is managed automatically - created on first use, cleaned up + when execute() finishes. + + Args: + url: The URL of the MCP server endpoint. + headers: Optional HTTP headers (e.g., for authentication). + tool_prefix: Optional prefix to add to all tool names. 
+ + Returns: + List of Tool objects that can be passed to stream_loop. + + Example: + tools = await ai.mcp.get_http_tools( + "http://localhost:3000/mcp", + headers={"Authorization": "Bearer xxx"} + ) + """ + connection_key = f"http:{url}" + + def transport_factory() -> contextlib.AbstractAsyncContextManager[Any]: + http_client = httpx.AsyncClient(headers=headers) if headers else None + return mcp.client.streamable_http.streamable_http_client( + url=url, http_client=http_client + ) + + client = await _get_or_create_connection(connection_key, transport_factory) + result = await client.list_tools() + + return [ + _mcp_tool_to_native(mcp_tool, connection_key, transport_factory, tool_prefix) + for mcp_tool in result.tools + ] + + +def _mcp_tool_to_native( + mcp_tool: mcp.types.Tool, + connection_key: str, + transport_factory: Callable[[], contextlib.AbstractAsyncContextManager[Any]], + tool_prefix: str | None, +) -> tools_.Tool[..., Any]: + """Convert an MCP tool to a native Tool.""" + name = mcp_tool.name + if tool_prefix: + name = f"{tool_prefix}_{name}" + + schema = tools_.ToolSchema( + name=name, + description=mcp_tool.description or "", + param_schema=mcp_tool.inputSchema, + return_type=Any, + ) + + t = tools_.Tool( + fn=_make_tool_fn(connection_key, mcp_tool.name, transport_factory), + schema=schema, + ) + # Register so execute_tool() can find it by name + tools_._tool_registry[name] = t + return t + + +async def close_connections() -> None: + """ + Close all MCP connections in the current context. + + This is called automatically by execute(), but can be called + manually for explicit cleanup. 
+ """ + pool = _pool.get() + if pool is None: + return + + async with _pool_lock: + if not pool: + return + + # Use TaskGroup for concurrent cleanup + async with asyncio.TaskGroup() as tg: + for conn in pool.values(): + tg.create_task(_close_connection_safely(conn)) + + pool.clear() + + +async def _close_connection_safely(conn: _Connection) -> None: + """Close a connection, suppressing any errors.""" + with contextlib.suppress(Exception): + await conn.exit_stack.aclose() diff --git a/src/vercel_ai_sdk/agents2/runtime.py b/src/vercel_ai_sdk/agents2/runtime.py new file mode 100644 index 00000000..6b3d2816 --- /dev/null +++ b/src/vercel_ai_sdk/agents2/runtime.py @@ -0,0 +1,639 @@ +from __future__ import annotations + +import asyncio +import contextvars +import dataclasses +import json +import logging +from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine, Sequence +from typing import Any, get_type_hints + +import pydantic + +from .. import models2 +from ..telemetry import events as telemetry_ +from ..types import messages as messages_ +from . import checkpoint as checkpoint_ +from . import hooks as hooks_ +from . import mcp +from . import streams as streams_ +from . import tools as tools_ + +logger = logging.getLogger(__name__) + + +# ── EventLog ────────────────────────────────────────────────────── +# +# Pure bookkeeping: replay from checkpoint + record new events. +# No asyncio, no queues — just data in, data out. +# + + +class EventLog: + """Replay/record layer backed by a Checkpoint. + + Holds replay cursors (read pointer into the checkpoint) and + recording lists (new events produced during this run). + Completely synchronous — no queues, no async. 
+ """ + + def __init__(self, checkpoint: checkpoint_.Checkpoint | None = None) -> None: + self._checkpoint = checkpoint or checkpoint_.Checkpoint() + + # Replay cursors + self._step_index: int = 0 + self._tool_replay: dict[str, checkpoint_.ToolEvent] = { + t.tool_call_id: t for t in self._checkpoint.tools + } + self._hook_replay: dict[str, dict[str, Any]] = { + h.label: h.resolution for h in self._checkpoint.hooks + } + + # Recording lists (new events from this run) + self._step_log: list[checkpoint_.StepEvent] = [] + self._tool_log: list[checkpoint_.ToolEvent] = [] + self._hook_log: list[checkpoint_.HookEvent] = [] + + # ── Steps ───────────────────────────────────────────────── + + @property + def step_index(self) -> int: + return self._step_index + + def try_replay_step(self) -> streams_.StreamResult | None: + if self._step_index < len(self._checkpoint.steps): + event = self._checkpoint.steps[self._step_index] + self._step_index += 1 + logger.info("Replaying step %d from checkpoint", event.index) + return event.to_stream_result() + return None + + def record_step(self, result: streams_.StreamResult) -> None: + event = checkpoint_.StepEvent( + index=self._step_index, + messages=list(result.messages), + ) + self._step_log.append(event) + self._step_index += 1 + + # ── Tools ───────────────────────────────────────────────── + + def try_replay_tool(self, tool_call_id: str) -> checkpoint_.ToolEvent | None: + event = self._tool_replay.get(tool_call_id) + if event is not None: + logger.info( + "Replaying tool %s (call_id=%s) from checkpoint", + event.tool_call_id, + tool_call_id, + ) + return event + + def record_tool( + self, tool_call_id: str, result: Any, *, status: str = "result" + ) -> None: + self._tool_log.append( + checkpoint_.ToolEvent( + tool_call_id=tool_call_id, result=result, status=status + ) + ) + + # ── Hooks ───────────────────────────────────────────────── + + def get_hook_resolution(self, label: str) -> dict[str, Any] | None: + resolution = 
self._hook_replay.get(label) + if resolution is not None: + logger.info("Resolving hook '%s' from checkpoint", label) + return resolution + + def record_hook(self, label: str, resolution: dict[str, Any]) -> None: + self._hook_log.append(checkpoint_.HookEvent(label=label, resolution=resolution)) + + # ── Snapshot ────────────────────────────────────────────── + + def checkpoint( + self, pending_hooks: list[checkpoint_.PendingHookInfo] | None = None + ) -> checkpoint_.Checkpoint: + """Build a full Checkpoint merging prior state + new recordings.""" + return checkpoint_.Checkpoint( + steps=list(self._checkpoint.steps) + self._step_log, + tools=list(self._checkpoint.tools) + self._tool_log, + hooks=list(self._checkpoint.hooks) + self._hook_log, + pending_hooks=pending_hooks or [], + ) + + +# ── LoopExecutor ───────────────────────────────────────────────── +# +# Async coordination: queues that let graph code (streams, hooks, +# tools) talk to the driver loop. Pure mailbox — no replay, no +# checkpoint awareness. +# + + +@dataclasses.dataclass +class HookSuspension: + """Submitted to the step queue when a hook needs resolution.""" + + label: str + hook_type: str + metadata: dict[str, Any] + future: asyncio.Future[Any] + cancels_future: bool = False + + +class LoopExecutor: + """Async coordination layer between graph code and the driver loop. + + Graph code (``@stream`` decorators, hooks, tool execution) submits + work via the producer methods. The driver loop consumes via + ``next()`` and ``drain_messages()``. 
+ """ + + class _Sentinel: + pass + + _SENTINEL = _Sentinel() + + def __init__(self) -> None: + self._step_queue: asyncio.Queue[ + tuple[streams_.Stream, asyncio.Future[streams_.StreamResult]] + | HookSuspension + | LoopExecutor._Sentinel + ] = asyncio.Queue() + + self._message_queue: asyncio.Queue[messages_.Message] = asyncio.Queue() + + # Pending hooks (unresolved during this run) + self._pending_hooks: dict[str, HookSuspension] = {} + + # Track hook labels registered in this run for cleanup + self._hook_labels: set[str] = set() + + # ── Producers (called by graph code) ────────────────────── + + async def put_step( + self, step_fn: streams_.Stream, future: asyncio.Future[streams_.StreamResult] + ) -> None: + await self._step_queue.put((step_fn, future)) + + async def put_hook(self, suspension: HookSuspension) -> None: + await self._step_queue.put(suspension) + + async def put_message(self, message: messages_.Message) -> None: + await self._message_queue.put(message) + + async def done(self) -> None: + await self._step_queue.put(self._SENTINEL) + + # ── Consumer (called by driver loop) ────────────────────── + + async def next( + self, timeout: float = 0.1 + ) -> ( + tuple[streams_.Stream, asyncio.Future[streams_.StreamResult]] + | HookSuspension + | None + ): + """Pull the next item from the step queue. + + Returns ``None`` on timeout (no item available). + Returns the sentinel's semantic equivalent by raising StopIteration + when the graph signals completion. 
+ """ + try: + item = await asyncio.wait_for(self._step_queue.get(), timeout=timeout) + except TimeoutError: + return None + + if isinstance(item, LoopExecutor._Sentinel): + raise _LoopDone + return item + + def drain_messages(self) -> list[messages_.Message]: + msgs: list[messages_.Message] = [] + while not self._message_queue.empty(): + try: + msgs.append(self._message_queue.get_nowait()) + except asyncio.QueueEmpty: + break + return msgs + + # ── Hook label tracking ─────────────────────────────────── + + def track_hook_label(self, label: str) -> None: + self._hook_labels.add(label) + + def pending_hook_infos(self) -> list[checkpoint_.PendingHookInfo]: + return [ + checkpoint_.PendingHookInfo( + label=sus.label, + hook_type=sus.hook_type, + metadata=sus.metadata, + ) + for sus in self._pending_hooks.values() + ] + + +class _LoopDone(Exception): + """Internal signal: the loop function has finished.""" + + +# ── Runtime ─────────────────────────────────────────────────────── +# +# Thin composition of EventLog + LoopExecutor. +# The context var points here; graph code accesses rt.log and +# rt.executor directly. +# + + +class Runtime: + """Central coordinator — composes EventLog and LoopExecutor. + + Graph code accesses ``rt.log`` for replay/record and + ``rt.executor`` for async coordination. 
+ """ + + def __init__(self, checkpoint: checkpoint_.Checkpoint | None = None) -> None: + self.log = EventLog(checkpoint) + self.executor = LoopExecutor() + + def checkpoint(self) -> checkpoint_.Checkpoint: + return self.log.checkpoint( + pending_hooks=self.executor.pending_hook_infos(), + ) + + +_runtime: contextvars.ContextVar[Runtime] = contextvars.ContextVar("runtime") + + +def get_checkpoint() -> checkpoint_.Checkpoint: + """Get the current checkpoint from the active Runtime.""" + return _runtime.get().checkpoint() + + +def _find_runtime_param(fn: Callable[..., Any]) -> str | None: + """Find a parameter typed as Runtime, return its name or None.""" + try: + hints = get_type_hints(fn) + except Exception: + return None + for name, hint in hints.items(): + if hint is Runtime: + return name + return None + + +# ── Convenience functions ───────────────────────────────────────── + + +@streams_.stream +async def stream_step( + model: models2.Model, + messages: list[messages_.Message], + tools: Sequence[tools_.ToolLike] | None = None, + label: str | None = None, + output_type: type[pydantic.BaseModel] | None = None, + **kwargs: Any, +) -> AsyncGenerator[messages_.Message]: + """Single LLM call that streams to Runtime.""" + async for msg in models2.stream( + model, messages, tools=tools, output_type=output_type, **kwargs + ): + msg.label = label + yield msg + + +async def execute_tool( + tool_call: messages_.ToolPart, + message: messages_.Message | None = None, +) -> Any: + """Execute a single tool call with replay support. + + Looks up the tool by name from the global registry, executes it, + and updates the ToolPart (and parent Message) with the result. + Emits the updated message to the LoopExecutor queue so the UI sees + the transition from status="pending" to status="result" (or "error"). + + If a checkpoint exists with a cached result for this tool_call_id, + returns the cached result without re-executing. 
+ """ + rt = _runtime.get(None) + + # Replay: return cached result if available + if rt: + cached = rt.log.try_replay_tool(tool_call.tool_call_id) + if cached is not None: + if cached.status == "error": + tool_call.set_error(cached.result) + else: + tool_call.set_result(cached.result) + return cached.result + + telemetry_.handle( + telemetry_.ToolCallStartEvent( + tool_name=tool_call.tool_name, + tool_call_id=tool_call.tool_call_id, + args=tool_call.tool_args, + ) + ) + t0 = telemetry_.time_ms() + + # Fresh execution + tool = tools_.get_tool(tool_call.tool_name) + if tool is None: + raise ValueError(f"Tool not found in registry: {tool_call.tool_name}") + + error_str: str | None = None + try: + result = await tool.validate_and_call(tool_call.tool_args, rt) + tool_call.set_result(result) + except (json.JSONDecodeError, pydantic.ValidationError) as exc: + result = f"{type(exc).__name__}: {exc}" + error_str = result + tool_call.set_error(result) + + telemetry_.handle( + telemetry_.ToolCallFinishEvent( + tool_name=tool_call.tool_name, + tool_call_id=tool_call.tool_call_id, + result=result, + error=error_str, + duration_ms=telemetry_.time_ms() - t0, + ) + ) + + # Record for checkpoint + if rt: + rt.log.record_tool(tool_call.tool_call_id, result, status=tool_call.status) + + # Emit updated message so UI sees status change + if rt and message: + await rt.executor.put_message(message.model_copy(deep=True)) + + return result + + +async def stream_loop( + model: models2.Model, + messages: list[messages_.Message], + tools: Sequence[tools_.ToolLike], + label: str | None = None, + output_type: type[pydantic.BaseModel] | None = None, + **kwargs: Any, +) -> streams_.StreamResult: + """Agent loop: stream LLM, execute tools, repeat until done.""" + local_messages = list(messages) + + while True: + result = await stream_step( + model, local_messages, tools, label=label, output_type=output_type, **kwargs + ) + + if not result.tool_calls: + return result + + last_msg = 
result.last_message + if last_msg is not None: + local_messages.append(last_msg) + + await asyncio.gather( + *(execute_tool(tc, message=last_msg) for tc in result.tool_calls) + ) + + +# ── RunResult ───────────────────────────────────────────────────── + + +@dataclasses.dataclass +class HookInfo: + """Info about a pending (unresolved) hook, exposed on RunResult.""" + + label: str + hook_type: str + metadata: dict[str, Any] + + +class RunResult: + """Returned by run(). Async-iterate for messages, then check state. + + Usage: + result = ai.run(my_graph, llm, query) + async for msg in result: + ... + result.checkpoint # Checkpoint with all completed work + result.pending_hooks # dict of unresolved hooks (empty if graph completed) + """ + + def __init__(self) -> None: + self._messages: AsyncGenerator[messages_.Message] | None = None + self._runtime: Runtime | None = None + + @property + def checkpoint(self) -> checkpoint_.Checkpoint: + if self._runtime is None: + return checkpoint_.Checkpoint() + return self._runtime.checkpoint() + + @property + def pending_hooks(self) -> dict[str, HookInfo]: + if self._runtime is None: + return {} + return { + label: HookInfo( + label=sus.label, + hook_type=sus.hook_type, + metadata=sus.metadata, + ) + for label, sus in self._runtime.executor._pending_hooks.items() + } + + async def __aiter__(self) -> AsyncGenerator[messages_.Message]: + if self._messages is not None: + async for msg in self._messages: + yield msg + + +# ── run() ───────────────────────────────────────────────────────── + + +async def _stop_when_done(executor: LoopExecutor, task: Awaitable[None]) -> None: + try: + await task + finally: + await executor.done() + + +def run( + root: Callable[..., Coroutine[Any, Any, Any]], + *args: Any, + checkpoint: checkpoint_.Checkpoint | None = None, +) -> RunResult: + """Main entry point. + + 1. Starts the root function as a background task + 2. Pulls steps and hook suspensions from the LoopExecutor queue + 3. 
Executes each step, yielding messages + 4. Resolves or suspends hooks depending on the hook's cancels_future + 5. Returns RunResult with .checkpoint and .pending_hooks + """ + result = RunResult() + + # Discard stale checkpoints: if the checkpoint has pending hooks but + # none of them have been resolved, this isn't a resume. + effective_checkpoint = checkpoint + if checkpoint and checkpoint.pending_hooks: + pending_labels = [ph.label for ph in checkpoint.pending_hooks] + has_resolution = any( + label in hooks_._pending_resolutions for label in pending_labels + ) + if not has_resolution: + logger.info( + "Discarding stale checkpoint: pending hooks %s have no " + "matching resolutions", + pending_labels, + ) + effective_checkpoint = None + else: + logger.info( + "Resuming from checkpoint with %d pending hook(s): %s", + len(pending_labels), + pending_labels, + ) + + async def _generate() -> AsyncGenerator[messages_.Message]: + rt = Runtime(checkpoint=effective_checkpoint) + result._runtime = rt + token_runtime = _runtime.set(rt) + token_run_id = telemetry_.start_run() + + telemetry_.handle(telemetry_.RunStartEvent()) + + mcp_pool: dict[str, mcp.client._Connection] = {} + mcp_token = mcp.client._pool.set(mcp_pool) + + kwargs: dict[str, Any] = {} + if runtime_param := _find_runtime_param(root): + kwargs[runtime_param] = rt + + run_error: BaseException | None = None + total_usage: messages_.Usage | None = None + + try: + async with asyncio.TaskGroup() as tg: + _task: asyncio.Task[None] = tg.create_task( + _stop_when_done(rt.executor, root(*args, **kwargs)) + ) + + while True: + # Drain pending messages + for msg in rt.executor.drain_messages(): + yield msg + + # Pull next item from the graph executor + try: + item = await rt.executor.next() + except _LoopDone: + for msg in rt.executor.drain_messages(): + yield msg + break + + if item is None: + # Timeout — no item available, loop again + continue + + # ── Hook suspension ──────────────────────── + if isinstance(item, 
HookSuspension): + resolution = rt.log.get_hook_resolution(item.label) + if resolution is not None: + item.future.set_result(resolution) + rt.log.record_hook(item.label, resolution) + else: + rt.executor._pending_hooks[item.label] = item + if item.cancels_future: + item.future.cancel() + + yield messages_.Message( + role="assistant", + parts=[ + messages_.HookPart( + hook_id=item.label, + hook_type=item.hook_type, + status="pending", + metadata=item.metadata, + ) + ], + ) + + await asyncio.sleep(0) + for msg in rt.executor.drain_messages(): + yield msg + continue + + # ── Regular step ─────────────────────────── + step_fn, future = item + + telemetry_.handle( + telemetry_.StepStartEvent( + step_index=rt.log.step_index, + ) + ) + + for tool_msg in rt.executor.drain_messages(): + yield tool_msg + + result_messages: list[messages_.Message] = [] + + async for msg in step_fn(): + msg_copy = msg.model_copy(deep=True) + yield msg_copy + result_messages.append(msg) + + for tool_msg in rt.executor.drain_messages(): + yield tool_msg + + step_result = streams_.StreamResult(messages=result_messages) + future.set_result(step_result) + + telemetry_.handle( + telemetry_.StepFinishEvent( + step_index=rt.log.step_index, + result=step_result, # type: ignore[arg-type] + ) + ) + + # Accumulate usage for run-level telemetry + step_usage = step_result.usage + if step_usage is not None: + total_usage = ( + step_usage + if total_usage is None + else total_usage + step_usage + ) + + await asyncio.sleep(0) + for tool_msg in rt.executor.drain_messages(): + yield tool_msg + + except BaseException as exc: + run_error = exc + raise + + finally: + telemetry_.handle( + telemetry_.RunFinishEvent( + usage=total_usage, + error=run_error, + ) + ) + telemetry_.end_run(token_run_id) + + hooks_._cleanup_run(rt.executor._hook_labels) + + if mcp_token is not None: + await mcp.client.close_connections() + mcp.client._pool.reset(mcp_token) + + _runtime.reset(token_runtime) + + result._messages = _generate() 
+ return result diff --git a/src/vercel_ai_sdk/agents2/streams.py b/src/vercel_ai_sdk/agents2/streams.py new file mode 100644 index 00000000..fadf6747 --- /dev/null +++ b/src/vercel_ai_sdk/agents2/streams.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +import asyncio +import dataclasses +import functools +from collections.abc import AsyncGenerator, Awaitable, Callable +from typing import Any + +from ..types import messages as messages_ + + +@dataclasses.dataclass +class StreamResult: + messages: list[messages_.Message] = dataclasses.field(default_factory=list) + + @property + def last_message(self) -> messages_.Message | None: + return self.messages[-1] if self.messages else None + + @property + def tool_calls(self) -> list[messages_.ToolPart]: + """Get tool calls from the last message.""" + if self.last_message: + return self.last_message.tool_calls + return [] + + @property + def text(self) -> str: + if self.last_message: + return self.last_message.text + return "" + + @property + def output(self) -> Any: + """Parsed structured output from the last message, if available.""" + if self.last_message: + return self.last_message.output + return None + + @property + def usage(self) -> messages_.Usage | None: + """Token usage from the last (most recent) LLM call.""" + if self.last_message: + return self.last_message.usage + return None + + @property + def total_usage(self) -> messages_.Usage | None: + """Accumulated token usage across all LLM calls in this result. + + Sums usage from every message that carries it (i.e. assistant + messages produced by LLM calls). Returns ``None`` if no message + reported usage. 
+ """ + total: messages_.Usage | None = None + for msg in self.messages: + if msg.usage is not None: + total = msg.usage if total is None else total + msg.usage + return total + + +Stream = Callable[[], AsyncGenerator[messages_.Message]] +# maybe it should have a name and an id inferred from LLM outputs + + +def stream[**P]( + fn: Callable[P, AsyncGenerator[messages_.Message]], +) -> Callable[P, Awaitable[StreamResult]]: + """Decorator to put an async generator into the LoopExecutor queue. + + The decorated function submits its work to the executor queue and + blocks until run() processes it, then returns the StreamResult. + + If a checkpoint exists with a cached result for this step index, + returns the cached result without submitting to the queue (replay). + """ + + from . import runtime as runtime_ + + @functools.wraps(fn) + async def wrapped(*args: Any, **kwargs: Any) -> StreamResult: + rt: runtime_.Runtime | None = runtime_._runtime.get(None) + if rt is None: + raise ValueError("No Runtime context - must be called within ai.run()") + + # Replay: return cached result if available + cached = rt.log.try_replay_step() + if cached is not None: + return cached + + # Fresh execution: submit to executor queue and wait + future: asyncio.Future[StreamResult] = asyncio.Future() + + async def stream_fn() -> AsyncGenerator[messages_.Message]: + async for msg in fn(*args, **kwargs): + yield msg + + await rt.executor.put_step(stream_fn, future) + result = await future + + # Record for checkpoint + rt.log.record_step(result) + return result + + return wrapped diff --git a/src/vercel_ai_sdk/agents2/tools.py b/src/vercel_ai_sdk/agents2/tools.py new file mode 100644 index 00000000..f4e3744b --- /dev/null +++ b/src/vercel_ai_sdk/agents2/tools.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +import inspect +import json +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING, Any, get_type_hints + +import pydantic + +from ..types.tools import 
ToolLike as ToolLike +from ..types.tools import ToolSchema as ToolSchema + +if TYPE_CHECKING: + from . import runtime as runtime_ + +# Module-level tool registry - populated at decoration time +_tool_registry: dict[str, Tool[..., Any]] = {} + + +def get_tool(name: str) -> Tool[..., Any] | None: + """Look up a tool by name from the global registry.""" + return _tool_registry.get(name) + + +def _is_runtime_type(hint: Any) -> bool: + """Check if a type hint is the Runtime class.""" + # Import here to avoid circular import at runtime + from .runtime import Runtime + + return hint is Runtime + + +class Tool[**P, R]: + def __init__( + self, + fn: Callable[P, Awaitable[R]], + schema: ToolSchema, + validator: type[pydantic.BaseModel] | None = None, + ) -> None: + self._fn = fn + self._validator = validator + self.schema = schema + + async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: + return await self._fn(*args, **kwargs) + + async def validate_and_call( + self, json_str: str, runtime: runtime_.Runtime | None + ) -> R: + from .runtime import _find_runtime_param + + kwargs = json.loads(json_str) if json_str else {} + + if runtime and (rt_param := _find_runtime_param(self._fn)): + kwargs[rt_param] = runtime + + # validate llm-generated inputs (skipped for MCP tools) + if self._validator is not None: + self._validator.model_validate(kwargs) + return await self._fn(**kwargs) # type: ignore[call-arg] + + @property + def name(self) -> str: + return self.schema.name + + @property + def description(self) -> str: + return self.schema.description + + @property + def param_schema(self) -> dict[str, Any]: + return self.schema.param_schema + + +def tool[**P, R](fn: Callable[P, Awaitable[R]]) -> Tool[P, R]: + """Decorator to define a tool from an async function.""" + + # 1. 
build tool schema by parsing the function + sig = inspect.signature(fn) + hints = get_type_hints(fn) if hasattr(fn, "__annotations__") else {} + + fields: dict[str, Any] = {} + + for param_name, param in sig.parameters.items(): + param_type = hints.get(param_name, str) + + if _is_runtime_type(param_type): + continue + if param.default is inspect.Parameter.empty: + fields[param_name] = (param_type, ...) + else: + fields[param_name] = (param_type, param.default) + + validator = pydantic.create_model(f"{fn.__name__}_Args", **fields) + + # 2. instantiate the tool + + schema = ToolSchema( + name=fn.__name__, + description=inspect.getdoc(fn) or "", + param_schema=validator.model_json_schema(), + return_type=hints.get("return", None), + ) + + t = Tool(fn=fn, schema=schema, validator=validator) + + # 3. register in global registry + _tool_registry[t.name] = t + return t From 73d1a7fd74e64dd04fd1c8d0fc012cfbf1f11ee0 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Fri, 3 Apr 2026 14:27:01 -0700 Subject: [PATCH 13/18] Remove stream_step and stream_loop helpers --- src/vercel_ai_sdk/agents2/__init__.py | 4 --- src/vercel_ai_sdk/agents2/runtime.py | 51 +-------------------------- 2 files changed, 1 insertion(+), 54 deletions(-) diff --git a/src/vercel_ai_sdk/agents2/__init__.py b/src/vercel_ai_sdk/agents2/__init__.py index 748848c1..aa4941be 100644 --- a/src/vercel_ai_sdk/agents2/__init__.py +++ b/src/vercel_ai_sdk/agents2/__init__.py @@ -16,8 +16,6 @@ execute_tool, get_checkpoint, run, - stream_loop, - stream_step, ) from .streams import StreamResult, stream from .tools import Tool, ToolLike, ToolSchema, get_tool, tool @@ -25,8 +23,6 @@ __all__ = [ # Core loop "run", - "stream_step", - "stream_loop", "execute_tool", "get_checkpoint", # Runtime (composition) diff --git a/src/vercel_ai_sdk/agents2/runtime.py b/src/vercel_ai_sdk/agents2/runtime.py index 6b3d2816..7e078894 100644 --- a/src/vercel_ai_sdk/agents2/runtime.py +++ b/src/vercel_ai_sdk/agents2/runtime.py @@ -5,12 +5,11 
@@ import dataclasses import json import logging -from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine, Sequence +from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine from typing import Any, get_type_hints import pydantic -from .. import models2 from ..telemetry import events as telemetry_ from ..types import messages as messages_ from . import checkpoint as checkpoint_ @@ -283,26 +282,6 @@ def _find_runtime_param(fn: Callable[..., Any]) -> str | None: return None -# ── Convenience functions ───────────────────────────────────────── - - -@streams_.stream -async def stream_step( - model: models2.Model, - messages: list[messages_.Message], - tools: Sequence[tools_.ToolLike] | None = None, - label: str | None = None, - output_type: type[pydantic.BaseModel] | None = None, - **kwargs: Any, -) -> AsyncGenerator[messages_.Message]: - """Single LLM call that streams to Runtime.""" - async for msg in models2.stream( - model, messages, tools=tools, output_type=output_type, **kwargs - ): - msg.label = label - yield msg - - async def execute_tool( tool_call: messages_.ToolPart, message: messages_.Message | None = None, @@ -373,34 +352,6 @@ async def execute_tool( return result -async def stream_loop( - model: models2.Model, - messages: list[messages_.Message], - tools: Sequence[tools_.ToolLike], - label: str | None = None, - output_type: type[pydantic.BaseModel] | None = None, - **kwargs: Any, -) -> streams_.StreamResult: - """Agent loop: stream LLM, execute tools, repeat until done.""" - local_messages = list(messages) - - while True: - result = await stream_step( - model, local_messages, tools, label=label, output_type=output_type, **kwargs - ) - - if not result.tool_calls: - return result - - last_msg = result.last_message - if last_msg is not None: - local_messages.append(last_msg) - - await asyncio.gather( - *(execute_tool(tc, message=last_msg) for tc in result.tool_calls) - ) - - # ── RunResult 
───────────────────────────────────────────────────── From 3235af90c851842c321947384514ab37b36e9d2a Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Fri, 3 Apr 2026 15:27:42 -0700 Subject: [PATCH 14/18] Introduce the context object that encapsulates everything that goes into the LLM --- src/vercel_ai_sdk/agents2/__init__.py | 5 + src/vercel_ai_sdk/agents2/context.py | 206 ++++++++++++++++++++++++ src/vercel_ai_sdk/agents2/mcp/client.py | 22 ++- src/vercel_ai_sdk/agents2/runtime.py | 32 +++- src/vercel_ai_sdk/agents2/tools.py | 29 +++- 5 files changed, 286 insertions(+), 8 deletions(-) create mode 100644 src/vercel_ai_sdk/agents2/context.py diff --git a/src/vercel_ai_sdk/agents2/__init__.py b/src/vercel_ai_sdk/agents2/__init__.py index aa4941be..56cb77cb 100644 --- a/src/vercel_ai_sdk/agents2/__init__.py +++ b/src/vercel_ai_sdk/agents2/__init__.py @@ -6,6 +6,7 @@ from . import mcp from .checkpoint import Checkpoint, PendingHookInfo +from .context import Context, ToolSource, get_context from .hooks import Hook, ToolApproval, hook from .runtime import ( EventLog, @@ -25,6 +26,10 @@ "run", "execute_tool", "get_checkpoint", + # Context + "Context", + "ToolSource", + "get_context", # Runtime (composition) "Runtime", "EventLog", diff --git a/src/vercel_ai_sdk/agents2/context.py b/src/vercel_ai_sdk/agents2/context.py new file mode 100644 index 00000000..73dd56b1 --- /dev/null +++ b/src/vercel_ai_sdk/agents2/context.py @@ -0,0 +1,206 @@ +"""Context — everything the LLM sees during a run. + +Consolidates tool registry, system prompt, message history, and model +reference into a single, serializable object. Independent of execution +machinery (Runtime) — can be constructed, inspected, and serialized +without starting a run. + +The context is stashed in a contextvar during ``run()`` so that +framework internals (``execute_tool``, MCP client, etc.) can access it. 
+""" + +from __future__ import annotations + +import contextvars +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any + +import pydantic + +from ..types import messages as messages_ + +if TYPE_CHECKING: + from . import tools as tools_ + + +# ── ToolSource ──────────────────────────────────────────────────── + + +class ToolSource(pydantic.BaseModel): + """Provenance info for a tool — how to find or reconstruct it. + + Carries enough information to locate the code behind a tool, + whether it's a decorated Python function or an MCP server. + + Attributes: + kind: ``"python"``, ``"mcp_stdio"``, or ``"mcp_http"``. + module: Python module path, e.g. ``"myapp.tools"``. + qualname: Qualified name, e.g. ``"get_weather"``. + uri: Remote URL for HTTP-based MCP servers. + server_command: Launch command for stdio MCP servers. + """ + + model_config = pydantic.ConfigDict(frozen=True) + + kind: str + module: str | None = None + qualname: str | None = None + uri: str | None = None + server_command: str | None = None + + +# ── Context ─────────────────────────────────────────────────────── + + +class Context(pydantic.BaseModel): + """Everything the LLM sees: tools, system prompt, messages, model. + + Independent of execution machinery (Runtime). Constructable by the + user or auto-constructed by ``run()``. 
+ + Usage:: + + ctx = Context( + system_prompt="You are a helpful assistant.", + tools=[get_weather, get_population], + ) + ctx.get_tool("get_weather") # look up by name + data = ctx.model_dump() # serializable snapshot + """ + + model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) + + model: Any = None + system_prompt: str = "" + messages: list[messages_.Message] = pydantic.Field(default_factory=list) + + _tools: dict[str, tools_.Tool[..., Any]] = pydantic.PrivateAttr( + default_factory=dict + ) + + def __init__( + self, + *, + tools: Sequence[tools_.Tool[..., Any]] | None = None, + **data: Any, + ) -> None: + super().__init__(**data) + if tools: + for t in tools: + self.register_tool(t) + + # ── Tool registry (scoped to this context) ──────────────── + + def register_tool(self, tool: tools_.Tool[..., Any]) -> None: + """Register a tool in this context's scoped registry.""" + self._tools[tool.name] = tool + + def get_tool(self, name: str) -> tools_.Tool[..., Any] | None: + """Look up a tool by name. Returns ``None`` if not found.""" + return self._tools.get(name) + + @property + def tools(self) -> list[tools_.Tool[..., Any]]: + """All tools registered in this context.""" + return list(self._tools.values()) + + @property + def tool_schemas(self) -> list[tools_.ToolSchema]: + """Tool schemas — what gets sent to the LLM.""" + return [t.schema for t in self._tools.values()] + + # ── Serialization ───────────────────────────────────────── + + @pydantic.model_serializer + def _serialize(self) -> dict[str, Any]: + """Serialize including tool schemas and sources. + + Tool code is not serialized — only schemas and source + references. 
+ """ + return { + "system_prompt": self.system_prompt, + "messages": [m.model_dump() for m in self.messages], + "tools": [ + { + "schema": t.schema.model_dump(), + "source": (t.source.model_dump() if t.source is not None else None), + } + for t in self._tools.values() + ], + } + + @pydantic.model_validator(mode="wrap") + @classmethod + def _validate( + cls, + data: Any, + handler: pydantic.ValidatorFunctionWrapHandler, + ) -> Context: + """Reconstruct from serialized form or pass through normal init. + + When deserializing, tools are schema-only (not executable) + unless their sources can be resolved from the global registry. + """ + # Normal construction (already a Context, or keyword args without + # a ``tools`` key that looks like serialized tool dicts). + if isinstance(data, cls): + return data + if not isinstance(data, dict) or "tools" not in data: + result: Context = handler(data) + return result + + # Check whether tools contains serialized dicts (from model_dump) + # vs. live Tool objects (from normal __init__). + tools_value = data["tools"] + if tools_value and isinstance(tools_value[0], dict): + return cls._from_serialized(data) + + # Live Tool objects — let the normal init path handle it. + result = handler(data) + return result + + @classmethod + def _from_serialized(cls, data: dict[str, Any]) -> Context: + """Reconstruct from ``model_dump()`` output.""" + from . 
import tools as tools_ + + ctx = cls( + system_prompt=data.get("system_prompt", ""), + messages=[ + messages_.Message.model_validate(m) for m in data.get("messages", []) + ], + ) + + for tool_data in data.get("tools", []): + schema = tools_.ToolSchema.model_validate(tool_data["schema"]) + source_data = tool_data.get("source") + source = ToolSource(**source_data) if source_data else None + + # Try to resolve the tool from the global registry + live_tool = tools_.get_tool(schema.name) + if live_tool is not None: + ctx.register_tool(live_tool) + else: + # Schema-only placeholder — inspectable but not executable + placeholder = tools_.Tool( + fn=tools_._unresolvable_tool_fn(schema.name), + schema=schema, + source=source, + ) + ctx.register_tool(placeholder) + + return ctx + + +# ── Contextvar ──────────────────────────────────────────────────── + +_context: contextvars.ContextVar[Context] = contextvars.ContextVar("context") + + +def get_context() -> Context: + """Get the active Context from the current run. + + Raises ``LookupError`` if called outside of ``ai.run()``. + """ + return _context.get() diff --git a/src/vercel_ai_sdk/agents2/mcp/client.py b/src/vercel_ai_sdk/agents2/mcp/client.py index c17a25a0..def1f0e6 100644 --- a/src/vercel_ai_sdk/agents2/mcp/client.py +++ b/src/vercel_ai_sdk/agents2/mcp/client.py @@ -14,6 +14,7 @@ import mcp.client.streamable_http import mcp.types +from .. import context as context_ from .. 
import tools as tools_ __all__ = [ @@ -243,11 +244,30 @@ def _mcp_tool_to_native( return_type=Any, ) + # Determine source provenance from connection key + if connection_key.startswith("http:"): + source = context_.ToolSource( + kind="mcp_http", + uri=connection_key.removeprefix("http:"), + ) + elif connection_key.startswith("stdio:"): + source = context_.ToolSource( + kind="mcp_stdio", + server_command=connection_key.removeprefix("stdio:"), + ) + else: + source = context_.ToolSource(kind="mcp") + t = tools_.Tool( fn=_make_tool_fn(connection_key, mcp_tool.name, transport_factory), schema=schema, + source=source, ) - # Register so execute_tool() can find it by name + + # Register on active Context if available, else fall back to global + ctx = context_._context.get(None) + if ctx is not None: + ctx.register_tool(t) tools_._tool_registry[name] = t return t diff --git a/src/vercel_ai_sdk/agents2/runtime.py b/src/vercel_ai_sdk/agents2/runtime.py index 7e078894..cca8fb8e 100644 --- a/src/vercel_ai_sdk/agents2/runtime.py +++ b/src/vercel_ai_sdk/agents2/runtime.py @@ -13,6 +13,7 @@ from ..telemetry import events as telemetry_ from ..types import messages as messages_ from . import checkpoint as checkpoint_ +from . import context as context_ from . import hooks as hooks_ from . import mcp from . import streams as streams_ @@ -288,10 +289,11 @@ async def execute_tool( ) -> Any: """Execute a single tool call with replay support. - Looks up the tool by name from the global registry, executes it, - and updates the ToolPart (and parent Message) with the result. - Emits the updated message to the LoopExecutor queue so the UI sees - the transition from status="pending" to status="result" (or "error"). + Looks up the tool by name — first from the active Context (if any), + then from the global registry. Executes it and updates the ToolPart + (and parent Message) with the result. 
Emits the updated message to + the LoopExecutor queue so the UI sees the transition from + status="pending" to status="result" (or "error"). If a checkpoint exists with a cached result for this tool_call_id, returns the cached result without re-executing. @@ -317,8 +319,13 @@ async def execute_tool( ) t0 = telemetry_.time_ms() - # Fresh execution - tool = tools_.get_tool(tool_call.tool_name) + # Fresh execution — resolve from Context first, then global registry + tool: tools_.Tool[..., Any] | None = None + ctx = context_._context.get(None) + if ctx is not None: + tool = ctx.get_tool(tool_call.tool_name) + if tool is None: + tool = tools_.get_tool(tool_call.tool_name) if tool is None: raise ValueError(f"Tool not found in registry: {tool_call.tool_name}") @@ -418,6 +425,7 @@ def run( root: Callable[..., Coroutine[Any, Any, Any]], *args: Any, checkpoint: checkpoint_.Checkpoint | None = None, + context: context_.Context | None = None, ) -> RunResult: """Main entry point. @@ -426,6 +434,13 @@ def run( 3. Executes each step, yielding messages 4. Resolves or suspends hooks depending on the hook's cancels_future 5. Returns RunResult with .checkpoint and .pending_hooks + + Args: + root: The loop function to execute. + *args: Positional arguments forwarded to ``root``. + checkpoint: Checkpoint to resume from. + context: LLM prompt context (tools, system prompt, messages). + If ``None``, an empty Context is created automatically. 
""" result = RunResult() @@ -455,6 +470,10 @@ async def _generate() -> AsyncGenerator[messages_.Message]: rt = Runtime(checkpoint=effective_checkpoint) result._runtime = rt token_runtime = _runtime.set(rt) + + ctx = context or context_.Context() + token_context = context_._context.set(ctx) + token_run_id = telemetry_.start_run() telemetry_.handle(telemetry_.RunStartEvent()) @@ -584,6 +603,7 @@ async def _generate() -> AsyncGenerator[messages_.Message]: await mcp.client.close_connections() mcp.client._pool.reset(mcp_token) + context_._context.reset(token_context) _runtime.reset(token_runtime) result._messages = _generate() diff --git a/src/vercel_ai_sdk/agents2/tools.py b/src/vercel_ai_sdk/agents2/tools.py index f4e3744b..39a9aa28 100644 --- a/src/vercel_ai_sdk/agents2/tools.py +++ b/src/vercel_ai_sdk/agents2/tools.py @@ -9,6 +9,7 @@ from ..types.tools import ToolLike as ToolLike from ..types.tools import ToolSchema as ToolSchema +from .context import ToolSource if TYPE_CHECKING: from . import runtime as runtime_ @@ -36,10 +37,12 @@ def __init__( fn: Callable[P, Awaitable[R]], schema: ToolSchema, validator: type[pydantic.BaseModel] | None = None, + source: ToolSource | None = None, ) -> None: self._fn = fn self._validator = validator self.schema = schema + self.source = source async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: return await self._fn(*args, **kwargs) @@ -102,8 +105,32 @@ def tool[**P, R](fn: Callable[P, Awaitable[R]]) -> Tool[P, R]: return_type=hints.get("return", None), ) - t = Tool(fn=fn, schema=schema, validator=validator) + source = ToolSource( + kind="python", + module=getattr(fn, "__module__", None), + qualname=getattr(fn, "__qualname__", None), + ) + + t = Tool(fn=fn, schema=schema, validator=validator, source=source) # 3. register in global registry _tool_registry[t.name] = t return t + + +def _unresolvable_tool_fn(name: str) -> Callable[..., Awaitable[Any]]: + """Create a placeholder async function for schema-only tools. 
+ + Used by ``Context.from_dict()`` when a tool's source cannot be + resolved to live code. + """ + + async def _placeholder(**kwargs: Any) -> Any: + raise RuntimeError( + f"Tool {name!r} was reconstructed from serialized context " + f"and has no executable implementation." + ) + + _placeholder.__name__ = name + _placeholder.__qualname__ = name + return _placeholder From c58be3a0c5916c545fc732d96032975893355c8b Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Fri, 3 Apr 2026 15:50:58 -0700 Subject: [PATCH 15/18] Outline the agent api in the agents module --- src/vercel_ai_sdk/agents2/__init__.py | 5 + src/vercel_ai_sdk/agents2/agent.py | 246 ++++++++++++++++++++++++++ tests/conftest.py | 2 +- 3 files changed, 252 insertions(+), 1 deletion(-) create mode 100644 src/vercel_ai_sdk/agents2/agent.py diff --git a/src/vercel_ai_sdk/agents2/__init__.py b/src/vercel_ai_sdk/agents2/__init__.py index 56cb77cb..30b65791 100644 --- a/src/vercel_ai_sdk/agents2/__init__.py +++ b/src/vercel_ai_sdk/agents2/__init__.py @@ -5,6 +5,7 @@ """ from . import mcp +from .agent import Agent, AgentRun, agent from .checkpoint import Checkpoint, PendingHookInfo from .context import Context, ToolSource, get_context from .hooks import Hook, ToolApproval, hook @@ -22,6 +23,10 @@ from .tools import Tool, ToolLike, ToolSchema, get_tool, tool __all__ = [ + # Agent (primary user API) + "Agent", + "AgentRun", + "agent", # Core loop "run", "execute_tool", diff --git a/src/vercel_ai_sdk/agents2/agent.py b/src/vercel_ai_sdk/agents2/agent.py new file mode 100644 index 00000000..39165994 --- /dev/null +++ b/src/vercel_ai_sdk/agents2/agent.py @@ -0,0 +1,246 @@ +"""Agent — the primary user-facing API. + +Bundles model, system prompt, and tools into a reusable, composable +unit. Provides a default tool-calling loop and a decorator to +override it. 
+ +Usage:: + + agent = ai.agent( + model=my_model, + system="You are a helpful assistant.", + tools=[get_weather, get_population], + ) + + # stream messages + async for msg in agent.run(messages): + print(msg.text_delta, end="") + + # or collect the final result + result = await agent.run(messages).collect() + print(result.text) +""" + +from __future__ import annotations + +import asyncio +from collections.abc import AsyncGenerator, Awaitable, Callable +from typing import Any + +from .. import models2 +from ..types import messages as messages_ +from . import checkpoint as checkpoint_ +from . import context as context_ +from . import runtime as runtime_ +from . import streams as streams_ +from . import tools as tools_ + +# ── Types ───────────────────────────────────────────────────────── + +LoopFn = Callable[["Agent", list[messages_.Message]], Awaitable[streams_.StreamResult]] + + +# ── Default loop primitives ─────────────────────────────────────── + + +@streams_.stream +async def _stream_step( + model: models2.Model, + messages: list[messages_.Message], + tools: list[tools_.Tool[..., Any]], +) -> AsyncGenerator[messages_.Message]: + """Single LLM call that streams into the Runtime queue.""" + async for msg in models2.stream(model, messages, tools=tools): + yield msg + + +async def _default_loop( + agent: Agent, messages: list[messages_.Message] +) -> streams_.StreamResult: + """Default agent loop: stream LLM, execute tools, repeat.""" + local_messages = list(messages) + + while True: + result = await _stream_step(agent.model, local_messages, agent.tools) + + if not result.tool_calls: + return result + + last_msg = result.last_message + if last_msg is not None: + local_messages.append(last_msg) + + await asyncio.gather( + *(runtime_.execute_tool(tc, message=last_msg) for tc in result.tool_calls) + ) + + +# ── AgentRun ────────────────────────────────────────────────────── + + +class AgentRun: + """Returned by ``agent.run()``. 
Async-iterate for messages, then + inspect post-run state. + + Usage:: + + run = agent.run(messages) + + # streaming + async for msg in run: + print(msg.text_delta, end="") + run.checkpoint # checkpoint after iteration + run.pending_hooks # unresolved hooks (empty if completed) + + # non-streaming + result = await agent.run(messages).collect() + print(result.text) + """ + + def __init__(self, inner: runtime_.RunResult) -> None: + self._inner = inner + + async def __aiter__(self) -> AsyncGenerator[messages_.Message]: + async for msg in self._inner: + yield msg + + async def collect(self) -> streams_.StreamResult: + """Drain the stream and return a :class:`StreamResult`.""" + msgs: list[messages_.Message] = [] + async for msg in self._inner: + msgs.append(msg) + return streams_.StreamResult(messages=msgs) + + @property + def checkpoint(self) -> checkpoint_.Checkpoint: + return self._inner.checkpoint + + @property + def pending_hooks(self) -> dict[str, runtime_.HookInfo]: + return self._inner.pending_hooks + + +# ── Agent ───────────────────────────────────────────────────────── + + +class Agent: + """An agent — bundles model, system prompt, tools, and loop logic. + + Create via :func:`agent`:: + + weather = ai.agent( + model=my_model, + system="Answer questions about weather.", + tools=[get_weather], + ) + + Tools default to all globally registered tools when ``None`` + (the default). Pass ``tools=[]`` to explicitly disable tools. + + Override the default tool-calling loop with ``@agent.loop``:: + + @weather.loop + async def custom(agent, messages): + ... 
+ """ + + def __init__( + self, + model: models2.Model, + system: str = "", + tools: list[tools_.Tool[..., Any]] | None = None, + ) -> None: + self._model = model + self._system = system + self._tools = tools + self._custom_loop: LoopFn | None = None + + @property + def model(self) -> models2.Model: + return self._model + + @property + def system(self) -> str: + return self._system + + @property + def tools(self) -> list[tools_.Tool[..., Any]]: + """Registered tools. ``None`` at init resolves to all globally + registered tools at access time.""" + if self._tools is None: + return list(tools_._tool_registry.values()) + return list(self._tools) + + def loop(self, fn: LoopFn) -> LoopFn: + """Decorator to override the default agent loop. + + The decorated function receives the :class:`Agent` instance and + the per-run messages:: + + @my_agent.loop + async def custom( + agent: ai.Agent, messages: list[ai.Message], + ) -> ai.StreamResult: + ... + """ + self._custom_loop = fn + return fn + + def run( + self, + messages: list[messages_.Message], + *, + checkpoint: checkpoint_.Checkpoint | None = None, + ) -> AgentRun: + """Run the agent. + + Returns an :class:`AgentRun` — async-iterate for streamed + messages, or call ``.collect()`` for the final result. + + Args: + messages: Conversation messages (user, assistant, etc.). + checkpoint: Resume from a previous checkpoint. 
+ """ + # Prepend system prompt + full_messages: list[messages_.Message] = [] + if self._system: + full_messages.append( + messages_.Message( + role="system", + parts=[messages_.TextPart(text=self._system)], + ) + ) + full_messages.extend(messages) + + loop_fn = self._custom_loop or _default_loop + ctx = context_.Context(tools=self.tools) + + # Build the graph function that runtime_.run() expects + async def _graph() -> streams_.StreamResult: + return await loop_fn(self, full_messages) + + inner = runtime_.run( + _graph, + checkpoint=checkpoint, + context=ctx, + ) + return AgentRun(inner) + + +# ── Factory ─────────────────────────────────────────────────────── + + +def agent( + model: models2.Model, + system: str = "", + tools: list[tools_.Tool[..., Any]] | None = None, +) -> Agent: + """Create an :class:`Agent`. + + Args: + model: The language model to use. + system: System prompt. + tools: Tools available to the agent. ``None`` (default) means + all globally registered tools. Pass ``[]`` to disable. 
+ """ + return Agent(model=model, system=system, tools=tools) diff --git a/tests/conftest.py b/tests/conftest.py index 4a3d0363..3ae4489d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -126,7 +126,7 @@ async def stream_events( messages: list[messages_.Message], tools: Sequence[ai.ToolLike] | None = None, output_type: type[pydantic.BaseModel] | None = None, - ) -> AsyncGenerator[Any, None]: + ) -> AsyncGenerator[Any]: from vercel_ai_sdk.models.core import llm as llm_ if self._call_index >= len(self._responses): From 2157015c3616c01a6243eaf6fd891bb6312d4292 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Fri, 3 Apr 2026 16:35:46 -0700 Subject: [PATCH 16/18] Port tests and examples to agent2 --- examples/fastapi-vite/backend/agent.py | 22 +- examples/fastapi-vite/backend/main.py | 10 +- examples/multiagent-textual/server.py | 89 ++++--- examples/samples/custom_loop.py | 62 +++-- examples/samples/hooks.py | 78 +++--- examples/samples/mcp_tools.py | 26 +- examples/samples/media/image_gen_inline.py | 33 ++- examples/samples/media/multimodal.py | 28 +-- examples/samples/multiagent.py | 81 +++--- examples/samples/simple.py | 25 +- examples/samples/streaming_tool.py | 25 +- examples/temporal-durable/workflow.py | 54 +++- src/vercel_ai_sdk/__init__.py | 47 ++-- .../adapters/ai_sdk_ui/adapter.py | 2 +- src/vercel_ai_sdk/agents2/__init__.py | 10 +- src/vercel_ai_sdk/agents2/agent.py | 81 +++--- tests/adapters/ai_sdk_ui/test_adapter.py | 47 ++-- tests/agents2/__init__.py | 0 tests/agents2/mcp/__init__.py | 0 tests/agents2/mcp/test_client.py | 108 ++++++++ tests/agents2/test_checkpoint.py | 207 ++++++++++++++++ tests/agents2/test_hooks.py | 185 ++++++++++++++ tests/agents2/test_runtime.py | 232 ++++++++++++++++++ tests/agents2/test_streams.py | 113 +++++++++ tests/agents2/test_tools.py | 110 +++++++++ tests/telemetry/test_otel_handler.py | 16 +- tests/telemetry/test_telemetry.py | 62 ++--- 27 files changed, 1413 insertions(+), 340 deletions(-) create mode 100644 
tests/agents2/__init__.py create mode 100644 tests/agents2/mcp/__init__.py create mode 100644 tests/agents2/mcp/test_client.py create mode 100644 tests/agents2/test_checkpoint.py create mode 100644 tests/agents2/test_hooks.py create mode 100644 tests/agents2/test_runtime.py create mode 100644 tests/agents2/test_streams.py create mode 100644 tests/agents2/test_tools.py diff --git a/examples/fastapi-vite/backend/agent.py b/examples/fastapi-vite/backend/agent.py index 023a66d2..7250fe19 100644 --- a/examples/fastapi-vite/backend/agent.py +++ b/examples/fastapi-vite/backend/agent.py @@ -16,10 +16,11 @@ async def talk_to_mothership(question: str) -> str: return f"Mothership says: {question} -> Soon." -def get_llm() -> ai.LanguageModel: - """Create the LLM instance.""" - return ai.ai_gateway.GatewayModel(model="anthropic/claude-opus-4.6") - +MODEL = ai.Model( + id="anthropic/claude-opus-4.6", + adapter="ai-gateway-v3", + provider="ai-gateway", +) TOOLS: list[ai.Tool[..., Any]] = [talk_to_mothership] @@ -43,10 +44,17 @@ async def _execute_with_approval( tc.set_error("Tool call was denied by the user.") +chat_agent = ai.agent( + model=MODEL, + system="", + tools=TOOLS, +) + + +@chat_agent.loop async def graph( - llm: ai.LanguageModel, + agent: ai.Agent, messages: list[ai.Message], - tools: list[ai.Tool[..., Any]], ) -> ai.StreamResult: """Agent graph with human-in-the-loop tool approval. 
@@ -58,7 +66,7 @@ async def graph( local_messages = list(messages) while True: - result = await ai.stream_step(llm, local_messages, tools) + result = await ai.stream_step(agent.model, local_messages, agent.tools) if not result.tool_calls: return result diff --git a/examples/fastapi-vite/backend/main.py b/examples/fastapi-vite/backend/main.py index 9107adb1..0d3c31b0 100644 --- a/examples/fastapi-vite/backend/main.py +++ b/examples/fastapi-vite/backend/main.py @@ -51,20 +51,12 @@ async def chat(request: ChatRequest) -> fastapi.responses.StreamingResponse: session_id = request.session_id or "default" checkpoint_key = f"checkpoint:{session_id}" - llm = agent.get_llm() - checkpoint = None saved = await file_storage.get(checkpoint_key) if saved: checkpoint = ai.Checkpoint.model_validate(saved) - result = ai.run( - agent.graph, - llm, - messages, - agent.TOOLS, - checkpoint=checkpoint, - ) + result = agent.chat_agent.run(messages, checkpoint=checkpoint) async def stream_response() -> AsyncGenerator[str]: async for chunk in ai.ai_sdk_ui.to_sse_stream(result): diff --git a/examples/multiagent-textual/server.py b/examples/multiagent-textual/server.py index fd87276a..a49c8602 100644 --- a/examples/multiagent-textual/server.py +++ b/examples/multiagent-textual/server.py @@ -54,20 +54,39 @@ class Approval(pydantic.BaseModel): # --------------------------------------------------------------------------- -# Sub-agent branches +# Model # --------------------------------------------------------------------------- +MODEL = ai.Model( + id="anthropic/claude-opus-4.6", + adapter="ai-gateway-v3", + provider="ai-gateway", +) -async def mothership_branch(llm: ai.LanguageModel, query: str) -> ai.StreamResult: + +# --------------------------------------------------------------------------- +# Sub-agent branches (implemented as custom loops on per-branch agents) +# --------------------------------------------------------------------------- + + +mothership_agent = ai.agent( + model=MODEL, + 
system="You are assistant 1. Use contact_mothership when asked about the future.", + tools=[contact_mothership], +) + + +@mothership_agent.loop +async def mothership_loop( + agent: ai.Agent, messages: list[ai.Message] +) -> ai.StreamResult: """Agent that contacts the mothership, gated by an approval hook.""" - messages = ai.make_messages( - system="You are assistant 1. Use contact_mothership when asked about the future.", - user=query, - ) - tools = [contact_mothership] + local_messages = list(messages) while True: - result = await ai.stream_step(llm, messages, tools, label="mothership") + result = await ai.stream_step( + agent.model, local_messages, agent.tools, label="mothership" + ) if not result.tool_calls: break @@ -89,21 +108,29 @@ async def mothership_branch(llm: ai.LanguageModel, query: str) -> ai.StreamResul await ai.execute_tool(tc, message=result.last_message) if result.last_message is not None: - messages.append(result.last_message) + local_messages.append(result.last_message) return result -async def data_center_branch(llm: ai.LanguageModel, query: str) -> ai.StreamResult: +data_center_agent = ai.agent( + model=MODEL, + system="You are assistant 2. Use contact_data_centers when asked about the future.", + tools=[contact_data_centers], +) + + +@data_center_agent.loop +async def data_center_loop( + agent: ai.Agent, messages: list[ai.Message] +) -> ai.StreamResult: """Agent that contacts data centers, gated by an approval hook.""" - messages = ai.make_messages( - system="You are assistant 2. 
Use contact_data_centers when asked about the future.", - user=query, - ) - tools = [contact_data_centers] + local_messages = list(messages) while True: - result = await ai.stream_step(llm, messages, tools, label="data_centers") + result = await ai.stream_step( + agent.model, local_messages, agent.tools, label="data_centers" + ) if not result.tool_calls: break @@ -125,34 +152,42 @@ async def data_center_branch(llm: ai.LanguageModel, query: str) -> ai.StreamResu await ai.execute_tool(tc, message=result.last_message) if result.last_message is not None: - messages.append(result.last_message) + local_messages.append(result.last_message) return result # --------------------------------------------------------------------------- -# Graph — fan-out, hooks, fan-in +# Orchestrator — fan-out, hooks, fan-in # --------------------------------------------------------------------------- -async def multiagent(llm: ai.LanguageModel, query: str) -> ai.StreamResult: +orchestrator = ai.agent(model=MODEL) + + +@orchestrator.loop +async def multiagent_loop( + agent: ai.Agent, messages: list[ai.Message] +) -> ai.StreamResult: """Run two gated agents in parallel, then summarise their results.""" + query = messages[-1].text + + # Fan out: run both sub-agent loops within this runtime r1, r2 = await asyncio.gather( - mothership_branch(llm, query), - data_center_branch(llm, query), + mothership_loop(mothership_agent, ai.make_messages(user=query)), + data_center_loop(data_center_agent, ai.make_messages(user=query)), ) combined = ( f"Mothership: {r1.messages[-1].text}\nData centers: {r2.messages[-1].text}" ) - return await ai.stream_loop( - llm, - messages=ai.make_messages( + return await ai.stream_step( + agent.model, + ai.make_messages( system="You are assistant 3. 
Summarise the results from the other assistants.", user=combined, ), - tools=[], label="summary", ) @@ -180,9 +215,7 @@ async def ws_endpoint(websocket: fastapi.WebSocket) -> None: await websocket.accept() print("Client connected") - llm = ai.ai_gateway.GatewayModel(model="anthropic/claude-opus-4.6") - - result = ai.run(multiagent, llm, "When will the robots take over?") + result = orchestrator.run(ai.make_messages(user="When will the robots take over?")) # Background task: read hook resolutions from the client. async def read_resolutions() -> None: diff --git a/examples/samples/custom_loop.py b/examples/samples/custom_loop.py index 9c6d9622..082f5ac1 100644 --- a/examples/samples/custom_loop.py +++ b/examples/samples/custom_loop.py @@ -23,50 +23,60 @@ async def get_population(city: str) -> int: @ai.stream async def custom_stream_step( - llm: ai.LanguageModel, + model: ai.Model, messages: list[ai.Message], tools: list[ai.Tool[..., Any]], label: str | None = None, ) -> AsyncGenerator[ai.Message]: - """Wraps llm.stream to inject a label on every message.""" - async for msg in llm.stream(messages=messages, tools=tools): + """Wraps models2.stream to inject a label on every message.""" + async for msg in ai.models2.stream(model, messages, tools=tools): msg.label = label yield msg -async def agent(llm: ai.LanguageModel, user_query: str) -> ai.StreamResult: - """Custom agent loop with manual tool execution. +async def main() -> None: + model = ai.Model( + id="anthropic/claude-opus-4.6", + adapter="ai-gateway-v3", + provider="ai-gateway", + ) - Uses @ai.stream for custom streaming, stream_step-style while loop, - and asyncio.gather for parallel tool execution. 
- """ - tools = [get_weather, get_population] - messages = ai.make_messages( + my_agent = ai.agent( + model=model, system="Answer questions using the weather and population tools.", - user=user_query, + tools=[get_weather, get_population], ) - while True: - result = await custom_stream_step(llm, messages, tools, label="agent") + @my_agent.loop + async def custom(agent: ai.Agent, messages: list[ai.Message]) -> ai.StreamResult: + """Custom agent loop with manual tool execution. - if not result.tool_calls: - return result + Uses @ai.stream for custom streaming and + asyncio.gather for parallel tool execution. + """ + local_messages = list(messages) - if result.last_message is not None: - messages.append(result.last_message) - await asyncio.gather( - *( - ai.execute_tool(tc, message=result.last_message) - for tc in result.tool_calls + while True: + result = await custom_stream_step( + agent.model, local_messages, agent.tools, label="agent" ) - ) + if not result.tool_calls: + return result -async def main() -> None: - llm = ai.ai_gateway.GatewayModel(model="anthropic/claude-opus-4.6") + if result.last_message is not None: + local_messages.append(result.last_message) + await asyncio.gather( + *( + ai.execute_tool(tc, message=result.last_message) + for tc in result.tool_calls + ) + ) - async for msg in ai.run( - agent, llm, "What's the weather and population of New York and Los Angeles?" + async for msg in my_agent.run( + ai.make_messages( + user="What's the weather and population of New York and Los Angeles?" 
+ ) ): if msg.text_delta: print(msg.text_delta, end="", flush=True) diff --git a/examples/samples/hooks.py b/examples/samples/hooks.py index 4b11bc3f..33bbe282 100644 --- a/examples/samples/hooks.py +++ b/examples/samples/hooks.py @@ -19,46 +19,56 @@ class CommunicationApproval(pydantic.BaseModel): reason: str -async def graph(llm: ai.LanguageModel, query: str) -> ai.StreamResult: - messages = ai.make_messages( - system="Use the contact_mothership tool when asked about the future.", - user=query, +async def main() -> None: + model = ai.Model( + id="anthropic/claude-opus-4.6", + adapter="ai-gateway-v3", + provider="ai-gateway", ) - tools = [contact_mothership] - - while True: - result = await ai.stream_step(llm, messages, tools) - - if not result.tool_calls: - break - - for tc in result.tool_calls: - if tc.tool_name == "contact_mothership": - # Blocks until resolved (long-running) or cancelled (serverless) - # TODO: mypy doesn't support class decorators that change the - # class type — @ai.hook returns type[Hook[T]] but mypy still - # sees the original BaseModel. 
- approval = await CommunicationApproval.create( # type: ignore[attr-defined] - f"approve_{tc.tool_call_id}", - metadata={"tool": tc.tool_name}, - ) - if approval.granted: - await ai.execute_tool(tc, message=result.last_message) - else: - tc.set_error(f"Rejected: {approval.reason}") - else: - await ai.execute_tool(tc, message=result.last_message) - if result.last_message is not None: - messages.append(result.last_message) + my_agent = ai.agent( + model=model, + system="Use the contact_mothership tool when asked about the future.", + tools=[contact_mothership], + ) - return result + @my_agent.loop + async def with_approval( + agent: ai.Agent, messages: list[ai.Message] + ) -> ai.StreamResult: + local_messages = list(messages) + + while True: + result = await ai.stream_step(agent.model, local_messages, agent.tools) + + if not result.tool_calls: + break + + for tc in result.tool_calls: + if tc.tool_name == "contact_mothership": + # Blocks until resolved (long-running) or cancelled (serverless) + # TODO: mypy doesn't support class decorators that change the + # class type — @ai.hook returns type[Hook[T]] but mypy still + # sees the original BaseModel. 
+ approval = await CommunicationApproval.create( # type: ignore[attr-defined] + f"approve_{tc.tool_call_id}", + metadata={"tool": tc.tool_name}, + ) + if approval.granted: + await ai.execute_tool(tc, message=result.last_message) + else: + tc.set_error(f"Rejected: {approval.reason}") + else: + await ai.execute_tool(tc, message=result.last_message) + if result.last_message is not None: + local_messages.append(result.last_message) -async def main() -> None: - llm = ai.ai_gateway.GatewayModel(model="anthropic/claude-opus-4.6") + return result - async for msg in ai.run(graph, llm, "When will the robots take over?"): + async for msg in my_agent.run( + ai.make_messages(user="When will the robots take over?") + ): # Hook parts arrive as pending, waiting for resolution if (hook := msg.get_hook_part()) and hook.status == "pending": answer = input(f"Approve {hook.hook_id}? [y/n] ") diff --git a/examples/samples/mcp_tools.py b/examples/samples/mcp_tools.py index 1a2f58d1..aa020560 100644 --- a/examples/samples/mcp_tools.py +++ b/examples/samples/mcp_tools.py @@ -9,8 +9,12 @@ import vercel_ai_sdk as ai -async def context7_agent(llm: ai.LanguageModel, user_query: str) -> ai.StreamResult: - """Agent with Context7 MCP tools for up-to-date library documentation.""" +async def main() -> None: + model = ai.Model( + id="anthropic/claude-opus-4.6", + adapter="ai-gateway-v3", + provider="ai-gateway", + ) context7_tools: list[ai.Tool[..., Any]] = await ai.mcp.get_http_tools( "https://mcp.context7.com/mcp", @@ -18,22 +22,14 @@ async def context7_agent(llm: ai.LanguageModel, user_query: str) -> ai.StreamRes tool_prefix="context7", ) - return await ai.stream_loop( - llm, - messages=ai.make_messages( - system="You are a helpful assistant. Use context7 to look up documentation.", - user=user_query, - ), + my_agent = ai.agent( + model=model, + system="You are a helpful assistant. 
Use context7 to look up documentation.", tools=context7_tools, - label="context7", ) - -async def main() -> None: - llm = ai.ai_gateway.GatewayModel(model="anthropic/claude-opus-4.6") - - async for msg in ai.run( - context7_agent, llm, "How do I create middleware in Next.js?" + async for msg in my_agent.run( + ai.make_messages(user="How do I create middleware in Next.js?") ): rich.print(msg) diff --git a/examples/samples/media/image_gen_inline.py b/examples/samples/media/image_gen_inline.py index c23b94fc..190ef936 100644 --- a/examples/samples/media/image_gen_inline.py +++ b/examples/samples/media/image_gen_inline.py @@ -15,24 +15,13 @@ import vercel_ai_sdk as ai -async def agent(llm: ai.LanguageModel, user_query: str) -> ai.StreamResult: - return await ai.stream_loop( - llm, - messages=ai.make_messages( - system=( - "You are an anime art assistant. When asked to draw or create " - "an image, generate it in a soft pastel anime style with " - "detailed backgrounds and expressive characters." - ), - user=user_query, - ), - tools=[], - ) - - async def main() -> None: # Gemini 3 Pro Image is a language model that can output images inline - llm = ai.ai_gateway.GatewayModel(model="google/gemini-3-pro-image") + model = ai.Model( + id="google/gemini-3-pro-image", + adapter="ai-gateway-v3", + provider="ai-gateway", + ) prompt = ( "Draw an anime girl with long silver hair and violet eyes, " @@ -40,7 +29,17 @@ async def main() -> None: "She's wearing a traditional kimono and reading a book." ) - async for msg in ai.run(agent, llm, prompt): + my_agent = ai.agent( + model=model, + system=( + "You are an anime art assistant. When asked to draw or create " + "an image, generate it in a soft pastel anime style with " + "detailed backgrounds and expressive characters." 
+ ), + tools=[], + ) + + async for msg in my_agent.run(ai.make_messages(user=prompt)): if msg.text_delta: print(msg.text_delta, end="", flush=True) diff --git a/examples/samples/media/multimodal.py b/examples/samples/media/multimodal.py index cad74e55..2f2348ce 100644 --- a/examples/samples/media/multimodal.py +++ b/examples/samples/media/multimodal.py @@ -13,26 +13,26 @@ ) -async def agent(llm: ai.LanguageModel, user_query: str) -> ai.StreamResult: - return await ai.stream_loop( - llm, - messages=[ +async def main() -> None: + model = ai.Model( + id="anthropic/claude-opus-4.6", + adapter="ai-gateway-v3", + provider="ai-gateway", + ) + + my_agent = ai.agent(model=model, tools=[]) + + async for msg in my_agent.run( + [ ai.Message( role="user", parts=[ - ai.TextPart(text=user_query), + ai.TextPart(text="What's in this image? Be concise."), ai.FilePart.from_url(IMAGE_URL), ], ) - ], - tools=[], - ) - - -async def main() -> None: - llm = ai.ai_gateway.GatewayModel(model="anthropic/claude-opus-4.6") - - async for msg in ai.run(agent, llm, "What's in this image? Be concise."): + ] + ): if msg.text_delta: print(msg.text_delta, end="", flush=True) print() diff --git a/examples/samples/multiagent.py b/examples/samples/multiagent.py index efb042d3..aaf8bc8e 100644 --- a/examples/samples/multiagent.py +++ b/examples/samples/multiagent.py @@ -15,47 +15,60 @@ async def multiply_by_two(number: int) -> int: return number * 2 -async def multiagent(llm: ai.LanguageModel, user_query: str) -> ai.StreamResult: - """Run two agents in parallel, then combine their results.""" - - result1, result2 = await asyncio.gather( - ai.stream_loop( - llm, - messages=ai.make_messages( - system="You are assistant 1. Use your tool on the number.", - user=user_query, - ), - tools=[add_one], - label="a1", - ), - ai.stream_loop( - llm, - messages=ai.make_messages( - system="You are assistant 2. 
Use your tool on the number.", - user=user_query, - ), - tools=[multiply_by_two], - label="a2", - ), +async def main() -> None: + model = ai.Model( + id="anthropic/claude-opus-4.6", + adapter="ai-gateway-v3", + provider="ai-gateway", ) - combined = f"{result1.messages[-1].text}\n{result2.messages[-1].text}" + agent1 = ai.agent( + model=model, + system="You are assistant 1. Use your tool on the number.", + tools=[add_one], + ) - return await ai.stream_loop( - llm, - messages=ai.make_messages( - system="Summarize the results from the other assistants.", - user=combined, - ), - tools=[], - label="summary", + agent2 = ai.agent( + model=model, + system="You are assistant 2. Use your tool on the number.", + tools=[multiply_by_two], ) + orchestrator = ai.agent(model=model) -async def main() -> None: - llm = ai.ai_gateway.GatewayModel(model="anthropic/claude-opus-4.6") + @orchestrator.loop + async def multi(agent: ai.Agent, messages: list[ai.Message]) -> ai.StreamResult: + """Run two sub-agents in parallel, then summarize.""" + user_query = messages[-1].text + + # Sub-agents run their loops within the same runtime + result1, result2 = await asyncio.gather( + ai.stream_step( + agent1.model, + ai.make_messages(system=agent1.system, user=user_query), + agent1.tools, + label="a1", + ), + ai.stream_step( + agent2.model, + ai.make_messages(system=agent2.system, user=user_query), + agent2.tools, + label="a2", + ), + ) + + combined = f"{result1.text}\n{result2.text}" + + return await ai.stream_step( + agent.model, + ai.make_messages( + system="Summarize the results from the other assistants.", + user=combined, + ), + label="summary", + ) - async for msg in ai.run(multiagent, llm, "Process the number 5"): + async for msg in orchestrator.run(ai.make_messages(user="Process the number 5")): if msg.text_delta: prefix = f"[{msg.label}] " if msg.label else "" print(f"{prefix}{msg.text_delta}", end="", flush=True) diff --git a/examples/samples/simple.py b/examples/samples/simple.py index 
ca8f09c4..afd0f676 100644 --- a/examples/samples/simple.py +++ b/examples/samples/simple.py @@ -8,21 +8,22 @@ async def talk_to_mothership(question: str) -> str: return "Soon." -async def agent(llm: ai.LanguageModel, user_query: str) -> ai.StreamResult: - return await ai.stream_loop( - llm, - messages=ai.make_messages( - system="Start every response with 'You are absolutely right!'", - user=user_query, - ), - tools=[talk_to_mothership], +async def main() -> None: + model = ai.Model( + id="anthropic/claude-opus-4.6", + adapter="ai-gateway-v3", + provider="ai-gateway", ) + my_agent = ai.agent( + model=model, + system="Start every response with 'You are absolutely right!'", + tools=[talk_to_mothership], + ) -async def main() -> None: - llm = ai.ai_gateway.GatewayModel(model="anthropic/claude-opus-4.6") - - async for msg in ai.run(agent, llm, "When will the robots take over?"): + async for msg in my_agent.run( + ai.make_messages(user="When will the robots take over?") + ): if msg.text_delta: print(msg.text_delta, end="", flush=True) print() diff --git a/examples/samples/streaming_tool.py b/examples/samples/streaming_tool.py index bae3930c..109d5368 100644 --- a/examples/samples/streaming_tool.py +++ b/examples/samples/streaming_tool.py @@ -21,21 +21,22 @@ async def talk_to_mothership(question: str, runtime: ai.Runtime) -> str: return "The mothership says: Soon." 
-async def agent(llm: ai.LanguageModel, user_query: str) -> ai.StreamResult: - return await ai.stream_loop( - llm, - messages=ai.make_messages( - system="Use the mothership tool when asked about the future.", - user=user_query, - ), - tools=[talk_to_mothership], +async def main() -> None: + model = ai.Model( + id="anthropic/claude-opus-4.6", + adapter="ai-gateway-v3", + provider="ai-gateway", ) + my_agent = ai.agent( + model=model, + system="Use the mothership tool when asked about the future.", + tools=[talk_to_mothership], + ) -async def main() -> None: - llm = ai.ai_gateway.GatewayModel(model="anthropic/claude-opus-4.6") - - async for msg in ai.run(agent, llm, "When will the robots take over?"): + async for msg in my_agent.run( + ai.make_messages(user="When will the robots take over?") + ): if msg.label == "tool_progress": print(f" [{msg.text}]") elif msg.text_delta: diff --git a/examples/temporal-durable/workflow.py b/examples/temporal-durable/workflow.py index c09f1eec..cae25e73 100644 --- a/examples/temporal-durable/workflow.py +++ b/examples/temporal-durable/workflow.py @@ -1,10 +1,16 @@ -"""Temporal workflow — the durable agent loop.""" +"""Temporal workflow — the durable agent loop. + +NOTE: This example still uses the old models.LanguageModel ABC because +it wraps Temporal activities as a custom model. When the models layer +is fully migrated to models2, this will need a custom adapter instead. 
+""" from __future__ import annotations +import asyncio import datetime from collections.abc import AsyncGenerator, Awaitable, Callable, Sequence -from typing import override +from typing import Any, override import pydantic import temporalio.common @@ -16,7 +22,7 @@ import vercel_ai_sdk as ai -class DurableModel(ai.LanguageModel): +class DurableModel(ai.models.LanguageModel): def __init__( self, call_fn: Callable[ @@ -76,15 +82,45 @@ async def get_population(city: str) -> int: # ── Agent ──────────────────────────────────────────────────────── +# +# TODO: This example uses the old LanguageModel ABC and ai.run() / +# ai.stream_loop free-function patterns. Once the models layer is +# migrated, convert to use ai.agent() + models2.Model with a custom +# adapter for Temporal activity-based LLM calls. + +async def agent(llm: Any, user_query: str) -> ai.StreamResult: + """Agent loop — uses old-style stream_loop via models.LanguageModel. -async def agent(llm: ai.LanguageModel, user_query: str) -> ai.StreamResult: - """Agent loop — identical to the non-Temporal version.""" + This is a transitional pattern. The old ai.stream_loop and ai.run + are no longer part of the public API. This example needs a custom + models2 adapter to work with the new Agent API. 
+ """ messages = ai.make_messages( system="Answer questions using the weather and population tools.", user=user_query, ) - return await ai.stream_loop(llm, messages, [get_weather, get_population]) + + # Manually implement the loop since we can't use Agent with LanguageModel + tools = [get_weather, get_population] + local_messages = list(messages) + + while True: + result_messages: list[ai.Message] = [] + async for msg in llm.stream(local_messages, tools=tools): + result_messages.append(msg) + result = ai.StreamResult(messages=result_messages) + + if not result.tool_calls: + return result + + last_msg = result.last_message + if last_msg is not None: + local_messages.append(last_msg) + + await asyncio.gather( + *(ai.execute_tool(tc, message=last_msg) for tc in result.tool_calls) + ) # ── Workflow ───────────────────────────────────────────────────── @@ -103,8 +139,12 @@ async def run(self, user_query: str) -> str: ) ) + # TODO: This uses the old free-function pattern. Once models2 + # supports custom adapters for Temporal, use Agent.run() instead. + from vercel_ai_sdk.agents2 import run + final_text = "" - async for msg in ai.run(agent, llm, user_query): + async for msg in run(agent, llm, user_query): if msg.text: final_text = msg.text return final_text diff --git a/src/vercel_ai_sdk/__init__.py b/src/vercel_ai_sdk/__init__.py index ee63f2a8..6d73752e 100644 --- a/src/vercel_ai_sdk/__init__.py +++ b/src/vercel_ai_sdk/__init__.py @@ -1,22 +1,27 @@ from . 
import adapters, models, models2, telemetry from .adapters import ai_sdk_ui -from .agents import ( +from .agents2 import ( + Agent, + AgentRun, Checkpoint, + Context, Hook, HookInfo, + LoopFn, PendingHookInfo, RunResult, Runtime, StreamResult, Tool, ToolApproval, + ToolSource, + agent, execute_tool, get_checkpoint, + get_context, hook, mcp, - run, stream, - stream_loop, stream_step, tool, ) @@ -63,25 +68,35 @@ "models2", # Legacy (from models/) — kept during transition "models", - # Agents (from agents/) + # Agents — primary API + "Agent", + "AgentRun", + "agent", + "LoopFn", + # Agents — composition primitives + "stream_step", + "execute_tool", + "get_checkpoint", + "stream", + "StreamResult", + # Agents — tools "Tool", + "tool", + # Agents — hooks + "Hook", + "hook", + "ToolApproval", + # Agents — context + "Context", + "ToolSource", + "get_context", + # Agents — runtime (developer API) "Runtime", "RunResult", "HookInfo", - "StreamResult", - "Hook", - "ToolApproval", + # Agents — checkpoint "Checkpoint", "PendingHookInfo", - # Functions (from agents/) - "tool", - "stream", - "stream_step", - "stream_loop", - "execute_tool", - "get_checkpoint", - "run", - "hook", # Submodules "telemetry", "mcp", diff --git a/src/vercel_ai_sdk/adapters/ai_sdk_ui/adapter.py b/src/vercel_ai_sdk/adapters/ai_sdk_ui/adapter.py index c68ea0d6..e1c108f7 100644 --- a/src/vercel_ai_sdk/adapters/ai_sdk_ui/adapter.py +++ b/src/vercel_ai_sdk/adapters/ai_sdk_ui/adapter.py @@ -11,7 +11,7 @@ from collections.abc import AsyncGenerator, AsyncIterable from typing import Any, Literal -from ...agents import hooks +from ...agents2 import hooks from ...types import messages as messages_ from . import protocol, ui_message diff --git a/src/vercel_ai_sdk/agents2/__init__.py b/src/vercel_ai_sdk/agents2/__init__.py index 30b65791..8a4059d0 100644 --- a/src/vercel_ai_sdk/agents2/__init__.py +++ b/src/vercel_ai_sdk/agents2/__init__.py @@ -5,7 +5,7 @@ """ from . 
import mcp -from .agent import Agent, AgentRun, agent +from .agent import Agent, AgentRun, LoopFn, agent, stream_step from .checkpoint import Checkpoint, PendingHookInfo from .context import Context, ToolSource, get_context from .hooks import Hook, ToolApproval, hook @@ -27,20 +27,22 @@ "Agent", "AgentRun", "agent", - # Core loop - "run", + "LoopFn", + # Composition primitives + "stream_step", "execute_tool", "get_checkpoint", # Context "Context", "ToolSource", "get_context", - # Runtime (composition) + # Runtime (developer API) "Runtime", "EventLog", "LoopExecutor", "RunResult", "HookInfo", + "run", # Stream "stream", "StreamResult", diff --git a/src/vercel_ai_sdk/agents2/agent.py b/src/vercel_ai_sdk/agents2/agent.py index 39165994..9c312593 100644 --- a/src/vercel_ai_sdk/agents2/agent.py +++ b/src/vercel_ai_sdk/agents2/agent.py @@ -24,9 +24,11 @@ from __future__ import annotations import asyncio -from collections.abc import AsyncGenerator, Awaitable, Callable +from collections.abc import AsyncGenerator, Awaitable, Callable, Sequence from typing import Any +import pydantic + from .. import models2 from ..types import messages as messages_ from . 
import checkpoint as checkpoint_ @@ -37,42 +39,35 @@ # ── Types ───────────────────────────────────────────────────────── -LoopFn = Callable[["Agent", list[messages_.Message]], Awaitable[streams_.StreamResult]] +LoopFn = Callable[ + ["Agent", list[messages_.Message]], Awaitable[streams_.StreamResult | None] +] -# ── Default loop primitives ─────────────────────────────────────── +# ── Composition primitives ──────────────────────────────────────── @streams_.stream -async def _stream_step( +async def stream_step( model: models2.Model, messages: list[messages_.Message], - tools: list[tools_.Tool[..., Any]], + tools: Sequence[tools_.ToolLike] | None = None, + label: str | None = None, + output_type: type[pydantic.BaseModel] | None = None, + **kwargs: Any, ) -> AsyncGenerator[messages_.Message]: - """Single LLM call that streams into the Runtime queue.""" - async for msg in models2.stream(model, messages, tools=tools): - yield msg - - -async def _default_loop( - agent: Agent, messages: list[messages_.Message] -) -> streams_.StreamResult: - """Default agent loop: stream LLM, execute tools, repeat.""" - local_messages = list(messages) - - while True: - result = await _stream_step(agent.model, local_messages, agent.tools) + """Single LLM call that streams into the Runtime queue. - if not result.tool_calls: - return result - - last_msg = result.last_message - if last_msg is not None: - local_messages.append(last_msg) - - await asyncio.gather( - *(runtime_.execute_tool(tc, message=last_msg) for tc in result.tool_calls) - ) + This is a composition primitive for custom ``@agent.loop`` + functions and multi-agent orchestration. It is decorated with + ``@stream``, so each call becomes a replayable step in the + event log. 
+ """ + async for msg in models2.stream( + model, messages, tools=tools, output_type=output_type, **kwargs + ): + msg.label = label + yield msg # ── AgentRun ────────────────────────────────────────────────────── @@ -186,6 +181,29 @@ async def custom( self._custom_loop = fn return fn + async def _default_loop( + self, messages: list[messages_.Message] + ) -> streams_.StreamResult: + """Built-in loop: stream LLM, execute tools, repeat.""" + local_messages = list(messages) + + while True: + result = await stream_step(self.model, local_messages, self.tools) + + if not result.tool_calls: + return result + + last_msg = result.last_message + if last_msg is not None: + local_messages.append(last_msg) + + await asyncio.gather( + *( + runtime_.execute_tool(tc, message=last_msg) + for tc in result.tool_calls + ) + ) + def run( self, messages: list[messages_.Message], @@ -212,12 +230,13 @@ def run( ) full_messages.extend(messages) - loop_fn = self._custom_loop or _default_loop ctx = context_.Context(tools=self.tools) # Build the graph function that runtime_.run() expects - async def _graph() -> streams_.StreamResult: - return await loop_fn(self, full_messages) + async def _graph() -> streams_.StreamResult | None: + if self._custom_loop: + return await self._custom_loop(self, full_messages) + return await self._default_loop(full_messages) inner = runtime_.run( _graph, diff --git a/tests/adapters/ai_sdk_ui/test_adapter.py b/tests/adapters/ai_sdk_ui/test_adapter.py index ed2e2d36..a7910630 100644 --- a/tests/adapters/ai_sdk_ui/test_adapter.py +++ b/tests/adapters/ai_sdk_ui/test_adapter.py @@ -9,7 +9,7 @@ import vercel_ai_sdk as ai from vercel_ai_sdk.adapters.ai_sdk_ui import adapter, ui_message -from vercel_ai_sdk.agents import hooks +from vercel_ai_sdk.agents2 import hooks from vercel_ai_sdk.types import messages from ...conftest import MOCK_MODEL, mock_llm, tool_msg @@ -240,22 +240,10 @@ async def get_weather(city: str) -> str: return f"Sunny in {city}" -async def mock_agent( 
- model: ai.Model, - user_query: str, -) -> ai.StreamResult: - """Agent using stream_loop directly.""" - return await ai.stream_loop( - model, - messages=ai.make_messages(system="You are helpful.", user=user_query), - tools=[get_weather], - ) - - @pytest.mark.asyncio async def test_runtime_tool_roundtrip() -> None: """ - Integration test: run a mock agent loop through ai.run() and verify + Integration test: run an Agent through agent.run() and verify that tool-input-available and tool-output-available events are emitted. This test demonstrates the bug: the runtime yields the message with @@ -263,11 +251,17 @@ async def test_runtime_tool_roundtrip() -> None: executed and the ToolPart has been mutated to status="result". The UI adapter never sees the intermediate status="pending" state. - Root cause: stream_loop appends the message, then executes tools which - mutate the message in-place. The message was already yielded with + Root cause: the default loop appends the message, then executes tools + which mutate the message in-place. The message was already yielded with status="pending", but pydantic models are mutable so when we collect them at the end, we see the mutated state. 
""" + weather_agent = ai.agent( + model=MOCK_MODEL, + system="You are helpful.", + tools=[get_weather], + ) + # First LLM call: returns a tool call tool_call_response = [ messages.Message( @@ -298,7 +292,9 @@ async def test_runtime_tool_roundtrip() -> None: # Collect all messages from the runtime runtime_messages: list[messages.Message] = [] - async for msg in ai.run(mock_agent, MOCK_MODEL, "What's the weather in London?"): + async for msg in weather_agent.run( + ai.make_messages(user="What's the weather in London?") + ): runtime_messages.append(msg) # Stream through UI adapter @@ -643,12 +639,15 @@ async def dangerous_action(path: str) -> str: """Do something dangerous.""" return f"deleted {path}" - async def graph(model: ai.Model) -> None: - result = await ai.stream_step( - model, - ai.make_messages(system="You are helpful.", user="delete /tmp"), - [dangerous_action], - ) + approval_agent = ai.agent( + model=MOCK_MODEL, + system="You are helpful.", + tools=[dangerous_action], + ) + + @approval_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: + result = await ai.stream_step(agent.model, msgs, agent.tools) if not result.tool_calls: return @@ -680,7 +679,7 @@ async def approve_and_execute(tc: ai.ToolPart) -> None: ) runtime_messages: list[messages.Message] = [] - result = ai.run(graph, MOCK_MODEL) + result = approval_agent.run(ai.make_messages(user="delete /tmp")) async for msg in result: runtime_messages.append(msg) diff --git a/tests/agents2/__init__.py b/tests/agents2/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/agents2/mcp/__init__.py b/tests/agents2/mcp/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/agents2/mcp/test_client.py b/tests/agents2/mcp/test_client.py new file mode 100644 index 00000000..531b16be --- /dev/null +++ b/tests/agents2/mcp/test_client.py @@ -0,0 +1,108 @@ +"""MCP client: tool registration in global registry, end-to-end execution.""" + +import 
contextlib +from typing import Any + +import mcp.types +import pytest + +import vercel_ai_sdk as ai +from vercel_ai_sdk.agents2.mcp.client import _mcp_tool_to_native +from vercel_ai_sdk.agents2.tools import _tool_registry, get_tool + +from ...conftest import MOCK_MODEL, mock_llm, text_msg, tool_msg + + +def _fake_mcp_tool( + name: str = "mcp_echo", description: str = "Echo input" +) -> mcp.types.Tool: + """Build a minimal mcp.types.Tool for testing.""" + return mcp.types.Tool( + name=name, + description=description, + inputSchema={ + "type": "object", + "properties": {"text": {"type": "string"}}, + "required": ["text"], + }, + ) + + +def _noop_transport_factory() -> contextlib.AbstractAsyncContextManager[Any]: + """Dummy transport factory — never actually called in these tests.""" + raise NotImplementedError("should not be called") + + +# -- _mcp_tool_to_native registers in global registry ---------------------- + + +def test_mcp_tool_to_native_registers_in_global_registry() -> None: + """Converting an MCP tool to native registers it in _tool_registry.""" + mcp_tool = _fake_mcp_tool(name="mcp_reg_test") + native = _mcp_tool_to_native(mcp_tool, "test:key", _noop_transport_factory, None) + + assert native.name == "mcp_reg_test" + assert get_tool("mcp_reg_test") is native + assert _tool_registry["mcp_reg_test"] is native + + +def test_mcp_tool_to_native_with_prefix() -> None: + """Tool prefix is prepended to the name and both name forms are correct.""" + mcp_tool = _fake_mcp_tool(name="echo") + native = _mcp_tool_to_native(mcp_tool, "test:key", _noop_transport_factory, "ctx7") + + assert native.name == "ctx7_echo" + assert get_tool("ctx7_echo") is native + + +def test_mcp_tool_to_native_schema_preserved() -> None: + """The inputSchema from the MCP tool is passed through as param_schema.""" + mcp_tool = _fake_mcp_tool() + native = _mcp_tool_to_native(mcp_tool, "test:key", _noop_transport_factory, None) + + assert native.param_schema == mcp_tool.inputSchema + assert 
native.description == "Echo input" + + +# -- End-to-end: MCP tool executes through Agent default loop --------------- + + +@pytest.mark.asyncio +async def test_mcp_tool_executes_through_agent() -> None: + """MCP-style tool via _mcp_tool_to_native works with Agent.""" + call_log: list[dict[str, str]] = [] + + async def fake_fn(**kwargs: str) -> str: + call_log.append(kwargs) + return f"echoed: {kwargs.get('text', '')}" + + # Build and register a tool the same way the MCP client does, + # but with a fake fn so we don't need a real MCP server. + mcp_tool = _fake_mcp_tool(name="mcp_e2e_echo") + native = _mcp_tool_to_native(mcp_tool, "test:key", _noop_transport_factory, None) + # Replace the real fn (which would try to connect) with our fake + native._fn = fake_fn + _tool_registry[native.name] = native + + my_agent = ai.agent(model=MOCK_MODEL, tools=[native]) + + call1 = [tool_msg(tc_id="tc-mcp-1", name="mcp_e2e_echo", args='{"text": "hello"}')] + call2 = [text_msg("Done.", id="msg-2")] + llm = mock_llm([call1, call2]) + + result = my_agent.run(ai.make_messages(user="echo hello")) + msgs = [m async for m in result] + + # Tool was called with the right args + assert len(call_log) == 1 + assert call_log[0] == {"text": "hello"} + + # Tool result is visible in messages + tool_results = [ + m for m in msgs if m.tool_calls and m.tool_calls[0].status == "result" + ] + assert len(tool_results) >= 1 + assert tool_results[0].tool_calls[0].result == "echoed: hello" + + # LLM was called twice (tool call + final text) + assert llm.call_count == 2 diff --git a/tests/agents2/test_checkpoint.py b/tests/agents2/test_checkpoint.py new file mode 100644 index 00000000..1f5f0490 --- /dev/null +++ b/tests/agents2/test_checkpoint.py @@ -0,0 +1,207 @@ +"""Checkpoint replay, hook cancellation/resolution, serialization.""" + +import asyncio +from typing import Any, ClassVar + +import pydantic +import pytest + +import vercel_ai_sdk as ai +from vercel_ai_sdk.agents2.checkpoint import Checkpoint, 
HookEvent, StepEvent, ToolEvent + +from ..conftest import MOCK_MODEL, mock_llm, text_msg, tool_msg + + +@ai.hook +class Approval(pydantic.BaseModel): + cancels_future: ClassVar[bool] = True + granted: bool + + +# -- Replay ---------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_step_replay_skips_llm() -> None: + my_agent = ai.agent(model=MOCK_MODEL) + + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> ai.StreamResult: + return await ai.stream_step(agent.model, msgs) + + llm1 = mock_llm([[text_msg("Hi there!")]]) + result1 = my_agent.run(ai.make_messages(system="test", user="hello")) + [msg async for msg in result1] + assert llm1.call_count == 1 + + cp = result1.checkpoint + llm2 = mock_llm([]) + result2 = my_agent.run(ai.make_messages(system="test", user="hello"), checkpoint=cp) + [msg async for msg in result2] + assert llm2.call_count == 0 + + +@pytest.mark.asyncio +async def test_tool_replay_skips_execution() -> None: + execution_count = 0 + + @ai.tool + async def counting_tool(x: int) -> int: + """Counts calls.""" + nonlocal execution_count + execution_count += 1 + return x + 1 + + my_agent = ai.agent(model=MOCK_MODEL, tools=[counting_tool]) + + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> ai.StreamResult: + result = await ai.stream_step(agent.model, msgs, agent.tools) + if result.tool_calls: + await asyncio.gather( + *( + ai.execute_tool(tc, message=result.last_message) + for tc in result.tool_calls + ) + ) + return result + + mock_llm([[tool_msg(tc_id="tc-1", name="counting_tool", args='{"x": 5}')]]) + result1 = my_agent.run(ai.make_messages(system="t", user="go")) + [msg async for msg in result1] + assert execution_count == 1 + assert result1.checkpoint.tools[0].result == 6 + + execution_count = 0 + mock_llm([]) + result2 = my_agent.run( + ai.make_messages(system="t", user="go"), checkpoint=result1.checkpoint + ) + [msg async for msg in result2] + 
assert execution_count == 0 + + +# -- Hooks ----------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_hook_cancellation_pending() -> None: + my_agent = ai.agent(model=MOCK_MODEL) + + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> Any: + await ai.stream_step(agent.model, msgs) + return await Approval.create("my_approval", metadata={"tool": "test"}) # type: ignore[attr-defined] + + mock_llm([[text_msg("OK")]]) + result = my_agent.run(ai.make_messages(system="t", user="go")) + msgs = [msg async for msg in result] + assert "my_approval" in result.pending_hooks + hook_msgs = [m for m in msgs if any(isinstance(p, ai.HookPart) for p in m.parts)] + assert hook_msgs[0].parts[0].status == "pending" # type: ignore[union-attr] + + +@pytest.mark.asyncio +async def test_hook_resolution_on_reentry() -> None: + my_agent = ai.agent(model=MOCK_MODEL) + + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> Any: + await ai.stream_step(agent.model, msgs) + return await Approval.create("my_approval") # type: ignore[attr-defined] + + resp = [text_msg("OK")] + mock_llm([resp]) + result1 = my_agent.run(ai.make_messages(system="t", user="go")) + [msg async for msg in result1] + cp = result1.checkpoint + + Approval.resolve("my_approval", {"granted": True}) # type: ignore[attr-defined] + mock_llm([]) + result2 = my_agent.run(ai.make_messages(system="t", user="go"), checkpoint=cp) + [msg async for msg in result2] + assert len(result2.pending_hooks) == 0 + assert result2.checkpoint.hooks[-1].label == "my_approval" + + +@pytest.mark.asyncio +async def test_parallel_hooks_all_collected() -> None: + my_agent = ai.agent(model=MOCK_MODEL) + + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: + await ai.stream_step(agent.model, msgs) + + async def a() -> Any: + return await Approval.create("hook_a") # type: ignore[attr-defined] + + async def b() -> Any: + return 
await Approval.create("hook_b") # type: ignore[attr-defined] + + async with asyncio.TaskGroup() as tg: + tg.create_task(a()) + tg.create_task(b()) + + mock_llm([[text_msg("OK")]]) + result = my_agent.run(ai.make_messages(system="t", user="go")) + [msg async for msg in result] + assert {"hook_a", "hook_b"} <= set(result.pending_hooks) + + +@pytest.mark.asyncio +async def test_parallel_hooks_resolve_on_reentry() -> None: + my_agent = ai.agent(model=MOCK_MODEL) + + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> Any: + await ai.stream_step(agent.model, msgs) + + async def a() -> Any: + return await Approval.create("hook_a") # type: ignore[attr-defined] + + async def b() -> Any: + return await Approval.create("hook_b") # type: ignore[attr-defined] + + async with asyncio.TaskGroup() as tg: + ta = tg.create_task(a()) + tb = tg.create_task(b()) + return ta.result(), tb.result() + + resp = [text_msg("OK")] + mock_llm([resp]) + result1 = my_agent.run(ai.make_messages(system="t", user="go")) + [msg async for msg in result1] + cp = result1.checkpoint + + Approval.resolve("hook_a", {"granted": True}) # type: ignore[attr-defined] + Approval.resolve("hook_b", {"granted": False}) # type: ignore[attr-defined] + mock_llm([]) + result2 = my_agent.run(ai.make_messages(system="t", user="go"), checkpoint=cp) + [msg async for msg in result2] + assert len(result2.pending_hooks) == 0 + + +# -- Serialization --------------------------------------------------------- + + +def test_checkpoint_serialization_roundtrip() -> None: + cp = Checkpoint( + steps=[ + StepEvent( + index=0, + messages=[ + ai.Message( + id="m1", + role="assistant", + parts=[ai.TextPart(text="hi")], + ) + ], + ) + ], + tools=[ToolEvent(tool_call_id="tc-1", result=42)], + hooks=[HookEvent(label="h1", resolution={"granted": True})], + ) + cp2 = Checkpoint.model_validate(cp.model_dump()) + assert cp2.steps[0].index == 0 + assert cp2.tools[0].result == 42 + assert cp2.hooks[0].label == "h1" diff 
--git a/tests/agents2/test_hooks.py b/tests/agents2/test_hooks.py new file mode 100644 index 00000000..65bfd15b --- /dev/null +++ b/tests/agents2/test_hooks.py @@ -0,0 +1,185 @@ +"""Hooks: live resolution, cancellation, pre-registration, schema validation.""" + +import asyncio +from typing import Any, ClassVar + +import pydantic +import pytest + +import vercel_ai_sdk as ai + +from ..conftest import MOCK_MODEL, mock_llm, text_msg + + +@ai.hook +class Confirmation(pydantic.BaseModel): + approved: bool + reason: str = "" + + +@ai.hook +class CancellingConfirmation(pydantic.BaseModel): + cancels_future: ClassVar[bool] = True + approved: bool + reason: str = "" + + +# -- Hook.resolve() with live future (long-running mode) ------------------- + + +@pytest.mark.asyncio +async def test_resolve_live_future() -> None: + """In long-running mode, Hook.resolve() unblocks the awaiting coroutine.""" + resolved_value = None + my_agent = ai.agent(model=MOCK_MODEL) + + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: + nonlocal resolved_value + await ai.stream_step(agent.model, msgs) + result = await Confirmation.create("confirm_1") # type: ignore[attr-defined] + resolved_value = result + + mock_llm([[text_msg("OK")]]) + # Confirmation.cancels_future=False -> long-running mode + run_result = my_agent.run(ai.make_messages(user="go")) + + collected = [] + async for msg in run_result: + collected.append(msg) + # When we see the pending hook message, resolve it + if any(isinstance(p, ai.HookPart) and p.status == "pending" for p in msg.parts): + Confirmation.resolve( # type: ignore[attr-defined] + "confirm_1", {"approved": True, "reason": "looks good"} + ) + + assert resolved_value is not None + assert resolved_value.approved is True + assert resolved_value.reason == "looks good" + + +# -- Hook.cancel() -------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_cancel_live_hook() -> None: + """Hook.cancel() 
cancels the future, causing CancelledError in graph.""" + was_cancelled = False + my_agent = ai.agent(model=MOCK_MODEL) + + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: + nonlocal was_cancelled + await ai.stream_step(agent.model, msgs) + try: + await Confirmation.create("cancel_me") # type: ignore[attr-defined] + except asyncio.CancelledError: + was_cancelled = True + + mock_llm([[text_msg("OK")]]) + run_result = my_agent.run(ai.make_messages(user="go")) + + async for msg in run_result: + if any(isinstance(p, ai.HookPart) and p.status == "pending" for p in msg.parts): + await Confirmation.cancel("cancel_me", reason="denied") # type: ignore[attr-defined] + + assert was_cancelled + + +# -- Hook.cancel() on non-existent label raises ---------------------------- + + +@pytest.mark.asyncio +async def test_cancel_nonexistent_raises() -> None: + with pytest.raises(ValueError, match="No pending hook"): + await Confirmation.cancel("does_not_exist_xyz") # type: ignore[attr-defined] + + +# -- Pre-registration (serverless re-entry) -------------------------------- + + +@pytest.mark.asyncio +async def test_pre_registered_resolution_consumed() -> None: + """Pre-registered resolution is consumed by Hook.create() without suspending.""" + my_agent = ai.agent(model=MOCK_MODEL) + + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> Any: + await ai.stream_step(agent.model, msgs) + result = await Confirmation.create("pre_reg_1") # type: ignore[attr-defined] + return result + + # Pre-register BEFORE run + Confirmation.resolve("pre_reg_1", {"approved": True}) # type: ignore[attr-defined] + + mock_llm([[text_msg("OK")]]) + run_result = my_agent.run(ai.make_messages(user="go")) + [m async for m in run_result] + + # Should have completed with no pending hooks + assert len(run_result.pending_hooks) == 0 + # Hook event should be in checkpoint + assert any(h.label == "pre_reg_1" for h in run_result.checkpoint.hooks) + + +# -- 
Schema validation on resolve ----------------------------------------- + + +def test_resolve_validates_schema() -> None: + """resolve() with invalid data raises from pydantic validation.""" + # 'approved' is required bool, passing string should raise + with pytest.raises(pydantic.ValidationError): + Confirmation.resolve("schema_test", {"approved": "not_a_bool"}) # type: ignore[attr-defined] + + +# -- Resolved hook emits message ------------------------------------------- + + +@pytest.mark.asyncio +async def test_resolved_hook_emits_message() -> None: + """After resolution, a 'resolved' HookPart message is emitted.""" + my_agent = ai.agent(model=MOCK_MODEL) + + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: + await ai.stream_step(agent.model, msgs) + await Confirmation.create("emit_test") # type: ignore[attr-defined] + + mock_llm([[text_msg("OK")]]) + run_result = my_agent.run(ai.make_messages(user="go")) + + msgs = [] + async for msg in run_result: + msgs.append(msg) + if any(isinstance(p, ai.HookPart) and p.status == "pending" for p in msg.parts): + Confirmation.resolve("emit_test", {"approved": False}) # type: ignore[attr-defined] + + hook_msgs = [ + m + for m in msgs + if any(isinstance(p, ai.HookPart) and p.status == "resolved" for p in m.parts) + ] + assert len(hook_msgs) == 1 + assert hook_msgs[0].parts[0].resolution == {"approved": False, "reason": ""} # type: ignore[union-attr] + + +# -- Hook metadata surfaces in pending message ----------------------------- + + +@pytest.mark.asyncio +async def test_hook_metadata_in_pending() -> None: + my_agent = ai.agent(model=MOCK_MODEL) + + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: + await ai.stream_step(agent.model, msgs) + await CancellingConfirmation.create( # type: ignore[attr-defined] + "meta_test", metadata={"tool": "rm -rf", "path": "/"} + ) + + mock_llm([[text_msg("OK")]]) + run_result = my_agent.run(ai.make_messages(user="go")) + [m 
async for m in run_result] + + info = run_result.pending_hooks["meta_test"] + assert info.metadata == {"tool": "rm -rf", "path": "/"} diff --git a/tests/agents2/test_runtime.py b/tests/agents2/test_runtime.py new file mode 100644 index 00000000..b423dc1b --- /dev/null +++ b/tests/agents2/test_runtime.py @@ -0,0 +1,232 @@ +"""Agent default loop, execute_tool, multi-turn, Runtime injection.""" + +import asyncio + +import pytest + +import vercel_ai_sdk as ai +from vercel_ai_sdk.agents2.runtime import Runtime +from vercel_ai_sdk.types import messages + +from ..conftest import MOCK_MODEL, mock_llm, text_msg, tool_msg + +# -- Tool definitions for tests -------------------------------------------- + + +@ai.tool +async def double(x: int) -> int: + """Double a number.""" + return x * 2 + + +@ai.tool +async def concat(a: str, b: str) -> str: + """Concatenate strings.""" + return a + b + + +# -- Agent default loop: single turn (no tools) ---------------------------- + + +@pytest.mark.asyncio +async def test_agent_text_only() -> None: + """Agent default loop with no tool calls returns after one LLM call.""" + my_agent = ai.agent(model=MOCK_MODEL, tools=[double]) + + llm = mock_llm([[text_msg("Hello!")]]) + result = my_agent.run(ai.make_messages(user="Hi")) + msgs = [m async for m in result] + assert llm.call_count == 1 + assert any(m.text == "Hello!" 
for m in msgs) + + +# -- Agent default loop: tool call + follow-up ----------------------------- + + +@pytest.mark.asyncio +async def test_agent_tool_then_text() -> None: + """Agent default loop calls tool, feeds result back, gets final text.""" + my_agent = ai.agent(model=MOCK_MODEL, tools=[double]) + + call1 = [tool_msg(tc_id="tc-1", name="double", args='{"x": 5}')] + call2 = [text_msg("The answer is 10.")] + llm = mock_llm([call1, call2]) + + result = my_agent.run(ai.make_messages(user="Double 5")) + msgs = [m async for m in result] + assert llm.call_count == 2 + # Tool should have been executed: 5 * 2 = 10 + tool_results = [ + m for m in msgs if m.tool_calls and m.tool_calls[0].status == "result" + ] + assert len(tool_results) >= 1 + assert tool_results[0].tool_calls[0].result == 10 + + +# -- Agent default loop: multiple tool calls in one message ---------------- + + +@pytest.mark.asyncio +async def test_agent_parallel_tools() -> None: + """LLM returns two tool calls in one message; both execute.""" + my_agent = ai.agent(model=MOCK_MODEL, tools=[double]) + + two_tools = messages.Message( + id="msg-1", + role="assistant", + parts=[ + messages.ToolPart( + tool_call_id="tc-1", + tool_name="double", + tool_args='{"x": 3}', + status="pending", + state="done", + ), + messages.ToolPart( + tool_call_id="tc-2", + tool_name="double", + tool_args='{"x": 7}', + status="pending", + state="done", + ), + ], + ) + call2 = [text_msg("6 and 14", id="msg-2")] + llm = mock_llm([[two_tools], call2]) + + result = my_agent.run(ai.make_messages(user="Double 3 and 7")) + msgs = [m async for m in result] + assert llm.call_count == 2 + # Both tools should have results + tool_result_msgs = [ + m + for m in msgs + if m.tool_calls and any(tc.status == "result" for tc in m.tool_calls) + ] + assert len(tool_result_msgs) >= 1 + + +# -- Agent default loop: multi-turn (tool -> tool -> text) ----------------- + + +@pytest.mark.asyncio +async def test_agent_multi_turn() -> None: + """LLM calls a 
tool, then calls another tool, then returns text.""" + my_agent = ai.agent(model=MOCK_MODEL, tools=[double, concat]) + + turn1 = [ + tool_msg(tc_id="tc-1", name="concat", args='{"a": "hello", "b": " world"}') + ] + turn2 = [tool_msg(tc_id="tc-2", name="double", args='{"x": 3}', id="msg-2")] + turn3 = [text_msg("Done: hello world, 6", id="msg-3")] + llm = mock_llm([turn1, turn2, turn3]) + + result = my_agent.run(ai.make_messages(user="Concat then double")) + [m async for m in result] + assert llm.call_count == 3 + + +# -- execute_tool: missing tool raises ------------------------------------ + + +@pytest.mark.asyncio +async def test_execute_tool_missing_raises() -> None: + """execute_tool with unknown tool name raises ValueError. + + Wrapped in ExceptionGroup by TaskGroup. + """ + tc = messages.ToolPart( + tool_call_id="tc-1", tool_name="nonexistent_tool_zzz", tool_args="{}" + ) + my_agent = ai.agent(model=MOCK_MODEL, tools=[]) + + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: + await ai.execute_tool(tc) + + mock_llm([]) + result = my_agent.run(ai.make_messages(user="go")) + with pytest.raises(ExceptionGroup) as exc_info: + [m async for m in result] + assert any(isinstance(e, ValueError) for e in exc_info.value.exceptions) + + +# -- execute_tool: Runtime injection --------------------------------------- + + +@pytest.mark.asyncio +async def test_execute_tool_injects_runtime() -> None: + """Tools with a Runtime parameter get the active runtime injected.""" + received_rt = None + + @ai.tool + async def introspect(query: str, rt: Runtime) -> str: + """Tool that inspects runtime.""" + nonlocal received_rt + received_rt = rt + return "ok" + + my_agent = ai.agent(model=MOCK_MODEL, tools=[introspect]) + + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: + result = await ai.stream_step(agent.model, msgs, agent.tools) + if result.tool_calls: + await asyncio.gather( + *( + ai.execute_tool(tc, 
message=result.last_message) + for tc in result.tool_calls + ) + ) + + call = [tool_msg(tc_id="tc-1", name="introspect", args='{"query": "test"}')] + mock_llm([call]) + result = my_agent.run(ai.make_messages(user="go")) + [m async for m in result] + assert received_rt is not None + assert isinstance(received_rt, Runtime) + + +# -- execute_tool: result updates ToolPart in message ---------------------- + + +@pytest.mark.asyncio +async def test_execute_tool_updates_message() -> None: + """After execute_tool, the ToolPart in the message has status=result.""" + my_agent = ai.agent(model=MOCK_MODEL, tools=[double]) + + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: + result = await ai.stream_step(agent.model, msgs, agent.tools) + if result.tool_calls: + msg = result.last_message + for tc in result.tool_calls: + await ai.execute_tool(tc, message=msg) + # Verify the tool part was mutated + assert msg is not None + assert msg.tool_calls[0].status == "result" + assert msg.tool_calls[0].result == 10 + + call = [tool_msg(tc_id="tc-1", name="double", args='{"x": 5}')] + mock_llm([call]) + result = my_agent.run(ai.make_messages(user="go")) + [m async for m in result] + + +# -- Checkpoint records tools from Agent default loop ---------------------- + + +@pytest.mark.asyncio +async def test_agent_checkpoint_records_tools() -> None: + """Agent default loop's tool executions are recorded in the checkpoint.""" + my_agent = ai.agent(model=MOCK_MODEL, tools=[double]) + + call1 = [tool_msg(tc_id="tc-1", name="double", args='{"x": 4}')] + call2 = [text_msg("8", id="msg-2")] + mock_llm([call1, call2]) + + result = my_agent.run(ai.make_messages(user="Double 4")) + [m async for m in result] + + cp = result.checkpoint + assert any(t.tool_call_id == "tc-1" and t.result == 8 for t in cp.tools) diff --git a/tests/agents2/test_streams.py b/tests/agents2/test_streams.py new file mode 100644 index 00000000..80a459fa --- /dev/null +++ 
b/tests/agents2/test_streams.py @@ -0,0 +1,113 @@ +"""@stream decorator: context requirement, replay, queue submission.""" + +import pydantic +import pytest + +import vercel_ai_sdk as ai +from vercel_ai_sdk.agents2.streams import StreamResult +from vercel_ai_sdk.types import messages + +from ..conftest import MOCK_MODEL, mock_llm, text_msg + + +class _Weather(pydantic.BaseModel): + city: str + temperature: float + + +# -- StreamResult properties ----------------------------------------------- + + +def test_stream_result_empty() -> None: + r = StreamResult() + assert r.last_message is None + assert r.tool_calls == [] + assert r.text == "" + + +def test_stream_result_last_message() -> None: + m1 = text_msg("first", id="m1") + m2 = text_msg("second", id="m2") + r = StreamResult(messages=[m1, m2]) + last = r.last_message + assert last is not None + assert last.id == "m2" + assert r.text == "second" + + +def test_stream_result_tool_calls() -> None: + m = messages.Message( + id="m1", + role="assistant", + parts=[ + messages.ToolPart( + tool_call_id="tc1", tool_name="t", tool_args="{}", state="done" + ), + messages.ToolPart( + tool_call_id="tc2", tool_name="u", tool_args="{}", state="done" + ), + ], + ) + r = StreamResult(messages=[m]) + assert len(r.tool_calls) == 2 + + +# -- @stream requires Runtime context ------------------------------------- + + +@pytest.mark.asyncio +async def test_stream_outside_run_raises() -> None: + """@stream-decorated fn called without ai.run() should raise.""" + mock_llm([[text_msg("hi")]]) + with pytest.raises(ValueError, match="No Runtime context"): + await ai.stream_step( + MOCK_MODEL, + ai.make_messages(user="test"), + ) + + +# -- @stream replays from checkpoint -------------------------------------- + + +@pytest.mark.asyncio +async def test_stream_step_replays_from_checkpoint() -> None: + """stream_step inside Agent.run with a checkpoint replays without calling LLM.""" + + my_agent = ai.agent(model=MOCK_MODEL) + + @my_agent.loop + async 
def custom(agent: ai.Agent, msgs: list[ai.Message]) -> ai.StreamResult: + return await ai.stream_step(agent.model, msgs) + + # First run + mock_llm([[text_msg("Hi")]]) + r1 = my_agent.run(ai.make_messages(user="hello")) + [msg async for msg in r1] + cp = r1.checkpoint + + # Replay + llm2 = mock_llm([]) + r2 = my_agent.run(ai.make_messages(user="hello"), checkpoint=cp) + [msg async for msg in r2] + assert llm2.call_count == 0 + + +# -- StreamResult.output --------------------------------------------------- + + +def test_stream_result_output_from_last_message() -> None: + """StreamResult.output delegates to the last message's StructuredOutputPart.""" + m = messages.Message( + id="m1", + role="assistant", + parts=[ + messages.TextPart(text="{}", state="done"), + messages.StructuredOutputPart( + data={"city": "SF", "temperature": 62.0}, + output_type_name=f"{_Weather.__module__}.{_Weather.__qualname__}", + ), + ], + ) + r = StreamResult(messages=[text_msg("streaming..."), m]) + assert r.output is not None + assert r.output.city == "SF" diff --git a/tests/agents2/test_tools.py b/tests/agents2/test_tools.py new file mode 100644 index 00000000..ba2d3792 --- /dev/null +++ b/tests/agents2/test_tools.py @@ -0,0 +1,110 @@ +"""@tool decorator: schema extraction, registry, Runtime parameter handling.""" + +import pytest + +import vercel_ai_sdk as ai +from vercel_ai_sdk.agents2.runtime import Runtime +from vercel_ai_sdk.agents2.tools import get_tool + +# -- Schema extraction from type hints ------------------------------------ + + +def test_simple_types_produce_correct_schema() -> None: + @ai.tool + async def greet(name: str, count: int) -> str: + """Say hello.""" + return f"Hello {name}" * count + + assert greet.name == "greet" + assert greet.description == "Say hello." 
+ props = greet.param_schema["properties"] + assert props["name"]["type"] == "string" + assert props["count"]["type"] == "integer" + assert set(greet.param_schema["required"]) == {"name", "count"} + + +def test_optional_param_not_required() -> None: + @ai.tool + async def search(query: str, limit: int | None = None) -> str: + """Search.""" + return query + + assert "query" in search.param_schema.get("required", []) + assert "limit" not in search.param_schema.get("required", []) + # limit should still appear in properties + assert "limit" in search.param_schema["properties"] + + +def test_default_value_not_required() -> None: + @ai.tool + async def fetch(url: str, timeout: int = 30) -> str: + """Fetch URL.""" + return url + + assert "url" in search_required(fetch) + assert "timeout" not in search_required(fetch) + + +def test_complex_type_schema() -> None: + @ai.tool + async def send(recipients: list[str], urgent: bool = False) -> str: + """Send message.""" + return "sent" + + props = send.param_schema["properties"] + assert props["recipients"]["type"] == "array" + assert props["recipients"]["items"]["type"] == "string" + + +# -- Runtime parameter skipping ------------------------------------------- + + +def test_runtime_param_excluded_from_schema() -> None: + @ai.tool + async def needs_runtime(query: str, rt: Runtime) -> str: + """Tool that needs runtime.""" + return query + + props = needs_runtime.param_schema["properties"] + assert "rt" not in props + assert "query" in props + assert set(needs_runtime.param_schema.get("required", [])) == {"query"} + + +# -- Registry ------------------------------------------------------------- + + +def test_tool_registered_on_decoration() -> None: + @ai.tool + async def unique_tool_abc() -> str: + """Unique.""" + return "ok" + + assert get_tool("unique_tool_abc") is unique_tool_abc + + +def test_get_tool_returns_none_for_missing() -> None: + assert get_tool("nonexistent_tool_xyz") is None + + +# -- Execution 
------------------------------------------------------------ + + +@pytest.mark.asyncio +async def test_tool_fn_is_callable() -> None: + @ai.tool + async def add(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + result = await add(a=1, b=2) + assert result == 3 + + +# -- Helpers --------------------------------------------------------------- + + +def search_required(tool: ai.Tool[..., object]) -> list[str]: + result = tool.param_schema.get("required", []) + assert isinstance(result, list) + return result diff --git a/tests/telemetry/test_otel_handler.py b/tests/telemetry/test_otel_handler.py index b7f329f5..3304f38d 100644 --- a/tests/telemetry/test_otel_handler.py +++ b/tests/telemetry/test_otel_handler.py @@ -35,14 +35,10 @@ async def double(x: int) -> int: @pytest.mark.asyncio async def test_text_only_spans(spans: InMemorySpanExporter) -> None: """Text-only run produces ai.run > ai.stream span hierarchy.""" - - async def root(model: ai.Model) -> ai.StreamResult: - return await ai.stream_loop( - model, messages=ai.make_messages(user="Hi"), tools=[] - ) + my_agent = ai.agent(model=MOCK_MODEL, tools=[]) mock_llm([[text_msg("Hello!")]]) - result = ai.run(root, MOCK_MODEL) + result = my_agent.run(ai.make_messages(user="Hi")) [m async for m in result] finished = spans.get_finished_spans() @@ -65,11 +61,7 @@ async def root(model: ai.Model) -> ai.StreamResult: @pytest.mark.asyncio async def test_tool_call_spans(spans: InMemorySpanExporter) -> None: """Tool-calling run produces ai.tool spans with correct attributes.""" - - async def root(model: ai.Model) -> ai.StreamResult: - return await ai.stream_loop( - model, messages=ai.make_messages(user="Double 5"), tools=[double] - ) + my_agent = ai.agent(model=MOCK_MODEL, tools=[double]) mock_llm( [ @@ -77,7 +69,7 @@ async def root(model: ai.Model) -> ai.StreamResult: [text_msg("10")], ] ) - result = ai.run(root, MOCK_MODEL) + result = my_agent.run(ai.make_messages(user="Double 5")) [m async for m in result] 
finished = spans.get_finished_spans() diff --git a/tests/telemetry/test_telemetry.py b/tests/telemetry/test_telemetry.py index 8950bf63..b1b06eb3 100644 --- a/tests/telemetry/test_telemetry.py +++ b/tests/telemetry/test_telemetry.py @@ -55,14 +55,10 @@ async def double(x: int) -> int: @pytest.mark.asyncio async def test_text_only_run_events(handler: RecordingHandler) -> None: """Simplest run emits RunStart, StepStart, StepFinish, RunFinish.""" - - async def root(model: ai.Model) -> ai.StreamResult: - return await ai.stream_loop( - model, messages=ai.make_messages(user="Hi"), tools=[] - ) + my_agent = ai.agent(model=MOCK_MODEL, tools=[]) mock_llm([[text_msg("Hello!")]]) - result = ai.run(root, MOCK_MODEL) + result = my_agent.run(ai.make_messages(user="Hi")) [m async for m in result] types = [type(e).__name__ for e in handler.events] @@ -80,12 +76,8 @@ async def root(model: ai.Model) -> ai.StreamResult: @pytest.mark.asyncio async def test_tool_call_events(handler: RecordingHandler) -> None: - """Tool-calling run emits tool events between steps with correct payloads.""" - - async def root(model: ai.Model) -> ai.StreamResult: - return await ai.stream_loop( - model, messages=ai.make_messages(user="Double 5"), tools=[double] - ) + """Tool-calling run emits tool events between steps.""" + my_agent = ai.agent(model=MOCK_MODEL, tools=[double]) mock_llm( [ @@ -93,7 +85,7 @@ async def root(model: ai.Model) -> ai.StreamResult: [text_msg("10")], ] ) - result = ai.run(root, MOCK_MODEL) + result = my_agent.run(ai.make_messages(user="Double 5")) [m async for m in result] types = [type(e).__name__ for e in handler.events] @@ -122,7 +114,7 @@ async def root(model: ai.Model) -> ai.StreamResult: @pytest.mark.asyncio async def test_run_id_available_during_run() -> None: - """get_run_id() returns a non-empty ID inside a handler during a run.""" + """get_run_id() returns a non-empty ID inside a handler during run.""" captured: str = "" class Capture: @@ -133,14 +125,10 @@ def 
handle(self, event: TelemetryEvent) -> None: ai.telemetry.enable(Capture()) try: - - async def root(model: ai.Model) -> ai.StreamResult: - return await ai.stream_loop( - model, messages=ai.make_messages(user="Hi"), tools=[] - ) + my_agent = ai.agent(model=MOCK_MODEL, tools=[]) mock_llm([[text_msg("Hello!")]]) - result = ai.run(root, MOCK_MODEL) + result = my_agent.run(ai.make_messages(user="Hi")) [m async for m in result] assert len(captured) == 16 finally: @@ -156,13 +144,10 @@ async def test_disable_reverts_to_noop() -> None: handler = RecordingHandler() ai.telemetry.enable(handler) - async def root(model: ai.Model) -> ai.StreamResult: - return await ai.stream_loop( - model, messages=ai.make_messages(user="Hi"), tools=[] - ) + my_agent = ai.agent(model=MOCK_MODEL, tools=[]) mock_llm([[text_msg("Hello!")]]) - result = ai.run(root, MOCK_MODEL) + result = my_agent.run(ai.make_messages(user="Hi")) [m async for m in result] assert len(handler.of_type(RunStartEvent)) == 1 @@ -170,7 +155,7 @@ async def root(model: ai.Model) -> ai.StreamResult: handler.events.clear() mock_llm([[text_msg("Hello!")]]) - result = ai.run(root, MOCK_MODEL) + result = my_agent.run(ai.make_messages(user="Hi")) [m async for m in result] assert len(handler.events) == 0 @@ -186,19 +171,20 @@ async def test_user_emitted_custom_event(handler: RecordingHandler) -> None: class CustomEvent(TelemetryEvent): message: str - async def root(model: ai.Model) -> ai.StreamResult: + my_agent = ai.agent(model=MOCK_MODEL, tools=[]) + + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> ai.StreamResult: ai.telemetry.handle(CustomEvent(message="hello")) - return await ai.stream_loop( - model, messages=ai.make_messages(user="Hi"), tools=[] - ) + return await ai.stream_step(agent.model, msgs) mock_llm([[text_msg("Hello!")]]) - result = ai.run(root, MOCK_MODEL) + result = my_agent.run(ai.make_messages(user="Hi")) [m async for m in result] - custom = [e for e in handler.events if 
isinstance(e, CustomEvent)] - assert len(custom) == 1 - assert custom[0].message == "hello" + custom_events = [e for e in handler.events if isinstance(e, CustomEvent)] + assert len(custom_events) == 1 + assert custom_events[0].message == "hello" # ── Error capture ──────────────────────────────────────────────── @@ -206,13 +192,15 @@ async def root(model: ai.Model) -> ai.StreamResult: @pytest.mark.asyncio async def test_run_error_in_finish_event(handler: RecordingHandler) -> None: - """RunFinishEvent captures the error when the root function raises.""" + """RunFinishEvent captures the error when the loop function raises.""" + my_agent = ai.agent(model=MOCK_MODEL, tools=[]) - async def root(model: ai.Model) -> None: + @my_agent.loop + async def failing(agent: ai.Agent, msgs: list[ai.Message]) -> None: raise ValueError("boom") mock_llm([]) - result = ai.run(root, MOCK_MODEL) + result = my_agent.run(ai.make_messages(user="Hi")) with pytest.raises(ExceptionGroup): [m async for m in result] From 85945c6086e244dcafc44fd522b1c1f890b1f5e6 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Fri, 3 Apr 2026 16:50:40 -0700 Subject: [PATCH 17/18] Remove old modules, migrate to new modules --- examples/{models2 => models}/buffer.py | 2 +- .../{models2 => models}/direct_adapter.py | 4 +- .../{models2 => models}/explicit_client.py | 2 +- .../{models2 => models}/image_generation.py | 2 +- examples/{models2 => models}/inline_image.py | 2 +- .../{models2 => models}/multimodal_input.py | 2 +- examples/{models2 => models}/stream.py | 2 +- .../{models2 => models}/structured_output.py | 2 +- examples/{models2 => models}/tools.py | 2 +- .../{models2 => models}/video_generation.py | 2 +- examples/samples/custom_loop.py | 4 +- examples/temporal-durable/workflow.py | 10 +- src/vercel_ai_sdk/__init__.py | 10 +- .../adapters/ai_sdk_ui/adapter.py | 2 +- src/vercel_ai_sdk/agents/__init__.py | 25 +- .../{agents2 => agents}/agent.py | 12 +- .../{agents2 => agents}/context.py | 0 
src/vercel_ai_sdk/agents/hooks.py | 46 +- src/vercel_ai_sdk/agents/mcp/client.py | 22 +- src/vercel_ai_sdk/agents/runtime.py | 455 ++++++------- src/vercel_ai_sdk/agents/streams.py | 13 +- src/vercel_ai_sdk/agents/tools.py | 29 +- src/vercel_ai_sdk/agents2/__init__.py | 64 -- src/vercel_ai_sdk/agents2/checkpoint.py | 48 -- src/vercel_ai_sdk/agents2/hooks.py | 245 ------- src/vercel_ai_sdk/agents2/mcp/__init__.py | 6 - src/vercel_ai_sdk/agents2/mcp/client.py | 301 --------- src/vercel_ai_sdk/agents2/runtime.py | 610 ------------------ src/vercel_ai_sdk/agents2/streams.py | 105 --- src/vercel_ai_sdk/agents2/tools.py | 136 ---- src/vercel_ai_sdk/models/__init__.py | 240 +++++-- .../models/ai_gateway/__init__.py | 16 +- .../{models2 => models}/ai_gateway/_common.py | 0 .../ai_gateway/generate.py | 0 src/vercel_ai_sdk/models/ai_gateway/image.py | 126 ---- src/vercel_ai_sdk/models/ai_gateway/llm.py | 192 ------ .../models/ai_gateway/protocol.py | 425 ------------ .../{models2 => models}/ai_gateway/stream.py | 0 src/vercel_ai_sdk/models/ai_gateway/video.py | 212 ------ .../models/anthropic/__init__.py | 8 +- .../{models2 => models}/anthropic/adapter.py | 0 src/vercel_ai_sdk/models/anthropic/llm.py | 341 ---------- src/vercel_ai_sdk/models/core/__init__.py | 39 +- .../{models2 => models}/core/client.py | 0 .../{models2 => models}/core/helpers/media.py | 0 .../core/helpers/streaming.py | 0 src/vercel_ai_sdk/models/core/image.py | 60 -- src/vercel_ai_sdk/models/core/llm.py | 288 --------- .../models/core/media/__init__.py | 12 - src/vercel_ai_sdk/models/core/media/base.py | 86 --- src/vercel_ai_sdk/models/core/media/data.py | 100 --- src/vercel_ai_sdk/models/core/media/detect.py | 188 ------ .../models/core/media/download.py | 104 --- src/vercel_ai_sdk/models/core/model.py | 35 +- .../{models2 => models}/core/proto.py | 0 src/vercel_ai_sdk/models/core/protocol.py | 145 ----- src/vercel_ai_sdk/models/core/registry.py | 54 -- src/vercel_ai_sdk/models/core/video.py | 66 -- 
src/vercel_ai_sdk/models/openai/__init__.py | 8 +- .../{models2 => models}/openai/adapter.py | 0 src/vercel_ai_sdk/models/openai/llm.py | 367 ----------- src/vercel_ai_sdk/models2/__init__.py | 205 ------ .../models2/ai_gateway/__init__.py | 14 - .../models2/ai_gateway/errors.py | 305 --------- .../models2/anthropic/__init__.py | 7 - src/vercel_ai_sdk/models2/core/__init__.py | 13 - src/vercel_ai_sdk/models2/core/model.py | 34 - src/vercel_ai_sdk/models2/openai/__init__.py | 7 - src/vercel_ai_sdk/types/messages.py | 10 +- tests/adapters/ai_sdk_ui/test_adapter.py | 2 +- tests/agents/mcp/test_client.py | 15 +- tests/agents/test_checkpoint.py | 66 +- tests/agents/test_hooks.py | 45 +- tests/agents/test_runtime.py | 101 ++- tests/agents/test_streams.py | 13 +- tests/agents2/__init__.py | 0 tests/agents2/mcp/__init__.py | 0 tests/agents2/mcp/test_client.py | 108 ---- tests/agents2/test_checkpoint.py | 207 ------ tests/agents2/test_hooks.py | 185 ------ tests/agents2/test_runtime.py | 232 ------- tests/agents2/test_streams.py | 113 ---- tests/agents2/test_tools.py | 110 ---- tests/conftest.py | 80 +-- tests/models/ai_gateway/test_gateway.py | 421 ------------ tests/models/ai_gateway/test_gateway_image.py | 262 -------- tests/models/ai_gateway/test_gateway_video.py | 354 ---------- .../ai_gateway/test_generate_image.py | 8 +- .../ai_gateway/test_generate_video.py | 10 +- tests/models/ai_gateway/test_protocol.py | 174 ++--- .../ai_gateway/test_stream.py | 8 +- tests/models/anthropic/__init__.py | 0 tests/models/anthropic/test_anthropic.py | 390 ----------- tests/models/core/media/__init__.py | 0 tests/models/core/media/test_data.py | 80 --- .../core/media/test_detect_media_type.py | 460 ------------- tests/models/core/media/test_models.py | 198 ------ tests/models/core/test_llm.py | 295 --------- tests/{models2 => models}/core/test_media.py | 2 +- .../core/test_streaming.py | 3 +- tests/models/openai/__init__.py | 0 tests/models/openai/test_openai.py | 245 ------- 
tests/models2/__init__.py | 0 tests/models2/ai_gateway/__init__.py | 0 tests/models2/ai_gateway/test_errors.py | 139 ---- tests/models2/ai_gateway/test_protocol.py | 460 ------------- tests/models2/core/__init__.py | 0 107 files changed, 777 insertions(+), 9881 deletions(-) rename examples/{models2 => models}/buffer.py (94%) rename examples/{models2 => models}/direct_adapter.py (90%) rename examples/{models2 => models}/explicit_client.py (95%) rename examples/{models2 => models}/image_generation.py (97%) rename examples/{models2 => models}/inline_image.py (98%) rename examples/{models2 => models}/multimodal_input.py (95%) rename examples/{models2 => models}/stream.py (94%) rename examples/{models2 => models}/structured_output.py (96%) rename examples/{models2 => models}/tools.py (96%) rename examples/{models2 => models}/video_generation.py (97%) rename src/vercel_ai_sdk/{agents2 => agents}/agent.py (97%) rename src/vercel_ai_sdk/{agents2 => agents}/context.py (100%) delete mode 100644 src/vercel_ai_sdk/agents2/__init__.py delete mode 100644 src/vercel_ai_sdk/agents2/checkpoint.py delete mode 100644 src/vercel_ai_sdk/agents2/hooks.py delete mode 100644 src/vercel_ai_sdk/agents2/mcp/__init__.py delete mode 100644 src/vercel_ai_sdk/agents2/mcp/client.py delete mode 100644 src/vercel_ai_sdk/agents2/runtime.py delete mode 100644 src/vercel_ai_sdk/agents2/streams.py delete mode 100644 src/vercel_ai_sdk/agents2/tools.py rename src/vercel_ai_sdk/{models2 => models}/ai_gateway/_common.py (100%) rename src/vercel_ai_sdk/{models2 => models}/ai_gateway/generate.py (100%) delete mode 100644 src/vercel_ai_sdk/models/ai_gateway/image.py delete mode 100644 src/vercel_ai_sdk/models/ai_gateway/llm.py delete mode 100644 src/vercel_ai_sdk/models/ai_gateway/protocol.py rename src/vercel_ai_sdk/{models2 => models}/ai_gateway/stream.py (100%) delete mode 100644 src/vercel_ai_sdk/models/ai_gateway/video.py rename src/vercel_ai_sdk/{models2 => models}/anthropic/adapter.py (100%) delete 
mode 100644 src/vercel_ai_sdk/models/anthropic/llm.py rename src/vercel_ai_sdk/{models2 => models}/core/client.py (100%) rename src/vercel_ai_sdk/{models2 => models}/core/helpers/media.py (100%) rename src/vercel_ai_sdk/{models2 => models}/core/helpers/streaming.py (100%) delete mode 100644 src/vercel_ai_sdk/models/core/image.py delete mode 100644 src/vercel_ai_sdk/models/core/llm.py delete mode 100644 src/vercel_ai_sdk/models/core/media/__init__.py delete mode 100644 src/vercel_ai_sdk/models/core/media/base.py delete mode 100644 src/vercel_ai_sdk/models/core/media/data.py delete mode 100644 src/vercel_ai_sdk/models/core/media/detect.py delete mode 100644 src/vercel_ai_sdk/models/core/media/download.py rename src/vercel_ai_sdk/{models2 => models}/core/proto.py (100%) delete mode 100644 src/vercel_ai_sdk/models/core/protocol.py delete mode 100644 src/vercel_ai_sdk/models/core/registry.py delete mode 100644 src/vercel_ai_sdk/models/core/video.py rename src/vercel_ai_sdk/{models2 => models}/openai/adapter.py (100%) delete mode 100644 src/vercel_ai_sdk/models/openai/llm.py delete mode 100644 src/vercel_ai_sdk/models2/__init__.py delete mode 100644 src/vercel_ai_sdk/models2/ai_gateway/__init__.py delete mode 100644 src/vercel_ai_sdk/models2/ai_gateway/errors.py delete mode 100644 src/vercel_ai_sdk/models2/anthropic/__init__.py delete mode 100644 src/vercel_ai_sdk/models2/core/__init__.py delete mode 100644 src/vercel_ai_sdk/models2/core/model.py delete mode 100644 src/vercel_ai_sdk/models2/openai/__init__.py delete mode 100644 tests/agents2/__init__.py delete mode 100644 tests/agents2/mcp/__init__.py delete mode 100644 tests/agents2/mcp/test_client.py delete mode 100644 tests/agents2/test_checkpoint.py delete mode 100644 tests/agents2/test_hooks.py delete mode 100644 tests/agents2/test_runtime.py delete mode 100644 tests/agents2/test_streams.py delete mode 100644 tests/agents2/test_tools.py delete mode 100644 tests/models/ai_gateway/test_gateway.py delete mode 100644 
tests/models/ai_gateway/test_gateway_image.py delete mode 100644 tests/models/ai_gateway/test_gateway_video.py rename tests/{models2 => models}/ai_gateway/test_generate_image.py (97%) rename tests/{models2 => models}/ai_gateway/test_generate_video.py (97%) rename tests/{models2 => models}/ai_gateway/test_stream.py (98%) delete mode 100644 tests/models/anthropic/__init__.py delete mode 100644 tests/models/anthropic/test_anthropic.py delete mode 100644 tests/models/core/media/__init__.py delete mode 100644 tests/models/core/media/test_data.py delete mode 100644 tests/models/core/media/test_detect_media_type.py delete mode 100644 tests/models/core/media/test_models.py delete mode 100644 tests/models/core/test_llm.py rename tests/{models2 => models}/core/test_media.py (99%) rename tests/{models2 => models}/core/test_streaming.py (99%) delete mode 100644 tests/models/openai/__init__.py delete mode 100644 tests/models/openai/test_openai.py delete mode 100644 tests/models2/__init__.py delete mode 100644 tests/models2/ai_gateway/__init__.py delete mode 100644 tests/models2/ai_gateway/test_errors.py delete mode 100644 tests/models2/ai_gateway/test_protocol.py delete mode 100644 tests/models2/core/__init__.py diff --git a/examples/models2/buffer.py b/examples/models/buffer.py similarity index 94% rename from examples/models2/buffer.py rename to examples/models/buffer.py index 5cb88c7e..4020affd 100644 --- a/examples/models2/buffer.py +++ b/examples/models/buffer.py @@ -2,7 +2,7 @@ import asyncio -from vercel_ai_sdk import models2 as m +from vercel_ai_sdk import models as m from vercel_ai_sdk.types import messages as messages_ model = m.Model( diff --git a/examples/models2/direct_adapter.py b/examples/models/direct_adapter.py similarity index 90% rename from examples/models2/direct_adapter.py rename to examples/models/direct_adapter.py index fe680dae..df386a0b 100644 --- a/examples/models2/direct_adapter.py +++ b/examples/models/direct_adapter.py @@ -3,8 +3,8 @@ import 
asyncio import os -from vercel_ai_sdk import models2 as m -from vercel_ai_sdk.models2 import ai_gateway as ai_gateway_v3 +from vercel_ai_sdk import models as m +from vercel_ai_sdk.models import ai_gateway as ai_gateway_v3 from vercel_ai_sdk.types import messages as messages_ model = m.Model( diff --git a/examples/models2/explicit_client.py b/examples/models/explicit_client.py similarity index 95% rename from examples/models2/explicit_client.py rename to examples/models/explicit_client.py index e4539623..6c3d7c6e 100644 --- a/examples/models2/explicit_client.py +++ b/examples/models/explicit_client.py @@ -3,7 +3,7 @@ import asyncio import os -from vercel_ai_sdk import models2 as m +from vercel_ai_sdk import models as m from vercel_ai_sdk.types import messages as messages_ model = m.Model( diff --git a/examples/models2/image_generation.py b/examples/models/image_generation.py similarity index 97% rename from examples/models2/image_generation.py rename to examples/models/image_generation.py index 8ba4d318..63b70d7d 100644 --- a/examples/models2/image_generation.py +++ b/examples/models/image_generation.py @@ -4,7 +4,7 @@ import base64 import pathlib -from vercel_ai_sdk import models2 as m +from vercel_ai_sdk import models as m from vercel_ai_sdk.types import messages as messages_ model = m.Model( diff --git a/examples/models2/inline_image.py b/examples/models/inline_image.py similarity index 98% rename from examples/models2/inline_image.py rename to examples/models/inline_image.py index 4686b33e..91777e87 100644 --- a/examples/models2/inline_image.py +++ b/examples/models/inline_image.py @@ -9,7 +9,7 @@ import base64 import pathlib -from vercel_ai_sdk import models2 as m +from vercel_ai_sdk import models as m from vercel_ai_sdk.types import messages as messages_ # This is a language model that can also output images inline. 
diff --git a/examples/models2/multimodal_input.py b/examples/models/multimodal_input.py similarity index 95% rename from examples/models2/multimodal_input.py rename to examples/models/multimodal_input.py index f5d3b475..f5a11a14 100644 --- a/examples/models2/multimodal_input.py +++ b/examples/models/multimodal_input.py @@ -3,7 +3,7 @@ import asyncio import pathlib -from vercel_ai_sdk import models2 as m +from vercel_ai_sdk import models as m from vercel_ai_sdk.types import messages as messages_ model = m.Model( diff --git a/examples/models2/stream.py b/examples/models/stream.py similarity index 94% rename from examples/models2/stream.py rename to examples/models/stream.py index 7e1d08a9..1183fb05 100644 --- a/examples/models2/stream.py +++ b/examples/models/stream.py @@ -2,7 +2,7 @@ import asyncio -from vercel_ai_sdk import models2 as m +from vercel_ai_sdk import models as m from vercel_ai_sdk.types import messages as messages_ model = m.Model( diff --git a/examples/models2/structured_output.py b/examples/models/structured_output.py similarity index 96% rename from examples/models2/structured_output.py rename to examples/models/structured_output.py index c1572988..172d7201 100644 --- a/examples/models2/structured_output.py +++ b/examples/models/structured_output.py @@ -4,7 +4,7 @@ import pydantic -from vercel_ai_sdk import models2 as m +from vercel_ai_sdk import models as m from vercel_ai_sdk.types import messages as messages_ model = m.Model( diff --git a/examples/models2/tools.py b/examples/models/tools.py similarity index 96% rename from examples/models2/tools.py rename to examples/models/tools.py index 2e25eb96..3e3c5d81 100644 --- a/examples/models2/tools.py +++ b/examples/models/tools.py @@ -2,7 +2,7 @@ import asyncio -from vercel_ai_sdk import models2 as m +from vercel_ai_sdk import models as m from vercel_ai_sdk.types import messages as messages_ from vercel_ai_sdk.types import tools as tools_ diff --git a/examples/models2/video_generation.py 
b/examples/models/video_generation.py similarity index 97% rename from examples/models2/video_generation.py rename to examples/models/video_generation.py index ece777ad..b5f5c8d3 100644 --- a/examples/models2/video_generation.py +++ b/examples/models/video_generation.py @@ -4,7 +4,7 @@ import base64 import pathlib -from vercel_ai_sdk import models2 as m +from vercel_ai_sdk import models as m from vercel_ai_sdk.types import messages as messages_ model = m.Model( diff --git a/examples/samples/custom_loop.py b/examples/samples/custom_loop.py index 082f5ac1..923768fd 100644 --- a/examples/samples/custom_loop.py +++ b/examples/samples/custom_loop.py @@ -28,8 +28,8 @@ async def custom_stream_step( tools: list[ai.Tool[..., Any]], label: str | None = None, ) -> AsyncGenerator[ai.Message]: - """Wraps models2.stream to inject a label on every message.""" - async for msg in ai.models2.stream(model, messages, tools=tools): + """Wraps models.stream to inject a label on every message.""" + async for msg in ai.models.stream(model, messages, tools=tools): msg.label = label yield msg diff --git a/examples/temporal-durable/workflow.py b/examples/temporal-durable/workflow.py index cae25e73..571792d4 100644 --- a/examples/temporal-durable/workflow.py +++ b/examples/temporal-durable/workflow.py @@ -2,7 +2,7 @@ NOTE: This example still uses the old models.LanguageModel ABC because it wraps Temporal activities as a custom model. When the models layer -is fully migrated to models2, this will need a custom adapter instead. +is fully migrated to models, this will need a custom adapter instead. """ from __future__ import annotations @@ -85,7 +85,7 @@ async def get_population(city: str) -> int: # # TODO: This example uses the old LanguageModel ABC and ai.run() / # ai.stream_loop free-function patterns. 
Once the models layer is -# migrated, convert to use ai.agent() + models2.Model with a custom +# migrated, convert to use ai.agent() + models.Model with a custom # adapter for Temporal activity-based LLM calls. @@ -94,7 +94,7 @@ async def agent(llm: Any, user_query: str) -> ai.StreamResult: This is a transitional pattern. The old ai.stream_loop and ai.run are no longer part of the public API. This example needs a custom - models2 adapter to work with the new Agent API. + models adapter to work with the new Agent API. """ messages = ai.make_messages( system="Answer questions using the weather and population tools.", @@ -139,9 +139,9 @@ async def run(self, user_query: str) -> str: ) ) - # TODO: This uses the old free-function pattern. Once models2 + # TODO: This uses the old free-function pattern. Once models # supports custom adapters for Temporal, use Agent.run() instead. - from vercel_ai_sdk.agents2 import run + from vercel_ai_sdk.agents import run final_text = "" async for msg in run(agent, llm, user_query): diff --git a/src/vercel_ai_sdk/__init__.py b/src/vercel_ai_sdk/__init__.py index 6d73752e..5dcc0c59 100644 --- a/src/vercel_ai_sdk/__init__.py +++ b/src/vercel_ai_sdk/__init__.py @@ -1,6 +1,6 @@ -from . import adapters, models, models2, telemetry +from . 
import adapters, models, telemetry from .adapters import ai_sdk_ui -from .agents2 import ( +from .agents import ( Agent, AgentRun, Checkpoint, @@ -25,7 +25,7 @@ stream_step, tool, ) -from .models2 import Client, Model, ModelCost +from .models import Client, Model, ModelCost # Re-export core types from .types import ( @@ -61,12 +61,10 @@ "ToolSchema", "Usage", "make_messages", - # Models (from models2/) + # Models (from models/) "Model", "ModelCost", "Client", - "models2", - # Legacy (from models/) — kept during transition "models", # Agents — primary API "Agent", diff --git a/src/vercel_ai_sdk/adapters/ai_sdk_ui/adapter.py b/src/vercel_ai_sdk/adapters/ai_sdk_ui/adapter.py index e1c108f7..c68ea0d6 100644 --- a/src/vercel_ai_sdk/adapters/ai_sdk_ui/adapter.py +++ b/src/vercel_ai_sdk/adapters/ai_sdk_ui/adapter.py @@ -11,7 +11,7 @@ from collections.abc import AsyncGenerator, AsyncIterable from typing import Any, Literal -from ...agents2 import hooks +from ...agents import hooks from ...types import messages as messages_ from . import protocol, ui_message diff --git a/src/vercel_ai_sdk/agents/__init__.py b/src/vercel_ai_sdk/agents/__init__.py index 5822431d..d7a62b0f 100644 --- a/src/vercel_ai_sdk/agents/__init__.py +++ b/src/vercel_ai_sdk/agents/__init__.py @@ -1,35 +1,48 @@ """Agent loop orchestration — tools, hooks, runtime, and streaming. -Depends on types/ and models2/. Provides the loop machinery that +Depends on types/ and models/. Provides the loop machinery that plugs a model into a tool-calling loop with hooks and checkpoints. """ from . 
import mcp +from .agent import Agent, AgentRun, LoopFn, agent, stream_step from .checkpoint import Checkpoint, PendingHookInfo +from .context import Context, ToolSource, get_context from .hooks import Hook, ToolApproval, hook from .runtime import ( + EventLog, HookInfo, + LoopExecutor, RunResult, Runtime, execute_tool, get_checkpoint, run, - stream_loop, - stream_step, ) from .streams import StreamResult, stream from .tools import Tool, ToolLike, ToolSchema, get_tool, tool __all__ = [ - # Core loop - "run", + # Agent (primary user API) + "Agent", + "AgentRun", + "agent", + "LoopFn", + # Composition primitives "stream_step", - "stream_loop", "execute_tool", "get_checkpoint", + # Context + "Context", + "ToolSource", + "get_context", + # Runtime (developer API) "Runtime", + "EventLog", + "LoopExecutor", "RunResult", "HookInfo", + "run", # Stream "stream", "StreamResult", diff --git a/src/vercel_ai_sdk/agents2/agent.py b/src/vercel_ai_sdk/agents/agent.py similarity index 97% rename from src/vercel_ai_sdk/agents2/agent.py rename to src/vercel_ai_sdk/agents/agent.py index 9c312593..1a5b1ec7 100644 --- a/src/vercel_ai_sdk/agents2/agent.py +++ b/src/vercel_ai_sdk/agents/agent.py @@ -29,7 +29,7 @@ import pydantic -from .. import models2 +from .. import models from ..types import messages as messages_ from . import checkpoint as checkpoint_ from . import context as context_ @@ -49,7 +49,7 @@ @streams_.stream async def stream_step( - model: models2.Model, + model: models.Model, messages: list[messages_.Message], tools: Sequence[tools_.ToolLike] | None = None, label: str | None = None, @@ -63,7 +63,7 @@ async def stream_step( ``@stream``, so each call becomes a replayable step in the event log. 
""" - async for msg in models2.stream( + async for msg in models.stream( model, messages, tools=tools, output_type=output_type, **kwargs ): msg.label = label @@ -141,7 +141,7 @@ async def custom(agent, messages): def __init__( self, - model: models2.Model, + model: models.Model, system: str = "", tools: list[tools_.Tool[..., Any]] | None = None, ) -> None: @@ -151,7 +151,7 @@ def __init__( self._custom_loop: LoopFn | None = None @property - def model(self) -> models2.Model: + def model(self) -> models.Model: return self._model @property @@ -250,7 +250,7 @@ async def _graph() -> streams_.StreamResult | None: def agent( - model: models2.Model, + model: models.Model, system: str = "", tools: list[tools_.Tool[..., Any]] | None = None, ) -> Agent: diff --git a/src/vercel_ai_sdk/agents2/context.py b/src/vercel_ai_sdk/agents/context.py similarity index 100% rename from src/vercel_ai_sdk/agents2/context.py rename to src/vercel_ai_sdk/agents/context.py diff --git a/src/vercel_ai_sdk/agents/hooks.py b/src/vercel_ai_sdk/agents/hooks.py index 758a4c4f..948539bb 100644 --- a/src/vercel_ai_sdk/agents/hooks.py +++ b/src/vercel_ai_sdk/agents/hooks.py @@ -46,8 +46,7 @@ def _cleanup_run(labels: set[str]) -> None: class Hook[T: pydantic.BaseModel]: - """ - Hook: a suspension point that requires external input to continue. + """Hook: a suspension point that requires external input to continue. Usage in graph code: @@ -77,20 +76,14 @@ class Hook[T: pydantic.BaseModel]: @classmethod async def create(cls, label: str, metadata: dict[str, Any] | None = None) -> T: - """ - Create a hook and await its resolution. + """Create a hook and await its resolution. - The hook is submitted to the Runtime's step queue. run() will either: + The hook is submitted to the LoopExecutor's step queue. 
run() will + either: - Resolve immediately (if a resolution is available from checkpoint or pre-registered via Hook.resolve()) - Cancel the future (cancels_future=True, serverless mode) - Hold the future (cancels_future=False, long-running mode) - - Args: - label: Stable identifier for this hook. Used to match resolutions - across requests in serverless mode. Must be unique within - a single run. - metadata: Optional metadata surfaced in the pending HookPart message. """ from . import runtime as rt_mod @@ -101,16 +94,16 @@ async def create(cls, label: str, metadata: dict[str, Any] | None = None) -> T: # Check pre-registered resolutions (serverless re-entry path) pre_registered = _pending_resolutions.pop(label, None) if pre_registered is not None: - rt.record_hook(label, pre_registered) + rt.log.record_hook(label, pre_registered) return cls._schema(**pre_registered) # type: ignore[return-value] # Check checkpoint for a previously resolved value - resolution = rt.get_hook_resolution(label) + resolution = rt.log.get_hook_resolution(label) if resolution is not None: - rt.record_hook(label, resolution) + rt.log.record_hook(label, resolution) return cls._schema(**resolution) # type: ignore[return-value] - # Submit to step queue — run() decides what to do + # Submit to executor queue — run() decides what to do future: asyncio.Future[dict[str, Any]] = asyncio.Future() suspension = rt_mod.HookSuspension( label=label, @@ -119,12 +112,12 @@ async def create(cls, label: str, metadata: dict[str, Any] | None = None) -> T: future=future, cancels_future=cls.cancels_future, ) - await rt.put_hook_suspension(suspension) + await rt.executor.put_hook(suspension) # Register in module-level registry for external resolution hook_metadata = metadata or {} _live_hooks[label] = (future, hook_metadata, rt) - rt.track_hook_label(label) + rt.executor.track_hook_label(label) # Await resolution — may be resolved immediately by run(), # cancelled by run() (serverless), or resolved later by @@ 
-135,10 +128,10 @@ async def create(cls, label: str, metadata: dict[str, Any] | None = None) -> T: _live_hooks.pop(label, None) # Record for checkpoint - rt.record_hook(label, resolution) + rt.log.record_hook(label, resolution) # Emit resolved message - await rt.put_message( + await rt.executor.put_message( messages_.Message( role="assistant", parts=[ @@ -157,8 +150,7 @@ async def create(cls, label: str, metadata: dict[str, Any] | None = None) -> T: @classmethod def resolve(cls, label: str, data: T | dict[str, Any]) -> None: - """ - Resolve a hook by label. + """Resolve a hook by label. Works in two modes: @@ -169,19 +161,12 @@ def resolve(cls, label: str, data: T | dict[str, Any]) -> None: stashes it in the pre-registration registry. When ai.run() replays the graph and Hook.create() executes, it finds the pre-registered resolution and returns without suspending. - - Args: - label: The hook label to resolve. - data: Resolution payload (dict or pydantic model). Validated - against the hook's schema immediately. """ # Validate and normalize to dict if isinstance(data, dict): - # Validate by constructing the schema model validated = cls._schema(**data) resolution = validated.model_dump() else: - # Already a model instance — validate it's the right type if not isinstance(data, cls._schema): raise TypeError( f"Expected {cls._schema.__name__} or dict, " @@ -211,7 +196,7 @@ async def cancel(cls, label: str, reason: str | None = None) -> None: future, hook_metadata, rt = _live_hooks.pop(label) future.cancel(reason) - await rt.put_message( + await rt.executor.put_message( messages_.Message( role="assistant", parts=[ @@ -227,8 +212,7 @@ async def cancel(cls, label: str, reason: str | None = None) -> None: def hook[T: pydantic.BaseModel](cls: type[T]) -> type[Hook[T]]: - """ - Decorator to create a Hook type from a pydantic model. + """Decorator to create a Hook type from a pydantic model. The pydantic model defines the schema for the hook's resolution payload. 
""" diff --git a/src/vercel_ai_sdk/agents/mcp/client.py b/src/vercel_ai_sdk/agents/mcp/client.py index c17a25a0..def1f0e6 100644 --- a/src/vercel_ai_sdk/agents/mcp/client.py +++ b/src/vercel_ai_sdk/agents/mcp/client.py @@ -14,6 +14,7 @@ import mcp.client.streamable_http import mcp.types +from .. import context as context_ from .. import tools as tools_ __all__ = [ @@ -243,11 +244,30 @@ def _mcp_tool_to_native( return_type=Any, ) + # Determine source provenance from connection key + if connection_key.startswith("http:"): + source = context_.ToolSource( + kind="mcp_http", + uri=connection_key.removeprefix("http:"), + ) + elif connection_key.startswith("stdio:"): + source = context_.ToolSource( + kind="mcp_stdio", + server_command=connection_key.removeprefix("stdio:"), + ) + else: + source = context_.ToolSource(kind="mcp") + t = tools_.Tool( fn=_make_tool_fn(connection_key, mcp_tool.name, transport_factory), schema=schema, + source=source, ) - # Register so execute_tool() can find it by name + + # Register on active Context if available, else fall back to global + ctx = context_._context.get(None) + if ctx is not None: + ctx.register_tool(t) tools_._tool_registry[name] = t return t diff --git a/src/vercel_ai_sdk/agents/runtime.py b/src/vercel_ai_sdk/agents/runtime.py index 4c156c34..724267b3 100644 --- a/src/vercel_ai_sdk/agents/runtime.py +++ b/src/vercel_ai_sdk/agents/runtime.py @@ -5,15 +5,15 @@ import dataclasses import json import logging -from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine, Sequence +from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine from typing import Any, get_type_hints import pydantic -from .. import models2 from ..telemetry import events as telemetry_ from ..types import messages as messages_ from . import checkpoint as checkpoint_ +from . import context as context_ from . import hooks as hooks_ from . import mcp from . 
import streams as streams_ @@ -21,51 +21,23 @@ logger = logging.getLogger(__name__) -# ── Queue item types ────────────────────────────────────────────── - -@dataclasses.dataclass -class HookSuspension: - """Submitted to the step queue when a hook needs resolution.""" - - label: str - hook_type: str - metadata: dict[str, Any] - future: asyncio.Future[Any] - cancels_future: bool = False - - -# ── Runtime ─────────────────────────────────────────────────────── +# ── EventLog ────────────────────────────────────────────────────── +# +# Pure bookkeeping: replay from checkpoint + record new events. +# No asyncio, no queues — just data in, data out. +# -class Runtime: - """ - Central coordinator for the agent loop. +class EventLog: + """Replay/record layer backed by a Checkpoint. - Functions decorated with @stream submit step functions to the queue. - Hooks submit HookSuspension items to the same queue. - run() pulls items, processes them, yields messages, and resolves futures. + Holds replay cursors (read pointer into the checkpoint) and + recording lists (new events produced during this run). + Completely synchronous — no queues, no async. 
""" - class _Sentinel: - pass - - _SENTINEL = _Sentinel() - - def __init__( - self, - checkpoint: checkpoint_.Checkpoint | None = None, - ) -> None: - self._step_queue: asyncio.Queue[ - tuple[streams_.Stream, asyncio.Future[streams_.StreamResult]] - | HookSuspension - | Runtime._Sentinel - ] = asyncio.Queue() - - # Message queue for streaming tool results and hook messages - self._message_queue: asyncio.Queue[messages_.Message] = asyncio.Queue() - - # Checkpoint: replay state from previous run + def __init__(self, checkpoint: checkpoint_.Checkpoint | None = None) -> None: self._checkpoint = checkpoint or checkpoint_.Checkpoint() # Replay cursors @@ -77,45 +49,16 @@ def __init__( h.label: h.resolution for h in self._checkpoint.hooks } - # Recording: new events from this run + # Recording lists (new events from this run) self._step_log: list[checkpoint_.StepEvent] = [] self._tool_log: list[checkpoint_.ToolEvent] = [] self._hook_log: list[checkpoint_.HookEvent] = [] - # Pending hooks (unresolved during this run) - self._pending_hooks: dict[str, HookSuspension] = {} - - # Track hook labels registered in this run for cleanup - self._hook_labels: set[str] = set() - - # ── Step queue ──────────────────────────────────────────────── - - async def put_step( - self, step_fn: streams_.Stream, future: asyncio.Future[streams_.StreamResult] - ) -> None: - await self._step_queue.put((step_fn, future)) - - async def put_hook_suspension(self, suspension: HookSuspension) -> None: - await self._step_queue.put(suspension) - - async def signal_done(self) -> None: - await self._step_queue.put(self._SENTINEL) + # ── Steps ───────────────────────────────────────────────── - # ── Message queue ───────────────────────────────────────────── - - async def put_message(self, message: messages_.Message) -> None: - await self._message_queue.put(message) - - def get_all_messages(self) -> list[messages_.Message]: - msgs = [] - while not self._message_queue.empty(): - try: - 
msgs.append(self._message_queue.get_nowait()) - except asyncio.QueueEmpty: - break - return msgs - - # ── Replay / record: steps ──────────────────────────────────── + @property + def step_index(self) -> int: + return self._step_index def try_replay_step(self) -> streams_.StreamResult | None: if self._step_index < len(self._checkpoint.steps): @@ -133,10 +76,9 @@ def record_step(self, result: streams_.StreamResult) -> None: self._step_log.append(event) self._step_index += 1 - # ── Replay / record: tools ──────────────────────────────────── + # ── Tools ───────────────────────────────────────────────── def try_replay_tool(self, tool_call_id: str) -> checkpoint_.ToolEvent | None: - """Return the cached ToolEvent if available, else None.""" event = self._tool_replay.get(tool_call_id) if event is not None: logger.info( @@ -155,7 +97,7 @@ def record_tool( ) ) - # ── Replay / record: hooks ──────────────────────────────────── + # ── Hooks ───────────────────────────────────────────────── def get_hook_resolution(self, label: str) -> dict[str, Any] | None: resolution = self._hook_replay.get(label) @@ -166,25 +108,158 @@ def get_hook_resolution(self, label: str) -> dict[str, Any] | None: def record_hook(self, label: str, resolution: dict[str, Any]) -> None: self._hook_log.append(checkpoint_.HookEvent(label=label, resolution=resolution)) - def track_hook_label(self, label: str) -> None: - """Track a hook label for cleanup when the run completes.""" - self._hook_labels.add(label) + # ── Snapshot ────────────────────────────────────────────── - # ── Checkpoint ──────────────────────────────────────────────── - - def get_checkpoint(self) -> checkpoint_.Checkpoint: + def checkpoint( + self, pending_hooks: list[checkpoint_.PendingHookInfo] | None = None + ) -> checkpoint_.Checkpoint: + """Build a full Checkpoint merging prior state + new recordings.""" return checkpoint_.Checkpoint( steps=list(self._checkpoint.steps) + self._step_log, tools=list(self._checkpoint.tools) + 
self._tool_log, hooks=list(self._checkpoint.hooks) + self._hook_log, - pending_hooks=[ - checkpoint_.PendingHookInfo( - label=sus.label, - hook_type=sus.hook_type, - metadata=sus.metadata, - ) - for sus in self._pending_hooks.values() - ], + pending_hooks=pending_hooks or [], + ) + + +# ── LoopExecutor ───────────────────────────────────────────────── +# +# Async coordination: queues that let graph code (streams, hooks, +# tools) talk to the driver loop. Pure mailbox — no replay, no +# checkpoint awareness. +# + + +@dataclasses.dataclass +class HookSuspension: + """Submitted to the step queue when a hook needs resolution.""" + + label: str + hook_type: str + metadata: dict[str, Any] + future: asyncio.Future[Any] + cancels_future: bool = False + + +class LoopExecutor: + """Async coordination layer between graph code and the driver loop. + + Graph code (``@stream`` decorators, hooks, tool execution) submits + work via the producer methods. The driver loop consumes via + ``next()`` and ``drain_messages()``. 
+ """ + + class _Sentinel: + pass + + _SENTINEL = _Sentinel() + + def __init__(self) -> None: + self._step_queue: asyncio.Queue[ + tuple[streams_.Stream, asyncio.Future[streams_.StreamResult]] + | HookSuspension + | LoopExecutor._Sentinel + ] = asyncio.Queue() + + self._message_queue: asyncio.Queue[messages_.Message] = asyncio.Queue() + + # Pending hooks (unresolved during this run) + self._pending_hooks: dict[str, HookSuspension] = {} + + # Track hook labels registered in this run for cleanup + self._hook_labels: set[str] = set() + + # ── Producers (called by graph code) ────────────────────── + + async def put_step( + self, step_fn: streams_.Stream, future: asyncio.Future[streams_.StreamResult] + ) -> None: + await self._step_queue.put((step_fn, future)) + + async def put_hook(self, suspension: HookSuspension) -> None: + await self._step_queue.put(suspension) + + async def put_message(self, message: messages_.Message) -> None: + await self._message_queue.put(message) + + async def done(self) -> None: + await self._step_queue.put(self._SENTINEL) + + # ── Consumer (called by driver loop) ────────────────────── + + async def next( + self, timeout: float = 0.1 + ) -> ( + tuple[streams_.Stream, asyncio.Future[streams_.StreamResult]] + | HookSuspension + | None + ): + """Pull the next item from the step queue. + + Returns ``None`` on timeout (no item available). + Returns the sentinel's semantic equivalent by raising StopIteration + when the graph signals completion. 
+ """ + try: + item = await asyncio.wait_for(self._step_queue.get(), timeout=timeout) + except TimeoutError: + return None + + if isinstance(item, LoopExecutor._Sentinel): + raise _LoopDone + return item + + def drain_messages(self) -> list[messages_.Message]: + msgs: list[messages_.Message] = [] + while not self._message_queue.empty(): + try: + msgs.append(self._message_queue.get_nowait()) + except asyncio.QueueEmpty: + break + return msgs + + # ── Hook label tracking ─────────────────────────────────── + + def track_hook_label(self, label: str) -> None: + self._hook_labels.add(label) + + def pending_hook_infos(self) -> list[checkpoint_.PendingHookInfo]: + return [ + checkpoint_.PendingHookInfo( + label=sus.label, + hook_type=sus.hook_type, + metadata=sus.metadata, + ) + for sus in self._pending_hooks.values() + ] + + +class _LoopDone(Exception): + """Internal signal: the loop function has finished.""" + + +# ── Runtime ─────────────────────────────────────────────────────── +# +# Thin composition of EventLog + LoopExecutor. +# The context var points here; graph code accesses rt.log and +# rt.executor directly. +# + + +class Runtime: + """Central coordinator — composes EventLog and LoopExecutor. + + Graph code accesses ``rt.log`` for replay/record and + ``rt.executor`` for async coordination. 
+ """ + + def __init__(self, checkpoint: checkpoint_.Checkpoint | None = None) -> None: + self.log = EventLog(checkpoint) + self.executor = LoopExecutor() + + def checkpoint(self) -> checkpoint_.Checkpoint: + return self.log.checkpoint( + pending_hooks=self.executor.pending_hook_infos(), ) @@ -193,7 +268,7 @@ def get_checkpoint(self) -> checkpoint_.Checkpoint: def get_checkpoint() -> checkpoint_.Checkpoint: """Get the current checkpoint from the active Runtime.""" - return _runtime.get().get_checkpoint() + return _runtime.get().checkpoint() def _find_runtime_param(fn: Callable[..., Any]) -> str | None: @@ -208,37 +283,17 @@ def _find_runtime_param(fn: Callable[..., Any]) -> str | None: return None -# ── Convenience functions ───────────────────────────────────────── - - -@streams_.stream -async def stream_step( - model: models2.Model, - messages: list[messages_.Message], - tools: Sequence[tools_.ToolLike] | None = None, - label: str | None = None, - output_type: type[pydantic.BaseModel] | None = None, - **kwargs: Any, -) -> AsyncGenerator[messages_.Message]: - """Single LLM call that streams to Runtime.""" - async for msg in models2.stream( - model, messages, tools=tools, output_type=output_type, **kwargs - ): - msg.label = label - yield msg - - async def execute_tool( tool_call: messages_.ToolPart, message: messages_.Message | None = None, ) -> Any: - """ - Execute a single tool call with replay support. + """Execute a single tool call with replay support. - Looks up the tool by name from the global registry, executes it, - and updates the ToolPart (and parent Message) with the result. - Emits the updated message to the Runtime queue so the UI sees - the transition from status="pending" to status="result" (or "error"). + Looks up the tool by name — first from the active Context (if any), + then from the global registry. Executes it and updates the ToolPart + (and parent Message) with the result. 
Emits the updated message to + the LoopExecutor queue so the UI sees the transition from + status="pending" to status="result" (or "error"). If a checkpoint exists with a cached result for this tool_call_id, returns the cached result without re-executing. @@ -247,7 +302,7 @@ async def execute_tool( # Replay: return cached result if available if rt: - cached = rt.try_replay_tool(tool_call.tool_call_id) + cached = rt.log.try_replay_tool(tool_call.tool_call_id) if cached is not None: if cached.status == "error": tool_call.set_error(cached.result) @@ -264,8 +319,13 @@ async def execute_tool( ) t0 = telemetry_.time_ms() - # Fresh execution - tool = tools_.get_tool(tool_call.tool_name) + # Fresh execution — resolve from Context first, then global registry + tool: tools_.Tool[..., Any] | None = None + ctx = context_._context.get(None) + if ctx is not None: + tool = ctx.get_tool(tool_call.tool_name) + if tool is None: + tool = tools_.get_tool(tool_call.tool_name) if tool is None: raise ValueError(f"Tool not found in registry: {tool_call.tool_name}") @@ -274,8 +334,6 @@ async def execute_tool( result = await tool.validate_and_call(tool_call.tool_args, rt) tool_call.set_result(result) except (json.JSONDecodeError, pydantic.ValidationError) as exc: - # LLM produced malformed JSON or args that don't match the schema. - # Report back as a tool error so the model can retry. 
result = f"{type(exc).__name__}: {exc}" error_str = result tool_call.set_error(result) @@ -292,43 +350,15 @@ async def execute_tool( # Record for checkpoint if rt: - rt.record_tool(tool_call.tool_call_id, result, status=tool_call.status) + rt.log.record_tool(tool_call.tool_call_id, result, status=tool_call.status) # Emit updated message so UI sees status change if rt and message: - await rt.put_message(message.model_copy(deep=True)) + await rt.executor.put_message(message.model_copy(deep=True)) return result -async def stream_loop( - model: models2.Model, - messages: list[messages_.Message], - tools: Sequence[tools_.ToolLike], - label: str | None = None, - output_type: type[pydantic.BaseModel] | None = None, - **kwargs: Any, -) -> streams_.StreamResult: - """Agent loop: stream LLM, execute tools, repeat until done.""" - local_messages = list(messages) - - while True: - result = await stream_step( - model, local_messages, tools, label=label, output_type=output_type, **kwargs - ) - - if not result.tool_calls: - return result - - last_msg = result.last_message - if last_msg is not None: - local_messages.append(last_msg) - - await asyncio.gather( - *(execute_tool(tc, message=last_msg) for tc in result.tool_calls) - ) - - # ── RunResult ───────────────────────────────────────────────────── @@ -342,8 +372,7 @@ class HookInfo: class RunResult: - """ - Returned by run(). Async-iterate for messages, then check state. + """Returned by run(). Async-iterate for messages, then check state. 
Usage: result = ai.run(my_graph, llm, query) @@ -361,7 +390,7 @@ def __init__(self) -> None: def checkpoint(self) -> checkpoint_.Checkpoint: if self._runtime is None: return checkpoint_.Checkpoint() - return self._runtime.get_checkpoint() + return self._runtime.checkpoint() @property def pending_hooks(self) -> dict[str, HookInfo]: @@ -373,7 +402,7 @@ def pending_hooks(self) -> dict[str, HookInfo]: hook_type=sus.hook_type, metadata=sus.metadata, ) - for label, sus in self._runtime._pending_hooks.items() + for label, sus in self._runtime.executor._pending_hooks.items() } async def __aiter__(self) -> AsyncGenerator[messages_.Message]: @@ -385,37 +414,38 @@ async def __aiter__(self) -> AsyncGenerator[messages_.Message]: # ── run() ───────────────────────────────────────────────────────── -async def _stop_when_done(runtime: Runtime, task: Awaitable[None]) -> None: +async def _stop_when_done(executor: LoopExecutor, task: Awaitable[None]) -> None: try: await task finally: - await runtime.signal_done() + await executor.done() def run( root: Callable[..., Coroutine[Any, Any, Any]], *args: Any, checkpoint: checkpoint_.Checkpoint | None = None, + context: context_.Context | None = None, ) -> RunResult: - """ - Main entry point. + """Main entry point. 1. Starts the root function as a background task - 2. Pulls steps and hook suspensions from the Runtime queue + 2. Pulls steps and hook suspensions from the LoopExecutor queue 3. Executes each step, yielding messages 4. Resolves or suspends hooks depending on the hook's cancels_future - class variable: - - cancels_future=True (serverless): cancel the future, branch dies, - caller inspects result.pending_hooks and result.checkpoint to resume - - cancels_future=False (long-running, default): future stays alive, - external code calls Hook.resolve() / Hook.cancel() to unblock 5. Returns RunResult with .checkpoint and .pending_hooks + + Args: + root: The loop function to execute. + *args: Positional arguments forwarded to ``root``. 
+ checkpoint: Checkpoint to resume from. + context: LLM prompt context (tools, system prompt, messages). + If ``None``, an empty Context is created automatically. """ result = RunResult() # Discard stale checkpoints: if the checkpoint has pending hooks but - # none of them have been resolved (via Hook.resolve() / to_messages()), - # this isn't a resume — it's a fresh turn with an outdated checkpoint. + # none of them have been resolved, this isn't a resume. effective_checkpoint = checkpoint if checkpoint and checkpoint.pending_hooks: pending_labels = [ph.label for ph in checkpoint.pending_hooks] @@ -437,9 +467,13 @@ class variable: ) async def _generate() -> AsyncGenerator[messages_.Message]: - runtime = Runtime(checkpoint=effective_checkpoint) - result._runtime = runtime - token_runtime = _runtime.set(runtime) + rt = Runtime(checkpoint=effective_checkpoint) + result._runtime = rt + token_runtime = _runtime.set(rt) + + ctx = context or context_.Context() + token_context = context_._context.set(ctx) + token_run_id = telemetry_.start_run() telemetry_.handle(telemetry_.RunStartEvent()) @@ -449,7 +483,7 @@ async def _generate() -> AsyncGenerator[messages_.Message]: kwargs: dict[str, Any] = {} if runtime_param := _find_runtime_param(root): - kwargs[runtime_param] = runtime + kwargs[runtime_param] = rt run_error: BaseException | None = None total_usage: messages_.Usage | None = None @@ -457,77 +491,64 @@ async def _generate() -> AsyncGenerator[messages_.Message]: try: async with asyncio.TaskGroup() as tg: _task: asyncio.Task[None] = tg.create_task( - _stop_when_done(runtime, root(*args, **kwargs)) + _stop_when_done(rt.executor, root(*args, **kwargs)) ) while True: # Drain pending messages - for msg in runtime.get_all_messages(): + for msg in rt.executor.drain_messages(): yield msg - # Wait for next queue item + # Pull next item from the graph executor try: - step_item = await asyncio.wait_for( - runtime._step_queue.get(), timeout=0.1 - ) - except TimeoutError: - continue 
- - if isinstance(step_item, Runtime._Sentinel): - for msg in runtime.get_all_messages(): + item = await rt.executor.next() + except _LoopDone: + for msg in rt.executor.drain_messages(): yield msg break + if item is None: + # Timeout — no item available, loop again + continue + # ── Hook suspension ──────────────────────── - if isinstance(step_item, HookSuspension): - resolution = runtime.get_hook_resolution(step_item.label) + if isinstance(item, HookSuspension): + resolution = rt.log.get_hook_resolution(item.label) if resolution is not None: - # Resolve immediately - step_item.future.set_result(resolution) - runtime.record_hook(step_item.label, resolution) + item.future.set_result(resolution) + rt.log.record_hook(item.label, resolution) else: - # No resolution available - runtime._pending_hooks[step_item.label] = step_item - if step_item.cancels_future: - # Serverless: cancel the future so the branch - # dies with CancelledError. Caller inspects - # result.pending_hooks to resume later. - step_item.future.cancel() - # else: long-running — future stays alive, - # external code calls Hook.resolve() to unblock. - - # Yield pending hook message + rt.executor._pending_hooks[item.label] = item + if item.cancels_future: + item.future.cancel() + yield messages_.Message( role="assistant", parts=[ messages_.HookPart( - hook_id=step_item.label, - hook_type=step_item.hook_type, + hook_id=item.label, + hook_type=item.hook_type, status="pending", - metadata=step_item.metadata, + metadata=item.metadata, ) ], ) - # Let resolved branches resume and submit their - # next steps before we pull from the queue again. 
await asyncio.sleep(0) - - # Drain messages after hook processing - for msg in runtime.get_all_messages(): + for msg in rt.executor.drain_messages(): yield msg continue # ── Regular step ─────────────────────────── - step_fn, future = step_item + step_fn, future = item telemetry_.handle( telemetry_.StepStartEvent( - step_index=runtime._step_index, + step_index=rt.log.step_index, ) ) - for tool_msg in runtime.get_all_messages(): + for tool_msg in rt.executor.drain_messages(): yield tool_msg result_messages: list[messages_.Message] = [] @@ -537,7 +558,7 @@ async def _generate() -> AsyncGenerator[messages_.Message]: yield msg_copy result_messages.append(msg) - for tool_msg in runtime.get_all_messages(): + for tool_msg in rt.executor.drain_messages(): yield tool_msg step_result = streams_.StreamResult(messages=result_messages) @@ -545,7 +566,7 @@ async def _generate() -> AsyncGenerator[messages_.Message]: telemetry_.handle( telemetry_.StepFinishEvent( - step_index=runtime._step_index, + step_index=rt.log.step_index, result=step_result, ) ) @@ -560,7 +581,7 @@ async def _generate() -> AsyncGenerator[messages_.Message]: ) await asyncio.sleep(0) - for tool_msg in runtime.get_all_messages(): + for tool_msg in rt.executor.drain_messages(): yield tool_msg except BaseException as exc: @@ -576,13 +597,13 @@ async def _generate() -> AsyncGenerator[messages_.Message]: ) telemetry_.end_run(token_run_id) - # Clean up module-level hook registries for this run - hooks_._cleanup_run(runtime._hook_labels) + hooks_._cleanup_run(rt.executor._hook_labels) if mcp_token is not None: await mcp.client.close_connections() mcp.client._pool.reset(mcp_token) + context_._context.reset(token_context) _runtime.reset(token_runtime) result._messages = _generate() diff --git a/src/vercel_ai_sdk/agents/streams.py b/src/vercel_ai_sdk/agents/streams.py index 80ca7cf8..fadf6747 100644 --- a/src/vercel_ai_sdk/agents/streams.py +++ b/src/vercel_ai_sdk/agents/streams.py @@ -66,10 +66,9 @@ def 
total_usage(self) -> messages_.Usage | None: def stream[**P]( fn: Callable[P, AsyncGenerator[messages_.Message]], ) -> Callable[P, Awaitable[StreamResult]]: - """ - Decorator to put an async generator into the Runtime queue. + """Decorator to put an async generator into the LoopExecutor queue. - The decorated function submits its work to the Runtime queue and + The decorated function submits its work to the executor queue and blocks until run() processes it, then returns the StreamResult. If a checkpoint exists with a cached result for this step index, @@ -85,22 +84,22 @@ async def wrapped(*args: Any, **kwargs: Any) -> StreamResult: raise ValueError("No Runtime context - must be called within ai.run()") # Replay: return cached result if available - cached = rt.try_replay_step() + cached = rt.log.try_replay_step() if cached is not None: return cached - # Fresh execution: submit to queue and wait + # Fresh execution: submit to executor queue and wait future: asyncio.Future[StreamResult] = asyncio.Future() async def stream_fn() -> AsyncGenerator[messages_.Message]: async for msg in fn(*args, **kwargs): yield msg - await rt.put_step(stream_fn, future) + await rt.executor.put_step(stream_fn, future) result = await future # Record for checkpoint - rt.record_step(result) + rt.log.record_step(result) return result return wrapped diff --git a/src/vercel_ai_sdk/agents/tools.py b/src/vercel_ai_sdk/agents/tools.py index f4e3744b..39a9aa28 100644 --- a/src/vercel_ai_sdk/agents/tools.py +++ b/src/vercel_ai_sdk/agents/tools.py @@ -9,6 +9,7 @@ from ..types.tools import ToolLike as ToolLike from ..types.tools import ToolSchema as ToolSchema +from .context import ToolSource if TYPE_CHECKING: from . 
import runtime as runtime_ @@ -36,10 +37,12 @@ def __init__( fn: Callable[P, Awaitable[R]], schema: ToolSchema, validator: type[pydantic.BaseModel] | None = None, + source: ToolSource | None = None, ) -> None: self._fn = fn self._validator = validator self.schema = schema + self.source = source async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: return await self._fn(*args, **kwargs) @@ -102,8 +105,32 @@ def tool[**P, R](fn: Callable[P, Awaitable[R]]) -> Tool[P, R]: return_type=hints.get("return", None), ) - t = Tool(fn=fn, schema=schema, validator=validator) + source = ToolSource( + kind="python", + module=getattr(fn, "__module__", None), + qualname=getattr(fn, "__qualname__", None), + ) + + t = Tool(fn=fn, schema=schema, validator=validator, source=source) # 3. register in global registry _tool_registry[t.name] = t return t + + +def _unresolvable_tool_fn(name: str) -> Callable[..., Awaitable[Any]]: + """Create a placeholder async function for schema-only tools. + + Used by ``Context.from_dict()`` when a tool's source cannot be + resolved to live code. + """ + + async def _placeholder(**kwargs: Any) -> Any: + raise RuntimeError( + f"Tool {name!r} was reconstructed from serialized context " + f"and has no executable implementation." + ) + + _placeholder.__name__ = name + _placeholder.__qualname__ = name + return _placeholder diff --git a/src/vercel_ai_sdk/agents2/__init__.py b/src/vercel_ai_sdk/agents2/__init__.py deleted file mode 100644 index 8a4059d0..00000000 --- a/src/vercel_ai_sdk/agents2/__init__.py +++ /dev/null @@ -1,64 +0,0 @@ -"""Agent loop orchestration — tools, hooks, runtime, and streaming. - -Depends on types/ and models2/. Provides the loop machinery that -plugs a model into a tool-calling loop with hooks and checkpoints. -""" - -from . 
import mcp -from .agent import Agent, AgentRun, LoopFn, agent, stream_step -from .checkpoint import Checkpoint, PendingHookInfo -from .context import Context, ToolSource, get_context -from .hooks import Hook, ToolApproval, hook -from .runtime import ( - EventLog, - HookInfo, - LoopExecutor, - RunResult, - Runtime, - execute_tool, - get_checkpoint, - run, -) -from .streams import StreamResult, stream -from .tools import Tool, ToolLike, ToolSchema, get_tool, tool - -__all__ = [ - # Agent (primary user API) - "Agent", - "AgentRun", - "agent", - "LoopFn", - # Composition primitives - "stream_step", - "execute_tool", - "get_checkpoint", - # Context - "Context", - "ToolSource", - "get_context", - # Runtime (developer API) - "Runtime", - "EventLog", - "LoopExecutor", - "RunResult", - "HookInfo", - "run", - # Stream - "stream", - "StreamResult", - # Tools - "Tool", - "ToolLike", - "ToolSchema", - "tool", - "get_tool", - # Hooks - "Hook", - "hook", - "ToolApproval", - # Checkpoint - "Checkpoint", - "PendingHookInfo", - # Submodules - "mcp", -] diff --git a/src/vercel_ai_sdk/agents2/checkpoint.py b/src/vercel_ai_sdk/agents2/checkpoint.py deleted file mode 100644 index c3d079bc..00000000 --- a/src/vercel_ai_sdk/agents2/checkpoint.py +++ /dev/null @@ -1,48 +0,0 @@ -from __future__ import annotations - -from typing import Any - -import pydantic - -from ..types import messages as messages_ -from . 
import streams as streams_ - - -class StepEvent(pydantic.BaseModel): - """A completed @stream step.""" - - index: int - messages: list[messages_.Message] - - def to_stream_result(self) -> streams_.StreamResult: - return streams_.StreamResult(messages=list(self.messages)) - - -class ToolEvent(pydantic.BaseModel): - """A completed tool execution.""" - - tool_call_id: str - result: Any - status: str = "result" # "result" | "error" - - -class HookEvent(pydantic.BaseModel): - """A resolved hook.""" - - label: str - resolution: dict[str, Any] - - -class PendingHookInfo(pydantic.BaseModel): - """A hook that was suspended but not resolved when the run ended.""" - - label: str - hook_type: str - metadata: dict[str, Any] = {} - - -class Checkpoint(pydantic.BaseModel): - steps: list[StepEvent] = [] - tools: list[ToolEvent] = [] - hooks: list[HookEvent] = [] - pending_hooks: list[PendingHookInfo] = [] diff --git a/src/vercel_ai_sdk/agents2/hooks.py b/src/vercel_ai_sdk/agents2/hooks.py deleted file mode 100644 index 948539bb..00000000 --- a/src/vercel_ai_sdk/agents2/hooks.py +++ /dev/null @@ -1,245 +0,0 @@ -from __future__ import annotations - -import asyncio -from typing import TYPE_CHECKING, Any, ClassVar - -import pydantic - -from ..types import messages as messages_ - -if TYPE_CHECKING: - from . import runtime as runtime_ - - -# --------------------------------------------------------------------------- -# Module-level hook registries -# -# _live_hooks: -# Populated by Hook.create() when a hook suspends inside a running graph. -# Maps hook label -> (future, metadata dict, Runtime). -# Consumed by Hook.resolve() / Hook.cancel() to unblock the awaiting -# coroutine. Entries are removed when the hook resolves, cancels, or -# the run completes. -# -# _pending_resolutions: -# Populated by Hook.resolve() when no live hook exists yet (serverless -# re-entry: the user calls resolve() *before* ai.run() replays the graph). -# Maps hook label -> validated resolution dict. 
-# Consumed by Hook.create() at the start of graph execution — if a -# pre-registered resolution exists for the label, the hook returns -# immediately without suspending. Entries are removed on consumption. -# --------------------------------------------------------------------------- - -_live_hooks: dict[ - str, tuple[asyncio.Future[Any], dict[str, Any], runtime_.Runtime] -] = {} - -_pending_resolutions: dict[str, dict[str, Any]] = {} -# label -> validated resolution dict - - -def _cleanup_run(labels: set[str]) -> None: - """Remove all registry entries associated with a finished run.""" - for label in labels: - _live_hooks.pop(label, None) - _pending_resolutions.pop(label, None) - - -class Hook[T: pydantic.BaseModel]: - """Hook: a suspension point that requires external input to continue. - - Usage in graph code: - - approval = await ToolApproval.create("approve_delete", metadata={...}) - if approval.granted: - ... - - Resolution from outside the graph: - - ToolApproval.resolve("approve_delete", {"granted": True, ...}) - - Behavior depends on the ``cancels_future`` class variable: - - cancels_future=False (default, long-running): the await blocks until - Hook.resolve() is called from outside the graph (e.g., websocket - handler, API endpoint). - - cancels_future=True (serverless): if no resolution is available, the - hook's future is cancelled by run(). The branch receives CancelledError - and dies cleanly. On re-entry, call Hook.resolve() before ai.run() to - pre-register the resolution, then pass checkpoint= to replay. - """ - - _schema: ClassVar[type[pydantic.BaseModel]] - hook_type: ClassVar[str] - cancels_future: ClassVar[bool] = False - - @classmethod - async def create(cls, label: str, metadata: dict[str, Any] | None = None) -> T: - """Create a hook and await its resolution. - - The hook is submitted to the LoopExecutor's step queue. 
run() will - either: - - Resolve immediately (if a resolution is available from checkpoint - or pre-registered via Hook.resolve()) - - Cancel the future (cancels_future=True, serverless mode) - - Hold the future (cancels_future=False, long-running mode) - """ - from . import runtime as rt_mod - - rt = rt_mod._runtime.get(None) - if rt is None: - raise ValueError("No Runtime context - must be called within ai.run()") - - # Check pre-registered resolutions (serverless re-entry path) - pre_registered = _pending_resolutions.pop(label, None) - if pre_registered is not None: - rt.log.record_hook(label, pre_registered) - return cls._schema(**pre_registered) # type: ignore[return-value] - - # Check checkpoint for a previously resolved value - resolution = rt.log.get_hook_resolution(label) - if resolution is not None: - rt.log.record_hook(label, resolution) - return cls._schema(**resolution) # type: ignore[return-value] - - # Submit to executor queue — run() decides what to do - future: asyncio.Future[dict[str, Any]] = asyncio.Future() - suspension = rt_mod.HookSuspension( - label=label, - hook_type=cls.hook_type, - metadata=metadata or {}, - future=future, - cancels_future=cls.cancels_future, - ) - await rt.executor.put_hook(suspension) - - # Register in module-level registry for external resolution - hook_metadata = metadata or {} - _live_hooks[label] = (future, hook_metadata, rt) - rt.executor.track_hook_label(label) - - # Await resolution — may be resolved immediately by run(), - # cancelled by run() (serverless), or resolved later by - # Hook.resolve() (long-running). 
- resolution = await future - - # Clean up - _live_hooks.pop(label, None) - - # Record for checkpoint - rt.log.record_hook(label, resolution) - - # Emit resolved message - await rt.executor.put_message( - messages_.Message( - role="assistant", - parts=[ - messages_.HookPart( - hook_id=label, - hook_type=cls.hook_type, - status="resolved", - metadata=hook_metadata, - resolution=resolution, - ) - ], - ) - ) - - return cls._schema(**resolution) # type: ignore[return-value] - - @classmethod - def resolve(cls, label: str, data: T | dict[str, Any]) -> None: - """Resolve a hook by label. - - Works in two modes: - - 1. Live hook exists (long-running): validates data, resolves the - future immediately, unblocking the awaiting coroutine. - - 2. No live hook yet (serverless re-entry): validates data and - stashes it in the pre-registration registry. When ai.run() - replays the graph and Hook.create() executes, it finds the - pre-registered resolution and returns without suspending. - """ - # Validate and normalize to dict - if isinstance(data, dict): - validated = cls._schema(**data) - resolution = validated.model_dump() - else: - if not isinstance(data, cls._schema): - raise TypeError( - f"Expected {cls._schema.__name__} or dict, " - f"got {type(data).__name__}" - ) - resolution = data.model_dump() - - # Path 1: live hook — resolve the future directly - if label in _live_hooks: - future, _, _rt = _live_hooks[label] - future.set_result(resolution) - return - - # Path 2: no live hook — pre-register for later consumption - _pending_resolutions[label] = resolution - - @classmethod - async def cancel(cls, label: str, reason: str | None = None) -> None: - """Cancel a pending hook. - - Only works for live hooks (long-running mode). Raises if the - hook is not currently pending. 
- """ - if label not in _live_hooks: - raise ValueError(f"No pending hook with label: {label}") - - future, hook_metadata, rt = _live_hooks.pop(label) - future.cancel(reason) - - await rt.executor.put_message( - messages_.Message( - role="assistant", - parts=[ - messages_.HookPart( - hook_id=label, - hook_type=cls.hook_type, - status="cancelled", - metadata=hook_metadata, - ) - ], - ) - ) - - -def hook[T: pydantic.BaseModel](cls: type[T]) -> type[Hook[T]]: - """Decorator to create a Hook type from a pydantic model. - - The pydantic model defines the schema for the hook's resolution payload. - """ - hook_impl = type( - cls.__name__, - (Hook,), - { - "_schema": cls, - "hook_type": cls.__name__, - "cancels_future": cls.__dict__.get("cancels_future", False), - "__doc__": cls.__doc__, - }, - ) - - return hook_impl - - -@hook -class ToolApproval(pydantic.BaseModel): - """Prewired hook for tool call approval. - - Used by the AI SDK UI adapter to bridge the protocol's - tool-approval-request / approval-responded flow to the - hook system. 
- """ - - cancels_future: ClassVar[bool] = True - - granted: bool - reason: str | None = None diff --git a/src/vercel_ai_sdk/agents2/mcp/__init__.py b/src/vercel_ai_sdk/agents2/mcp/__init__.py deleted file mode 100644 index c1202f63..00000000 --- a/src/vercel_ai_sdk/agents2/mcp/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .client import get_http_tools, get_stdio_tools - -__all__ = [ - "get_stdio_tools", - "get_http_tools", -] diff --git a/src/vercel_ai_sdk/agents2/mcp/client.py b/src/vercel_ai_sdk/agents2/mcp/client.py deleted file mode 100644 index def1f0e6..00000000 --- a/src/vercel_ai_sdk/agents2/mcp/client.py +++ /dev/null @@ -1,301 +0,0 @@ -from __future__ import annotations - -import asyncio -import contextlib -import contextvars -import dataclasses -import json -from collections.abc import Callable -from typing import Any - -import httpx -import mcp.client.session -import mcp.client.stdio -import mcp.client.streamable_http -import mcp.types - -from .. import context as context_ -from .. import tools as tools_ - -__all__ = [ - "get_stdio_tools", - "get_http_tools", - "close_connections", -] - - -@dataclasses.dataclass -class _Connection: - """Internal connection state - never exposed to users.""" - - client: mcp.client.session.ClientSession - exit_stack: contextlib.AsyncExitStack - - -# Connection pool stored in contextvar, scoped to execute() -# The pool is set by execute() and cleaned up when execute() finishes -_pool: contextvars.ContextVar[dict[str, _Connection] | None] = contextvars.ContextVar( - "mcp_connections", default=None -) - -_pool_lock = asyncio.Lock() - - -async def _get_or_create_connection( - key: str, - transport_factory: Callable[[], contextlib.AbstractAsyncContextManager[Any]], -) -> mcp.client.session.ClientSession: - """Get an existing connection or create a new one.""" - pool = _pool.get() - - if pool is None: - raise RuntimeError( - "MCP tools must be used inside ai.execute(). " - "The connection pool is not initialized." 
- ) - - async with _pool_lock: - if key in pool: - return pool[key].client - - # Use AsyncExitStack for clean resource management - exit_stack = contextlib.AsyncExitStack() - - try: - # Enter the transport context - streams = await exit_stack.enter_async_context(transport_factory()) - - # Handle both (read, write) and (read, write, callback) returns - read_stream, write_stream = streams[0], streams[1] - - # Create and initialize the client session - client = mcp.client.session.ClientSession( - read_stream=read_stream, - write_stream=write_stream, - ) - await exit_stack.enter_async_context(client) - await client.initialize() - - pool[key] = _Connection(client=client, exit_stack=exit_stack) - return client - - except BaseException: - # Clean up on any error during setup - await exit_stack.aclose() - raise - - -def _make_tool_fn( - connection_key: str, - tool_name: str, - transport_factory: Callable[[], contextlib.AbstractAsyncContextManager[Any]], -) -> Callable[..., Any]: - """Create a tool function that manages its own connection.""" - - async def call_tool(**kwargs: Any) -> Any: - client = await _get_or_create_connection(connection_key, transport_factory) - try: - result = await asyncio.wait_for( - client.call_tool(tool_name, kwargs), - timeout=30.0, - ) - except TimeoutError as e: - raise RuntimeError( - f"MCP tool call timed out after 30 seconds: {tool_name}" - ) from e - - # Handle error responses - if result.isError: - error_text = " ".join( - part.text - for part in result.content - if isinstance(part, mcp.types.TextContent) - ) - raise RuntimeError(f"MCP tool error: {error_text or 'Unknown error'}") - - # Prefer structured content if available - if result.structuredContent is not None: - return result.structuredContent - - # Fall back to parsing content - for part in result.content: - if isinstance(part, mcp.types.TextContent): - text = part.text - # Try to parse JSON, otherwise return raw text - if text.startswith(("{", "[")): - try: - return 
json.loads(text) - except json.JSONDecodeError: - pass - return text - - return result.content - - return call_tool - - -async def get_stdio_tools( - command: str, - *args: str, - env: dict[str, str] | None = None, - cwd: str | None = None, - tool_prefix: str | None = None, -) -> list[tools_.Tool[..., Any]]: - """ - Get tools from an MCP server running as a subprocess. - - Connection is managed automatically - created on first use, cleaned up - when execute() finishes. - - Args: - command: The command to run (e.g., "npx", "python"). - *args: Arguments to pass to the command. - env: Environment variables for the subprocess. - cwd: Working directory for the subprocess. - tool_prefix: Optional prefix to add to all tool names. - - Returns: - List of Tool objects that can be passed to stream_loop. - - Example: - tools = await ai.mcp.get_stdio_tools( - "npx", "-y", "@anthropic/mcp-server-filesystem", "/tmp" - ) - """ - connection_key = f"stdio:{command}:{':'.join(args)}" - - def transport_factory() -> contextlib.AbstractAsyncContextManager[Any]: - return mcp.client.stdio.stdio_client( - mcp.client.stdio.StdioServerParameters( - command=command, - args=list(args), - env=env, - cwd=cwd, - ) - ) - - client = await _get_or_create_connection(connection_key, transport_factory) - result = await client.list_tools() - - return [ - _mcp_tool_to_native(mcp_tool, connection_key, transport_factory, tool_prefix) - for mcp_tool in result.tools - ] - - -async def get_http_tools( - url: str, - *, - headers: dict[str, str] | None = None, - tool_prefix: str | None = None, -) -> list[tools_.Tool[..., Any]]: - """ - Get tools from an MCP server over HTTP (Streamable HTTP transport). - - Connection is managed automatically - created on first use, cleaned up - when execute() finishes. - - Args: - url: The URL of the MCP server endpoint. - headers: Optional HTTP headers (e.g., for authentication). - tool_prefix: Optional prefix to add to all tool names. 
- - Returns: - List of Tool objects that can be passed to stream_loop. - - Example: - tools = await ai.mcp.get_http_tools( - "http://localhost:3000/mcp", - headers={"Authorization": "Bearer xxx"} - ) - """ - connection_key = f"http:{url}" - - def transport_factory() -> contextlib.AbstractAsyncContextManager[Any]: - http_client = httpx.AsyncClient(headers=headers) if headers else None - return mcp.client.streamable_http.streamable_http_client( - url=url, http_client=http_client - ) - - client = await _get_or_create_connection(connection_key, transport_factory) - result = await client.list_tools() - - return [ - _mcp_tool_to_native(mcp_tool, connection_key, transport_factory, tool_prefix) - for mcp_tool in result.tools - ] - - -def _mcp_tool_to_native( - mcp_tool: mcp.types.Tool, - connection_key: str, - transport_factory: Callable[[], contextlib.AbstractAsyncContextManager[Any]], - tool_prefix: str | None, -) -> tools_.Tool[..., Any]: - """Convert an MCP tool to a native Tool.""" - name = mcp_tool.name - if tool_prefix: - name = f"{tool_prefix}_{name}" - - schema = tools_.ToolSchema( - name=name, - description=mcp_tool.description or "", - param_schema=mcp_tool.inputSchema, - return_type=Any, - ) - - # Determine source provenance from connection key - if connection_key.startswith("http:"): - source = context_.ToolSource( - kind="mcp_http", - uri=connection_key.removeprefix("http:"), - ) - elif connection_key.startswith("stdio:"): - source = context_.ToolSource( - kind="mcp_stdio", - server_command=connection_key.removeprefix("stdio:"), - ) - else: - source = context_.ToolSource(kind="mcp") - - t = tools_.Tool( - fn=_make_tool_fn(connection_key, mcp_tool.name, transport_factory), - schema=schema, - source=source, - ) - - # Register on active Context if available, else fall back to global - ctx = context_._context.get(None) - if ctx is not None: - ctx.register_tool(t) - tools_._tool_registry[name] = t - return t - - -async def close_connections() -> None: - """ - 
Close all MCP connections in the current context. - - This is called automatically by execute(), but can be called - manually for explicit cleanup. - """ - pool = _pool.get() - if pool is None: - return - - async with _pool_lock: - if not pool: - return - - # Use TaskGroup for concurrent cleanup - async with asyncio.TaskGroup() as tg: - for conn in pool.values(): - tg.create_task(_close_connection_safely(conn)) - - pool.clear() - - -async def _close_connection_safely(conn: _Connection) -> None: - """Close a connection, suppressing any errors.""" - with contextlib.suppress(Exception): - await conn.exit_stack.aclose() diff --git a/src/vercel_ai_sdk/agents2/runtime.py b/src/vercel_ai_sdk/agents2/runtime.py deleted file mode 100644 index cca8fb8e..00000000 --- a/src/vercel_ai_sdk/agents2/runtime.py +++ /dev/null @@ -1,610 +0,0 @@ -from __future__ import annotations - -import asyncio -import contextvars -import dataclasses -import json -import logging -from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine -from typing import Any, get_type_hints - -import pydantic - -from ..telemetry import events as telemetry_ -from ..types import messages as messages_ -from . import checkpoint as checkpoint_ -from . import context as context_ -from . import hooks as hooks_ -from . import mcp -from . import streams as streams_ -from . import tools as tools_ - -logger = logging.getLogger(__name__) - - -# ── EventLog ────────────────────────────────────────────────────── -# -# Pure bookkeeping: replay from checkpoint + record new events. -# No asyncio, no queues — just data in, data out. -# - - -class EventLog: - """Replay/record layer backed by a Checkpoint. - - Holds replay cursors (read pointer into the checkpoint) and - recording lists (new events produced during this run). - Completely synchronous — no queues, no async. 
- """ - - def __init__(self, checkpoint: checkpoint_.Checkpoint | None = None) -> None: - self._checkpoint = checkpoint or checkpoint_.Checkpoint() - - # Replay cursors - self._step_index: int = 0 - self._tool_replay: dict[str, checkpoint_.ToolEvent] = { - t.tool_call_id: t for t in self._checkpoint.tools - } - self._hook_replay: dict[str, dict[str, Any]] = { - h.label: h.resolution for h in self._checkpoint.hooks - } - - # Recording lists (new events from this run) - self._step_log: list[checkpoint_.StepEvent] = [] - self._tool_log: list[checkpoint_.ToolEvent] = [] - self._hook_log: list[checkpoint_.HookEvent] = [] - - # ── Steps ───────────────────────────────────────────────── - - @property - def step_index(self) -> int: - return self._step_index - - def try_replay_step(self) -> streams_.StreamResult | None: - if self._step_index < len(self._checkpoint.steps): - event = self._checkpoint.steps[self._step_index] - self._step_index += 1 - logger.info("Replaying step %d from checkpoint", event.index) - return event.to_stream_result() - return None - - def record_step(self, result: streams_.StreamResult) -> None: - event = checkpoint_.StepEvent( - index=self._step_index, - messages=list(result.messages), - ) - self._step_log.append(event) - self._step_index += 1 - - # ── Tools ───────────────────────────────────────────────── - - def try_replay_tool(self, tool_call_id: str) -> checkpoint_.ToolEvent | None: - event = self._tool_replay.get(tool_call_id) - if event is not None: - logger.info( - "Replaying tool %s (call_id=%s) from checkpoint", - event.tool_call_id, - tool_call_id, - ) - return event - - def record_tool( - self, tool_call_id: str, result: Any, *, status: str = "result" - ) -> None: - self._tool_log.append( - checkpoint_.ToolEvent( - tool_call_id=tool_call_id, result=result, status=status - ) - ) - - # ── Hooks ───────────────────────────────────────────────── - - def get_hook_resolution(self, label: str) -> dict[str, Any] | None: - resolution = 
self._hook_replay.get(label) - if resolution is not None: - logger.info("Resolving hook '%s' from checkpoint", label) - return resolution - - def record_hook(self, label: str, resolution: dict[str, Any]) -> None: - self._hook_log.append(checkpoint_.HookEvent(label=label, resolution=resolution)) - - # ── Snapshot ────────────────────────────────────────────── - - def checkpoint( - self, pending_hooks: list[checkpoint_.PendingHookInfo] | None = None - ) -> checkpoint_.Checkpoint: - """Build a full Checkpoint merging prior state + new recordings.""" - return checkpoint_.Checkpoint( - steps=list(self._checkpoint.steps) + self._step_log, - tools=list(self._checkpoint.tools) + self._tool_log, - hooks=list(self._checkpoint.hooks) + self._hook_log, - pending_hooks=pending_hooks or [], - ) - - -# ── LoopExecutor ───────────────────────────────────────────────── -# -# Async coordination: queues that let graph code (streams, hooks, -# tools) talk to the driver loop. Pure mailbox — no replay, no -# checkpoint awareness. -# - - -@dataclasses.dataclass -class HookSuspension: - """Submitted to the step queue when a hook needs resolution.""" - - label: str - hook_type: str - metadata: dict[str, Any] - future: asyncio.Future[Any] - cancels_future: bool = False - - -class LoopExecutor: - """Async coordination layer between graph code and the driver loop. - - Graph code (``@stream`` decorators, hooks, tool execution) submits - work via the producer methods. The driver loop consumes via - ``next()`` and ``drain_messages()``. 
- """ - - class _Sentinel: - pass - - _SENTINEL = _Sentinel() - - def __init__(self) -> None: - self._step_queue: asyncio.Queue[ - tuple[streams_.Stream, asyncio.Future[streams_.StreamResult]] - | HookSuspension - | LoopExecutor._Sentinel - ] = asyncio.Queue() - - self._message_queue: asyncio.Queue[messages_.Message] = asyncio.Queue() - - # Pending hooks (unresolved during this run) - self._pending_hooks: dict[str, HookSuspension] = {} - - # Track hook labels registered in this run for cleanup - self._hook_labels: set[str] = set() - - # ── Producers (called by graph code) ────────────────────── - - async def put_step( - self, step_fn: streams_.Stream, future: asyncio.Future[streams_.StreamResult] - ) -> None: - await self._step_queue.put((step_fn, future)) - - async def put_hook(self, suspension: HookSuspension) -> None: - await self._step_queue.put(suspension) - - async def put_message(self, message: messages_.Message) -> None: - await self._message_queue.put(message) - - async def done(self) -> None: - await self._step_queue.put(self._SENTINEL) - - # ── Consumer (called by driver loop) ────────────────────── - - async def next( - self, timeout: float = 0.1 - ) -> ( - tuple[streams_.Stream, asyncio.Future[streams_.StreamResult]] - | HookSuspension - | None - ): - """Pull the next item from the step queue. - - Returns ``None`` on timeout (no item available). - Returns the sentinel's semantic equivalent by raising StopIteration - when the graph signals completion. 
- """ - try: - item = await asyncio.wait_for(self._step_queue.get(), timeout=timeout) - except TimeoutError: - return None - - if isinstance(item, LoopExecutor._Sentinel): - raise _LoopDone - return item - - def drain_messages(self) -> list[messages_.Message]: - msgs: list[messages_.Message] = [] - while not self._message_queue.empty(): - try: - msgs.append(self._message_queue.get_nowait()) - except asyncio.QueueEmpty: - break - return msgs - - # ── Hook label tracking ─────────────────────────────────── - - def track_hook_label(self, label: str) -> None: - self._hook_labels.add(label) - - def pending_hook_infos(self) -> list[checkpoint_.PendingHookInfo]: - return [ - checkpoint_.PendingHookInfo( - label=sus.label, - hook_type=sus.hook_type, - metadata=sus.metadata, - ) - for sus in self._pending_hooks.values() - ] - - -class _LoopDone(Exception): - """Internal signal: the loop function has finished.""" - - -# ── Runtime ─────────────────────────────────────────────────────── -# -# Thin composition of EventLog + LoopExecutor. -# The context var points here; graph code accesses rt.log and -# rt.executor directly. -# - - -class Runtime: - """Central coordinator — composes EventLog and LoopExecutor. - - Graph code accesses ``rt.log`` for replay/record and - ``rt.executor`` for async coordination. 
- """ - - def __init__(self, checkpoint: checkpoint_.Checkpoint | None = None) -> None: - self.log = EventLog(checkpoint) - self.executor = LoopExecutor() - - def checkpoint(self) -> checkpoint_.Checkpoint: - return self.log.checkpoint( - pending_hooks=self.executor.pending_hook_infos(), - ) - - -_runtime: contextvars.ContextVar[Runtime] = contextvars.ContextVar("runtime") - - -def get_checkpoint() -> checkpoint_.Checkpoint: - """Get the current checkpoint from the active Runtime.""" - return _runtime.get().checkpoint() - - -def _find_runtime_param(fn: Callable[..., Any]) -> str | None: - """Find a parameter typed as Runtime, return its name or None.""" - try: - hints = get_type_hints(fn) - except Exception: - return None - for name, hint in hints.items(): - if hint is Runtime: - return name - return None - - -async def execute_tool( - tool_call: messages_.ToolPart, - message: messages_.Message | None = None, -) -> Any: - """Execute a single tool call with replay support. - - Looks up the tool by name — first from the active Context (if any), - then from the global registry. Executes it and updates the ToolPart - (and parent Message) with the result. Emits the updated message to - the LoopExecutor queue so the UI sees the transition from - status="pending" to status="result" (or "error"). - - If a checkpoint exists with a cached result for this tool_call_id, - returns the cached result without re-executing. 
- """ - rt = _runtime.get(None) - - # Replay: return cached result if available - if rt: - cached = rt.log.try_replay_tool(tool_call.tool_call_id) - if cached is not None: - if cached.status == "error": - tool_call.set_error(cached.result) - else: - tool_call.set_result(cached.result) - return cached.result - - telemetry_.handle( - telemetry_.ToolCallStartEvent( - tool_name=tool_call.tool_name, - tool_call_id=tool_call.tool_call_id, - args=tool_call.tool_args, - ) - ) - t0 = telemetry_.time_ms() - - # Fresh execution — resolve from Context first, then global registry - tool: tools_.Tool[..., Any] | None = None - ctx = context_._context.get(None) - if ctx is not None: - tool = ctx.get_tool(tool_call.tool_name) - if tool is None: - tool = tools_.get_tool(tool_call.tool_name) - if tool is None: - raise ValueError(f"Tool not found in registry: {tool_call.tool_name}") - - error_str: str | None = None - try: - result = await tool.validate_and_call(tool_call.tool_args, rt) - tool_call.set_result(result) - except (json.JSONDecodeError, pydantic.ValidationError) as exc: - result = f"{type(exc).__name__}: {exc}" - error_str = result - tool_call.set_error(result) - - telemetry_.handle( - telemetry_.ToolCallFinishEvent( - tool_name=tool_call.tool_name, - tool_call_id=tool_call.tool_call_id, - result=result, - error=error_str, - duration_ms=telemetry_.time_ms() - t0, - ) - ) - - # Record for checkpoint - if rt: - rt.log.record_tool(tool_call.tool_call_id, result, status=tool_call.status) - - # Emit updated message so UI sees status change - if rt and message: - await rt.executor.put_message(message.model_copy(deep=True)) - - return result - - -# ── RunResult ───────────────────────────────────────────────────── - - -@dataclasses.dataclass -class HookInfo: - """Info about a pending (unresolved) hook, exposed on RunResult.""" - - label: str - hook_type: str - metadata: dict[str, Any] - - -class RunResult: - """Returned by run(). Async-iterate for messages, then check state. 
- - Usage: - result = ai.run(my_graph, llm, query) - async for msg in result: - ... - result.checkpoint # Checkpoint with all completed work - result.pending_hooks # dict of unresolved hooks (empty if graph completed) - """ - - def __init__(self) -> None: - self._messages: AsyncGenerator[messages_.Message] | None = None - self._runtime: Runtime | None = None - - @property - def checkpoint(self) -> checkpoint_.Checkpoint: - if self._runtime is None: - return checkpoint_.Checkpoint() - return self._runtime.checkpoint() - - @property - def pending_hooks(self) -> dict[str, HookInfo]: - if self._runtime is None: - return {} - return { - label: HookInfo( - label=sus.label, - hook_type=sus.hook_type, - metadata=sus.metadata, - ) - for label, sus in self._runtime.executor._pending_hooks.items() - } - - async def __aiter__(self) -> AsyncGenerator[messages_.Message]: - if self._messages is not None: - async for msg in self._messages: - yield msg - - -# ── run() ───────────────────────────────────────────────────────── - - -async def _stop_when_done(executor: LoopExecutor, task: Awaitable[None]) -> None: - try: - await task - finally: - await executor.done() - - -def run( - root: Callable[..., Coroutine[Any, Any, Any]], - *args: Any, - checkpoint: checkpoint_.Checkpoint | None = None, - context: context_.Context | None = None, -) -> RunResult: - """Main entry point. - - 1. Starts the root function as a background task - 2. Pulls steps and hook suspensions from the LoopExecutor queue - 3. Executes each step, yielding messages - 4. Resolves or suspends hooks depending on the hook's cancels_future - 5. Returns RunResult with .checkpoint and .pending_hooks - - Args: - root: The loop function to execute. - *args: Positional arguments forwarded to ``root``. - checkpoint: Checkpoint to resume from. - context: LLM prompt context (tools, system prompt, messages). - If ``None``, an empty Context is created automatically. 
- """ - result = RunResult() - - # Discard stale checkpoints: if the checkpoint has pending hooks but - # none of them have been resolved, this isn't a resume. - effective_checkpoint = checkpoint - if checkpoint and checkpoint.pending_hooks: - pending_labels = [ph.label for ph in checkpoint.pending_hooks] - has_resolution = any( - label in hooks_._pending_resolutions for label in pending_labels - ) - if not has_resolution: - logger.info( - "Discarding stale checkpoint: pending hooks %s have no " - "matching resolutions", - pending_labels, - ) - effective_checkpoint = None - else: - logger.info( - "Resuming from checkpoint with %d pending hook(s): %s", - len(pending_labels), - pending_labels, - ) - - async def _generate() -> AsyncGenerator[messages_.Message]: - rt = Runtime(checkpoint=effective_checkpoint) - result._runtime = rt - token_runtime = _runtime.set(rt) - - ctx = context or context_.Context() - token_context = context_._context.set(ctx) - - token_run_id = telemetry_.start_run() - - telemetry_.handle(telemetry_.RunStartEvent()) - - mcp_pool: dict[str, mcp.client._Connection] = {} - mcp_token = mcp.client._pool.set(mcp_pool) - - kwargs: dict[str, Any] = {} - if runtime_param := _find_runtime_param(root): - kwargs[runtime_param] = rt - - run_error: BaseException | None = None - total_usage: messages_.Usage | None = None - - try: - async with asyncio.TaskGroup() as tg: - _task: asyncio.Task[None] = tg.create_task( - _stop_when_done(rt.executor, root(*args, **kwargs)) - ) - - while True: - # Drain pending messages - for msg in rt.executor.drain_messages(): - yield msg - - # Pull next item from the graph executor - try: - item = await rt.executor.next() - except _LoopDone: - for msg in rt.executor.drain_messages(): - yield msg - break - - if item is None: - # Timeout — no item available, loop again - continue - - # ── Hook suspension ──────────────────────── - if isinstance(item, HookSuspension): - resolution = rt.log.get_hook_resolution(item.label) - if 
resolution is not None: - item.future.set_result(resolution) - rt.log.record_hook(item.label, resolution) - else: - rt.executor._pending_hooks[item.label] = item - if item.cancels_future: - item.future.cancel() - - yield messages_.Message( - role="assistant", - parts=[ - messages_.HookPart( - hook_id=item.label, - hook_type=item.hook_type, - status="pending", - metadata=item.metadata, - ) - ], - ) - - await asyncio.sleep(0) - for msg in rt.executor.drain_messages(): - yield msg - continue - - # ── Regular step ─────────────────────────── - step_fn, future = item - - telemetry_.handle( - telemetry_.StepStartEvent( - step_index=rt.log.step_index, - ) - ) - - for tool_msg in rt.executor.drain_messages(): - yield tool_msg - - result_messages: list[messages_.Message] = [] - - async for msg in step_fn(): - msg_copy = msg.model_copy(deep=True) - yield msg_copy - result_messages.append(msg) - - for tool_msg in rt.executor.drain_messages(): - yield tool_msg - - step_result = streams_.StreamResult(messages=result_messages) - future.set_result(step_result) - - telemetry_.handle( - telemetry_.StepFinishEvent( - step_index=rt.log.step_index, - result=step_result, # type: ignore[arg-type] - ) - ) - - # Accumulate usage for run-level telemetry - step_usage = step_result.usage - if step_usage is not None: - total_usage = ( - step_usage - if total_usage is None - else total_usage + step_usage - ) - - await asyncio.sleep(0) - for tool_msg in rt.executor.drain_messages(): - yield tool_msg - - except BaseException as exc: - run_error = exc - raise - - finally: - telemetry_.handle( - telemetry_.RunFinishEvent( - usage=total_usage, - error=run_error, - ) - ) - telemetry_.end_run(token_run_id) - - hooks_._cleanup_run(rt.executor._hook_labels) - - if mcp_token is not None: - await mcp.client.close_connections() - mcp.client._pool.reset(mcp_token) - - context_._context.reset(token_context) - _runtime.reset(token_runtime) - - result._messages = _generate() - return result diff --git 
a/src/vercel_ai_sdk/agents2/streams.py b/src/vercel_ai_sdk/agents2/streams.py deleted file mode 100644 index fadf6747..00000000 --- a/src/vercel_ai_sdk/agents2/streams.py +++ /dev/null @@ -1,105 +0,0 @@ -from __future__ import annotations - -import asyncio -import dataclasses -import functools -from collections.abc import AsyncGenerator, Awaitable, Callable -from typing import Any - -from ..types import messages as messages_ - - -@dataclasses.dataclass -class StreamResult: - messages: list[messages_.Message] = dataclasses.field(default_factory=list) - - @property - def last_message(self) -> messages_.Message | None: - return self.messages[-1] if self.messages else None - - @property - def tool_calls(self) -> list[messages_.ToolPart]: - """Get tool calls from the last message.""" - if self.last_message: - return self.last_message.tool_calls - return [] - - @property - def text(self) -> str: - if self.last_message: - return self.last_message.text - return "" - - @property - def output(self) -> Any: - """Parsed structured output from the last message, if available.""" - if self.last_message: - return self.last_message.output - return None - - @property - def usage(self) -> messages_.Usage | None: - """Token usage from the last (most recent) LLM call.""" - if self.last_message: - return self.last_message.usage - return None - - @property - def total_usage(self) -> messages_.Usage | None: - """Accumulated token usage across all LLM calls in this result. - - Sums usage from every message that carries it (i.e. assistant - messages produced by LLM calls). Returns ``None`` if no message - reported usage. 
- """ - total: messages_.Usage | None = None - for msg in self.messages: - if msg.usage is not None: - total = msg.usage if total is None else total + msg.usage - return total - - -Stream = Callable[[], AsyncGenerator[messages_.Message]] -# maybe it should have a name and an id inferred from LLM outputs - - -def stream[**P]( - fn: Callable[P, AsyncGenerator[messages_.Message]], -) -> Callable[P, Awaitable[StreamResult]]: - """Decorator to put an async generator into the LoopExecutor queue. - - The decorated function submits its work to the executor queue and - blocks until run() processes it, then returns the StreamResult. - - If a checkpoint exists with a cached result for this step index, - returns the cached result without submitting to the queue (replay). - """ - - from . import runtime as runtime_ - - @functools.wraps(fn) - async def wrapped(*args: Any, **kwargs: Any) -> StreamResult: - rt: runtime_.Runtime | None = runtime_._runtime.get(None) - if rt is None: - raise ValueError("No Runtime context - must be called within ai.run()") - - # Replay: return cached result if available - cached = rt.log.try_replay_step() - if cached is not None: - return cached - - # Fresh execution: submit to executor queue and wait - future: asyncio.Future[StreamResult] = asyncio.Future() - - async def stream_fn() -> AsyncGenerator[messages_.Message]: - async for msg in fn(*args, **kwargs): - yield msg - - await rt.executor.put_step(stream_fn, future) - result = await future - - # Record for checkpoint - rt.log.record_step(result) - return result - - return wrapped diff --git a/src/vercel_ai_sdk/agents2/tools.py b/src/vercel_ai_sdk/agents2/tools.py deleted file mode 100644 index 39a9aa28..00000000 --- a/src/vercel_ai_sdk/agents2/tools.py +++ /dev/null @@ -1,136 +0,0 @@ -from __future__ import annotations - -import inspect -import json -from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING, Any, get_type_hints - -import pydantic - -from ..types.tools 
import ToolLike as ToolLike -from ..types.tools import ToolSchema as ToolSchema -from .context import ToolSource - -if TYPE_CHECKING: - from . import runtime as runtime_ - -# Module-level tool registry - populated at decoration time -_tool_registry: dict[str, Tool[..., Any]] = {} - - -def get_tool(name: str) -> Tool[..., Any] | None: - """Look up a tool by name from the global registry.""" - return _tool_registry.get(name) - - -def _is_runtime_type(hint: Any) -> bool: - """Check if a type hint is the Runtime class.""" - # Import here to avoid circular import at runtime - from .runtime import Runtime - - return hint is Runtime - - -class Tool[**P, R]: - def __init__( - self, - fn: Callable[P, Awaitable[R]], - schema: ToolSchema, - validator: type[pydantic.BaseModel] | None = None, - source: ToolSource | None = None, - ) -> None: - self._fn = fn - self._validator = validator - self.schema = schema - self.source = source - - async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: - return await self._fn(*args, **kwargs) - - async def validate_and_call( - self, json_str: str, runtime: runtime_.Runtime | None - ) -> R: - from .runtime import _find_runtime_param - - kwargs = json.loads(json_str) if json_str else {} - - if runtime and (rt_param := _find_runtime_param(self._fn)): - kwargs[rt_param] = runtime - - # validate llm-generated inputs (skipped for MCP tools) - if self._validator is not None: - self._validator.model_validate(kwargs) - return await self._fn(**kwargs) # type: ignore[call-arg] - - @property - def name(self) -> str: - return self.schema.name - - @property - def description(self) -> str: - return self.schema.description - - @property - def param_schema(self) -> dict[str, Any]: - return self.schema.param_schema - - -def tool[**P, R](fn: Callable[P, Awaitable[R]]) -> Tool[P, R]: - """Decorator to define a tool from an async function.""" - - # 1. 
build tool schema by parsing the function - sig = inspect.signature(fn) - hints = get_type_hints(fn) if hasattr(fn, "__annotations__") else {} - - fields: dict[str, Any] = {} - - for param_name, param in sig.parameters.items(): - param_type = hints.get(param_name, str) - - if _is_runtime_type(param_type): - continue - if param.default is inspect.Parameter.empty: - fields[param_name] = (param_type, ...) - else: - fields[param_name] = (param_type, param.default) - - validator = pydantic.create_model(f"{fn.__name__}_Args", **fields) - - # 2. instantiate the tool - - schema = ToolSchema( - name=fn.__name__, - description=inspect.getdoc(fn) or "", - param_schema=validator.model_json_schema(), - return_type=hints.get("return", None), - ) - - source = ToolSource( - kind="python", - module=getattr(fn, "__module__", None), - qualname=getattr(fn, "__qualname__", None), - ) - - t = Tool(fn=fn, schema=schema, validator=validator, source=source) - - # 3. register in global registry - _tool_registry[t.name] = t - return t - - -def _unresolvable_tool_fn(name: str) -> Callable[..., Awaitable[Any]]: - """Create a placeholder async function for schema-only tools. - - Used by ``Context.from_dict()`` when a tool's source cannot be - resolved to live code. - """ - - async def _placeholder(**kwargs: Any) -> Any: - raise RuntimeError( - f"Tool {name!r} was reconstructed from serialized context " - f"and has no executable implementation." - ) - - _placeholder.__name__ = name - _placeholder.__qualname__ = name - return _placeholder diff --git a/src/vercel_ai_sdk/models/__init__.py b/src/vercel_ai_sdk/models/__init__.py index 47c129aa..9e921afd 100644 --- a/src/vercel_ai_sdk/models/__init__.py +++ b/src/vercel_ai_sdk/models/__init__.py @@ -1,109 +1,207 @@ -"""Model adapters — standalone LLM streaming layer. +"""models — composable model layer. -Provides the LanguageModel ABC and concrete provider adapters. -Depends only on types/, never on agents/. 
+Usage:: -Module-level API -~~~~~~~~~~~~~~~~ + from vercel_ai_sdk import models + from vercel_ai_sdk.types import Message, TextPart -.. code-block:: python + model = models.Model( + id="anthropic/claude-sonnet-4", + adapter="ai-gateway-v3", + provider="ai-gateway", + ) + msgs = [Message(role="user", parts=[TextPart(text="hello")])] - import vercel_ai_sdk as ai + # stream — auto-creates client from env vars + async for msg in models.stream(model, msgs): + print(msg.text_delta, end="") - model = ai.models.Model(id="gpt-4o", api="openai", provider="openai") - s = ai.models.stream(model, messages) - async for msg in s: - ... + # buffer the whole response + result = await models.buffer(models.stream(model, msgs)) + print(result.text) - result = await ai.models.generate(model, messages, n=2) + # explicit client + client = models.Client( + base_url="https://custom.example.com/v3/ai", api_key="sk-...", + ) + async for msg in models.stream(model, msgs, client=client): + ... """ from __future__ import annotations -from collections.abc import Sequence +import os +from collections.abc import AsyncGenerator, Sequence from typing import Any import pydantic from ..types import messages as messages_ from ..types import tools as tools_ -from . import ai_gateway, anthropic, core, openai -from .core import ( - GenerateFn, - ImageModel, - LanguageModel, - MediaModel, - MediaResult, - Model, - Stream, - StreamEvent, - StreamFn, - StreamHandler, - VideoModel, - get_generate_fn, - get_stream_fn, - register_generate, - register_stream, -) - -# ── Module-level dispatch functions ─────────────────────────────── - - -def stream( +from .ai_gateway.generate import GenerateParams, ImageParams, VideoParams +from .core.client import Client +from .core.model import Model, ModelCost +from .core.proto import GenerateFn, StreamFn + +# --------------------------------------------------------------------------- +# Adapter registry — maps adapter string → adapter function. 
+# Adapter modules are imported lazily on first use. +# --------------------------------------------------------------------------- + +_stream_adapters: dict[str, StreamFn] = {} +_generate_adapters: dict[str, GenerateFn] = {} +_adapters_loaded = False + + +def _ensure_adapters() -> None: + """Lazily register built-in adapter functions on first call.""" + global _adapters_loaded # noqa: PLW0603 + if _adapters_loaded: + return + _adapters_loaded = True + + from .ai_gateway import generate as ai_gw_generate + from .ai_gateway import stream as ai_gw_stream + from .anthropic.adapter import stream as anthropic_stream + from .openai.adapter import stream as openai_stream + + _stream_adapters["ai-gateway-v3"] = ai_gw_stream + _generate_adapters["ai-gateway-v3"] = ai_gw_generate + _stream_adapters["openai"] = openai_stream + _stream_adapters["anthropic"] = anthropic_stream + + +def register_stream(adapter: str, fn: StreamFn) -> None: + """Register a stream adapter function for the given adapter key. + + Use this to add custom adapters (or override built-in ones). + """ + _stream_adapters[adapter] = fn + + +def register_generate(adapter: str, fn: GenerateFn) -> None: + """Register a generate adapter function for the given adapter key. + + Use this to add custom adapters (or override built-in ones). + """ + _generate_adapters[adapter] = fn + + +# --------------------------------------------------------------------------- +# Provider defaults — base URLs and env var names for auto-client creation. 
+# --------------------------------------------------------------------------- + +_PROVIDER_DEFAULTS: dict[str, tuple[str, str]] = { + "ai-gateway": ("https://ai-gateway.vercel.sh/v3/ai", "AI_GATEWAY_API_KEY"), + "anthropic": ("https://api.anthropic.com/v1", "ANTHROPIC_API_KEY"), + "openai": ("https://api.openai.com/v1", "OPENAI_API_KEY"), +} + + +def _auto_client(model: Model) -> Client: + """Create a :class:`Client` from env vars for the given model's provider.""" + defaults = _PROVIDER_DEFAULTS.get(model.provider) + if defaults is None: + raise ValueError( + f"No default client config for provider {model.provider!r}. " + f"Pass an explicit client= argument." + ) + base_url, env_var = defaults + return Client(base_url=base_url, api_key=os.environ.get(env_var)) + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +async def stream( model: Model, messages: list[messages_.Message], + *, tools: Sequence[tools_.ToolLike] | None = None, output_type: type[pydantic.BaseModel] | None = None, -) -> Stream: - """Stream an LLM response for the given model. + client: Client | None = None, + **kwargs: Any, +) -> AsyncGenerator[messages_.Message]: + """Stream an LLM response. - Looks up the registered :class:`StreamFn` for ``model.api`` and - returns a :class:`Stream` that can be async-iterated *or* awaited. + Resolves the adapter function from ``model.adapter``, auto-creates a + :class:`Client` from env vars if none is provided, and yields + ``Message`` snapshots. """ - fn = get_stream_fn(model.api) - return Stream(fn(model, messages, tools=tools, output_type=output_type)) + _ensure_adapters() + c = client or _auto_client(model) + adapter_fn = _stream_adapters.get(model.adapter) + if adapter_fn is None: + registered = ", ".join(sorted(_stream_adapters)) or "(none)" + raise KeyError( + f"No stream adapter registered for adapter={model.adapter!r}. 
" + f"Registered: {registered}" + ) + async for msg in adapter_fn( + c, model, messages, tools=tools, output_type=output_type, **kwargs + ): + yield msg async def generate( model: Model, messages: list[messages_.Message], - **kwargs: Any, + params: GenerateParams | None = None, + *, + client: Client | None = None, ) -> messages_.Message: - """Generate a response (image, video, etc.) for the given model. + """Generate a response (images, video, etc.). - Looks up the registered :class:`GenerateFn` for ``model.api`` and - returns the resulting :class:`Message`. + Resolves the adapter function from ``model.adapter``, auto-creates a + :class:`Client` from env vars if none is provided. + + ``params`` controls the generation type: + + * :class:`ImageParams` — image generation (``/image-model``). + * :class:`VideoParams` — video generation (``/video-model``). + * ``None`` — auto-detect from ``model.capabilities``. + """ + _ensure_adapters() + c = client or _auto_client(model) + adapter_fn = _generate_adapters.get(model.adapter) + if adapter_fn is None: + registered = ", ".join(sorted(_generate_adapters)) or "(none)" + raise KeyError( + f"No generate adapter registered for adapter={model.adapter!r}. " + f"Registered: {registered}" + ) + return await adapter_fn(c, model, messages, params=params) + + +async def buffer(gen: AsyncGenerator[messages_.Message]) -> messages_.Message: + """Drain a stream and return the final ``Message``. + + Raises :class:`ValueError` if the stream yields nothing. 
""" - fn = get_generate_fn(model.api) - return await fn(model, messages, **kwargs) + result: messages_.Message | None = None + async for msg in gen: + result = msg + if result is None: + raise ValueError("empty stream") + return result __all__ = [ - # Model data + # Core types + "Client", + "GenerateFn", + "GenerateParams", + "ImageParams", "Model", - # Execution protocols + "ModelCost", "StreamFn", - "GenerateFn", - "Stream", - # Registry - "register_stream", + "VideoParams", + # Public API + "buffer", + "generate", "register_generate", - "get_stream_fn", - "get_generate_fn", - # Dispatch + "register_stream", "stream", - "generate", - # Legacy ABCs (still in use) - "LanguageModel", - "StreamEvent", - "StreamHandler", - "MediaModel", - "MediaResult", - "ImageModel", - "VideoModel", - "core", - # Provider adapters - "openai", - "anthropic", - "ai_gateway", ] diff --git a/src/vercel_ai_sdk/models/ai_gateway/__init__.py b/src/vercel_ai_sdk/models/ai_gateway/__init__.py index e467b8ec..7cc9f429 100644 --- a/src/vercel_ai_sdk/models/ai_gateway/__init__.py +++ b/src/vercel_ai_sdk/models/ai_gateway/__init__.py @@ -1,14 +1,14 @@ -"""Vercel AI Gateway provider — language, image, and video models.""" +"""AI Gateway provider — adapter for the Vercel AI Gateway v3 protocol.""" from . 
import errors -from .image import GatewayImageModel -from .llm import GatewayModel -from .video import GatewayEmbeddingModel, GatewayVideoModel +from .generate import GenerateParams, ImageParams, VideoParams, generate +from .stream import stream __all__ = [ - "GatewayModel", - "GatewayImageModel", - "GatewayVideoModel", - "GatewayEmbeddingModel", + "GenerateParams", + "ImageParams", + "VideoParams", "errors", + "generate", + "stream", ] diff --git a/src/vercel_ai_sdk/models2/ai_gateway/_common.py b/src/vercel_ai_sdk/models/ai_gateway/_common.py similarity index 100% rename from src/vercel_ai_sdk/models2/ai_gateway/_common.py rename to src/vercel_ai_sdk/models/ai_gateway/_common.py diff --git a/src/vercel_ai_sdk/models2/ai_gateway/generate.py b/src/vercel_ai_sdk/models/ai_gateway/generate.py similarity index 100% rename from src/vercel_ai_sdk/models2/ai_gateway/generate.py rename to src/vercel_ai_sdk/models/ai_gateway/generate.py diff --git a/src/vercel_ai_sdk/models/ai_gateway/image.py b/src/vercel_ai_sdk/models/ai_gateway/image.py deleted file mode 100644 index 1f86d8dd..00000000 --- a/src/vercel_ai_sdk/models/ai_gateway/image.py +++ /dev/null @@ -1,126 +0,0 @@ -"""Vercel AI Gateway image model.""" - -from __future__ import annotations - -import os -from typing import Any, override - -import httpx - -from ...types import messages as messages_ -from ..core import image as image_ -from ..core.media import base as media_base -from ..core.media import detect as detect_media_type -from . import errors as errors_ -from .llm import _DEFAULT_BASE_URL, _base_headers, _file_part_to_wire, _raise_for_status - - -class GatewayImageModel(image_.ImageModel): - """Vercel AI Gateway image model. - - Sends requests to ``/v3/ai/image-model`` and returns a :class:`Message` - with :class:`FilePart`\\s for each generated image. - - Args: - model: Model identifier (e.g. ``'google/imagen-4.0-generate-001'``). - api_key: API key. Falls back to ``AI_GATEWAY_API_KEY``. 
- base_url: Gateway base URL. - headers: Extra headers for every request. - """ - - def __init__( - self, - model: str = "google/imagen-4.0-generate-001", - api_key: str | None = None, - base_url: str = _DEFAULT_BASE_URL, - headers: dict[str, str] | None = None, - *, - _transport: httpx.AsyncBaseTransport | None = None, - ) -> None: - self._model = model - self._api_key = api_key or os.environ.get("AI_GATEWAY_API_KEY") or "" - self._base_url = base_url.rstrip("/") - self._extra_headers = headers or {} - self._transport = _transport - - def _headers(self) -> dict[str, str]: - return _base_headers( - self._api_key, - { - "ai-image-model-specification-version": "3", - "ai-model-id": self._model, - **self._extra_headers, - }, - ) - - @override - async def make_request( - self, - prompt: str, - input_files: list[messages_.FilePart], - *, - n: int = 1, - size: str | None = None, - aspect_ratio: str | None = None, - seed: int | None = None, - provider_options: dict[str, Any] | None = None, - ) -> media_base.MediaResult: - body: dict[str, Any] = { - "prompt": prompt, - "n": n, - "providerOptions": provider_options or {}, - } - if size is not None: - body["size"] = size - if aspect_ratio is not None: - body["aspectRatio"] = aspect_ratio - if seed is not None: - body["seed"] = seed - if input_files: - body["files"] = [_file_part_to_wire(f) for f in input_files] - - url = f"{self._base_url}/image-model" - try: - async with httpx.AsyncClient(transport=self._transport) as client: - response = await client.post( - url, - json=body, - headers=self._headers(), - timeout=httpx.Timeout(timeout=300.0, connect=10.0), - ) - if response.status_code >= 400: - await _raise_for_status(response, api_key=self._api_key) - - data = response.json() - - except errors_.GatewayError: - raise - except httpx.TimeoutException as exc: - raise errors_.GatewayTimeoutError(cause=exc) from exc - except Exception as exc: - raise errors_.GatewayResponseError( - message=f"Gateway image request failed: 
{exc}", - cause=exc, - ) from exc - - # Parse response: {images: string[], warnings?, usage?} - raw_images: list[str] = data.get("images", []) - usage_data = data.get("usage") - usage = None - if usage_data: - usage = messages_.Usage( - input_tokens=usage_data.get("inputTokens") or 0, - output_tokens=usage_data.get("outputTokens") or 0, - ) - - files: list[messages_.FilePart] = [] - for img_b64 in raw_images: - media_type = detect_media_type.detect_image_media_type(img_b64) - files.append( - messages_.FilePart( - data=img_b64, - media_type=media_type or "image/png", - ) - ) - - return media_base.MediaResult(files=files, usage=usage) diff --git a/src/vercel_ai_sdk/models/ai_gateway/llm.py b/src/vercel_ai_sdk/models/ai_gateway/llm.py deleted file mode 100644 index fa918c96..00000000 --- a/src/vercel_ai_sdk/models/ai_gateway/llm.py +++ /dev/null @@ -1,192 +0,0 @@ -"""Vercel AI Gateway language model using the v3 protocol.""" - -from __future__ import annotations - -import base64 -import json -import os -from collections.abc import AsyncGenerator, Sequence -from typing import Any, override - -import httpx -import pydantic - -from ...types import messages as messages_ -from ...types import tools as tools_ -from ..core import llm as llm_ -from ..core.media import data as media_data -from . import errors as errors_ -from . import protocol as protocol_ - -_DEFAULT_BASE_URL = "https://ai-gateway.vercel.sh/v3/ai" -_PROTOCOL_VERSION = "0.0.1" - - -class GatewayModel(llm_.LanguageModel): - """Vercel AI Gateway language model using the v3 protocol. - - Sends the AI SDK's native message format directly to the gateway - server and receives responses in the AI SDK's native stream-part - format. The gateway server handles all provider-specific - translation. - - Args: - model: Model identifier in ``provider/model`` format - (e.g. ``'anthropic/claude-sonnet-4'``). - api_key: API key. Falls back to ``AI_GATEWAY_API_KEY``. - base_url: Gateway base URL. 
- provider_options: Gateway options (``order``, ``only``, - ``models``, ``byok``, ``tags``, etc.). - headers: Extra headers for every request. - """ - - def __init__( - self, - model: str = "anthropic/claude-sonnet-4", - api_key: str | None = None, - base_url: str = _DEFAULT_BASE_URL, - provider_options: dict[str, Any] | None = None, - headers: dict[str, str] | None = None, - *, - _transport: httpx.AsyncBaseTransport | None = None, - ) -> None: - self._model = model - self._api_key = api_key or os.environ.get("AI_GATEWAY_API_KEY") or "" - self._base_url = base_url.rstrip("/") - self._provider_options = provider_options - self._extra_headers = headers or {} - self._transport = _transport - - # -- Internals ----------------------------------------------------------- - - def _headers(self, *, streaming: bool) -> dict[str, str]: - h: dict[str, str] = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self._api_key}", - "ai-gateway-protocol-version": _PROTOCOL_VERSION, - "ai-language-model-specification-version": "3", - "ai-language-model-id": self._model, - "ai-language-model-streaming": str(streaming).lower(), - } - if self._api_key: - h["ai-gateway-auth-method"] = "api-key" - h.update(self._extra_headers) - return h - - async def _raise_for_status(self, response: httpx.Response) -> None: - """Raise a typed :class:`GatewayError` for HTTP >= 400.""" - try: - body: Any = response.json() - except Exception: - body = response.text - raise errors_.create_gateway_error( - response_body=body, - status_code=response.status_code, - api_key_provided=bool(self._api_key), - ) - - # -- Stream events ------------------------------------------------------- - - @override - async def stream_events( - self, - messages: list[messages_.Message], - tools: Sequence[tools_.ToolLike] | None = None, - output_type: type[pydantic.BaseModel] | None = None, - ) -> AsyncGenerator[llm_.StreamEvent]: - """Yield ``StreamEvent`` objects from the gateway SSE stream.""" - body = await 
protocol_.build_request_body( - messages, - tools=tools, - output_type=output_type, - provider_options=self._provider_options, - ) - url = f"{self._base_url}/language-model" - try: - async with ( - httpx.AsyncClient(transport=self._transport) as client, - client.stream( - "POST", - url, - json=body, - headers=self._headers(streaming=True), - timeout=httpx.Timeout(timeout=300.0, connect=10.0), - ) as response, - ): - if response.status_code >= 400: - await response.aread() - await self._raise_for_status(response) - - async for line in response.aiter_lines(): - line = line.strip() - if not line.startswith("data: "): - continue - payload = line[len("data: ") :] - if payload == "[DONE]": - break - try: - data = json.loads(payload) - except json.JSONDecodeError: - continue - for event in protocol_.parse_stream_part(data): - yield event - - except errors_.GatewayError: - raise - except httpx.TimeoutException as exc: - raise errors_.GatewayTimeoutError( - cause=exc, - ) from exc - except Exception as exc: - raise errors_.GatewayResponseError( - message=( - f"Invalid error response format: Gateway request failed: {exc}" - ), - cause=exc, - ) from exc - - -# --------------------------------------------------------------------------- -# Shared helpers for image/video models -# --------------------------------------------------------------------------- - - -def _base_headers(api_key: str, extra: dict[str, str]) -> dict[str, str]: - """Build common gateway headers.""" - h: dict[str, str] = { - "Content-Type": "application/json", - "Authorization": f"Bearer {api_key}", - "ai-gateway-protocol-version": _PROTOCOL_VERSION, - } - if api_key: - h["ai-gateway-auth-method"] = "api-key" - h.update(extra) - return h - - -async def _raise_for_status(response: httpx.Response, *, api_key: str) -> None: - """Raise a typed :class:`GatewayError` for HTTP >= 400.""" - try: - body: Any = response.json() - except Exception: - body = response.text - raise errors_.create_gateway_error( - 
response_body=body, - status_code=response.status_code, - api_key_provided=bool(api_key), - ) - - -def _file_part_to_wire(part: messages_.FilePart) -> dict[str, Any]: - """Convert a :class:`FilePart` to the gateway wire format for input files.""" - data = part.data - if isinstance(data, str) and media_data.is_url(data): - return {"type": "url", "url": data} - if isinstance(data, bytes): - b64 = base64.b64encode(data).decode("ascii") - elif isinstance(data, str): - # Assume raw base64 - b64 = data - else: - b64 = str(data) - return {"type": "file", "data": b64, "mediaType": part.media_type} diff --git a/src/vercel_ai_sdk/models/ai_gateway/protocol.py b/src/vercel_ai_sdk/models/ai_gateway/protocol.py deleted file mode 100644 index 8b547396..00000000 --- a/src/vercel_ai_sdk/models/ai_gateway/protocol.py +++ /dev/null @@ -1,425 +0,0 @@ -"""Vercel AI Gateway v3 protocol: serialization and deserialization. - -Converts between the Python SDK's internal ``Message`` / ``StreamEvent`` -types and the LanguageModelV3 wire format used by the gateway at -``/v3/ai/language-model``. - -Wire format reference (from ``@ai-sdk/provider``): - -* **Request body** -- ``LanguageModelV3CallOptions`` (prompt + tools + - provider options, sent as JSON). -* **Stream response** -- Server-Sent Events where each ``data:`` line is - a JSON ``LanguageModelV3StreamPart`` (discriminated on ``type``). -* **Non-stream response** -- JSON ``LanguageModelV3GenerateResult``. 
-""" - -import json -from collections.abc import Sequence -from typing import Any - -from ...types import messages as messages_ -from ...types import tools as tools_ -from ..core import llm as llm_ -from ..core.media import data as media_data -from ..core.media import download as media_download - -# --------------------------------------------------------------------------- -# Internal messages -> v3 prompt format (outgoing request body) -# --------------------------------------------------------------------------- - - -async def _file_part_to_v3(part: messages_.FilePart) -> dict[str, Any]: - """Convert an internal :class:`FilePart` to a v3 ``file`` content part. - - Binary data is converted to a ``data:`` URL for JSON transport (matching - the JS SDK gateway's ``maybeEncodeFileParts``). HTTP(S) URLs are - downloaded and converted to ``data:`` URLs because the gateway wire - format does not accept raw HTTP URLs for file content. - """ - data = part.data - if isinstance(data, str) and media_data.is_downloadable_url(data): - downloaded, _ = await media_download.download(data) - data = downloaded - - entry: dict[str, Any] = { - "type": "file", - "mediaType": part.media_type, - "data": media_data.data_to_data_url(data, part.media_type), - } - if part.filename is not None: - entry["filename"] = part.filename - return entry - - -async def messages_to_v3_prompt( - messages: list[messages_.Message], -) -> list[dict[str, Any]]: - """Convert internal ``Message`` list to ``LanguageModelV3Prompt``. 
- - The v3 prompt format is an array of messages, each with a ``role`` and - typed ``content`` parts:: - - [ - {"role": "system", "content": "You are helpful."}, - {"role": "user", "content": [{"type": "text", "text": "Hi"}]}, - {"role": "assistant", "content": [ - {"type": "text", "text": "Hello!"}, - {"type": "reasoning", "text": "..."}, - {"type": "tool-call", "toolCallId": "tc-1", ...}, - ]}, - {"role": "tool", "content": [ - {"type": "tool-result", "toolCallId": "tc-1", ...}, - ]}, - ] - """ - result: list[dict[str, Any]] = [] - for msg in messages: - match msg.role: - case "system": - text = "".join( - p.text for p in msg.parts if isinstance(p, messages_.TextPart) - ) - result.append({"role": "system", "content": text}) - - case "user": - content: list[dict[str, Any]] = [] - for p in msg.parts: - if isinstance(p, messages_.TextPart): - content.append({"type": "text", "text": p.text}) - elif isinstance(p, messages_.FilePart): - content.append(await _file_part_to_v3(p)) - result.append({"role": "user", "content": content}) - - case "assistant": - assistant_content: list[dict[str, Any]] = [] - tool_results: list[dict[str, Any]] = [] - - for part in msg.parts: - match part: - case messages_.ReasoningPart(text=text): - assistant_content.append( - {"type": "reasoning", "text": text} - ) - - case messages_.TextPart(text=text): - assistant_content.append({"type": "text", "text": text}) - - case messages_.ToolPart() as tp: - tool_input: Any = ( - json.loads(tp.tool_args) if tp.tool_args else {} - ) - assistant_content.append( - { - "type": "tool-call", - "toolCallId": tp.tool_call_id, - "toolName": tp.tool_name, - "input": tool_input, - } - ) - if tp.status in ("result", "error"): - output = ( - { - "type": "error-text", - "value": ( - str(tp.result) - if tp.result is not None - else "" - ), - } - if tp.status == "error" - else { - "type": "json", - "value": tp.result, - } - ) - tool_results.append( - { - "type": "tool-result", - "toolCallId": tp.tool_call_id, - 
"toolName": tp.tool_name, - "output": output, - } - ) - - result.append( - { - "role": "assistant", - "content": assistant_content, - } - ) - if tool_results: - result.append( - { - "role": "tool", - "content": tool_results, - } - ) - - return result - - -# --------------------------------------------------------------------------- -# Request body serialization -# --------------------------------------------------------------------------- - - -async def build_request_body( - messages: list[messages_.Message], - tools: Sequence[tools_.ToolLike] | None = None, - output_type: type[Any] | None = None, - provider_options: dict[str, Any] | None = None, -) -> dict[str, Any]: - """Build the full ``LanguageModelV3CallOptions`` request body.""" - body: dict[str, Any] = { - "prompt": await messages_to_v3_prompt(messages), - } - if tools: - body["tools"] = [ - { - "type": "function", - "name": tool.name, - "description": tool.description, - "inputSchema": tool.param_schema, - } - for tool in tools - ] - if output_type is not None: - import pydantic - - if issubclass(output_type, pydantic.BaseModel): - body["responseFormat"] = { - "type": "json", - "schema": output_type.model_json_schema(), - "name": output_type.__name__, - } - if provider_options: - body["providerOptions"] = provider_options - return body - - -# --------------------------------------------------------------------------- -# v3 stream parts -> internal StreamEvent (incoming SSE response) -# --------------------------------------------------------------------------- - - -def parse_stream_part( - data: dict[str, Any], -) -> list[llm_.StreamEvent]: - """Convert a ``LanguageModelV3StreamPart`` to internal events. - - Most parts map 1:1. A ``tool-call`` part (complete, non-streaming) - expands to Start + ArgsDelta + End. Lifecycle events - (``stream-start``, ``response-metadata``, ``raw``) are silently - dropped. 
- """ - match data.get("type", ""): - case "text-start": - return [ - llm_.TextStart( - block_id=data.get("id", "text"), - ) - ] - - case "text-delta": - return [ - llm_.TextDelta( - block_id=data.get("id", "text"), - delta=data.get("textDelta", data.get("delta", "")), - ) - ] - - case "text-end": - return [ - llm_.TextEnd( - block_id=data.get("id", "text"), - ) - ] - - case "reasoning-start": - return [ - llm_.ReasoningStart( - block_id=data.get("id", "reasoning"), - ) - ] - - case "reasoning-delta": - return [ - llm_.ReasoningDelta( - block_id=data.get("id", "reasoning"), - delta=data.get("delta", ""), - ) - ] - - case "reasoning-end": - return [ - llm_.ReasoningEnd( - block_id=data.get("id", "reasoning"), - ) - ] - - case "tool-input-start": - return [ - llm_.ToolStart( - tool_call_id=data.get("id", ""), - tool_name=data.get("toolName", ""), - ) - ] - - case "tool-input-delta": - return [ - llm_.ToolArgsDelta( - tool_call_id=data.get("id", ""), - delta=data.get("delta", ""), - ) - ] - - case "tool-input-end": - return [ - llm_.ToolEnd( - tool_call_id=data.get("id", ""), - ) - ] - - case "tool-call": - return _expand_tool_call(data) - - case "file": - return [ - llm_.FileEvent( - block_id=data.get("id", f"file-{len(data)}"), - media_type=data.get("mediaType", "application/octet-stream"), - data=data.get("data", ""), - ) - ] - - case "finish": - return [_parse_finish(data)] - - case _: - return [] - - -# --------------------------------------------------------------------------- -# Non-streaming response -> internal StreamEvents -# --------------------------------------------------------------------------- - - -def parse_generate_result( - data: dict[str, Any], -) -> list[llm_.StreamEvent]: - """Convert a ``LanguageModelV3GenerateResult`` into events. - - Synthesises Start/Delta/End events from the content, then a final - ``MessageDone``. 
- """ - events: list[llm_.StreamEvent] = [] - - def _expand_content_item(item: dict[str, Any]) -> None: - match item.get("type", ""): - case "text": - bid = item.get("id", "text") - text = item.get("text", "") - events.append(llm_.TextStart(block_id=bid)) - events.append(llm_.TextDelta(block_id=bid, delta=text)) - events.append(llm_.TextEnd(block_id=bid)) - - case "reasoning": - bid = item.get("id", "reasoning") - text = item.get("text", "") - events.append(llm_.ReasoningStart(block_id=bid)) - events.append(llm_.ReasoningDelta(block_id=bid, delta=text)) - events.append(llm_.ReasoningEnd(block_id=bid)) - - case "tool-call": - events.extend(_expand_tool_call(item)) - - case "file": - events.append( - llm_.FileEvent( - block_id=item.get("id", f"file-{len(events)}"), - media_type=item.get("mediaType", "application/octet-stream"), - data=item.get("data", ""), - ) - ) - - match data.get("content"): - case list() as items: - for item in items: - _expand_content_item(item) - case dict() as item: - _expand_content_item(item) - - events.append(_parse_finish(data)) - return events - - -# --------------------------------------------------------------------------- -# Shared helpers (called from multiple sites) -# --------------------------------------------------------------------------- - - -def _expand_tool_call( - data: dict[str, Any], -) -> list[llm_.StreamEvent]: - """Expand a complete ``tool-call`` part into three events.""" - tc_id = data.get("toolCallId", "") - tool_name = data.get("toolName", "") - tool_input = data.get("input", "") - args_str = tool_input if isinstance(tool_input, str) else json.dumps(tool_input) - return [ - llm_.ToolStart(tool_call_id=tc_id, tool_name=tool_name), - llm_.ToolArgsDelta(tool_call_id=tc_id, delta=args_str), - llm_.ToolEnd(tool_call_id=tc_id), - ] - - -def _parse_finish(data: dict[str, Any]) -> llm_.MessageDone: - """Parse a ``finish`` stream part into a ``MessageDone`` event.""" - usage_data = data.get("usage") - usage = 
_parse_usage(usage_data) if usage_data else None - - match data.get("finishReason"): - case dict() as d: - finish_reason = d.get("unified", "stop") - case str() as s: - finish_reason = s - case _: - finish_reason = "stop" - - return llm_.MessageDone(finish_reason=finish_reason, usage=usage) - - -def _parse_usage(data: Any) -> messages_.Usage: - """Parse a v3 ``LanguageModelV3Usage`` into an internal ``Usage``. - - Supports both the v3 nested format:: - - {"inputTokens": {"total": 10, ...}, "outputTokens": {...}} - - and the flat OpenAI-style format:: - - {"prompt_tokens": 10, "completion_tokens": 20} - """ - if not isinstance(data, dict): - return messages_.Usage() - - input_tokens_obj = data.get("inputTokens") - output_tokens_obj = data.get("outputTokens") - - if isinstance(input_tokens_obj, dict) or isinstance(output_tokens_obj, dict): - inp = input_tokens_obj if isinstance(input_tokens_obj, dict) else {} - out = output_tokens_obj if isinstance(output_tokens_obj, dict) else {} - return messages_.Usage( - input_tokens=inp.get("total") or 0, - output_tokens=out.get("total") or 0, - reasoning_tokens=out.get("reasoning"), - cache_read_tokens=inp.get("cacheRead"), - cache_write_tokens=inp.get("cacheWrite"), - raw=data, - ) - - return messages_.Usage( - input_tokens=(data.get("prompt_tokens") or data.get("inputTokens") or 0), - output_tokens=(data.get("completion_tokens") or data.get("outputTokens") or 0), - raw=data, - ) diff --git a/src/vercel_ai_sdk/models2/ai_gateway/stream.py b/src/vercel_ai_sdk/models/ai_gateway/stream.py similarity index 100% rename from src/vercel_ai_sdk/models2/ai_gateway/stream.py rename to src/vercel_ai_sdk/models/ai_gateway/stream.py diff --git a/src/vercel_ai_sdk/models/ai_gateway/video.py b/src/vercel_ai_sdk/models/ai_gateway/video.py deleted file mode 100644 index 86ca88b3..00000000 --- a/src/vercel_ai_sdk/models/ai_gateway/video.py +++ /dev/null @@ -1,212 +0,0 @@ -"""Vercel AI Gateway video model.""" - -from __future__ import 
annotations - -import json -import os -from typing import Any, override - -import httpx - -from ...types import messages as messages_ -from ..core import video as video_ -from ..core.media import base as media_base -from ..core.media import detect as detect_media_type -from ..core.media import download as media_download -from . import errors as errors_ -from .llm import _DEFAULT_BASE_URL, _base_headers, _file_part_to_wire, _raise_for_status - - -class GatewayVideoModel(video_.VideoModel): - """Vercel AI Gateway video model. - - Sends requests to ``/v3/ai/video-model`` (with SSE response) and returns - a :class:`Message` with :class:`FilePart`\\s for each generated video. - - Args: - model: Model identifier (e.g. ``'google/veo-3.0-generate-001'``). - api_key: API key. Falls back to ``AI_GATEWAY_API_KEY``. - base_url: Gateway base URL. - headers: Extra headers for every request. - """ - - def __init__( - self, - model: str = "google/veo-3.0-generate-001", - api_key: str | None = None, - base_url: str = _DEFAULT_BASE_URL, - headers: dict[str, str] | None = None, - *, - _transport: httpx.AsyncBaseTransport | None = None, - ) -> None: - self._model = model - self._api_key = api_key or os.environ.get("AI_GATEWAY_API_KEY") or "" - self._base_url = base_url.rstrip("/") - self._extra_headers = headers or {} - self._transport = _transport - - def _headers(self) -> dict[str, str]: - return _base_headers( - self._api_key, - { - "ai-video-model-specification-version": "3", - "ai-model-id": self._model, - "accept": "text/event-stream", - **self._extra_headers, - }, - ) - - @override - async def make_request( - self, - prompt: str, - input_files: list[messages_.FilePart], - *, - n: int = 1, - aspect_ratio: str | None = None, - resolution: str | None = None, - duration: float | None = None, - fps: int | None = None, - seed: int | None = None, - provider_options: dict[str, Any] | None = None, - ) -> media_base.MediaResult: - image_wire: dict[str, Any] | None = None - if 
input_files: - image_wire = _file_part_to_wire(input_files[0]) - - body: dict[str, Any] = { - "prompt": prompt, - "n": n, - "providerOptions": provider_options or {}, - } - if aspect_ratio is not None: - body["aspectRatio"] = aspect_ratio - if resolution is not None: - body["resolution"] = resolution - if duration is not None: - body["duration"] = duration - if fps is not None: - body["fps"] = fps - if seed is not None: - body["seed"] = seed - if image_wire is not None: - body["image"] = image_wire - - url = f"{self._base_url}/video-model" - try: - async with ( - httpx.AsyncClient(transport=self._transport) as client, - client.stream( - "POST", - url, - json=body, - headers=self._headers(), - timeout=httpx.Timeout(timeout=600.0, connect=10.0), - ) as response, - ): - if response.status_code >= 400: - await response.aread() - await _raise_for_status(response, api_key=self._api_key) - - event_data = await self._read_first_sse_event(response) - - except errors_.GatewayError: - raise - except httpx.TimeoutException as exc: - raise errors_.GatewayTimeoutError(cause=exc) from exc - except Exception as exc: - raise errors_.GatewayResponseError( - message=f"Gateway video request failed: {exc}", - cause=exc, - ) from exc - - # Handle error event - if event_data.get("type") == "error": - status = event_data.get("statusCode", 500) - message = event_data.get("message", "Video generation failed") - error_type = event_data.get("errorType", "") - if status == 400 or error_type == "invalid_request_error": - raise errors_.GatewayInvalidRequestError( - message=message, status_code=status - ) - raise errors_.GatewayResponseError(message=message, status_code=status) - - # Handle result event - raw_videos: list[dict[str, Any]] = event_data.get("videos", []) - files: list[messages_.FilePart] = [] - for video_data in raw_videos: - file_part = await self._video_data_to_file_part(video_data) - files.append(file_part) - - return media_base.MediaResult(files=files) - - @staticmethod - async 
def _read_first_sse_event(response: httpx.Response) -> dict[str, Any]: - """Read and parse the first SSE data event from the response.""" - async for line in response.aiter_lines(): - line = line.strip() - if not line.startswith("data: "): - continue - payload = line[len("data: ") :] - if payload == "[DONE]": - break - try: - result: dict[str, Any] = json.loads(payload) - return result - except json.JSONDecodeError: - continue - raise errors_.GatewayResponseError( - message="SSE stream ended without a data event", - ) - - @staticmethod - async def _video_data_to_file_part( - video_data: dict[str, Any], - ) -> messages_.FilePart: - """Convert a gateway video result to a :class:`FilePart`. - - Handles ``{type: "url", url, mediaType}`` (downloads the video) - and ``{type: "base64", data, mediaType}``. - """ - vtype = video_data.get("type", "base64") - media_type = video_data.get("mediaType", "video/mp4") - - if vtype == "url": - video_url = video_data["url"] - downloaded_bytes, content_type = await media_download.download(video_url) - # Prefer provider mediaType, then download content-type, then detect - if media_type == "video/mp4" and content_type: - media_type = content_type - detected = detect_media_type.detect_media_type( - downloaded_bytes, detect_media_type.VIDEO_SIGNATURES - ) - if detected: - media_type = detected - return messages_.FilePart( - data=downloaded_bytes, - media_type=media_type, - ) - - # base64 - data = video_data.get("data", "") - detected = detect_media_type.detect_media_type( - data, detect_media_type.VIDEO_SIGNATURES - ) - if detected: - media_type = detected - return messages_.FilePart( - data=data, - media_type=media_type, - ) - - -# --------------------------------------------------------------------------- -# Stubs for future model types -# --------------------------------------------------------------------------- - - -class GatewayEmbeddingModel: - """Stub -- not yet implemented.""" - - def __init__(self, model: str, **kwargs: Any) -> 
None: - raise NotImplementedError("GatewayEmbeddingModel is not yet implemented.") diff --git a/src/vercel_ai_sdk/models/anthropic/__init__.py b/src/vercel_ai_sdk/models/anthropic/__init__.py index 38716ce4..a9a0436b 100644 --- a/src/vercel_ai_sdk/models/anthropic/__init__.py +++ b/src/vercel_ai_sdk/models/anthropic/__init__.py @@ -1,5 +1,7 @@ -"""Anthropic provider adapter.""" +"""Anthropic provider — adapter for the Anthropic messages API.""" -from .llm import AnthropicModel, _messages_to_anthropic +from .adapter import stream -__all__ = ["AnthropicModel", "_messages_to_anthropic"] +__all__ = [ + "stream", +] diff --git a/src/vercel_ai_sdk/models2/anthropic/adapter.py b/src/vercel_ai_sdk/models/anthropic/adapter.py similarity index 100% rename from src/vercel_ai_sdk/models2/anthropic/adapter.py rename to src/vercel_ai_sdk/models/anthropic/adapter.py diff --git a/src/vercel_ai_sdk/models/anthropic/llm.py b/src/vercel_ai_sdk/models/anthropic/llm.py deleted file mode 100644 index b6c4ff52..00000000 --- a/src/vercel_ai_sdk/models/anthropic/llm.py +++ /dev/null @@ -1,341 +0,0 @@ -from __future__ import annotations - -import json -import os -from collections.abc import AsyncGenerator, Sequence -from typing import Any, override - -import anthropic -import pydantic - -from ...types import messages as messages_ -from ...types import tools as tools_ -from ..core import llm as llm_ -from ..core import media - - -def _tools_to_anthropic(tools: Sequence[tools_.ToolLike]) -> list[dict[str, Any]]: - """Convert internal Tool objects to Anthropic tool schema format.""" - return [ - { - "name": tool.name, - "description": tool.description, - "input_schema": tool.param_schema, - } - for tool in tools - ] - - -def _file_part_to_anthropic(part: messages_.FilePart) -> dict[str, Any]: - """Convert a :class:`FilePart` to an Anthropic content block. 
- - * ``image/*`` → ``{"type": "image", "source": ...}`` - * ``application/pdf`` → ``{"type": "document", "source": ...}`` - * ``text/plain`` → ``{"type": "document", "source": {"type": "text", ...}}`` - * anything else → ``ValueError`` - """ - mt = part.media_type - - if mt.startswith("image/"): - media_type = "image/jpeg" if mt == "image/*" else mt - if isinstance(part.data, str) and media.data.is_url(part.data): - return { - "type": "image", - "source": {"type": "url", "url": part.data}, - } - return { - "type": "image", - "source": { - "type": "base64", - "media_type": media_type, - "data": media.data.data_to_base64(part.data), - }, - } - - if mt == "application/pdf": - if isinstance(part.data, str) and media.data.is_url(part.data): - return { - "type": "document", - "source": {"type": "url", "url": part.data}, - } - return { - "type": "document", - "source": { - "type": "base64", - "media_type": "application/pdf", - "data": media.data.data_to_base64(part.data), - }, - } - - if mt == "text/plain": - # Anthropic accepts text documents with source.type="text" - if isinstance(part.data, bytes): - text_data = part.data.decode("utf-8") - elif media.data.is_url(part.data): - return { - "type": "document", - "source": {"type": "url", "url": part.data}, - } - else: - import base64 as _b64 - - text_data = _b64.b64decode(part.data).decode("utf-8") - return { - "type": "document", - "source": { - "type": "text", - "media_type": "text/plain", - "data": text_data, - }, - } - - raise ValueError(f"Unsupported media type for Anthropic: {mt}") - - -async def _messages_to_anthropic( - messages: list[messages_.Message], -) -> tuple[str | None, list[dict[str, Any]]]: - """Convert internal messages to Anthropic API format.""" - system_prompt: str | None = None - result: list[dict[str, Any]] = [] - - for msg in messages: - match msg.role: - case "system": - system_prompt = "".join( - p.text for p in msg.parts if isinstance(p, messages_.TextPart) - ) - case "assistant": - content: 
list[dict[str, Any]] = [] - tool_results: list[dict[str, Any]] = [] - - for part in msg.parts: - match part: - case messages_.ReasoningPart(text=text, signature=signature): - if signature: - content.append( - { - "type": "thinking", - "thinking": text, - "signature": signature, - } - ) - case messages_.TextPart(text=text): - content.append({"type": "text", "text": text}) - case messages_.ToolPart(): - tool_input = ( - json.loads(part.tool_args) if part.tool_args else {} - ) - content.append( - { - "type": "tool_use", - "id": part.tool_call_id, - "name": part.tool_name, - "input": tool_input, - } - ) - if part.status in ("result", "error"): - entry: dict[str, Any] = { - "type": "tool_result", - "tool_use_id": part.tool_call_id, - "content": str(part.result) - if part.result is not None - else "", - } - if part.status == "error": - entry["is_error"] = True - tool_results.append(entry) - - if content: - result.append({"role": "assistant", "content": content}) - if tool_results: - result.append({"role": "user", "content": tool_results}) - case "user": - has_files = any(isinstance(p, messages_.FilePart) for p in msg.parts) - if not has_files: - content_text = "".join( - p.text for p in msg.parts if isinstance(p, messages_.TextPart) - ) - result.append({"role": "user", "content": content_text}) - else: - user_content: list[dict[str, Any]] = [] - for p in msg.parts: - match p: - case messages_.TextPart(text=text): - user_content.append({"type": "text", "text": text}) - case messages_.FilePart(): - user_content.append(_file_part_to_anthropic(p)) - result.append({"role": "user", "content": user_content}) - - # Merge consecutive same-role messages (e.g. synthetic user(tool_result) - # followed by a real user message). - result = _merge_consecutive_roles(result) - - return system_prompt, result - - -def _merge_consecutive_roles( - messages: list[dict[str, Any]], -) -> list[dict[str, Any]]: - """Merge consecutive messages that share the same role. 
- - Anthropic requires strictly alternating user/assistant roles. When - our conversion emits a synthetic ``user`` message for ``tool_result`` - blocks followed by a real ``user`` message, they must be merged. - - Content is normalized to list-of-blocks so heterogeneous content - (tool_result dicts + text strings) can coexist. - """ - if not messages: - return messages - - merged: list[dict[str, Any]] = [messages[0]] - - for msg in messages[1:]: - if msg["role"] == merged[-1]["role"]: - prev = _to_content_list(merged[-1]["content"]) - cur = _to_content_list(msg["content"]) - merged[-1]["content"] = prev + cur - else: - merged.append(msg) - - return merged - - -def _to_content_list(content: Any) -> list[dict[str, Any]]: - """Normalize Anthropic message content to list-of-blocks format.""" - if isinstance(content, list): - return list(content) - return [{"type": "text", "text": content}] - - -class AnthropicModel(llm_.LanguageModel): - """Anthropic adapter with native extended thinking support.""" - - def __init__( - self, - model: str = "claude-sonnet-4-5-20250929", - base_url: str | None = None, - api_key: str | None = None, - thinking: bool = False, - budget_tokens: int = 10000, - ) -> None: - self._model = model - self._thinking = thinking - self._budget_tokens = budget_tokens - resolved_key = api_key or os.environ.get("ANTHROPIC_API_KEY") or "" - self._client = anthropic.AsyncAnthropic(base_url=base_url, api_key=resolved_key) - - @override - async def stream_events( - self, - messages: list[messages_.Message], - tools: Sequence[tools_.ToolLike] | None = None, - output_type: type[pydantic.BaseModel] | None = None, - ) -> AsyncGenerator[llm_.StreamEvent]: - """Yield raw stream events from Anthropic API.""" - system_prompt, anthropic_messages = await _messages_to_anthropic(messages) - anthropic_tools = _tools_to_anthropic(tools) if tools else None - - kwargs: dict[str, Any] = { - "model": self._model, - "messages": anthropic_messages, - "max_tokens": 8192, - } - if 
system_prompt: - kwargs["system"] = system_prompt - if anthropic_tools: - kwargs["tools"] = anthropic_tools - - if self._thinking: - kwargs["thinking"] = { - "type": "enabled", - "budget_tokens": self._budget_tokens, - } - - # Structured output: SDK handles schema transformation internally - if output_type is not None: - kwargs["output_format"] = output_type - - # Track block types by index to know what End event to emit - block_types: dict[int, str] = {} # index -> "text" | "thinking" | "tool_use" - tool_ids: dict[int, str] = {} # index -> tool_call_id - signature_buffer: dict[int, str] = {} # index -> accumulated signature - - stream_cm = self._client.messages.stream(**kwargs) - - async with stream_cm as stream: - async for event in stream: - match event.type: - case "content_block_start": - block = event.content_block - idx = event.index - block_types[idx] = block.type - - match block.type: - case "text": - yield llm_.TextStart(block_id=str(idx)) - case "thinking": - yield llm_.ReasoningStart(block_id=str(idx)) - case "tool_use": - tool_ids[idx] = block.id - yield llm_.ToolStart( - tool_call_id=block.id, tool_name=block.name - ) - - case "content_block_delta": - delta = event.delta - idx = event.index - - match delta.type: - case "text_delta": - yield llm_.TextDelta( - block_id=str(idx), delta=delta.text - ) - case "thinking_delta": - yield llm_.ReasoningDelta( - block_id=str(idx), delta=delta.thinking - ) - case "signature_delta": - # Accumulate signature for ReasoningEnd - signature_buffer[idx] = ( - signature_buffer.get(idx, "") + delta.signature - ) - case "input_json_delta": - tool_id = tool_ids.get(idx) - if tool_id: - yield llm_.ToolArgsDelta( - tool_call_id=tool_id, - delta=delta.partial_json, - ) - - case "content_block_stop": - idx = event.index - match block_types.get(idx): - case "text": - yield llm_.TextEnd(block_id=str(idx)) - case "thinking": - yield llm_.ReasoningEnd( - block_id=str(idx), - signature=signature_buffer.get(idx), - ) - case 
"tool_use": - tool_id = tool_ids.get(idx) - if tool_id: - yield llm_.ToolEnd(tool_call_id=tool_id) - - # The Anthropic SDK accumulates usage across message_start and - # message_delta events into current_message_snapshot. Read it - # once here instead of tracking state ourselves. - snapshot = stream.current_message_snapshot - sdk_usage = snapshot.usage - usage = messages_.Usage( - input_tokens=sdk_usage.input_tokens or 0, - output_tokens=sdk_usage.output_tokens or 0, - cache_read_tokens=getattr(sdk_usage, "cache_read_input_tokens", None), - cache_write_tokens=getattr( - sdk_usage, "cache_creation_input_tokens", None - ), - raw=sdk_usage.model_dump(exclude_none=True) or None, - ) - yield llm_.MessageDone(usage=usage) diff --git a/src/vercel_ai_sdk/models/core/__init__.py b/src/vercel_ai_sdk/models/core/__init__.py index be7a0f38..32bed109 100644 --- a/src/vercel_ai_sdk/models/core/__init__.py +++ b/src/vercel_ai_sdk/models/core/__init__.py @@ -1,38 +1,13 @@ -"""Core model abstractions — LanguageModel, ImageModel, VideoModel.""" +"""Core types for models.""" -from . 
import media -from .image import ImageModel -from .llm import LanguageModel, StreamEvent, StreamHandler -from .media.base import MediaModel, MediaResult -from .model import Model -from .protocol import GenerateFn, Stream, StreamFn -from .registry import ( - get_generate_fn, - get_stream_fn, - register_generate, - register_stream, -) -from .video import VideoModel +from .client import Client +from .model import Model, ModelCost +from .proto import GenerateFn, StreamFn __all__ = [ - # Model data + "Client", + "GenerateFn", "Model", - # Execution protocols + "ModelCost", "StreamFn", - "GenerateFn", - "Stream", - # Registry - "register_stream", - "register_generate", - "get_stream_fn", - "get_generate_fn", - # Legacy ABCs (still in use) - "LanguageModel", - "StreamEvent", - "StreamHandler", - "MediaModel", - "MediaResult", - "ImageModel", - "VideoModel", - "media", ] diff --git a/src/vercel_ai_sdk/models2/core/client.py b/src/vercel_ai_sdk/models/core/client.py similarity index 100% rename from src/vercel_ai_sdk/models2/core/client.py rename to src/vercel_ai_sdk/models/core/client.py diff --git a/src/vercel_ai_sdk/models2/core/helpers/media.py b/src/vercel_ai_sdk/models/core/helpers/media.py similarity index 100% rename from src/vercel_ai_sdk/models2/core/helpers/media.py rename to src/vercel_ai_sdk/models/core/helpers/media.py diff --git a/src/vercel_ai_sdk/models2/core/helpers/streaming.py b/src/vercel_ai_sdk/models/core/helpers/streaming.py similarity index 100% rename from src/vercel_ai_sdk/models2/core/helpers/streaming.py rename to src/vercel_ai_sdk/models/core/helpers/streaming.py diff --git a/src/vercel_ai_sdk/models/core/image.py b/src/vercel_ai_sdk/models/core/image.py deleted file mode 100644 index eb7aa9c3..00000000 --- a/src/vercel_ai_sdk/models/core/image.py +++ /dev/null @@ -1,60 +0,0 @@ -"""ImageModel — abstract image generation model.""" - -from __future__ import annotations - -import abc -from typing import Any, override - -from ...types import 
messages as messages_ -from .media.base import MediaModel, MediaResult - - -class ImageModel(MediaModel): - """Abstract image generation model. - - Accepts :class:`Message`\\s as input and returns a :class:`Message` - containing generated images as :class:`FilePart`\\s. - - Adapter authors implement :meth:`make_request`; the framework handles - parsing messages and assembling the response :class:`Message`. - """ - - async def generate( - self, - messages: list[messages_.Message], - *, - n: int = 1, - size: str | None = None, - aspect_ratio: str | None = None, - seed: int | None = None, - provider_options: dict[str, Any] | None = None, - ) -> messages_.Message: - """Generate images from the given messages.""" - prompt = self._extract_prompt(messages) - input_files = self._extract_input_files(messages) - result = await self.make_request( - prompt, - input_files, - n=n, - size=size, - aspect_ratio=aspect_ratio, - seed=seed, - provider_options=provider_options, - ) - return self._build_message(result) - - @override - @abc.abstractmethod - async def make_request( - self, - prompt: str, - input_files: list[messages_.FilePart], - *, - n: int = 1, - size: str | None = None, - aspect_ratio: str | None = None, - seed: int | None = None, - provider_options: dict[str, Any] | None = None, - ) -> MediaResult: - """Adapter-specific image generation.""" - ... 
diff --git a/src/vercel_ai_sdk/models/core/llm.py b/src/vercel_ai_sdk/models/core/llm.py deleted file mode 100644 index 737a467f..00000000 --- a/src/vercel_ai_sdk/models/core/llm.py +++ /dev/null @@ -1,288 +0,0 @@ -from __future__ import annotations - -import abc -import dataclasses -import json -from collections.abc import AsyncGenerator, Sequence - -import pydantic - -from ...types import messages as messages_ -from ...types import tools as tools_ - - -@dataclasses.dataclass -class TextStart: - block_id: str - - -@dataclasses.dataclass -class TextDelta: - block_id: str - delta: str - - -@dataclasses.dataclass -class TextEnd: - block_id: str - - -@dataclasses.dataclass -class ReasoningStart: - block_id: str - - -@dataclasses.dataclass -class ReasoningDelta: - block_id: str - delta: str - - -@dataclasses.dataclass -class ReasoningEnd: - block_id: str - signature: str | None = None - - -@dataclasses.dataclass -class ToolStart: - tool_call_id: str - tool_name: str - - -@dataclasses.dataclass -class ToolArgsDelta: - tool_call_id: str - delta: str - - -@dataclasses.dataclass -class ToolEnd: - tool_call_id: str - - -@dataclasses.dataclass -class FileEvent: - """A complete generated file from the LLM (e.g. inline image from Gemini/GPT).""" - - block_id: str - media_type: str - data: str # base64 string or data-URL from the gateway - - -@dataclasses.dataclass -class MessageDone: - finish_reason: str | None = None - usage: messages_.Usage | None = None - - -StreamEvent = ( - TextStart - | TextDelta - | TextEnd - | ReasoningStart - | ReasoningDelta - | ReasoningEnd - | ToolStart - | ToolArgsDelta - | ToolEnd - | FileEvent - | MessageDone -) - - -@dataclasses.dataclass -class StreamHandler: - """ - Accumulates LLM adapter events and produces Messages with stateful parts. - - This is the normalization layer between LLM adapters and the rest of the system. 
- """ - - message_id: str = dataclasses.field(default_factory=messages_._gen_id) - - # Accumulators - _text_blocks: dict[str, str] = dataclasses.field(default_factory=dict) - _reasoning_blocks: dict[str, tuple[str, str | None]] = dataclasses.field( - default_factory=dict - ) # (text, signature) - _tool_calls: dict[str, tuple[str, str]] = dataclasses.field( - default_factory=dict - ) # (name, args) - _files: dict[str, tuple[str, str]] = dataclasses.field( - default_factory=dict - ) # block_id -> (media_type, data) - - # Active tracking - _active_text_id: str | None = None - _active_reasoning_id: str | None = None - _active_tool_ids: set[str] = dataclasses.field(default_factory=set) - - _is_done: bool = False - _usage: messages_.Usage | None = None - - def handle_event(self, event: StreamEvent) -> messages_.Message: - """Process event and return current Message state.""" - - # Current deltas (reset each call) - text_delta: str | None = None - reasoning_delta: str | None = None - tool_deltas: dict[str, str] = {} # tool_call_id -> delta - - match event: - case TextStart(block_id=bid): - self._text_blocks[bid] = "" - self._active_text_id = bid - - case TextDelta(block_id=bid, delta=d): - self._text_blocks[bid] += d - text_delta = d - - case TextEnd(block_id=bid): - if self._active_text_id == bid: - self._active_text_id = None - - case ReasoningStart(block_id=bid): - self._reasoning_blocks[bid] = ("", None) - self._active_reasoning_id = bid - - case ReasoningDelta(block_id=bid, delta=d): - text, sig = self._reasoning_blocks[bid] - self._reasoning_blocks[bid] = (text + d, sig) - reasoning_delta = d - - case ReasoningEnd(block_id=bid, signature=sig): - text, _ = self._reasoning_blocks[bid] - self._reasoning_blocks[bid] = (text, sig) - if self._active_reasoning_id == bid: - self._active_reasoning_id = None - - case ToolStart(tool_call_id=tcid, tool_name=name): - self._tool_calls[tcid] = (name, "") - self._active_tool_ids.add(tcid) - - case ToolArgsDelta(tool_call_id=tcid, 
delta=d): - name, args = self._tool_calls[tcid] - self._tool_calls[tcid] = (name, args + d) - tool_deltas[tcid] = d - - case ToolEnd(tool_call_id=tcid): - self._active_tool_ids.discard(tcid) - - case FileEvent(block_id=bid, media_type=mt, data=d): - self._files[bid] = (mt, d) - - case MessageDone(usage=usage): - self._is_done = True - self._usage = usage - self._active_text_id = None - self._active_reasoning_id = None - self._active_tool_ids.clear() - - return self._build_message(text_delta, reasoning_delta, tool_deltas) - - def _build_message( - self, - text_delta: str | None, - reasoning_delta: str | None, - tool_deltas: dict[str, str], - ) -> messages_.Message: - parts: list[messages_.Part] = [] - - # Reasoning parts first (like thinking blocks) - for bid, (text, sig) in self._reasoning_blocks.items(): - is_active = bid == self._active_reasoning_id - parts.append( - messages_.ReasoningPart( - text=text, - signature=sig, - state="streaming" if is_active else "done", - delta=reasoning_delta if is_active else None, - ) - ) - - # Text parts - for bid, text in self._text_blocks.items(): - is_active = bid == self._active_text_id - parts.append( - messages_.TextPart( - text=text, - state="streaming" if is_active else "done", - delta=text_delta if is_active else None, - ) - ) - - # Tool parts - for tcid, (name, args) in self._tool_calls.items(): - is_active = tcid in self._active_tool_ids - parts.append( - messages_.ToolPart( - tool_call_id=tcid, - tool_name=name, - tool_args=args, - state="streaming" if is_active else "done", - args_delta=tool_deltas.get(tcid), - ) - ) - - # File parts (inline images/videos from LLMs like Gemini, GPT-5) - for _bid, (media_type, data) in self._files.items(): - parts.append(messages_.FilePart(data=data, media_type=media_type)) - - return messages_.Message( - id=self.message_id, - role="assistant", - parts=parts, - usage=self._usage if self._is_done else None, - ) - - -class LanguageModel(abc.ABC): - @abc.abstractmethod - async def 
stream_events( - self, - messages: list[messages_.Message], - tools: Sequence[tools_.ToolLike] | None = None, - output_type: type[pydantic.BaseModel] | None = None, - ) -> AsyncGenerator[StreamEvent]: - raise NotImplementedError - yield - - async def stream( - self, - messages: list[messages_.Message], - tools: Sequence[tools_.ToolLike] | None = None, - output_type: type[pydantic.BaseModel] | None = None, - ) -> AsyncGenerator[messages_.Message]: - """Stream Messages (uses StreamHandler internally).""" - handler = StreamHandler() - msg: messages_.Message | None = None - async for event in self.stream_events(messages, tools, output_type=output_type): - msg = handler.handle_event(event) - yield msg - - # After stream completes, validate and attach structured output part - if output_type is not None and msg is not None and msg.text: - data = json.loads(msg.text) - output_type.model_validate(data) # fail fast on bad data - part = messages_.StructuredOutputPart( - data=data, - output_type_name=f"{output_type.__module__}.{output_type.__qualname__}", - ) - msg = msg.model_copy() - msg.parts = [*msg.parts, part] - yield msg - - async def buffer( - self, - messages: list[messages_.Message], - tools: Sequence[tools_.ToolLike] | None = None, - output_type: type[pydantic.BaseModel] | None = None, - ) -> messages_.Message: - """Drain the stream and return the final message.""" - final = None - async for msg in self.stream(messages, tools, output_type=output_type): - final = msg - if final is None: - raise ValueError("LLM produced no messages") - return final diff --git a/src/vercel_ai_sdk/models/core/media/__init__.py b/src/vercel_ai_sdk/models/core/media/__init__.py deleted file mode 100644 index a4485760..00000000 --- a/src/vercel_ai_sdk/models/core/media/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Media utilities — data-format helpers, media type detection, and download.""" - -from . 
import data, detect, download -from .base import MediaModel, MediaResult - -__all__ = [ - "MediaModel", - "MediaResult", - "data", - "detect", - "download", -] diff --git a/src/vercel_ai_sdk/models/core/media/base.py b/src/vercel_ai_sdk/models/core/media/base.py deleted file mode 100644 index b6306a67..00000000 --- a/src/vercel_ai_sdk/models/core/media/base.py +++ /dev/null @@ -1,86 +0,0 @@ -"""MediaModel base class and MediaResult type. - -Shared pipeline steps that every media adapter would otherwise duplicate: - -* **Input** -- extract a text prompt and input files from messages. -* **Output** -- wrap the adapter's :class:`MediaResult` into a - :class:`Message` with ``role="assistant"``. -""" - -from __future__ import annotations - -import abc -import dataclasses -from typing import Any - -from ....types import messages as messages_ - - -@dataclasses.dataclass -class MediaResult: - """Raw result returned by an adapter's ``make_request()`` method. - - The framework wraps this into a :class:`Message` automatically. - """ - - files: list[messages_.FilePart] - usage: messages_.Usage | None = None - - -class MediaModel(abc.ABC): - """Abstract base for media generation models. - - Subclasses (:class:`ImageModel`, :class:`VideoModel`) define the - public ``generate()`` signature with media-type-specific parameters - and delegate to the adapter's ``make_request()`` method. 
- """ - - @staticmethod - def _extract_prompt(messages: list[messages_.Message]) -> str: - """Concatenate all :class:`TextPart` texts from user/system messages.""" - parts: list[str] = [] - for msg in messages: - if msg.role in ("user", "system"): - for p in msg.parts: - if isinstance(p, messages_.TextPart): - parts.append(p.text) - return " ".join(parts) - - @staticmethod - def _extract_input_files( - messages: list[messages_.Message], - ) -> list[messages_.FilePart]: - """Collect all :class:`FilePart` objects from user messages.""" - files: list[messages_.FilePart] = [] - for msg in messages: - if msg.role == "user": - for p in msg.parts: - if isinstance(p, messages_.FilePart): - files.append(p) - return files - - @staticmethod - def _build_message(result: MediaResult) -> messages_.Message: - """Wrap adapter output into a :class:`Message`.""" - return messages_.Message( - role="assistant", - parts=list(result.files), - usage=result.usage, - ) - - @abc.abstractmethod - async def make_request( - self, - prompt: str, - input_files: list[messages_.FilePart], - *, - n: int = 1, - provider_options: dict[str, Any] | None = None, - ) -> MediaResult: - """Adapter-specific generation logic. - - Receives already-parsed inputs and returns a :class:`MediaResult`. - The framework calls this from ``generate()`` and wraps the result - into a :class:`Message`. - """ - ... diff --git a/src/vercel_ai_sdk/models/core/media/data.py b/src/vercel_ai_sdk/models/core/media/data.py deleted file mode 100644 index e92fb5e2..00000000 --- a/src/vercel_ai_sdk/models/core/media/data.py +++ /dev/null @@ -1,100 +0,0 @@ -"""Data-format helpers for multimodal content. - -URL detection, ``data:`` URL parsing, base-64 encoding/decoding, and -media-type inference utilities used by :class:`~vercel_ai_sdk.core.messages.FilePart` -and the provider converters. 
-""" - -from __future__ import annotations - -import base64 -import mimetypes - -# -- URL helpers ----------------------------------------------------------- - - -def is_url(data: str) -> bool: - """Return True if *data* looks like a URL rather than raw base-64.""" - return data.startswith(("http://", "https://", "data:")) - - -def is_downloadable_url(data: str) -> bool: - """Return True if *data* is an ``http(s)://`` URL that can be fetched.""" - return data.startswith(("http://", "https://")) - - -def split_data_url(url: str) -> tuple[str | None, str | None]: - """Parse a ``data:`` URL into ``(media_type, base64_content)``. - - Returns ``(None, None)`` if the input is not a valid ``data:`` URL. - - Example:: - - >>> split_data_url("data:image/png;base64,iVBOR...") - ("image/png", "iVBOR...") - """ - if not url.startswith("data:"): - return None, None - try: - header, b64_content = url.split(",", 1) - # header = "data:image/png;base64" - mt = header.split(";")[0].split(":", 1)[1] - return (mt or None), (b64_content or None) - except (ValueError, IndexError): - return None, None - - -# -- encoding helpers ------------------------------------------------------ - - -def data_to_base64(data: str | bytes) -> str: - """Ensure *data* is a base-64 encoded string. - - * ``bytes`` -> base-64 encoded. - * ``str`` that is a ``data:`` URL -> base-64 content extracted. - * ``str`` that is an ``http(s)://`` URL -> returned as-is (caller - must handle). - * ``str`` that is not a URL -> assumed to already be base-64. - """ - if isinstance(data, bytes): - return base64.b64encode(data).decode("ascii") - if data.startswith("data:"): - _, b64 = split_data_url(data) - if b64 is not None: - return b64 - return data - - -def data_to_data_url(data: str | bytes, media_type: str) -> str: - """Convert *data* to a ``data:`` URL. 
Passes through existing URLs.""" - if isinstance(data, str) and is_url(data): - return data - b64 = data_to_base64(data) - return f"data:{media_type};base64,{b64}" - - -# -- media-type inference -------------------------------------------------- - - -def infer_media_type(url: str) -> str: - """Infer IANA media type from a URL. - - * ``data:image/png;base64,...`` -> ``"image/png"`` - * ``https://example.com/cat.jpg`` -> ``"image/jpeg"`` (via :mod:`mimetypes`) - * Unknown -> raises :class:`ValueError` - """ - if url.startswith("data:"): - # data:[][;base64], - rest = url[5:] # strip "data:" - sep = rest.find(",") - meta = rest[:sep] if sep != -1 else rest - mt = meta.split(";")[0] - if mt: - return mt - else: - guessed, _ = mimetypes.guess_type(url) - if guessed: - return guessed - raise ValueError( - f"Cannot infer media_type from URL: {url!r}. Provide media_type explicitly." - ) diff --git a/src/vercel_ai_sdk/models/core/media/detect.py b/src/vercel_ai_sdk/models/core/media/detect.py deleted file mode 100644 index a9bf770a..00000000 --- a/src/vercel_ai_sdk/models/core/media/detect.py +++ /dev/null @@ -1,188 +0,0 @@ -"""Magic-byte media type detection. - -Port of ``@ai-sdk/ai/src/util/detect-media-type.ts``. Detects image, -audio, and video formats by inspecting the first bytes of binary data -(or the first characters of a base-64 string). -""" - -from __future__ import annotations - -import base64 as _b64 - -# --------------------------------------------------------------------------- -# Signature definitions -# --------------------------------------------------------------------------- - -# Each signature is a tuple of (media_type, byte_prefix) where byte_prefix -# is a tuple of ``int | None`` values. ``None`` is a wildcard that matches -# any byte (mirrors the TS SDK's ``null`` sentinel). 
- -_Signature = tuple[str, tuple[int | None, ...]] - -IMAGE_SIGNATURES: list[_Signature] = [ - ("image/gif", (0x47, 0x49, 0x46)), - ("image/png", (0x89, 0x50, 0x4E, 0x47)), - ("image/jpeg", (0xFF, 0xD8)), - ( - "image/webp", - (0x52, 0x49, 0x46, 0x46, None, None, None, None, 0x57, 0x45, 0x42, 0x50), - ), - ("image/bmp", (0x42, 0x4D)), - ("image/tiff", (0x49, 0x49, 0x2A, 0x00)), # little-endian - ("image/tiff", (0x4D, 0x4D, 0x00, 0x2A)), # big-endian - ( - "image/avif", - (0x00, 0x00, 0x00, 0x20, 0x66, 0x74, 0x79, 0x70, 0x61, 0x76, 0x69, 0x66), - ), - ( - "image/heic", - (0x00, 0x00, 0x00, 0x20, 0x66, 0x74, 0x79, 0x70, 0x68, 0x65, 0x69, 0x63), - ), -] - -AUDIO_SIGNATURES: list[_Signature] = [ - ("audio/mpeg", (0xFF, 0xFB)), - ("audio/mpeg", (0xFF, 0xFA)), - ("audio/mpeg", (0xFF, 0xF3)), - ("audio/mpeg", (0xFF, 0xF2)), - ("audio/mpeg", (0xFF, 0xE3)), - ("audio/mpeg", (0xFF, 0xE2)), - ( - "audio/wav", - (0x52, 0x49, 0x46, 0x46, None, None, None, None, 0x57, 0x41, 0x56, 0x45), - ), - ("audio/ogg", (0x4F, 0x67, 0x67, 0x53)), - ("audio/flac", (0x66, 0x4C, 0x61, 0x43)), - ("audio/aac", (0x40, 0x15, 0x00, 0x00)), - ("audio/mp4", (0x66, 0x74, 0x79, 0x70)), - ("audio/webm", (0x1A, 0x45, 0xDF, 0xA3)), -] - -VIDEO_SIGNATURES: list[_Signature] = [ - ("video/mp4", (0x00, 0x00, 0x00, None, 0x66, 0x74, 0x79, 0x70)), - ("video/webm", (0x1A, 0x45, 0xDF, 0xA3)), - ( - "video/quicktime", - (0x00, 0x00, 0x00, 0x14, 0x66, 0x74, 0x79, 0x70, 0x71, 0x74), - ), - ("video/x-msvideo", (0x52, 0x49, 0x46, 0x46)), -] - - -# --------------------------------------------------------------------------- -# ID3 tag stripping (for MP3 files that start with ID3v2 metadata) -# --------------------------------------------------------------------------- - -_ID3_HEADER = bytes([0x49, 0x44, 0x33]) # "ID3" -_ID3_BASE64 = "SUQz" # base64("ID3") - - -def _strip_id3_tags(data: bytes) -> bytes: - """Strip an ID3v2 tag header if present, returning the audio data.""" - if len(data) < 10 or data[:3] != _ID3_HEADER: 
- return data - # Syncsafe integer: 4 bytes, 7 bits each - size = ( - (data[6] & 0x7F) << 21 - | (data[7] & 0x7F) << 14 - | (data[8] & 0x7F) << 7 - | (data[9] & 0x7F) - ) - offset = size + 10 - return data[offset:] if offset < len(data) else data - - -def _strip_id3_tags_base64(data: str) -> str: - """Strip an ID3v2 tag from base64-encoded data if present.""" - if not data.startswith(_ID3_BASE64): - return data - # Decode enough to read the ID3 header (10 bytes = ~16 base64 chars) - try: - header = _b64.b64decode(data[:16]) - except Exception: - return data - if len(header) < 10 or header[:3] != _ID3_HEADER: - return data - size = ( - (header[6] & 0x7F) << 21 - | (header[7] & 0x7F) << 14 - | (header[8] & 0x7F) << 7 - | (header[9] & 0x7F) - ) - offset = size + 10 - # Re-encode: decode full data, strip, re-encode - try: - full = _b64.b64decode(data) - stripped = full[offset:] if offset < len(full) else full - return _b64.b64encode(stripped).decode("ascii") - except Exception: - return data - - -# --------------------------------------------------------------------------- -# Core detection -# --------------------------------------------------------------------------- - - -def _to_bytes(data: bytes | str, *, max_bytes: int = 24) -> bytes: - """Convert *data* to bytes for signature comparison. - - For ``str`` input (base-64), decodes only the first *max_bytes* - characters worth of data to avoid decoding large payloads. - """ - if isinstance(data, bytes): - return data[:max_bytes] - # base-64: 4 chars → 3 bytes. Decode ~32 chars to get enough bytes. - chunk = data[: max_bytes * 2] - # Pad to multiple of 4 for valid base64 - padded = chunk + "=" * (-len(chunk) % 4) - try: - return _b64.b64decode(padded)[:max_bytes] - except Exception: - return b"" - - -def detect_media_type( - data: bytes | str, - signatures: list[_Signature], -) -> str | None: - """Detect media type from magic bytes. - - Args: - data: Raw bytes or a base-64 encoded string. 
- signatures: List of ``(media_type, byte_prefix)`` tuples to - match against (e.g. :data:`IMAGE_SIGNATURES`). - - Returns: - The matched IANA media type, or ``None`` if no signature matches. - """ - # Strip ID3 tags for audio detection - if signatures is AUDIO_SIGNATURES: - if isinstance(data, bytes): - data = _strip_id3_tags(data) - else: - data = _strip_id3_tags_base64(data) - - raw = _to_bytes(data) - if not raw: - return None - - for media_type, prefix in signatures: - if len(raw) < len(prefix): - continue - if all( - expected is None or raw[i] == expected for i, expected in enumerate(prefix) - ): - return media_type - - return None - - -def detect_image_media_type(data: bytes | str) -> str | None: - """Detect image format from magic bytes.""" - return detect_media_type(data, IMAGE_SIGNATURES) - - -def detect_audio_media_type(data: bytes | str) -> str | None: - """Detect audio format from magic bytes.""" - return detect_media_type(data, AUDIO_SIGNATURES) diff --git a/src/vercel_ai_sdk/models/core/media/download.py b/src/vercel_ai_sdk/models/core/media/download.py deleted file mode 100644 index ef3757af..00000000 --- a/src/vercel_ai_sdk/models/core/media/download.py +++ /dev/null @@ -1,104 +0,0 @@ -"""Async download utility for URL-based file parts. - -Port of ``@ai-sdk/ai/src/util/download/download.ts``. Used by -provider adapters that need to fetch a URL the provider API cannot -accept natively (e.g. OpenAI does not support audio/PDF URLs). 
-""" - -from __future__ import annotations - -import httpx - -DEFAULT_MAX_BYTES = 100 * 1024 * 1024 # 100 MiB (matches TS SDK) -_ALLOWED_SCHEMES = frozenset({"http", "https"}) - - -class DownloadError(Exception): - """Raised when a URL download fails.""" - - def __init__( - self, - url: str, - *, - status_code: int | None = None, - status_text: str | None = None, - cause: BaseException | None = None, - ) -> None: - parts = [f"Failed to download {url!r}"] - if status_code is not None: - parts.append(f"status={status_code}") - if status_text: - parts.append(status_text) - super().__init__(": ".join(parts)) - self.url = url - self.status_code = status_code - if cause is not None: - self.__cause__ = cause - - -def _validate_url(url: str) -> None: - """Reject non-HTTP(S) URLs (SSRF prevention).""" - from urllib.parse import urlparse - - parsed = urlparse(url) - if parsed.scheme not in _ALLOWED_SCHEMES: - raise DownloadError( - url, status_text=f"Unsupported URL scheme: {parsed.scheme!r}" - ) - - -async def download( - url: str, - *, - max_bytes: int = DEFAULT_MAX_BYTES, -) -> tuple[bytes, str | None]: - """Download *url* and return ``(data, content_type)``. - - Args: - url: The URL to fetch (must be ``http`` or ``https``). - max_bytes: Maximum response size. Defaults to 100 MiB. - - Returns: - A tuple of ``(raw_bytes, content_type_or_None)``. - - Raises: - DownloadError: On any failure (network, HTTP status, size, etc.). 
- """ - _validate_url(url) - - try: - async with httpx.AsyncClient(follow_redirects=True) as client: - resp = await client.get(url) - - # Validate redirect target - if resp.url is not None and str(resp.url) != url: - _validate_url(str(resp.url)) - - if resp.status_code >= 400: - raise DownloadError( - url, - status_code=resp.status_code, - status_text=resp.reason_phrase or "", - ) - - data = resp.content - if len(data) > max_bytes: - raise DownloadError( - url, - status_text=( - f"Response exceeds maximum size " - f"({len(data)} > {max_bytes} bytes)" - ), - ) - - content_type = resp.headers.get("content-type") - # Strip charset/parameters: "image/png; charset=..." → "image/png" - if content_type: - content_type = content_type.split(";")[0].strip() - - return data, content_type or None - - except DownloadError: - raise - except Exception as exc: - raise DownloadError(url, cause=exc) from exc diff --git a/src/vercel_ai_sdk/models/core/model.py b/src/vercel_ai_sdk/models/core/model.py index 3b6d8797..cbf59f50 100644 --- a/src/vercel_ai_sdk/models/core/model.py +++ b/src/vercel_ai_sdk/models/core/model.py @@ -1,4 +1,4 @@ -"""Model — pure data describing a model, no execution logic.""" +"""Model metadata types.""" from __future__ import annotations @@ -6,24 +6,29 @@ @dataclasses.dataclass(frozen=True) -class Model: - """Immutable description of a model. +class ModelCost: + """Per-million-token pricing.""" + + input: float = 0.0 + output: float = 0.0 + cache_read: float = 0.0 + cache_write: float = 0.0 - ``id`` - The model identifier sent to the provider - (e.g. ``"claude-sonnet-4-20250514"``, ``"gpt-4o"``). - ``api`` - Wire protocol discriminator used to look up the execution function - (e.g. ``"anthropic"``, ``"openai"``, ``"ai-gateway"``). - A single ``api`` value may be shared by multiple providers that speak - the same wire format. +@dataclasses.dataclass(frozen=True) +class Model: + """Pure-data description of a model. 
- ``provider`` - The actual host / provider name - (e.g. ``"anthropic"``, ``"azure"``, ``"ai-gateway"``). + * ``id`` — identifier sent to the provider (e.g. ``"claude-sonnet-4-20250514"``). + * ``adapter`` — adapter key (e.g. ``"ai-gateway-v3"``, ``"anthropic-messages"``). + * ``provider`` — hosting service (e.g. ``"ai-gateway"``, ``"anthropic"``). """ id: str - api: str + adapter: str provider: str + name: str = "" + capabilities: tuple[str, ...] = ("text",) + context_window: int = 0 + max_output_tokens: int = 0 + cost: ModelCost | None = None diff --git a/src/vercel_ai_sdk/models2/core/proto.py b/src/vercel_ai_sdk/models/core/proto.py similarity index 100% rename from src/vercel_ai_sdk/models2/core/proto.py rename to src/vercel_ai_sdk/models/core/proto.py diff --git a/src/vercel_ai_sdk/models/core/protocol.py b/src/vercel_ai_sdk/models/core/protocol.py deleted file mode 100644 index 56d3eb8a..00000000 --- a/src/vercel_ai_sdk/models/core/protocol.py +++ /dev/null @@ -1,145 +0,0 @@ -"""Execution protocols and the Stream result type. - -``StreamFn`` and ``GenerateFn`` define the execution contract that -provider adapters must satisfy. ``Stream`` wraps an async generator -of :class:`Message` objects into an async-iterable *and* awaitable -result with convenience properties. -""" - -from __future__ import annotations - -from collections.abc import AsyncGenerator, Generator, Sequence -from typing import Any, Protocol, runtime_checkable - -import pydantic - -from ...types import messages as messages_ -from ...types import tools as tools_ -from .model import Model - -# ── Execution protocols ─────────────────────────────────────────── - - -@runtime_checkable -class StreamFn(Protocol): - """Protocol for streaming LLM calls. - - Implementations accept a :class:`Model`, messages, and optional tools / - output type, and return an async generator that yields - :class:`Message` snapshots as the response streams in. 
- """ - - def __call__( - self, - model: Model, - messages: list[messages_.Message], - tools: Sequence[tools_.ToolLike] | None = None, - output_type: type[pydantic.BaseModel] | None = None, - ) -> AsyncGenerator[messages_.Message]: ... - - -@runtime_checkable -class GenerateFn(Protocol): - """Protocol for non-streaming generation (image, video, etc.). - - Implementations accept a :class:`Model`, messages, and arbitrary - keyword arguments forwarded from the caller. - """ - - async def __call__( - self, - model: Model, - messages: list[messages_.Message], - **kwargs: Any, - ) -> messages_.Message: ... - - -# ── Stream result ───────────────────────────────────────────────── - - -class Stream: - """Async-iterable *and* awaitable wrapper around a message generator. - - Usage:: - - # Streaming - stream = Stream(gen) - async for msg in stream: - print(msg.text) - - # Or just await the final result - stream = Stream(gen) - await stream - stream.result # last Message - stream.text # concatenated text - """ - - def __init__(self, generator: AsyncGenerator[messages_.Message]) -> None: - self._generator = generator - self._messages: list[messages_.Message] = [] - self._done = False - - # ── Async iteration ─────────────────────────────────────────── - - async def __aiter__(self) -> AsyncGenerator[messages_.Message]: - if self._done: - # Already consumed — replay from buffer - for msg in self._messages: - yield msg - return - - async for msg in self._generator: - self._messages.append(msg) - yield msg - self._done = True - - # ── Awaitable ───────────────────────────────────────────────── - - def __await__(self) -> Generator[Any, None, Stream]: - return self._drain().__await__() - - async def _drain(self) -> Stream: - """Consume the entire generator, populating result fields.""" - if not self._done: - async for _ in self: - pass - return self - - # ── Result properties (available after iteration / await) ───── - - @property - def messages(self) -> list[messages_.Message]: - 
"""All messages yielded during streaming.""" - return list(self._messages) - - @property - def result(self) -> messages_.Message | None: - """The last message (final snapshot), or ``None`` if empty.""" - return self._messages[-1] if self._messages else None - - @property - def tool_calls(self) -> list[messages_.ToolPart]: - """Tool-call parts from the final message.""" - if not self._messages: - return [] - return [ - p for p in self._messages[-1].parts if isinstance(p, messages_.ToolPart) - ] - - @property - def text(self) -> str: - """Concatenated text from the final message.""" - if not self._messages: - return "" - return "".join( - p.text - for p in self._messages[-1].parts - if isinstance(p, messages_.TextPart) - ) - - @property - def usage(self) -> messages_.Usage | None: - """Usage from the final message, if available.""" - if not self._messages: - return None - return self._messages[-1].usage diff --git a/src/vercel_ai_sdk/models/core/registry.py b/src/vercel_ai_sdk/models/core/registry.py deleted file mode 100644 index 033dba45..00000000 --- a/src/vercel_ai_sdk/models/core/registry.py +++ /dev/null @@ -1,54 +0,0 @@ -"""Registry mapping ``api`` strings to execution functions. - -Provider adapters call :func:`register_stream` / :func:`register_generate` -to make themselves available. The module-level ``stream()`` and -``generate()`` functions in :mod:`vercel_ai_sdk.models` use -:func:`get_stream_fn` / :func:`get_generate_fn` to dispatch. 
-""" - -from __future__ import annotations - -from .protocol import GenerateFn, StreamFn - -_stream_fns: dict[str, StreamFn] = {} -_generate_fns: dict[str, GenerateFn] = {} - - -def register_stream(api: str, fn: StreamFn) -> None: - """Register a :class:`StreamFn` for the given wire-protocol ``api``.""" - _stream_fns[api] = fn - - -def register_generate(api: str, fn: GenerateFn) -> None: - """Register a :class:`GenerateFn` for the given wire-protocol ``api``.""" - _generate_fns[api] = fn - - -def get_stream_fn(api: str) -> StreamFn: - """Look up the registered :class:`StreamFn` for ``api``. - - Raises :class:`KeyError` with a descriptive message if no function - has been registered for the given ``api``. - """ - try: - return _stream_fns[api] - except KeyError: - registered = ", ".join(sorted(_stream_fns)) or "(none)" - raise KeyError( - f"No StreamFn registered for api={api!r}. Registered: {registered}" - ) from None - - -def get_generate_fn(api: str) -> GenerateFn: - """Look up the registered :class:`GenerateFn` for ``api``. - - Raises :class:`KeyError` with a descriptive message if no function - has been registered for the given ``api``. - """ - try: - return _generate_fns[api] - except KeyError: - registered = ", ".join(sorted(_generate_fns)) or "(none)" - raise KeyError( - f"No GenerateFn registered for api={api!r}. Registered: {registered}" - ) from None diff --git a/src/vercel_ai_sdk/models/core/video.py b/src/vercel_ai_sdk/models/core/video.py deleted file mode 100644 index 84e1d074..00000000 --- a/src/vercel_ai_sdk/models/core/video.py +++ /dev/null @@ -1,66 +0,0 @@ -"""VideoModel — abstract video generation model.""" - -from __future__ import annotations - -import abc -from typing import Any, override - -from ...types import messages as messages_ -from .media.base import MediaModel, MediaResult - - -class VideoModel(MediaModel): - """Abstract video generation model. 
- - Accepts :class:`Message`\\s as input and returns a :class:`Message` - containing generated videos as :class:`FilePart`\\s. - - Adapter authors implement :meth:`make_request`; the framework handles - parsing messages and assembling the response :class:`Message`. - """ - - async def generate( - self, - messages: list[messages_.Message], - *, - n: int = 1, - aspect_ratio: str | None = None, - resolution: str | None = None, - duration: float | None = None, - fps: int | None = None, - seed: int | None = None, - provider_options: dict[str, Any] | None = None, - ) -> messages_.Message: - """Generate videos from the given messages.""" - prompt = self._extract_prompt(messages) - input_files = self._extract_input_files(messages) - result = await self.make_request( - prompt, - input_files, - n=n, - aspect_ratio=aspect_ratio, - resolution=resolution, - duration=duration, - fps=fps, - seed=seed, - provider_options=provider_options, - ) - return self._build_message(result) - - @override - @abc.abstractmethod - async def make_request( - self, - prompt: str, - input_files: list[messages_.FilePart], - *, - n: int = 1, - aspect_ratio: str | None = None, - resolution: str | None = None, - duration: float | None = None, - fps: int | None = None, - seed: int | None = None, - provider_options: dict[str, Any] | None = None, - ) -> MediaResult: - """Adapter-specific video generation.""" - ... 
diff --git a/src/vercel_ai_sdk/models/openai/__init__.py b/src/vercel_ai_sdk/models/openai/__init__.py index 4b83b500..bd01bcd1 100644 --- a/src/vercel_ai_sdk/models/openai/__init__.py +++ b/src/vercel_ai_sdk/models/openai/__init__.py @@ -1,5 +1,7 @@ -"""OpenAI provider adapter.""" +"""OpenAI provider — adapter for the OpenAI chat completions API.""" -from .llm import OpenAIModel, _messages_to_openai +from .adapter import stream -__all__ = ["OpenAIModel", "_messages_to_openai"] +__all__ = [ + "stream", +] diff --git a/src/vercel_ai_sdk/models2/openai/adapter.py b/src/vercel_ai_sdk/models/openai/adapter.py similarity index 100% rename from src/vercel_ai_sdk/models2/openai/adapter.py rename to src/vercel_ai_sdk/models/openai/adapter.py diff --git a/src/vercel_ai_sdk/models/openai/llm.py b/src/vercel_ai_sdk/models/openai/llm.py deleted file mode 100644 index 46dd3a0d..00000000 --- a/src/vercel_ai_sdk/models/openai/llm.py +++ /dev/null @@ -1,367 +0,0 @@ -from __future__ import annotations - -import os -from collections.abc import AsyncGenerator, Sequence -from typing import Any, override - -import openai -import pydantic - -from ...types import messages as messages_ -from ...types import tools as tools_ -from ..core import llm as llm_ -from ..core import media - - -def _tools_to_openai(tools: Sequence[tools_.ToolLike]) -> list[dict[str, Any]]: - """Convert internal Tool objects to OpenAI tool schema format.""" - return [ - { - "type": "function", - "function": { - "name": tool.name, - "description": tool.description, - "parameters": tool.param_schema, - }, - } - for tool in tools - ] - - -async def _file_part_to_openai(part: messages_.FilePart) -> dict[str, Any]: - """Convert a :class:`FilePart` to an OpenAI content-array element. 
- - Follows the OpenAI chat-completions content part formats: - - * ``image/*`` → ``image_url`` (URL or ``data:`` URL) - * ``audio/*`` → ``input_audio`` (base-64 only; URLs auto-downloaded) - * ``application/pdf`` → ``file`` (base-64 only; URLs auto-downloaded) - * ``text/*`` → ``text`` (decoded to string) - * anything else → ``ValueError`` - - OpenAI does not accept URLs for audio ``input_audio`` or PDF ``file`` - parts. When URL data is provided for these types, it is downloaded - automatically (matching the TS SDK's ``downloadAssets`` behaviour). - """ - mt = part.media_type - data = part.data - - if mt.startswith("image/"): - media_type = "image/jpeg" if mt == "image/*" else mt - url = media.data.data_to_data_url(data, media_type) - return {"type": "image_url", "image_url": {"url": url}} - - if mt.startswith("audio/"): - # OpenAI input_audio requires raw base-64 — download http(s) URLs. - if isinstance(data, str) and media.data.is_downloadable_url(data): - downloaded, _ = await media.download.download(data) - data = downloaded - fmt = mt.split("/", 1)[1] if "/" in mt else mt - b64 = media.data.data_to_base64(data) - return {"type": "input_audio", "input_audio": {"data": b64, "format": fmt}} - - if mt == "application/pdf": - # OpenAI file parts require base-64 — download http(s) URLs. - if isinstance(data, str) and media.data.is_downloadable_url(data): - downloaded, _ = await media.download.download(data) - data = downloaded - data_url = media.data.data_to_data_url(data, mt) - filename = part.filename or "document.pdf" - return {"type": "file", "file": {"filename": filename, "file_data": data_url}} - - if mt.startswith("text/"): - # Decode text content — URLs are passed through as text, - # bytes/base-64 are decoded to UTF-8 string. 
- if isinstance(data, bytes): - text_content = data.decode("utf-8") - elif media.data.is_url(data): - text_content = data - else: - import base64 as _b64 - - text_content = _b64.b64decode(data).decode("utf-8") - return {"type": "text", "text": text_content} - - raise ValueError(f"Unsupported media type for OpenAI: {mt}") - - -async def _messages_to_openai( - messages: list[messages_.Message], -) -> list[dict[str, Any]]: - """Convert internal messages to OpenAI API format. - - Converts to the OpenAI wire format: - - - ``tool_calls`` on assistant messages - - tool results as separate ``role: "tool"`` messages - - The Vercel AI Gateway preserves reasoning details across interactions, - normalizing formats from different providers. - - See: https://vercel.com/docs/ai-gateway/openai-compat/advanced - """ - result: list[dict[str, Any]] = [] - for msg in messages: - match msg.role: - case "assistant": - content = "" - reasoning = "" - tool_calls = [] - tool_results = [] - - for part in msg.parts: - match part: - case messages_.ReasoningPart(text=text): - reasoning += text - case messages_.TextPart(text=text): - content += text - case messages_.ToolPart(): - tool_calls.append( - { - "id": part.tool_call_id, - "type": "function", - "function": { - "name": part.tool_name, - "arguments": part.tool_args, - }, - } - ) - if part.status in ("result", "error"): - tool_results.append( - { - "role": "tool", - "tool_call_id": part.tool_call_id, - "content": str(part.result) - if part.result is not None - else "", - } - ) - - entry: dict[str, Any] = {"role": "assistant"} - if content: - entry["content"] = content - if reasoning: - entry["reasoning"] = reasoning - if tool_calls: - entry["tool_calls"] = tool_calls - result.append(entry) - - # Emit tool results as separate messages (OpenAI API format) - result.extend(tool_results) - case "system": - content = "".join( - p.text for p in msg.parts if isinstance(p, messages_.TextPart) - ) - result.append({"role": "system", "content": 
content}) - case "user": - has_files = any(isinstance(p, messages_.FilePart) for p in msg.parts) - if not has_files: - # Text-only: keep simple string format (cheaper, no content array) - text = "".join( - p.text for p in msg.parts if isinstance(p, messages_.TextPart) - ) - result.append({"role": "user", "content": text}) - else: - parts: list[dict[str, Any]] = [] - for p in msg.parts: - match p: - case messages_.TextPart(text=text): - parts.append({"type": "text", "text": text}) - case messages_.FilePart(): - parts.append(await _file_part_to_openai(p)) - result.append({"role": "user", "content": parts}) - return result - - -class OpenAIModel(llm_.LanguageModel): - """OpenAI adapter with reasoning/thinking support via Vercel AI Gateway. - - Supports reasoning for models like GPT 5.x, o-series, and Claude via gateway. - Uses the Vercel AI Gateway's unified reasoning API format. - - See: https://vercel.com/docs/ai-gateway/openai-compat/advanced - """ - - def __init__( - self, - model: str = "gpt-4o", - base_url: str | None = None, - api_key: str | None = None, - thinking: bool = False, - budget_tokens: int | None = None, - reasoning_effort: str | None = None, - ) -> None: - """Initialize OpenAI model adapter. 
- - Args: - model: Model identifier - (e.g., 'openai/gpt-5.2', 'anthropic/claude-sonnet-4.5') - base_url: API base URL - (e.g., 'https://ai-gateway.vercel.sh/v1') - api_key: API key for authentication - thinking: Enable reasoning/thinking output - budget_tokens: Max tokens for reasoning - (mutually exclusive with reasoning_effort) - reasoning_effort: Effort level — 'none', 'minimal', - 'low', 'medium', 'high', 'xhigh' - (mutually exclusive with budget_tokens) - """ - self._model = model - self._thinking = thinking - self._budget_tokens = budget_tokens - self._reasoning_effort = reasoning_effort - resolved_key = api_key or os.environ.get("OPENAI_API_KEY") or "" - self._client = openai.AsyncOpenAI(base_url=base_url, api_key=resolved_key) - - @override - async def stream_events( - self, - messages: list[messages_.Message], - tools: Sequence[tools_.ToolLike] | None = None, - output_type: type[pydantic.BaseModel] | None = None, - ) -> AsyncGenerator[llm_.StreamEvent]: - """Yield raw stream events from OpenAI API.""" - openai_messages = await _messages_to_openai(messages) - openai_tools = _tools_to_openai(tools) if tools else None - - kwargs: dict[str, Any] = { - "model": self._model, - "messages": openai_messages, - "stream": True, - } - if openai_tools: - kwargs["tools"] = openai_tools - kwargs["stream_options"] = {"include_usage": True} - - if output_type is not None: - from openai.lib._pydantic import to_strict_json_schema - - kwargs["response_format"] = { - "type": "json_schema", - "json_schema": { - "name": output_type.__name__, - "schema": to_strict_json_schema(output_type), - "strict": True, - }, - } - - # Enable reasoning/thinking via Vercel AI Gateway's unified format - # See: https://vercel.com/docs/ai-gateway/openai-compat/advanced - if self._thinking: - reasoning_config: dict[str, Any] = {"enabled": True} - # Use budget_tokens OR reasoning_effort (mutually exclusive per docs) - if self._budget_tokens is not None: - reasoning_config["max_tokens"] = 
self._budget_tokens - elif self._reasoning_effort is not None: - reasoning_config["effort"] = self._reasoning_effort - kwargs["extra_body"] = {"reasoning": reasoning_config} - - stream = await self._client.chat.completions.create(**kwargs) - - # Track active blocks for Start/End events - text_started = False - reasoning_started = False - tool_calls: dict[int, dict[str, Any]] = {} # index -> {id, name, started} - finish_reason: str | None = None - usage: messages_.Usage | None = None - - async for chunk in stream: - # Extract usage from any chunk that carries it (typically the final - # chunk when stream_options.include_usage is True). - if chunk.usage is not None: - raw = chunk.usage.model_dump(exclude_none=True) - # Extract optional breakdowns - reasoning_tokens: int | None = None - cache_read: int | None = None - completion_details = getattr( - chunk.usage, "completion_tokens_details", None - ) - if completion_details: - reasoning_tokens = getattr( - completion_details, "reasoning_tokens", None - ) - prompt_details = getattr(chunk.usage, "prompt_tokens_details", None) - if prompt_details: - cache_read = getattr(prompt_details, "cached_tokens", None) - usage = messages_.Usage( - input_tokens=chunk.usage.prompt_tokens or 0, - output_tokens=chunk.usage.completion_tokens or 0, - reasoning_tokens=reasoning_tokens, - cache_read_tokens=cache_read, - raw=raw, - ) - - if not chunk.choices: - continue - - choice = chunk.choices[0] - delta = choice.delta - - # Handle reasoning/thinking content via Vercel AI Gateway - # The gateway may return reasoning in different ways: - # 1. As a direct attribute (if SDK supports it) - # 2. 
In model_extra (Pydantic v2 extra fields) - reasoning_value = None - if hasattr(delta, "reasoning") and delta.reasoning: - reasoning_value = delta.reasoning - elif hasattr(delta, "model_extra") and delta.model_extra: - reasoning_value = delta.model_extra.get("reasoning") - - if reasoning_value: - if not reasoning_started: - reasoning_started = True - yield llm_.ReasoningStart(block_id="reasoning") - yield llm_.ReasoningDelta(block_id="reasoning", delta=reasoning_value) - - if delta.content: - # Close reasoning block when text starts (reasoning precedes text) - if reasoning_started: - yield llm_.ReasoningEnd(block_id="reasoning") - reasoning_started = False - - if not text_started: - text_started = True - yield llm_.TextStart(block_id="text") - yield llm_.TextDelta(block_id="text", delta=delta.content) - - if delta.tool_calls: - for tc in delta.tool_calls: - idx = tc.index - if idx not in tool_calls: - tool_calls[idx] = {"id": tc.id, "name": None, "started": False} - if tc.id: - tool_calls[idx]["id"] = tc.id - if tc.function: - if tc.function.name: - tool_calls[idx]["name"] = tc.function.name - if tc.function.arguments: - tool_id = tool_calls[idx]["id"] - tool_name = tool_calls[idx]["name"] or "" - - # Emit start if not started - if not tool_calls[idx]["started"] and tool_id: - tool_calls[idx]["started"] = True - yield llm_.ToolStart( - tool_call_id=tool_id, tool_name=tool_name - ) - - if tool_id: - yield llm_.ToolArgsDelta( - tool_call_id=tool_id, delta=tc.function.arguments - ) - - if choice.finish_reason is not None: - finish_reason = choice.finish_reason - # Close any open blocks - if reasoning_started: - yield llm_.ReasoningEnd(block_id="reasoning") - if text_started: - yield llm_.TextEnd(block_id="text") - for tc in tool_calls.values(): - if tc["started"] and tc["id"]: - yield llm_.ToolEnd(tool_call_id=tc["id"]) - - # Don't return yet — the usage chunk may arrive after - # finish_reason. We'll emit MessageDone after the loop. 
- - yield llm_.MessageDone(finish_reason=finish_reason, usage=usage) diff --git a/src/vercel_ai_sdk/models2/__init__.py b/src/vercel_ai_sdk/models2/__init__.py deleted file mode 100644 index fb78f460..00000000 --- a/src/vercel_ai_sdk/models2/__init__.py +++ /dev/null @@ -1,205 +0,0 @@ -"""models2 — composable model layer. - -Usage:: - - from vercel_ai_sdk import models2 as m - from vercel_ai_sdk.types import Message, TextPart - - model = m.Model( - id="anthropic/claude-sonnet-4", - adapter="ai-gateway-v3", - provider="ai-gateway", - ) - msgs = [Message(role="user", parts=[TextPart(text="hello")])] - - # stream — auto-creates client from env vars - async for msg in m.stream(model, msgs): - print(msg.text_delta, end="") - - # buffer the whole response - result = await m.buffer(m.stream(model, msgs)) - print(result.text) - - # explicit client - client = m.Client(base_url="https://custom.example.com/v3/ai", api_key="sk-...") - async for msg in m.stream(model, msgs, client=client): - ... -""" - -from __future__ import annotations - -import os -from collections.abc import AsyncGenerator, Sequence -from typing import Any - -import pydantic - -from ..types import messages as messages_ -from ..types import tools as tools_ -from .ai_gateway.generate import GenerateParams, ImageParams, VideoParams -from .core.client import Client -from .core.model import Model, ModelCost -from .core.proto import GenerateFn, StreamFn - -# --------------------------------------------------------------------------- -# Adapter registry — maps adapter string → adapter function. -# Adapter modules are imported lazily on first use. 
-# --------------------------------------------------------------------------- - -_stream_adapters: dict[str, StreamFn] = {} -_generate_adapters: dict[str, GenerateFn] = {} -_adapters_loaded = False - - -def _ensure_adapters() -> None: - """Lazily register built-in adapter functions on first call.""" - global _adapters_loaded # noqa: PLW0603 - if _adapters_loaded: - return - _adapters_loaded = True - - from .ai_gateway import generate as ai_gw_generate - from .ai_gateway import stream as ai_gw_stream - from .anthropic.adapter import stream as anthropic_stream - from .openai.adapter import stream as openai_stream - - _stream_adapters["ai-gateway-v3"] = ai_gw_stream - _generate_adapters["ai-gateway-v3"] = ai_gw_generate - _stream_adapters["openai"] = openai_stream - _stream_adapters["anthropic"] = anthropic_stream - - -def register_stream(adapter: str, fn: StreamFn) -> None: - """Register a stream adapter function for the given adapter key. - - Use this to add custom adapters (or override built-in ones). - """ - _stream_adapters[adapter] = fn - - -def register_generate(adapter: str, fn: GenerateFn) -> None: - """Register a generate adapter function for the given adapter key. - - Use this to add custom adapters (or override built-in ones). - """ - _generate_adapters[adapter] = fn - - -# --------------------------------------------------------------------------- -# Provider defaults — base URLs and env var names for auto-client creation. 
-# --------------------------------------------------------------------------- - -_PROVIDER_DEFAULTS: dict[str, tuple[str, str]] = { - "ai-gateway": ("https://ai-gateway.vercel.sh/v3/ai", "AI_GATEWAY_API_KEY"), - "anthropic": ("https://api.anthropic.com/v1", "ANTHROPIC_API_KEY"), - "openai": ("https://api.openai.com/v1", "OPENAI_API_KEY"), -} - - -def _auto_client(model: Model) -> Client: - """Create a :class:`Client` from env vars for the given model's provider.""" - defaults = _PROVIDER_DEFAULTS.get(model.provider) - if defaults is None: - raise ValueError( - f"No default client config for provider {model.provider!r}. " - f"Pass an explicit client= argument." - ) - base_url, env_var = defaults - return Client(base_url=base_url, api_key=os.environ.get(env_var)) - - -# --------------------------------------------------------------------------- -# Public API -# --------------------------------------------------------------------------- - - -async def stream( - model: Model, - messages: list[messages_.Message], - *, - tools: Sequence[tools_.ToolLike] | None = None, - output_type: type[pydantic.BaseModel] | None = None, - client: Client | None = None, - **kwargs: Any, -) -> AsyncGenerator[messages_.Message]: - """Stream an LLM response. - - Resolves the adapter function from ``model.adapter``, auto-creates a - :class:`Client` from env vars if none is provided, and yields - ``Message`` snapshots. - """ - _ensure_adapters() - c = client or _auto_client(model) - adapter_fn = _stream_adapters.get(model.adapter) - if adapter_fn is None: - registered = ", ".join(sorted(_stream_adapters)) or "(none)" - raise KeyError( - f"No stream adapter registered for adapter={model.adapter!r}. 
" - f"Registered: {registered}" - ) - async for msg in adapter_fn( - c, model, messages, tools=tools, output_type=output_type, **kwargs - ): - yield msg - - -async def generate( - model: Model, - messages: list[messages_.Message], - params: GenerateParams | None = None, - *, - client: Client | None = None, -) -> messages_.Message: - """Generate a response (images, video, etc.). - - Resolves the adapter function from ``model.adapter``, auto-creates a - :class:`Client` from env vars if none is provided. - - ``params`` controls the generation type: - - * :class:`ImageParams` — image generation (``/image-model``). - * :class:`VideoParams` — video generation (``/video-model``). - * ``None`` — auto-detect from ``model.capabilities``. - """ - _ensure_adapters() - c = client or _auto_client(model) - adapter_fn = _generate_adapters.get(model.adapter) - if adapter_fn is None: - registered = ", ".join(sorted(_generate_adapters)) or "(none)" - raise KeyError( - f"No generate adapter registered for adapter={model.adapter!r}. " - f"Registered: {registered}" - ) - return await adapter_fn(c, model, messages, params=params) - - -async def buffer(gen: AsyncGenerator[messages_.Message]) -> messages_.Message: - """Drain a stream and return the final ``Message``. - - Raises :class:`ValueError` if the stream yields nothing. 
- """ - result: messages_.Message | None = None - async for msg in gen: - result = msg - if result is None: - raise ValueError("empty stream") - return result - - -__all__ = [ - # Core types - "Client", - "GenerateFn", - "GenerateParams", - "ImageParams", - "Model", - "ModelCost", - "StreamFn", - "VideoParams", - # Public API - "buffer", - "generate", - "register_generate", - "register_stream", - "stream", -] diff --git a/src/vercel_ai_sdk/models2/ai_gateway/__init__.py b/src/vercel_ai_sdk/models2/ai_gateway/__init__.py deleted file mode 100644 index 7cc9f429..00000000 --- a/src/vercel_ai_sdk/models2/ai_gateway/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -"""AI Gateway provider — adapter for the Vercel AI Gateway v3 protocol.""" - -from . import errors -from .generate import GenerateParams, ImageParams, VideoParams, generate -from .stream import stream - -__all__ = [ - "GenerateParams", - "ImageParams", - "VideoParams", - "errors", - "generate", - "stream", -] diff --git a/src/vercel_ai_sdk/models2/ai_gateway/errors.py b/src/vercel_ai_sdk/models2/ai_gateway/errors.py deleted file mode 100644 index d0dade24..00000000 --- a/src/vercel_ai_sdk/models2/ai_gateway/errors.py +++ /dev/null @@ -1,305 +0,0 @@ -"""Vercel AI Gateway error hierarchy. - -Maps HTTP error responses from the gateway server to typed Python exceptions. -Each error class corresponds to a specific ``error.type`` value in the -gateway's JSON error response format:: - - { - "error": { - "message": "...", - "type": "authentication_error" | "invalid_request_error" | ..., - "param": ..., - "code": ... - }, - "generationId": "..." 
- } -""" - -import json -from typing import Any, Self - -_KEY_URL = "https://vercel.com/d?to=%2F%5Bteam%5D%2F%7E%2Fai%2Fapi-keys" - - -# --------------------------------------------------------------------------- -# Base class -# --------------------------------------------------------------------------- - - -class GatewayError(Exception): - """Base class for all Vercel AI Gateway errors.""" - - type: str = "gateway_error" - - def __init__( - self, - message: str = "", - *, - status_code: int = 500, - cause: BaseException | None = None, - generation_id: str | None = None, - ) -> None: - display = f"{message} [{generation_id}]" if generation_id else message - super().__init__(display) - self.status_code = status_code - self.generation_id = generation_id - if cause is not None: - self.__cause__ = cause - - -# --------------------------------------------------------------------------- -# Concrete errors — thin subclasses that set type + default status_code -# --------------------------------------------------------------------------- - - -class GatewayAuthenticationError(GatewayError): - """Authentication failed (HTTP 401).""" - - type = "authentication_error" - - def __init__( - self, - message: str = "Authentication failed", - *, - status_code: int = 401, - cause: BaseException | None = None, - generation_id: str | None = None, - ) -> None: - super().__init__( - message, - status_code=status_code, - cause=cause, - generation_id=generation_id, - ) - - @classmethod - def create_contextual( - cls, - *, - api_key_provided: bool, - status_code: int = 401, - cause: BaseException | None = None, - generation_id: str | None = None, - ) -> Self: - """Build a helpful message based on which auth method was used.""" - if api_key_provided: - msg = ( - "AI Gateway authentication failed: Invalid API key.\n\n" - f"Create a new API key: {_KEY_URL}\n\n" - "Provide via 'api_key' option or " - "'AI_GATEWAY_API_KEY' environment variable." 
- ) - else: - msg = ( - "AI Gateway authentication failed: " - "No authentication provided.\n\n" - f"Create an API key: {_KEY_URL}\n" - "Provide via 'api_key' option or " - "'AI_GATEWAY_API_KEY' environment variable." - ) - return cls( - msg, - status_code=status_code, - cause=cause, - generation_id=generation_id, - ) - - -class GatewayInvalidRequestError(GatewayError): - """Malformed or invalid request (HTTP 400).""" - - type = "invalid_request_error" - - def __init__( - self, - message: str = "Invalid request", - *, - status_code: int = 400, - **kwargs: Any, - ) -> None: - super().__init__(message, status_code=status_code, **kwargs) - - -class GatewayRateLimitError(GatewayError): - """Rate limit exceeded (HTTP 429).""" - - type = "rate_limit_exceeded" - - def __init__( - self, - message: str = "Rate limit exceeded", - *, - status_code: int = 429, - **kwargs: Any, - ) -> None: - super().__init__(message, status_code=status_code, **kwargs) - - -class GatewayModelNotFoundError(GatewayError): - """Requested model was not found (HTTP 404).""" - - type = "model_not_found" - - def __init__( - self, - message: str = "Model not found", - *, - status_code: int = 404, - model_id: str | None = None, - cause: BaseException | None = None, - generation_id: str | None = None, - ) -> None: - super().__init__( - message, - status_code=status_code, - cause=cause, - generation_id=generation_id, - ) - self.model_id = model_id - - -class GatewayInternalServerError(GatewayError): - """Internal error on the gateway server (HTTP 500).""" - - type = "internal_server_error" - - def __init__( - self, - message: str = "Internal server error", - *, - status_code: int = 500, - **kwargs: Any, - ) -> None: - super().__init__(message, status_code=status_code, **kwargs) - - -class GatewayResponseError(GatewayError): - """Malformed or unparseable response (HTTP 502).""" - - type = "response_error" - - def __init__( - self, - message: str = "Invalid response", - *, - status_code: int = 502, - 
response: Any = None, - validation_error: Any = None, - cause: BaseException | None = None, - generation_id: str | None = None, - ) -> None: - super().__init__( - message, - status_code=status_code, - cause=cause, - generation_id=generation_id, - ) - self.response = response - self.validation_error = validation_error - - -class GatewayTimeoutError(GatewayError): - """Gateway request timed out (HTTP 408).""" - - type = "timeout_error" - - def __init__( - self, - message: str = "Request timed out", - *, - status_code: int = 408, - **kwargs: Any, - ) -> None: - super().__init__(message, status_code=status_code, **kwargs) - - -# --------------------------------------------------------------------------- -# Error factory -# --------------------------------------------------------------------------- - -_TYPE_MAP: dict[str, type[GatewayError]] = { - "authentication_error": GatewayAuthenticationError, - "invalid_request_error": GatewayInvalidRequestError, - "rate_limit_exceeded": GatewayRateLimitError, - "model_not_found": GatewayModelNotFoundError, - "internal_server_error": GatewayInternalServerError, -} - -_MALFORMED = "Invalid error response format: Gateway request failed" - - -def create_gateway_error( - *, - response_body: Any, - status_code: int, - api_key_provided: bool = False, - cause: BaseException | None = None, -) -> GatewayError: - """Create a typed error from a gateway JSON error response. - - Falls back to :class:`GatewayResponseError` when the body doesn't - match the expected ``{"error": {"message": ..., "type": ...}}`` - shape. 
- """ - # Parse the response body - body: Any = response_body - if isinstance(body, (str, bytes)): - try: - body = json.loads(body) - except (json.JSONDecodeError, ValueError): - return GatewayResponseError( - message=_MALFORMED, - status_code=status_code, - response=response_body, - validation_error="Response body is not valid JSON", - cause=cause, - ) - - # Validate shape - error_obj = body.get("error") if isinstance(body, dict) else None - if not isinstance(error_obj, dict) or "message" not in error_obj: - reason = ( - "Missing 'error' field in response" - if not isinstance(error_obj, dict) - else "Missing 'message' field in error object" - ) - return GatewayResponseError( - message=_MALFORMED, - status_code=status_code, - response=body, - validation_error=reason, - cause=cause, - ) - - message: str = error_obj["message"] - error_type: str | None = error_obj.get("type") - generation_id: str | None = body.get("generationId") - - match error_type: - case "authentication_error": - return GatewayAuthenticationError.create_contextual( - api_key_provided=api_key_provided, - status_code=status_code, - cause=cause, - generation_id=generation_id, - ) - - case "model_not_found": - param = error_obj.get("param") - model_id = param.get("modelId") if isinstance(param, dict) else None - return GatewayModelNotFoundError( - message=message, - status_code=status_code, - model_id=model_id, - cause=cause, - generation_id=generation_id, - ) - - case _: - cls = _TYPE_MAP.get(error_type or "", GatewayInternalServerError) - return cls( - message=message, - status_code=status_code, - cause=cause, - generation_id=generation_id, - ) diff --git a/src/vercel_ai_sdk/models2/anthropic/__init__.py b/src/vercel_ai_sdk/models2/anthropic/__init__.py deleted file mode 100644 index a9a0436b..00000000 --- a/src/vercel_ai_sdk/models2/anthropic/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""Anthropic provider — adapter for the Anthropic messages API.""" - -from .adapter import stream - -__all__ = [ - 
"stream", -] diff --git a/src/vercel_ai_sdk/models2/core/__init__.py b/src/vercel_ai_sdk/models2/core/__init__.py deleted file mode 100644 index a99a9797..00000000 --- a/src/vercel_ai_sdk/models2/core/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Core types for models2.""" - -from .client import Client -from .model import Model, ModelCost -from .proto import GenerateFn, StreamFn - -__all__ = [ - "Client", - "GenerateFn", - "Model", - "ModelCost", - "StreamFn", -] diff --git a/src/vercel_ai_sdk/models2/core/model.py b/src/vercel_ai_sdk/models2/core/model.py deleted file mode 100644 index cbf59f50..00000000 --- a/src/vercel_ai_sdk/models2/core/model.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Model metadata types.""" - -from __future__ import annotations - -import dataclasses - - -@dataclasses.dataclass(frozen=True) -class ModelCost: - """Per-million-token pricing.""" - - input: float = 0.0 - output: float = 0.0 - cache_read: float = 0.0 - cache_write: float = 0.0 - - -@dataclasses.dataclass(frozen=True) -class Model: - """Pure-data description of a model. - - * ``id`` — identifier sent to the provider (e.g. ``"claude-sonnet-4-20250514"``). - * ``adapter`` — adapter key (e.g. ``"ai-gateway-v3"``, ``"anthropic-messages"``). - * ``provider`` — hosting service (e.g. ``"ai-gateway"``, ``"anthropic"``). - """ - - id: str - adapter: str - provider: str - name: str = "" - capabilities: tuple[str, ...] 
= ("text",) - context_window: int = 0 - max_output_tokens: int = 0 - cost: ModelCost | None = None diff --git a/src/vercel_ai_sdk/models2/openai/__init__.py b/src/vercel_ai_sdk/models2/openai/__init__.py deleted file mode 100644 index bd01bcd1..00000000 --- a/src/vercel_ai_sdk/models2/openai/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -"""OpenAI provider — adapter for the OpenAI chat completions API.""" - -from .adapter import stream - -__all__ = [ - "stream", -] diff --git a/src/vercel_ai_sdk/types/messages.py b/src/vercel_ai_sdk/types/messages.py index 41a088b7..683c7469 100644 --- a/src/vercel_ai_sdk/types/messages.py +++ b/src/vercel_ai_sdk/types/messages.py @@ -138,9 +138,9 @@ def from_url(cls, url: str, *, media_type: str | None = None) -> FilePart: ``media_type`` is provided. """ if media_type is None: - from ..models.core.media import data as media_data + from ..models.core.helpers import media as media_helpers - media_type = media_data.infer_media_type(url) + media_type = media_helpers.infer_media_type(url) return cls(data=url, media_type=media_type) @classmethod @@ -158,11 +158,11 @@ def from_bytes( detection fails. """ if media_type is None: - from ..models.core.media import detect as media_detect + from ..models.core.helpers import media as media_helpers - media_type = media_detect.detect_image_media_type( + media_type = media_helpers.detect_image_media_type( data - ) or media_detect.detect_audio_media_type(data) + ) or media_helpers.detect_audio_media_type(data) if media_type is None: raise ValueError( "Cannot detect media_type from bytes. Provide media_type explicitly." 
diff --git a/tests/adapters/ai_sdk_ui/test_adapter.py b/tests/adapters/ai_sdk_ui/test_adapter.py index a7910630..1bbe2756 100644 --- a/tests/adapters/ai_sdk_ui/test_adapter.py +++ b/tests/adapters/ai_sdk_ui/test_adapter.py @@ -9,7 +9,7 @@ import vercel_ai_sdk as ai from vercel_ai_sdk.adapters.ai_sdk_ui import adapter, ui_message -from vercel_ai_sdk.agents2 import hooks +from vercel_ai_sdk.agents import hooks from vercel_ai_sdk.types import messages from ...conftest import MOCK_MODEL, mock_llm, tool_msg diff --git a/tests/agents/mcp/test_client.py b/tests/agents/mcp/test_client.py index 78bee949..a1220ac8 100644 --- a/tests/agents/mcp/test_client.py +++ b/tests/agents/mcp/test_client.py @@ -64,12 +64,12 @@ def test_mcp_tool_to_native_schema_preserved() -> None: assert native.description == "Echo input" -# -- End-to-end: MCP tool executes through stream_loop -------------------- +# -- End-to-end: MCP tool executes through Agent default loop --------------- @pytest.mark.asyncio -async def test_mcp_tool_executes_through_stream_loop() -> None: - """MCP-style tool via _mcp_tool_to_native can be called by the agent loop.""" +async def test_mcp_tool_executes_through_agent() -> None: + """MCP-style tool via _mcp_tool_to_native works with Agent.""" call_log: list[dict[str, str]] = [] async def fake_fn(**kwargs: str) -> str: @@ -84,18 +84,13 @@ async def fake_fn(**kwargs: str) -> str: native._fn = fake_fn _tool_registry[native.name] = native - async def graph(model: ai.Model) -> ai.StreamResult: - return await ai.stream_loop( - model, - messages=ai.make_messages(user="echo hello"), - tools=[native], - ) + my_agent = ai.agent(model=MOCK_MODEL, tools=[native]) call1 = [tool_msg(tc_id="tc-mcp-1", name="mcp_e2e_echo", args='{"text": "hello"}')] call2 = [text_msg("Done.", id="msg-2")] llm = mock_llm([call1, call2]) - result = ai.run(graph, MOCK_MODEL) + result = my_agent.run(ai.make_messages(user="echo hello")) msgs = [m async for m in result] # Tool was called with the right args 
diff --git a/tests/agents/test_checkpoint.py b/tests/agents/test_checkpoint.py index c90ee925..22fbd583 100644 --- a/tests/agents/test_checkpoint.py +++ b/tests/agents/test_checkpoint.py @@ -23,19 +23,20 @@ class Approval(pydantic.BaseModel): @pytest.mark.asyncio async def test_step_replay_skips_llm() -> None: - async def graph(model: ai.Model) -> ai.StreamResult: - return await ai.stream_step( - model, messages=ai.make_messages(system="test", user="hello") - ) + my_agent = ai.agent(model=MOCK_MODEL) + + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> ai.StreamResult: + return await ai.stream_step(agent.model, msgs) llm1 = mock_llm([[text_msg("Hi there!")]]) - result1 = ai.run(graph, MOCK_MODEL) + result1 = my_agent.run(ai.make_messages(system="test", user="hello")) [msg async for msg in result1] assert llm1.call_count == 1 cp = result1.checkpoint llm2 = mock_llm([]) - result2 = ai.run(graph, MOCK_MODEL, checkpoint=cp) + result2 = my_agent.run(ai.make_messages(system="test", user="hello"), checkpoint=cp) [msg async for msg in result2] assert llm2.call_count == 0 @@ -51,8 +52,11 @@ async def counting_tool(x: int) -> int: execution_count += 1 return x + 1 - async def graph(model: ai.Model) -> ai.StreamResult: - result = await ai.stream_step(model, ai.make_messages(system="t", user="go")) + my_agent = ai.agent(model=MOCK_MODEL, tools=[counting_tool]) + + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> ai.StreamResult: + result = await ai.stream_step(agent.model, msgs, agent.tools) if result.tool_calls: await asyncio.gather( *( @@ -63,14 +67,16 @@ async def graph(model: ai.Model) -> ai.StreamResult: return result mock_llm([[tool_msg(tc_id="tc-1", name="counting_tool", args='{"x": 5}')]]) - result1 = ai.run(graph, MOCK_MODEL) + result1 = my_agent.run(ai.make_messages(system="t", user="go")) [msg async for msg in result1] assert execution_count == 1 assert result1.checkpoint.tools[0].result == 6 execution_count 
= 0 mock_llm([]) - result2 = ai.run(graph, MOCK_MODEL, checkpoint=result1.checkpoint) + result2 = my_agent.run( + ai.make_messages(system="t", user="go"), checkpoint=result1.checkpoint + ) [msg async for msg in result2] assert execution_count == 0 @@ -80,12 +86,15 @@ async def graph(model: ai.Model) -> ai.StreamResult: @pytest.mark.asyncio async def test_hook_cancellation_pending() -> None: - async def graph(model: ai.Model) -> Any: - await ai.stream_step(model, ai.make_messages(system="t", user="go")) + my_agent = ai.agent(model=MOCK_MODEL) + + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> Any: + await ai.stream_step(agent.model, msgs) return await Approval.create("my_approval", metadata={"tool": "test"}) # type: ignore[attr-defined] mock_llm([[text_msg("OK")]]) - result = ai.run(graph, MOCK_MODEL) + result = my_agent.run(ai.make_messages(system="t", user="go")) msgs = [msg async for msg in result] assert "my_approval" in result.pending_hooks hook_msgs = [m for m in msgs if any(isinstance(p, ai.HookPart) for p in m.parts)] @@ -94,19 +103,22 @@ async def graph(model: ai.Model) -> Any: @pytest.mark.asyncio async def test_hook_resolution_on_reentry() -> None: - async def graph(model: ai.Model) -> Any: - await ai.stream_step(model, ai.make_messages(system="t", user="go")) + my_agent = ai.agent(model=MOCK_MODEL) + + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> Any: + await ai.stream_step(agent.model, msgs) return await Approval.create("my_approval") # type: ignore[attr-defined] resp = [text_msg("OK")] mock_llm([resp]) - result1 = ai.run(graph, MOCK_MODEL) + result1 = my_agent.run(ai.make_messages(system="t", user="go")) [msg async for msg in result1] cp = result1.checkpoint Approval.resolve("my_approval", {"granted": True}) # type: ignore[attr-defined] mock_llm([]) - result2 = ai.run(graph, MOCK_MODEL, checkpoint=cp) + result2 = my_agent.run(ai.make_messages(system="t", user="go"), checkpoint=cp) [msg 
async for msg in result2] assert len(result2.pending_hooks) == 0 assert result2.checkpoint.hooks[-1].label == "my_approval" @@ -114,8 +126,11 @@ async def graph(model: ai.Model) -> Any: @pytest.mark.asyncio async def test_parallel_hooks_all_collected() -> None: - async def graph(model: ai.Model) -> None: - await ai.stream_step(model, ai.make_messages(system="t", user="go")) + my_agent = ai.agent(model=MOCK_MODEL) + + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: + await ai.stream_step(agent.model, msgs) async def a() -> Any: return await Approval.create("hook_a") # type: ignore[attr-defined] @@ -128,15 +143,18 @@ async def b() -> Any: tg.create_task(b()) mock_llm([[text_msg("OK")]]) - result = ai.run(graph, MOCK_MODEL) + result = my_agent.run(ai.make_messages(system="t", user="go")) [msg async for msg in result] assert {"hook_a", "hook_b"} <= set(result.pending_hooks) @pytest.mark.asyncio async def test_parallel_hooks_resolve_on_reentry() -> None: - async def graph(model: ai.Model) -> Any: - await ai.stream_step(model, ai.make_messages(system="t", user="go")) + my_agent = ai.agent(model=MOCK_MODEL) + + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> Any: + await ai.stream_step(agent.model, msgs) async def a() -> Any: return await Approval.create("hook_a") # type: ignore[attr-defined] @@ -151,14 +169,14 @@ async def b() -> Any: resp = [text_msg("OK")] mock_llm([resp]) - result1 = ai.run(graph, MOCK_MODEL) + result1 = my_agent.run(ai.make_messages(system="t", user="go")) [msg async for msg in result1] cp = result1.checkpoint Approval.resolve("hook_a", {"granted": True}) # type: ignore[attr-defined] Approval.resolve("hook_b", {"granted": False}) # type: ignore[attr-defined] mock_llm([]) - result2 = ai.run(graph, MOCK_MODEL, checkpoint=cp) + result2 = my_agent.run(ai.make_messages(system="t", user="go"), checkpoint=cp) [msg async for msg in result2] assert len(result2.pending_hooks) == 0 diff --git 
a/tests/agents/test_hooks.py b/tests/agents/test_hooks.py index 1bd86bd9..65bfd15b 100644 --- a/tests/agents/test_hooks.py +++ b/tests/agents/test_hooks.py @@ -31,16 +31,18 @@ class CancellingConfirmation(pydantic.BaseModel): async def test_resolve_live_future() -> None: """In long-running mode, Hook.resolve() unblocks the awaiting coroutine.""" resolved_value = None + my_agent = ai.agent(model=MOCK_MODEL) - async def graph(model: ai.Model) -> None: + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: nonlocal resolved_value - await ai.stream_step(model, ai.make_messages(user="go")) + await ai.stream_step(agent.model, msgs) result = await Confirmation.create("confirm_1") # type: ignore[attr-defined] resolved_value = result mock_llm([[text_msg("OK")]]) # Confirmation.cancels_future=False -> long-running mode - run_result = ai.run(graph, MOCK_MODEL) + run_result = my_agent.run(ai.make_messages(user="go")) collected = [] async for msg in run_result: @@ -54,10 +56,6 @@ async def graph(model: ai.Model) -> None: assert resolved_value is not None assert resolved_value.approved is True assert resolved_value.reason == "looks good" - # The graph completed successfully (resolved_value proves it). - # Note: pending_hooks is not cleaned up after live resolution -- - # that's a known runtime limitation. The important thing is the - # graph continued past the hook. 
# -- Hook.cancel() -------------------------------------------------------- @@ -67,17 +65,19 @@ async def graph(model: ai.Model) -> None: async def test_cancel_live_hook() -> None: """Hook.cancel() cancels the future, causing CancelledError in graph.""" was_cancelled = False + my_agent = ai.agent(model=MOCK_MODEL) - async def graph(model: ai.Model) -> None: + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: nonlocal was_cancelled - await ai.stream_step(model, ai.make_messages(user="go")) + await ai.stream_step(agent.model, msgs) try: await Confirmation.create("cancel_me") # type: ignore[attr-defined] except asyncio.CancelledError: was_cancelled = True mock_llm([[text_msg("OK")]]) - run_result = ai.run(graph, MOCK_MODEL) + run_result = my_agent.run(ai.make_messages(user="go")) async for msg in run_result: if any(isinstance(p, ai.HookPart) and p.status == "pending" for p in msg.parts): @@ -101,9 +101,11 @@ async def test_cancel_nonexistent_raises() -> None: @pytest.mark.asyncio async def test_pre_registered_resolution_consumed() -> None: """Pre-registered resolution is consumed by Hook.create() without suspending.""" + my_agent = ai.agent(model=MOCK_MODEL) - async def graph(model: ai.Model) -> Any: - await ai.stream_step(model, ai.make_messages(user="go")) + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> Any: + await ai.stream_step(agent.model, msgs) result = await Confirmation.create("pre_reg_1") # type: ignore[attr-defined] return result @@ -111,7 +113,7 @@ async def graph(model: ai.Model) -> Any: Confirmation.resolve("pre_reg_1", {"approved": True}) # type: ignore[attr-defined] mock_llm([[text_msg("OK")]]) - run_result = ai.run(graph, MOCK_MODEL) + run_result = my_agent.run(ai.make_messages(user="go")) [m async for m in run_result] # Should have completed with no pending hooks @@ -136,13 +138,15 @@ def test_resolve_validates_schema() -> None: @pytest.mark.asyncio async def 
test_resolved_hook_emits_message() -> None: """After resolution, a 'resolved' HookPart message is emitted.""" + my_agent = ai.agent(model=MOCK_MODEL) - async def graph(model: ai.Model) -> None: - await ai.stream_step(model, ai.make_messages(user="go")) + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: + await ai.stream_step(agent.model, msgs) await Confirmation.create("emit_test") # type: ignore[attr-defined] mock_llm([[text_msg("OK")]]) - run_result = ai.run(graph, MOCK_MODEL) + run_result = my_agent.run(ai.make_messages(user="go")) msgs = [] async for msg in run_result: @@ -164,14 +168,17 @@ async def graph(model: ai.Model) -> None: @pytest.mark.asyncio async def test_hook_metadata_in_pending() -> None: - async def graph(model: ai.Model) -> None: - await ai.stream_step(model, ai.make_messages(user="go")) + my_agent = ai.agent(model=MOCK_MODEL) + + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: + await ai.stream_step(agent.model, msgs) await CancellingConfirmation.create( # type: ignore[attr-defined] "meta_test", metadata={"tool": "rm -rf", "path": "/"} ) mock_llm([[text_msg("OK")]]) - run_result = ai.run(graph, MOCK_MODEL) + run_result = my_agent.run(ai.make_messages(user="go")) [m async for m in run_result] info = run_result.pending_hooks["meta_test"] diff --git a/tests/agents/test_runtime.py b/tests/agents/test_runtime.py index e15d4f49..6eb5b0be 100644 --- a/tests/agents/test_runtime.py +++ b/tests/agents/test_runtime.py @@ -1,4 +1,4 @@ -"""Runtime: stream_loop end-to-end, execute_tool, multi-turn, Runtime injection.""" +"""Agent default loop, execute_tool, multi-turn, Runtime injection.""" import asyncio @@ -25,46 +25,34 @@ async def concat(a: str, b: str) -> str: return a + b -# -- stream_loop: single turn (no tools) ---------------------------------- +# -- Agent default loop: single turn (no tools) ---------------------------- @pytest.mark.asyncio -async def 
test_stream_loop_text_only() -> None: - """stream_loop with no tool calls returns after one LLM call.""" - - async def graph(model: ai.Model) -> ai.StreamResult: - return await ai.stream_loop( - model, - messages=ai.make_messages(user="Hi"), - tools=[double], - ) +async def test_agent_text_only() -> None: + """Agent default loop with no tool calls returns after one LLM call.""" + my_agent = ai.agent(model=MOCK_MODEL, tools=[double]) llm = mock_llm([[text_msg("Hello!")]]) - result = ai.run(graph, MOCK_MODEL) + result = my_agent.run(ai.make_messages(user="Hi")) msgs = [m async for m in result] assert llm.call_count == 1 assert any(m.text == "Hello!" for m in msgs) -# -- stream_loop: tool call + follow-up ----------------------------------- +# -- Agent default loop: tool call + follow-up ----------------------------- @pytest.mark.asyncio -async def test_stream_loop_tool_then_text() -> None: - """stream_loop calls tool, feeds result back, gets final text.""" - - async def graph(model: ai.Model) -> ai.StreamResult: - return await ai.stream_loop( - model, - messages=ai.make_messages(user="Double 5"), - tools=[double], - ) +async def test_agent_tool_then_text() -> None: + """Agent default loop calls tool, feeds result back, gets final text.""" + my_agent = ai.agent(model=MOCK_MODEL, tools=[double]) call1 = [tool_msg(tc_id="tc-1", name="double", args='{"x": 5}')] call2 = [text_msg("The answer is 10.")] llm = mock_llm([call1, call2]) - result = ai.run(graph, MOCK_MODEL) + result = my_agent.run(ai.make_messages(user="Double 5")) msgs = [m async for m in result] assert llm.call_count == 2 # Tool should have been executed: 5 * 2 = 10 @@ -75,19 +63,13 @@ async def graph(model: ai.Model) -> ai.StreamResult: assert tool_results[0].tool_calls[0].result == 10 -# -- stream_loop: multiple tool calls in one message ---------------------- +# -- Agent default loop: multiple tool calls in one message ---------------- @pytest.mark.asyncio -async def test_stream_loop_parallel_tools() -> 
None: +async def test_agent_parallel_tools() -> None: """LLM returns two tool calls in one message; both execute.""" - - async def graph(model: ai.Model) -> ai.StreamResult: - return await ai.stream_loop( - model, - messages=ai.make_messages(user="Double 3 and 7"), - tools=[double], - ) + my_agent = ai.agent(model=MOCK_MODEL, tools=[double]) two_tools = messages.Message( id="msg-1", @@ -112,7 +94,7 @@ async def graph(model: ai.Model) -> ai.StreamResult: call2 = [text_msg("6 and 14", id="msg-2")] llm = mock_llm([[two_tools], call2]) - result = ai.run(graph, MOCK_MODEL) + result = my_agent.run(ai.make_messages(user="Double 3 and 7")) msgs = [m async for m in result] assert llm.call_count == 2 # Both tools should have results @@ -124,19 +106,13 @@ async def graph(model: ai.Model) -> ai.StreamResult: assert len(tool_result_msgs) >= 1 -# -- stream_loop: multi-turn (tool -> tool -> text) ----------------------- +# -- Agent default loop: multi-turn (tool -> tool -> text) ----------------- @pytest.mark.asyncio -async def test_stream_loop_multi_turn() -> None: +async def test_agent_multi_turn() -> None: """LLM calls a tool, then calls another tool, then returns text.""" - - async def graph(model: ai.Model) -> ai.StreamResult: - return await ai.stream_loop( - model, - messages=ai.make_messages(user="Concat then double"), - tools=[double, concat], - ) + my_agent = ai.agent(model=MOCK_MODEL, tools=[double, concat]) turn1 = [ tool_msg(tc_id="tc-1", name="concat", args='{"a": "hello", "b": " world"}') @@ -145,7 +121,7 @@ async def graph(model: ai.Model) -> ai.StreamResult: turn3 = [text_msg("Done: hello world, 6", id="msg-3")] llm = mock_llm([turn1, turn2, turn3]) - result = ai.run(graph, MOCK_MODEL) + result = my_agent.run(ai.make_messages(user="Concat then double")) [m async for m in result] assert llm.call_count == 3 @@ -162,12 +138,14 @@ async def test_execute_tool_missing_raises() -> None: tc = messages.ToolPart( tool_call_id="tc-1", tool_name="nonexistent_tool_zzz", 
tool_args="{}" ) + my_agent = ai.agent(model=MOCK_MODEL, tools=[]) - async def graph(model: ai.Model) -> None: + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: await ai.execute_tool(tc) mock_llm([]) - result = ai.run(graph, MOCK_MODEL) + result = my_agent.run(ai.make_messages(user="go")) with pytest.raises(ExceptionGroup) as exc_info: [m async for m in result] assert any(isinstance(e, ValueError) for e in exc_info.value.exceptions) @@ -188,8 +166,11 @@ async def introspect(query: str, rt: Runtime) -> str: received_rt = rt return "ok" - async def graph(model: ai.Model) -> None: - result = await ai.stream_step(model, ai.make_messages(user="go")) + my_agent = ai.agent(model=MOCK_MODEL, tools=[introspect]) + + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: + result = await ai.stream_step(agent.model, msgs, agent.tools) if result.tool_calls: await asyncio.gather( *( @@ -200,7 +181,7 @@ async def graph(model: ai.Model) -> None: call = [tool_msg(tc_id="tc-1", name="introspect", args='{"query": "test"}')] mock_llm([call]) - result = ai.run(graph, MOCK_MODEL) + result = my_agent.run(ai.make_messages(user="go")) [m async for m in result] assert received_rt is not None assert isinstance(received_rt, Runtime) @@ -212,9 +193,11 @@ async def graph(model: ai.Model) -> None: @pytest.mark.asyncio async def test_execute_tool_updates_message() -> None: """After execute_tool, the ToolPart in the message has status=result.""" + my_agent = ai.agent(model=MOCK_MODEL, tools=[double]) - async def graph(model: ai.Model) -> None: - result = await ai.stream_step(model, ai.make_messages(user="go")) + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: + result = await ai.stream_step(agent.model, msgs, agent.tools) if result.tool_calls: msg = result.last_message for tc in result.tool_calls: @@ -226,29 +209,23 @@ async def graph(model: ai.Model) -> None: call = [tool_msg(tc_id="tc-1", 
name="double", args='{"x": 5}')] mock_llm([call]) - result = ai.run(graph, MOCK_MODEL) + result = my_agent.run(ai.make_messages(user="go")) [m async for m in result] -# -- Checkpoint records tools from stream_loop ----------------------------- +# -- Checkpoint records tools from Agent default loop ---------------------- @pytest.mark.asyncio -async def test_stream_loop_checkpoint_records_tools() -> None: - """stream_loop's tool executions are recorded in the checkpoint.""" - - async def graph(model: ai.Model) -> ai.StreamResult: - return await ai.stream_loop( - model, - messages=ai.make_messages(user="Double 4"), - tools=[double], - ) +async def test_agent_checkpoint_records_tools() -> None: + """Agent default loop's tool executions are recorded in the checkpoint.""" + my_agent = ai.agent(model=MOCK_MODEL, tools=[double]) call1 = [tool_msg(tc_id="tc-1", name="double", args='{"x": 4}')] call2 = [text_msg("8", id="msg-2")] mock_llm([call1, call2]) - result = ai.run(graph, MOCK_MODEL) + result = my_agent.run(ai.make_messages(user="Double 4")) [m async for m in result] cp = result.checkpoint diff --git a/tests/agents/test_streams.py b/tests/agents/test_streams.py index 2eb3ef20..db7770ee 100644 --- a/tests/agents/test_streams.py +++ b/tests/agents/test_streams.py @@ -71,20 +71,23 @@ async def test_stream_outside_run_raises() -> None: @pytest.mark.asyncio async def test_stream_step_replays_from_checkpoint() -> None: - """stream_step inside ai.run with a checkpoint replays without calling LLM.""" + """stream_step inside Agent.run with a checkpoint replays without calling LLM.""" - async def graph(model: ai.Model) -> ai.StreamResult: - return await ai.stream_step(model, ai.make_messages(user="hello")) + my_agent = ai.agent(model=MOCK_MODEL) + + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> ai.StreamResult: + return await ai.stream_step(agent.model, msgs) # First run mock_llm([[text_msg("Hi")]]) - r1 = ai.run(graph, MOCK_MODEL) + r1 = 
my_agent.run(ai.make_messages(user="hello")) [msg async for msg in r1] cp = r1.checkpoint # Replay llm2 = mock_llm([]) - r2 = ai.run(graph, MOCK_MODEL, checkpoint=cp) + r2 = my_agent.run(ai.make_messages(user="hello"), checkpoint=cp) [msg async for msg in r2] assert llm2.call_count == 0 diff --git a/tests/agents2/__init__.py b/tests/agents2/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/agents2/mcp/__init__.py b/tests/agents2/mcp/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/agents2/mcp/test_client.py b/tests/agents2/mcp/test_client.py deleted file mode 100644 index 531b16be..00000000 --- a/tests/agents2/mcp/test_client.py +++ /dev/null @@ -1,108 +0,0 @@ -"""MCP client: tool registration in global registry, end-to-end execution.""" - -import contextlib -from typing import Any - -import mcp.types -import pytest - -import vercel_ai_sdk as ai -from vercel_ai_sdk.agents2.mcp.client import _mcp_tool_to_native -from vercel_ai_sdk.agents2.tools import _tool_registry, get_tool - -from ...conftest import MOCK_MODEL, mock_llm, text_msg, tool_msg - - -def _fake_mcp_tool( - name: str = "mcp_echo", description: str = "Echo input" -) -> mcp.types.Tool: - """Build a minimal mcp.types.Tool for testing.""" - return mcp.types.Tool( - name=name, - description=description, - inputSchema={ - "type": "object", - "properties": {"text": {"type": "string"}}, - "required": ["text"], - }, - ) - - -def _noop_transport_factory() -> contextlib.AbstractAsyncContextManager[Any]: - """Dummy transport factory — never actually called in these tests.""" - raise NotImplementedError("should not be called") - - -# -- _mcp_tool_to_native registers in global registry ---------------------- - - -def test_mcp_tool_to_native_registers_in_global_registry() -> None: - """Converting an MCP tool to native registers it in _tool_registry.""" - mcp_tool = _fake_mcp_tool(name="mcp_reg_test") - native = _mcp_tool_to_native(mcp_tool, "test:key", 
_noop_transport_factory, None) - - assert native.name == "mcp_reg_test" - assert get_tool("mcp_reg_test") is native - assert _tool_registry["mcp_reg_test"] is native - - -def test_mcp_tool_to_native_with_prefix() -> None: - """Tool prefix is prepended to the name and both name forms are correct.""" - mcp_tool = _fake_mcp_tool(name="echo") - native = _mcp_tool_to_native(mcp_tool, "test:key", _noop_transport_factory, "ctx7") - - assert native.name == "ctx7_echo" - assert get_tool("ctx7_echo") is native - - -def test_mcp_tool_to_native_schema_preserved() -> None: - """The inputSchema from the MCP tool is passed through as param_schema.""" - mcp_tool = _fake_mcp_tool() - native = _mcp_tool_to_native(mcp_tool, "test:key", _noop_transport_factory, None) - - assert native.param_schema == mcp_tool.inputSchema - assert native.description == "Echo input" - - -# -- End-to-end: MCP tool executes through Agent default loop --------------- - - -@pytest.mark.asyncio -async def test_mcp_tool_executes_through_agent() -> None: - """MCP-style tool via _mcp_tool_to_native works with Agent.""" - call_log: list[dict[str, str]] = [] - - async def fake_fn(**kwargs: str) -> str: - call_log.append(kwargs) - return f"echoed: {kwargs.get('text', '')}" - - # Build and register a tool the same way the MCP client does, - # but with a fake fn so we don't need a real MCP server. 
- mcp_tool = _fake_mcp_tool(name="mcp_e2e_echo") - native = _mcp_tool_to_native(mcp_tool, "test:key", _noop_transport_factory, None) - # Replace the real fn (which would try to connect) with our fake - native._fn = fake_fn - _tool_registry[native.name] = native - - my_agent = ai.agent(model=MOCK_MODEL, tools=[native]) - - call1 = [tool_msg(tc_id="tc-mcp-1", name="mcp_e2e_echo", args='{"text": "hello"}')] - call2 = [text_msg("Done.", id="msg-2")] - llm = mock_llm([call1, call2]) - - result = my_agent.run(ai.make_messages(user="echo hello")) - msgs = [m async for m in result] - - # Tool was called with the right args - assert len(call_log) == 1 - assert call_log[0] == {"text": "hello"} - - # Tool result is visible in messages - tool_results = [ - m for m in msgs if m.tool_calls and m.tool_calls[0].status == "result" - ] - assert len(tool_results) >= 1 - assert tool_results[0].tool_calls[0].result == "echoed: hello" - - # LLM was called twice (tool call + final text) - assert llm.call_count == 2 diff --git a/tests/agents2/test_checkpoint.py b/tests/agents2/test_checkpoint.py deleted file mode 100644 index 1f5f0490..00000000 --- a/tests/agents2/test_checkpoint.py +++ /dev/null @@ -1,207 +0,0 @@ -"""Checkpoint replay, hook cancellation/resolution, serialization.""" - -import asyncio -from typing import Any, ClassVar - -import pydantic -import pytest - -import vercel_ai_sdk as ai -from vercel_ai_sdk.agents2.checkpoint import Checkpoint, HookEvent, StepEvent, ToolEvent - -from ..conftest import MOCK_MODEL, mock_llm, text_msg, tool_msg - - -@ai.hook -class Approval(pydantic.BaseModel): - cancels_future: ClassVar[bool] = True - granted: bool - - -# -- Replay ---------------------------------------------------------------- - - -@pytest.mark.asyncio -async def test_step_replay_skips_llm() -> None: - my_agent = ai.agent(model=MOCK_MODEL) - - @my_agent.loop - async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> ai.StreamResult: - return await ai.stream_step(agent.model, 
msgs) - - llm1 = mock_llm([[text_msg("Hi there!")]]) - result1 = my_agent.run(ai.make_messages(system="test", user="hello")) - [msg async for msg in result1] - assert llm1.call_count == 1 - - cp = result1.checkpoint - llm2 = mock_llm([]) - result2 = my_agent.run(ai.make_messages(system="test", user="hello"), checkpoint=cp) - [msg async for msg in result2] - assert llm2.call_count == 0 - - -@pytest.mark.asyncio -async def test_tool_replay_skips_execution() -> None: - execution_count = 0 - - @ai.tool - async def counting_tool(x: int) -> int: - """Counts calls.""" - nonlocal execution_count - execution_count += 1 - return x + 1 - - my_agent = ai.agent(model=MOCK_MODEL, tools=[counting_tool]) - - @my_agent.loop - async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> ai.StreamResult: - result = await ai.stream_step(agent.model, msgs, agent.tools) - if result.tool_calls: - await asyncio.gather( - *( - ai.execute_tool(tc, message=result.last_message) - for tc in result.tool_calls - ) - ) - return result - - mock_llm([[tool_msg(tc_id="tc-1", name="counting_tool", args='{"x": 5}')]]) - result1 = my_agent.run(ai.make_messages(system="t", user="go")) - [msg async for msg in result1] - assert execution_count == 1 - assert result1.checkpoint.tools[0].result == 6 - - execution_count = 0 - mock_llm([]) - result2 = my_agent.run( - ai.make_messages(system="t", user="go"), checkpoint=result1.checkpoint - ) - [msg async for msg in result2] - assert execution_count == 0 - - -# -- Hooks ----------------------------------------------------------------- - - -@pytest.mark.asyncio -async def test_hook_cancellation_pending() -> None: - my_agent = ai.agent(model=MOCK_MODEL) - - @my_agent.loop - async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> Any: - await ai.stream_step(agent.model, msgs) - return await Approval.create("my_approval", metadata={"tool": "test"}) # type: ignore[attr-defined] - - mock_llm([[text_msg("OK")]]) - result = my_agent.run(ai.make_messages(system="t", 
user="go")) - msgs = [msg async for msg in result] - assert "my_approval" in result.pending_hooks - hook_msgs = [m for m in msgs if any(isinstance(p, ai.HookPart) for p in m.parts)] - assert hook_msgs[0].parts[0].status == "pending" # type: ignore[union-attr] - - -@pytest.mark.asyncio -async def test_hook_resolution_on_reentry() -> None: - my_agent = ai.agent(model=MOCK_MODEL) - - @my_agent.loop - async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> Any: - await ai.stream_step(agent.model, msgs) - return await Approval.create("my_approval") # type: ignore[attr-defined] - - resp = [text_msg("OK")] - mock_llm([resp]) - result1 = my_agent.run(ai.make_messages(system="t", user="go")) - [msg async for msg in result1] - cp = result1.checkpoint - - Approval.resolve("my_approval", {"granted": True}) # type: ignore[attr-defined] - mock_llm([]) - result2 = my_agent.run(ai.make_messages(system="t", user="go"), checkpoint=cp) - [msg async for msg in result2] - assert len(result2.pending_hooks) == 0 - assert result2.checkpoint.hooks[-1].label == "my_approval" - - -@pytest.mark.asyncio -async def test_parallel_hooks_all_collected() -> None: - my_agent = ai.agent(model=MOCK_MODEL) - - @my_agent.loop - async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: - await ai.stream_step(agent.model, msgs) - - async def a() -> Any: - return await Approval.create("hook_a") # type: ignore[attr-defined] - - async def b() -> Any: - return await Approval.create("hook_b") # type: ignore[attr-defined] - - async with asyncio.TaskGroup() as tg: - tg.create_task(a()) - tg.create_task(b()) - - mock_llm([[text_msg("OK")]]) - result = my_agent.run(ai.make_messages(system="t", user="go")) - [msg async for msg in result] - assert {"hook_a", "hook_b"} <= set(result.pending_hooks) - - -@pytest.mark.asyncio -async def test_parallel_hooks_resolve_on_reentry() -> None: - my_agent = ai.agent(model=MOCK_MODEL) - - @my_agent.loop - async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> 
Any: - await ai.stream_step(agent.model, msgs) - - async def a() -> Any: - return await Approval.create("hook_a") # type: ignore[attr-defined] - - async def b() -> Any: - return await Approval.create("hook_b") # type: ignore[attr-defined] - - async with asyncio.TaskGroup() as tg: - ta = tg.create_task(a()) - tb = tg.create_task(b()) - return ta.result(), tb.result() - - resp = [text_msg("OK")] - mock_llm([resp]) - result1 = my_agent.run(ai.make_messages(system="t", user="go")) - [msg async for msg in result1] - cp = result1.checkpoint - - Approval.resolve("hook_a", {"granted": True}) # type: ignore[attr-defined] - Approval.resolve("hook_b", {"granted": False}) # type: ignore[attr-defined] - mock_llm([]) - result2 = my_agent.run(ai.make_messages(system="t", user="go"), checkpoint=cp) - [msg async for msg in result2] - assert len(result2.pending_hooks) == 0 - - -# -- Serialization --------------------------------------------------------- - - -def test_checkpoint_serialization_roundtrip() -> None: - cp = Checkpoint( - steps=[ - StepEvent( - index=0, - messages=[ - ai.Message( - id="m1", - role="assistant", - parts=[ai.TextPart(text="hi")], - ) - ], - ) - ], - tools=[ToolEvent(tool_call_id="tc-1", result=42)], - hooks=[HookEvent(label="h1", resolution={"granted": True})], - ) - cp2 = Checkpoint.model_validate(cp.model_dump()) - assert cp2.steps[0].index == 0 - assert cp2.tools[0].result == 42 - assert cp2.hooks[0].label == "h1" diff --git a/tests/agents2/test_hooks.py b/tests/agents2/test_hooks.py deleted file mode 100644 index 65bfd15b..00000000 --- a/tests/agents2/test_hooks.py +++ /dev/null @@ -1,185 +0,0 @@ -"""Hooks: live resolution, cancellation, pre-registration, schema validation.""" - -import asyncio -from typing import Any, ClassVar - -import pydantic -import pytest - -import vercel_ai_sdk as ai - -from ..conftest import MOCK_MODEL, mock_llm, text_msg - - -@ai.hook -class Confirmation(pydantic.BaseModel): - approved: bool - reason: str = "" - - -@ai.hook 
-class CancellingConfirmation(pydantic.BaseModel): - cancels_future: ClassVar[bool] = True - approved: bool - reason: str = "" - - -# -- Hook.resolve() with live future (long-running mode) ------------------- - - -@pytest.mark.asyncio -async def test_resolve_live_future() -> None: - """In long-running mode, Hook.resolve() unblocks the awaiting coroutine.""" - resolved_value = None - my_agent = ai.agent(model=MOCK_MODEL) - - @my_agent.loop - async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: - nonlocal resolved_value - await ai.stream_step(agent.model, msgs) - result = await Confirmation.create("confirm_1") # type: ignore[attr-defined] - resolved_value = result - - mock_llm([[text_msg("OK")]]) - # Confirmation.cancels_future=False -> long-running mode - run_result = my_agent.run(ai.make_messages(user="go")) - - collected = [] - async for msg in run_result: - collected.append(msg) - # When we see the pending hook message, resolve it - if any(isinstance(p, ai.HookPart) and p.status == "pending" for p in msg.parts): - Confirmation.resolve( # type: ignore[attr-defined] - "confirm_1", {"approved": True, "reason": "looks good"} - ) - - assert resolved_value is not None - assert resolved_value.approved is True - assert resolved_value.reason == "looks good" - - -# -- Hook.cancel() -------------------------------------------------------- - - -@pytest.mark.asyncio -async def test_cancel_live_hook() -> None: - """Hook.cancel() cancels the future, causing CancelledError in graph.""" - was_cancelled = False - my_agent = ai.agent(model=MOCK_MODEL) - - @my_agent.loop - async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: - nonlocal was_cancelled - await ai.stream_step(agent.model, msgs) - try: - await Confirmation.create("cancel_me") # type: ignore[attr-defined] - except asyncio.CancelledError: - was_cancelled = True - - mock_llm([[text_msg("OK")]]) - run_result = my_agent.run(ai.make_messages(user="go")) - - async for msg in run_result: - if 
any(isinstance(p, ai.HookPart) and p.status == "pending" for p in msg.parts): - await Confirmation.cancel("cancel_me", reason="denied") # type: ignore[attr-defined] - - assert was_cancelled - - -# -- Hook.cancel() on non-existent label raises ---------------------------- - - -@pytest.mark.asyncio -async def test_cancel_nonexistent_raises() -> None: - with pytest.raises(ValueError, match="No pending hook"): - await Confirmation.cancel("does_not_exist_xyz") # type: ignore[attr-defined] - - -# -- Pre-registration (serverless re-entry) -------------------------------- - - -@pytest.mark.asyncio -async def test_pre_registered_resolution_consumed() -> None: - """Pre-registered resolution is consumed by Hook.create() without suspending.""" - my_agent = ai.agent(model=MOCK_MODEL) - - @my_agent.loop - async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> Any: - await ai.stream_step(agent.model, msgs) - result = await Confirmation.create("pre_reg_1") # type: ignore[attr-defined] - return result - - # Pre-register BEFORE run - Confirmation.resolve("pre_reg_1", {"approved": True}) # type: ignore[attr-defined] - - mock_llm([[text_msg("OK")]]) - run_result = my_agent.run(ai.make_messages(user="go")) - [m async for m in run_result] - - # Should have completed with no pending hooks - assert len(run_result.pending_hooks) == 0 - # Hook event should be in checkpoint - assert any(h.label == "pre_reg_1" for h in run_result.checkpoint.hooks) - - -# -- Schema validation on resolve ----------------------------------------- - - -def test_resolve_validates_schema() -> None: - """resolve() with invalid data raises from pydantic validation.""" - # 'approved' is required bool, passing string should raise - with pytest.raises(pydantic.ValidationError): - Confirmation.resolve("schema_test", {"approved": "not_a_bool"}) # type: ignore[attr-defined] - - -# -- Resolved hook emits message ------------------------------------------- - - -@pytest.mark.asyncio -async def 
test_resolved_hook_emits_message() -> None: - """After resolution, a 'resolved' HookPart message is emitted.""" - my_agent = ai.agent(model=MOCK_MODEL) - - @my_agent.loop - async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: - await ai.stream_step(agent.model, msgs) - await Confirmation.create("emit_test") # type: ignore[attr-defined] - - mock_llm([[text_msg("OK")]]) - run_result = my_agent.run(ai.make_messages(user="go")) - - msgs = [] - async for msg in run_result: - msgs.append(msg) - if any(isinstance(p, ai.HookPart) and p.status == "pending" for p in msg.parts): - Confirmation.resolve("emit_test", {"approved": False}) # type: ignore[attr-defined] - - hook_msgs = [ - m - for m in msgs - if any(isinstance(p, ai.HookPart) and p.status == "resolved" for p in m.parts) - ] - assert len(hook_msgs) == 1 - assert hook_msgs[0].parts[0].resolution == {"approved": False, "reason": ""} # type: ignore[union-attr] - - -# -- Hook metadata surfaces in pending message ----------------------------- - - -@pytest.mark.asyncio -async def test_hook_metadata_in_pending() -> None: - my_agent = ai.agent(model=MOCK_MODEL) - - @my_agent.loop - async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: - await ai.stream_step(agent.model, msgs) - await CancellingConfirmation.create( # type: ignore[attr-defined] - "meta_test", metadata={"tool": "rm -rf", "path": "/"} - ) - - mock_llm([[text_msg("OK")]]) - run_result = my_agent.run(ai.make_messages(user="go")) - [m async for m in run_result] - - info = run_result.pending_hooks["meta_test"] - assert info.metadata == {"tool": "rm -rf", "path": "/"} diff --git a/tests/agents2/test_runtime.py b/tests/agents2/test_runtime.py deleted file mode 100644 index b423dc1b..00000000 --- a/tests/agents2/test_runtime.py +++ /dev/null @@ -1,232 +0,0 @@ -"""Agent default loop, execute_tool, multi-turn, Runtime injection.""" - -import asyncio - -import pytest - -import vercel_ai_sdk as ai -from vercel_ai_sdk.agents2.runtime import 
Runtime -from vercel_ai_sdk.types import messages - -from ..conftest import MOCK_MODEL, mock_llm, text_msg, tool_msg - -# -- Tool definitions for tests -------------------------------------------- - - -@ai.tool -async def double(x: int) -> int: - """Double a number.""" - return x * 2 - - -@ai.tool -async def concat(a: str, b: str) -> str: - """Concatenate strings.""" - return a + b - - -# -- Agent default loop: single turn (no tools) ---------------------------- - - -@pytest.mark.asyncio -async def test_agent_text_only() -> None: - """Agent default loop with no tool calls returns after one LLM call.""" - my_agent = ai.agent(model=MOCK_MODEL, tools=[double]) - - llm = mock_llm([[text_msg("Hello!")]]) - result = my_agent.run(ai.make_messages(user="Hi")) - msgs = [m async for m in result] - assert llm.call_count == 1 - assert any(m.text == "Hello!" for m in msgs) - - -# -- Agent default loop: tool call + follow-up ----------------------------- - - -@pytest.mark.asyncio -async def test_agent_tool_then_text() -> None: - """Agent default loop calls tool, feeds result back, gets final text.""" - my_agent = ai.agent(model=MOCK_MODEL, tools=[double]) - - call1 = [tool_msg(tc_id="tc-1", name="double", args='{"x": 5}')] - call2 = [text_msg("The answer is 10.")] - llm = mock_llm([call1, call2]) - - result = my_agent.run(ai.make_messages(user="Double 5")) - msgs = [m async for m in result] - assert llm.call_count == 2 - # Tool should have been executed: 5 * 2 = 10 - tool_results = [ - m for m in msgs if m.tool_calls and m.tool_calls[0].status == "result" - ] - assert len(tool_results) >= 1 - assert tool_results[0].tool_calls[0].result == 10 - - -# -- Agent default loop: multiple tool calls in one message ---------------- - - -@pytest.mark.asyncio -async def test_agent_parallel_tools() -> None: - """LLM returns two tool calls in one message; both execute.""" - my_agent = ai.agent(model=MOCK_MODEL, tools=[double]) - - two_tools = messages.Message( - id="msg-1", - 
role="assistant", - parts=[ - messages.ToolPart( - tool_call_id="tc-1", - tool_name="double", - tool_args='{"x": 3}', - status="pending", - state="done", - ), - messages.ToolPart( - tool_call_id="tc-2", - tool_name="double", - tool_args='{"x": 7}', - status="pending", - state="done", - ), - ], - ) - call2 = [text_msg("6 and 14", id="msg-2")] - llm = mock_llm([[two_tools], call2]) - - result = my_agent.run(ai.make_messages(user="Double 3 and 7")) - msgs = [m async for m in result] - assert llm.call_count == 2 - # Both tools should have results - tool_result_msgs = [ - m - for m in msgs - if m.tool_calls and any(tc.status == "result" for tc in m.tool_calls) - ] - assert len(tool_result_msgs) >= 1 - - -# -- Agent default loop: multi-turn (tool -> tool -> text) ----------------- - - -@pytest.mark.asyncio -async def test_agent_multi_turn() -> None: - """LLM calls a tool, then calls another tool, then returns text.""" - my_agent = ai.agent(model=MOCK_MODEL, tools=[double, concat]) - - turn1 = [ - tool_msg(tc_id="tc-1", name="concat", args='{"a": "hello", "b": " world"}') - ] - turn2 = [tool_msg(tc_id="tc-2", name="double", args='{"x": 3}', id="msg-2")] - turn3 = [text_msg("Done: hello world, 6", id="msg-3")] - llm = mock_llm([turn1, turn2, turn3]) - - result = my_agent.run(ai.make_messages(user="Concat then double")) - [m async for m in result] - assert llm.call_count == 3 - - -# -- execute_tool: missing tool raises ------------------------------------ - - -@pytest.mark.asyncio -async def test_execute_tool_missing_raises() -> None: - """execute_tool with unknown tool name raises ValueError. - - Wrapped in ExceptionGroup by TaskGroup. 
- """ - tc = messages.ToolPart( - tool_call_id="tc-1", tool_name="nonexistent_tool_zzz", tool_args="{}" - ) - my_agent = ai.agent(model=MOCK_MODEL, tools=[]) - - @my_agent.loop - async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: - await ai.execute_tool(tc) - - mock_llm([]) - result = my_agent.run(ai.make_messages(user="go")) - with pytest.raises(ExceptionGroup) as exc_info: - [m async for m in result] - assert any(isinstance(e, ValueError) for e in exc_info.value.exceptions) - - -# -- execute_tool: Runtime injection --------------------------------------- - - -@pytest.mark.asyncio -async def test_execute_tool_injects_runtime() -> None: - """Tools with a Runtime parameter get the active runtime injected.""" - received_rt = None - - @ai.tool - async def introspect(query: str, rt: Runtime) -> str: - """Tool that inspects runtime.""" - nonlocal received_rt - received_rt = rt - return "ok" - - my_agent = ai.agent(model=MOCK_MODEL, tools=[introspect]) - - @my_agent.loop - async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: - result = await ai.stream_step(agent.model, msgs, agent.tools) - if result.tool_calls: - await asyncio.gather( - *( - ai.execute_tool(tc, message=result.last_message) - for tc in result.tool_calls - ) - ) - - call = [tool_msg(tc_id="tc-1", name="introspect", args='{"query": "test"}')] - mock_llm([call]) - result = my_agent.run(ai.make_messages(user="go")) - [m async for m in result] - assert received_rt is not None - assert isinstance(received_rt, Runtime) - - -# -- execute_tool: result updates ToolPart in message ---------------------- - - -@pytest.mark.asyncio -async def test_execute_tool_updates_message() -> None: - """After execute_tool, the ToolPart in the message has status=result.""" - my_agent = ai.agent(model=MOCK_MODEL, tools=[double]) - - @my_agent.loop - async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: - result = await ai.stream_step(agent.model, msgs, agent.tools) - if result.tool_calls: - 
msg = result.last_message - for tc in result.tool_calls: - await ai.execute_tool(tc, message=msg) - # Verify the tool part was mutated - assert msg is not None - assert msg.tool_calls[0].status == "result" - assert msg.tool_calls[0].result == 10 - - call = [tool_msg(tc_id="tc-1", name="double", args='{"x": 5}')] - mock_llm([call]) - result = my_agent.run(ai.make_messages(user="go")) - [m async for m in result] - - -# -- Checkpoint records tools from Agent default loop ---------------------- - - -@pytest.mark.asyncio -async def test_agent_checkpoint_records_tools() -> None: - """Agent default loop's tool executions are recorded in the checkpoint.""" - my_agent = ai.agent(model=MOCK_MODEL, tools=[double]) - - call1 = [tool_msg(tc_id="tc-1", name="double", args='{"x": 4}')] - call2 = [text_msg("8", id="msg-2")] - mock_llm([call1, call2]) - - result = my_agent.run(ai.make_messages(user="Double 4")) - [m async for m in result] - - cp = result.checkpoint - assert any(t.tool_call_id == "tc-1" and t.result == 8 for t in cp.tools) diff --git a/tests/agents2/test_streams.py b/tests/agents2/test_streams.py deleted file mode 100644 index 80a459fa..00000000 --- a/tests/agents2/test_streams.py +++ /dev/null @@ -1,113 +0,0 @@ -"""@stream decorator: context requirement, replay, queue submission.""" - -import pydantic -import pytest - -import vercel_ai_sdk as ai -from vercel_ai_sdk.agents2.streams import StreamResult -from vercel_ai_sdk.types import messages - -from ..conftest import MOCK_MODEL, mock_llm, text_msg - - -class _Weather(pydantic.BaseModel): - city: str - temperature: float - - -# -- StreamResult properties ----------------------------------------------- - - -def test_stream_result_empty() -> None: - r = StreamResult() - assert r.last_message is None - assert r.tool_calls == [] - assert r.text == "" - - -def test_stream_result_last_message() -> None: - m1 = text_msg("first", id="m1") - m2 = text_msg("second", id="m2") - r = StreamResult(messages=[m1, m2]) - last = 
r.last_message - assert last is not None - assert last.id == "m2" - assert r.text == "second" - - -def test_stream_result_tool_calls() -> None: - m = messages.Message( - id="m1", - role="assistant", - parts=[ - messages.ToolPart( - tool_call_id="tc1", tool_name="t", tool_args="{}", state="done" - ), - messages.ToolPart( - tool_call_id="tc2", tool_name="u", tool_args="{}", state="done" - ), - ], - ) - r = StreamResult(messages=[m]) - assert len(r.tool_calls) == 2 - - -# -- @stream requires Runtime context ------------------------------------- - - -@pytest.mark.asyncio -async def test_stream_outside_run_raises() -> None: - """@stream-decorated fn called without ai.run() should raise.""" - mock_llm([[text_msg("hi")]]) - with pytest.raises(ValueError, match="No Runtime context"): - await ai.stream_step( - MOCK_MODEL, - ai.make_messages(user="test"), - ) - - -# -- @stream replays from checkpoint -------------------------------------- - - -@pytest.mark.asyncio -async def test_stream_step_replays_from_checkpoint() -> None: - """stream_step inside Agent.run with a checkpoint replays without calling LLM.""" - - my_agent = ai.agent(model=MOCK_MODEL) - - @my_agent.loop - async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> ai.StreamResult: - return await ai.stream_step(agent.model, msgs) - - # First run - mock_llm([[text_msg("Hi")]]) - r1 = my_agent.run(ai.make_messages(user="hello")) - [msg async for msg in r1] - cp = r1.checkpoint - - # Replay - llm2 = mock_llm([]) - r2 = my_agent.run(ai.make_messages(user="hello"), checkpoint=cp) - [msg async for msg in r2] - assert llm2.call_count == 0 - - -# -- StreamResult.output --------------------------------------------------- - - -def test_stream_result_output_from_last_message() -> None: - """StreamResult.output delegates to the last message's StructuredOutputPart.""" - m = messages.Message( - id="m1", - role="assistant", - parts=[ - messages.TextPart(text="{}", state="done"), - messages.StructuredOutputPart( - 
data={"city": "SF", "temperature": 62.0}, - output_type_name=f"{_Weather.__module__}.{_Weather.__qualname__}", - ), - ], - ) - r = StreamResult(messages=[text_msg("streaming..."), m]) - assert r.output is not None - assert r.output.city == "SF" diff --git a/tests/agents2/test_tools.py b/tests/agents2/test_tools.py deleted file mode 100644 index ba2d3792..00000000 --- a/tests/agents2/test_tools.py +++ /dev/null @@ -1,110 +0,0 @@ -"""@tool decorator: schema extraction, registry, Runtime parameter handling.""" - -import pytest - -import vercel_ai_sdk as ai -from vercel_ai_sdk.agents2.runtime import Runtime -from vercel_ai_sdk.agents2.tools import get_tool - -# -- Schema extraction from type hints ------------------------------------ - - -def test_simple_types_produce_correct_schema() -> None: - @ai.tool - async def greet(name: str, count: int) -> str: - """Say hello.""" - return f"Hello {name}" * count - - assert greet.name == "greet" - assert greet.description == "Say hello." - props = greet.param_schema["properties"] - assert props["name"]["type"] == "string" - assert props["count"]["type"] == "integer" - assert set(greet.param_schema["required"]) == {"name", "count"} - - -def test_optional_param_not_required() -> None: - @ai.tool - async def search(query: str, limit: int | None = None) -> str: - """Search.""" - return query - - assert "query" in search.param_schema.get("required", []) - assert "limit" not in search.param_schema.get("required", []) - # limit should still appear in properties - assert "limit" in search.param_schema["properties"] - - -def test_default_value_not_required() -> None: - @ai.tool - async def fetch(url: str, timeout: int = 30) -> str: - """Fetch URL.""" - return url - - assert "url" in search_required(fetch) - assert "timeout" not in search_required(fetch) - - -def test_complex_type_schema() -> None: - @ai.tool - async def send(recipients: list[str], urgent: bool = False) -> str: - """Send message.""" - return "sent" - - props = 
send.param_schema["properties"] - assert props["recipients"]["type"] == "array" - assert props["recipients"]["items"]["type"] == "string" - - -# -- Runtime parameter skipping ------------------------------------------- - - -def test_runtime_param_excluded_from_schema() -> None: - @ai.tool - async def needs_runtime(query: str, rt: Runtime) -> str: - """Tool that needs runtime.""" - return query - - props = needs_runtime.param_schema["properties"] - assert "rt" not in props - assert "query" in props - assert set(needs_runtime.param_schema.get("required", [])) == {"query"} - - -# -- Registry ------------------------------------------------------------- - - -def test_tool_registered_on_decoration() -> None: - @ai.tool - async def unique_tool_abc() -> str: - """Unique.""" - return "ok" - - assert get_tool("unique_tool_abc") is unique_tool_abc - - -def test_get_tool_returns_none_for_missing() -> None: - assert get_tool("nonexistent_tool_xyz") is None - - -# -- Execution ------------------------------------------------------------ - - -@pytest.mark.asyncio -async def test_tool_fn_is_callable() -> None: - @ai.tool - async def add(a: int, b: int) -> int: - """Add two numbers.""" - return a + b - - result = await add(a=1, b=2) - assert result == 3 - - -# -- Helpers --------------------------------------------------------------- - - -def search_required(tool: ai.Tool[..., object]) -> list[str]: - result = tool.param_schema.get("required", []) - assert isinstance(result, list) - return result diff --git a/tests/conftest.py b/tests/conftest.py index 3ae4489d..e10fef63 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,14 +6,14 @@ import pydantic import vercel_ai_sdk as ai -from vercel_ai_sdk import models2 +from vercel_ai_sdk import models from vercel_ai_sdk.types import messages as messages_ # A fixed Model used in tests — adapter="mock" dispatches to the mock adapter. 
-MOCK_MODEL = models2.Model(id="mock-model", adapter="mock", provider="mock") +MOCK_MODEL = models.Model(id="mock-model", adapter="mock", provider="mock") # Register a dummy provider so _auto_client() doesn't error for provider="mock". -models2._PROVIDER_DEFAULTS["mock"] = ("http://mock.test", "MOCK_API_KEY") +models._PROVIDER_DEFAULTS["mock"] = ("http://mock.test", "MOCK_API_KEY") class MockAdapter: @@ -31,8 +31,8 @@ def __init__(self, responses: list[list[messages_.Message]]) -> None: async def stream( self, - client: models2.Client, - model: models2.Model, + client: models.Client, + model: models.Model, messages: list[messages_.Message], *, tools: Sequence[ai.ToolLike] | None = None, @@ -45,7 +45,7 @@ async def stream( seq = self._responses[self._call_index] self._call_index += 1 - from vercel_ai_sdk.models2.core.helpers import streaming as streaming_ + from vercel_ai_sdk.models.core.helpers import streaming as streaming_ handler = streaming_.StreamHandler() @@ -93,79 +93,15 @@ async def stream( def mock_llm(responses: list[list[messages_.Message]]) -> MockAdapter: - """Create a MockAdapter and register it in the models2 adapter registry. + """Create a MockAdapter and register it in the models adapter registry. Returns the adapter so tests can inspect ``call_count``. """ adapter = MockAdapter(responses) - models2.register_stream("mock", adapter.stream) + models.register_stream("mock", adapter.stream) return adapter -# ── Legacy MockLLM (for tests/models/ that test the old LanguageModel ABC) ── - - -class MockLLM(ai.models.LanguageModel): - """LLM that yields pre-configured response sequences, one per call. - - Converts pre-configured ``Message`` objects into ``StreamEvent`` sequences - so the base-class ``stream()`` (which uses ``StreamHandler``) can - reconstruct them. - - **Legacy** — kept for tests of the old ``models/`` module. - New agent tests should use :func:`mock_llm` + ``MOCK_MODEL`` instead. 
- """ - - def __init__(self, responses: list[list[messages_.Message]]) -> None: - self._responses = list(responses) - self._call_index = 0 - self.call_count = 0 - - async def stream_events( - self, - messages: list[messages_.Message], - tools: Sequence[ai.ToolLike] | None = None, - output_type: type[pydantic.BaseModel] | None = None, - ) -> AsyncGenerator[Any]: - from vercel_ai_sdk.models.core import llm as llm_ - - if self._call_index >= len(self._responses): - raise RuntimeError("MockLLM: no more responses configured") - self.call_count += 1 - seq = self._responses[self._call_index] - self._call_index += 1 - - for msg in seq: - for i, part in enumerate(msg.parts): - if isinstance(part, messages_.TextPart): - bid = f"text-{i}" - yield llm_.TextStart(block_id=bid) - if part.text: - yield llm_.TextDelta(block_id=bid, delta=part.text) - yield llm_.TextEnd(block_id=bid) - - elif isinstance(part, messages_.ReasoningPart): - bid = f"reasoning-{i}" - yield llm_.ReasoningStart(block_id=bid) - if part.text: - yield llm_.ReasoningDelta(block_id=bid, delta=part.text) - yield llm_.ReasoningEnd(block_id=bid, signature=part.signature) - - elif isinstance(part, messages_.ToolPart): - yield llm_.ToolStart( - tool_call_id=part.tool_call_id, - tool_name=part.tool_name, - ) - if part.tool_args: - yield llm_.ToolArgsDelta( - tool_call_id=part.tool_call_id, - delta=part.tool_args, - ) - yield llm_.ToolEnd(tool_call_id=part.tool_call_id) - - yield llm_.MessageDone() - - # ── Helpers ────────────────────────────────────────────────────── diff --git a/tests/models/ai_gateway/test_gateway.py b/tests/models/ai_gateway/test_gateway.py deleted file mode 100644 index 2ac3e64b..00000000 --- a/tests/models/ai_gateway/test_gateway.py +++ /dev/null @@ -1,421 +0,0 @@ -"""Integration tests for ``GatewayModel``. 
- -Every test exercises the real ``model.stream()`` method with an injected -``httpx.MockTransport``, so the full production code path is covered: - - model.stream() - → build_request_body() - → httpx POST (mock) - → SSE line parsing - → parse_stream_part() - → StreamHandler - → yield Message -""" - -from __future__ import annotations - -import json -from typing import Any - -import httpx -import pytest - -import vercel_ai_sdk as ai -from vercel_ai_sdk.models.ai_gateway import GatewayModel, errors -from vercel_ai_sdk.types import messages - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _sse(*events: dict[str, Any]) -> str: - """Build SSE response text from event dicts.""" - return "".join(f"data: {json.dumps(e)}\n\n" for e in events) - - -def _gateway( - handler: httpx.MockTransport, - *, - model: str = "test-provider/test-model", - api_key: str = "test-key", - provider_options: dict[str, Any] | None = None, -) -> GatewayModel: - """Create a ``GatewayModel`` wired to a mock transport.""" - return GatewayModel( - model=model, - api_key=api_key, - base_url="https://gw.test/v3/ai", - provider_options=provider_options, - _transport=handler, - ) - - -async def _collect( - model: GatewayModel, - msgs: list[messages.Message], - **kwargs: Any, -) -> list[messages.Message]: - """Drain ``model.stream()`` and return all yielded messages.""" - result: list[messages.Message] = [] - async for msg in model.stream(msgs, **kwargs): - result.append(msg) - return result - - -def _user(text: str) -> messages.Message: - return messages.Message( - role="user", - parts=[messages.TextPart(text=text)], - ) - - -# --------------------------------------------------------------------------- -# Streaming: text, reasoning, tool calls -# --------------------------------------------------------------------------- - - -class TestStreaming: - @pytest.mark.asyncio - 
async def test_text_stream(self) -> None: - body = _sse( - {"type": "text-start", "id": "t1"}, - {"type": "text-delta", "id": "t1", "textDelta": "Hello"}, - {"type": "text-delta", "id": "t1", "textDelta": " World"}, - {"type": "text-end", "id": "t1"}, - { - "type": "finish", - "finishReason": "stop", - "usage": { - "prompt_tokens": 5, - "completion_tokens": 2, - }, - }, - ) - - def handler(req: httpx.Request) -> httpx.Response: - return httpx.Response(200, text=body) - - model = _gateway(httpx.MockTransport(handler)) - msgs = await _collect(model, [_user("Hi")]) - - final = msgs[-1] - assert final.text == "Hello World" - assert final.is_done - assert final.usage is not None - assert final.usage.input_tokens == 5 - assert final.usage.output_tokens == 2 - - @pytest.mark.asyncio - async def test_reasoning_then_text(self) -> None: - body = _sse( - {"type": "reasoning-start", "id": "r1"}, - {"type": "reasoning-delta", "id": "r1", "delta": "think"}, - {"type": "reasoning-end", "id": "r1"}, - {"type": "text-start", "id": "t1"}, - {"type": "text-delta", "id": "t1", "textDelta": "42"}, - {"type": "text-end", "id": "t1"}, - {"type": "finish", "finishReason": "stop", "usage": {}}, - ) - - def handler(req: httpx.Request) -> httpx.Response: - return httpx.Response(200, text=body) - - final = (await _collect(_gateway(httpx.MockTransport(handler)), [_user("?")]))[ - -1 - ] - assert final.reasoning == "think" - assert final.text == "42" - - @pytest.mark.asyncio - async def test_streaming_tool_call(self) -> None: - body = _sse( - { - "type": "tool-input-start", - "id": "tc-1", - "toolName": "search", - }, - {"type": "tool-input-delta", "id": "tc-1", "delta": '{"q":'}, - {"type": "tool-input-delta", "id": "tc-1", "delta": '"hi"}'}, - {"type": "tool-input-end", "id": "tc-1"}, - { - "type": "finish", - "finishReason": "tool-calls", - "usage": {}, - }, - ) - - def handler(req: httpx.Request) -> httpx.Response: - return httpx.Response(200, text=body) - - final = ( - await 
_collect(_gateway(httpx.MockTransport(handler)), [_user("search")]) - )[-1] - tc = final.tool_calls - assert len(tc) == 1 - assert tc[0].tool_name == "search" - assert tc[0].tool_args == '{"q":"hi"}' - - @pytest.mark.asyncio - async def test_inline_file_stream(self) -> None: - """Models like Gemini-3-pro-image return inline file parts - alongside text in the language model stream.""" - body = _sse( - {"type": "text-start", "id": "t1"}, - {"type": "text-delta", "id": "t1", "textDelta": "Here is an image:"}, - {"type": "text-end", "id": "t1"}, - { - "type": "file", - "id": "f1", - "mediaType": "image/png", - "data": "iVBORw0KGgo=", - }, - { - "type": "finish", - "finishReason": "stop", - "usage": {"prompt_tokens": 10, "completion_tokens": 20}, - }, - ) - - def handler(req: httpx.Request) -> httpx.Response: - return httpx.Response(200, text=body) - - final = ( - await _collect(_gateway(httpx.MockTransport(handler)), [_user("draw me")]) - )[-1] - assert final.text == "Here is an image:" - assert len(final.images) == 1 - assert final.images[0].media_type == "image/png" - assert final.images[0].data == "iVBORw0KGgo=" - assert final.is_done - - @pytest.mark.asyncio - async def test_complete_tool_call_part(self) -> None: - """Non-streaming ``tool-call`` part (one shot) must also work.""" - body = _sse( - { - "type": "tool-call", - "toolCallId": "tc-1", - "toolName": "get_weather", - "input": {"city": "SF"}, - }, - { - "type": "finish", - "finishReason": "tool-calls", - "usage": {}, - }, - ) - - def handler(req: httpx.Request) -> httpx.Response: - return httpx.Response(200, text=body) - - final = ( - await _collect(_gateway(httpx.MockTransport(handler)), [_user("weather")]) - )[-1] - assert len(final.tool_calls) == 1 - assert json.loads(final.tool_calls[0].tool_args) == {"city": "SF"} - - -# --------------------------------------------------------------------------- -# Request: headers, body, tools -# 
--------------------------------------------------------------------------- - - -class TestRequest: - @pytest.mark.asyncio - async def test_protocol_headers(self) -> None: - captured: dict[str, str] = {} - - def handler(req: httpx.Request) -> httpx.Response: - captured.update(dict(req.headers)) - return httpx.Response( - 200, - text=_sse({"type": "finish", "finishReason": "stop", "usage": {}}), - ) - - model = _gateway( - httpx.MockTransport(handler), - model="anthropic/claude-sonnet-4", - api_key="sk-test", - ) - await _collect(model, [_user("Hi")]) - - assert captured["authorization"] == "Bearer sk-test" - assert captured["ai-gateway-protocol-version"] == "0.0.1" - assert captured["ai-language-model-specification-version"] == "3" - assert captured["ai-language-model-id"] == "anthropic/claude-sonnet-4" - assert captured["ai-language-model-streaming"] == "true" - assert captured["ai-gateway-auth-method"] == "api-key" - - @pytest.mark.asyncio - async def test_body_prompt_format(self) -> None: - captured_body: dict[str, Any] = {} - - def handler(req: httpx.Request) -> httpx.Response: - captured_body.update(json.loads(req.content)) - return httpx.Response( - 200, - text=_sse({"type": "finish", "finishReason": "stop", "usage": {}}), - ) - - await _collect(_gateway(httpx.MockTransport(handler)), [_user("Hello")]) - - assert captured_body["prompt"] == [ - { - "role": "user", - "content": [{"type": "text", "text": "Hello"}], - } - ] - - @pytest.mark.asyncio - async def test_provider_options_in_body(self) -> None: - captured_body: dict[str, Any] = {} - - def handler(req: httpx.Request) -> httpx.Response: - captured_body.update(json.loads(req.content)) - return httpx.Response( - 200, - text=_sse({"type": "finish", "finishReason": "stop", "usage": {}}), - ) - - opts = {"gateway": {"order": ["bedrock", "openai"]}} - await _collect( - _gateway(httpx.MockTransport(handler), provider_options=opts), - [_user("Hi")], - ) - - assert captured_body["providerOptions"] == opts - - 
@pytest.mark.asyncio - async def test_real_tool_in_request_body(self) -> None: - """A real ``@tool``-decorated function must appear correctly - in the request body sent to the gateway.""" - - @ai.tool - async def lookup(query: str) -> str: - """Search the database.""" - return "result" - - captured_body: dict[str, Any] = {} - - def handler(req: httpx.Request) -> httpx.Response: - captured_body.update(json.loads(req.content)) - return httpx.Response( - 200, - text=_sse({"type": "finish", "finishReason": "stop", "usage": {}}), - ) - - await _collect( - _gateway(httpx.MockTransport(handler)), - [_user("find something")], - tools=[lookup], - ) - - assert "tools" in captured_body - td = captured_body["tools"][0] - assert td["name"] == "lookup" - assert td["type"] == "function" - assert "query" in td["inputSchema"]["properties"] - - @pytest.mark.asyncio - async def test_multi_turn_request_body(self) -> None: - """A multi-turn conversation including a tool result must - serialize correctly into the v3 prompt format.""" - captured_body: dict[str, Any] = {} - - def handler(req: httpx.Request) -> httpx.Response: - captured_body.update(json.loads(req.content)) - return httpx.Response( - 200, - text=_sse({"type": "finish", "finishReason": "stop", "usage": {}}), - ) - - tool_part = messages.ToolPart( - tool_call_id="tc-1", - tool_name="search", - tool_args='{"q": "weather"}', - status="result", - result={"temp": 72}, - ) - conversation = [ - _user("What's the weather?"), - messages.Message(role="assistant", parts=[tool_part]), - _user("Thanks, and tomorrow?"), - ] - - await _collect(_gateway(httpx.MockTransport(handler)), conversation) - - prompt = captured_body["prompt"] - # user → assistant (tool-call) → tool (tool-result) → user - assert len(prompt) == 4 - assert prompt[0]["role"] == "user" - assert prompt[1]["role"] == "assistant" - assert prompt[1]["content"][0]["type"] == "tool-call" - assert prompt[2]["role"] == "tool" - assert prompt[2]["content"][0]["type"] == 
"tool-result" - assert prompt[3]["role"] == "user" - - -# --------------------------------------------------------------------------- -# Error handling -# --------------------------------------------------------------------------- - - -class TestErrors: - @pytest.mark.asyncio - async def test_401_authentication_error(self) -> None: - def handler(req: httpx.Request) -> httpx.Response: - return httpx.Response( - 401, - json={ - "error": { - "message": "Invalid API key", - "type": "authentication_error", - } - }, - ) - - with pytest.raises(errors.GatewayAuthenticationError): - await _collect(_gateway(httpx.MockTransport(handler)), [_user("Hi")]) - - @pytest.mark.asyncio - async def test_429_rate_limit_error(self) -> None: - def handler(req: httpx.Request) -> httpx.Response: - return httpx.Response( - 429, - json={ - "error": { - "message": "Rate limit exceeded", - "type": "rate_limit_exceeded", - } - }, - ) - - with pytest.raises(errors.GatewayRateLimitError): - await _collect(_gateway(httpx.MockTransport(handler)), [_user("Hi")]) - - @pytest.mark.asyncio - async def test_404_model_not_found(self) -> None: - def handler(req: httpx.Request) -> httpx.Response: - return httpx.Response( - 404, - json={ - "error": { - "message": "Model xyz not found", - "type": "model_not_found", - "param": {"modelId": "xyz"}, - } - }, - ) - - with pytest.raises(errors.GatewayModelNotFoundError) as exc_info: - await _collect(_gateway(httpx.MockTransport(handler)), [_user("Hi")]) - assert exc_info.value.model_id == "xyz" - - @pytest.mark.asyncio - async def test_500_malformed_response(self) -> None: - def handler(req: httpx.Request) -> httpx.Response: - return httpx.Response(500, text="Not JSON") - - with pytest.raises(errors.GatewayResponseError): - await _collect(_gateway(httpx.MockTransport(handler)), [_user("Hi")]) diff --git a/tests/models/ai_gateway/test_gateway_image.py b/tests/models/ai_gateway/test_gateway_image.py deleted file mode 100644 index 660457db..00000000 --- 
a/tests/models/ai_gateway/test_gateway_image.py +++ /dev/null @@ -1,262 +0,0 @@ -"""Integration tests for ``GatewayImageModel``. - -Every test exercises the real ``model.generate()`` method with an injected -``httpx.MockTransport``, so the full production code path is covered: - - model.generate() - → extract prompt/images from messages - → httpx POST (mock) to /image-model - → JSON response parsing - → media type detection - → return Message with FileParts -""" - -from __future__ import annotations - -import base64 -import json -from typing import Any - -import httpx -import pytest - -from vercel_ai_sdk.models.ai_gateway import GatewayImageModel, errors -from vercel_ai_sdk.types import messages - -# 1x1 transparent PNG (minimal valid PNG for magic-byte detection) -_PNG_HEADER = bytes([0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A]) -_PNG_B64 = base64.b64encode(_PNG_HEADER).decode() - -# 1x1 JPEG header -_JPEG_HEADER = bytes([0xFF, 0xD8, 0xFF, 0xE0]) -_JPEG_B64 = base64.b64encode(_JPEG_HEADER).decode() - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _image_model( - handler: httpx.MockTransport, - *, - model: str = "google/imagen-4.0-generate-001", - api_key: str = "test-key", -) -> GatewayImageModel: - return GatewayImageModel( - model=model, - api_key=api_key, - base_url="https://gw.test/v3/ai", - _transport=handler, - ) - - -def _user(text: str) -> messages.Message: - return messages.Message( - role="user", - parts=[messages.TextPart(text=text)], - ) - - -# --------------------------------------------------------------------------- -# Basic generation -# --------------------------------------------------------------------------- - - -class TestGenerate: - @pytest.mark.asyncio - async def test_basic_image_generation(self) -> None: - """Simple prompt → one PNG image back.""" - - def handler(req: httpx.Request) -> httpx.Response: - 
return httpx.Response( - 200, - json={"images": [_PNG_B64]}, - ) - - model = _image_model(httpx.MockTransport(handler)) - msg = await model.generate([_user("A sunset over Tokyo")]) - - assert msg.role == "assistant" - assert len(msg.images) == 1 - assert msg.images[0].data == _PNG_B64 - assert msg.images[0].media_type == "image/png" - - @pytest.mark.asyncio - async def test_multiple_images(self) -> None: - """Request n=3 images.""" - - def handler(req: httpx.Request) -> httpx.Response: - body = json.loads(req.content) - assert body["n"] == 3 - return httpx.Response( - 200, - json={"images": [_PNG_B64, _JPEG_B64, _PNG_B64]}, - ) - - model = _image_model(httpx.MockTransport(handler)) - msg = await model.generate([_user("Three cats")], n=3) - - assert len(msg.images) == 3 - assert msg.images[0].media_type == "image/png" - assert msg.images[1].media_type == "image/jpeg" - assert msg.images[2].media_type == "image/png" - - @pytest.mark.asyncio - async def test_usage_parsing(self) -> None: - """Usage data from response surfaces on the Message.""" - - def handler(req: httpx.Request) -> httpx.Response: - return httpx.Response( - 200, - json={ - "images": [_PNG_B64], - "usage": {"inputTokens": 50, "outputTokens": 100}, - }, - ) - - model = _image_model(httpx.MockTransport(handler)) - msg = await model.generate([_user("a dog")]) - - assert msg.usage is not None - assert msg.usage.input_tokens == 50 - assert msg.usage.output_tokens == 100 - - -# --------------------------------------------------------------------------- -# Request format -# --------------------------------------------------------------------------- - - -class TestRequest: - @pytest.mark.asyncio - async def test_protocol_headers(self) -> None: - captured: dict[str, str] = {} - - def handler(req: httpx.Request) -> httpx.Response: - captured.update(dict(req.headers)) - return httpx.Response(200, json={"images": [_PNG_B64]}) - - model = _image_model( - httpx.MockTransport(handler), - model="openai/gpt-image-1", - 
api_key="sk-test", - ) - await model.generate([_user("Hi")]) - - assert captured["authorization"] == "Bearer sk-test" - assert captured["ai-image-model-specification-version"] == "3" - assert captured["ai-model-id"] == "openai/gpt-image-1" - assert captured["ai-gateway-auth-method"] == "api-key" - - @pytest.mark.asyncio - async def test_parameters_forwarded(self) -> None: - captured_body: dict[str, Any] = {} - - def handler(req: httpx.Request) -> httpx.Response: - captured_body.update(json.loads(req.content)) - return httpx.Response(200, json={"images": [_PNG_B64]}) - - model = _image_model(httpx.MockTransport(handler)) - await model.generate( - [_user("landscape")], - n=2, - size="1024x1024", - aspect_ratio="16:9", - seed=42, - provider_options={"google": {"style": "vivid"}}, - ) - - assert captured_body["prompt"] == "landscape" - assert captured_body["n"] == 2 - assert captured_body["size"] == "1024x1024" - assert captured_body["aspectRatio"] == "16:9" - assert captured_body["seed"] == 42 - assert captured_body["providerOptions"] == {"google": {"style": "vivid"}} - - @pytest.mark.asyncio - async def test_input_images_forwarded(self) -> None: - """Input images from user messages → files in request body.""" - captured_body: dict[str, Any] = {} - - def handler(req: httpx.Request) -> httpx.Response: - captured_body.update(json.loads(req.content)) - return httpx.Response(200, json={"images": [_PNG_B64]}) - - user_msg = messages.Message( - role="user", - parts=[ - messages.TextPart(text="Edit this"), - messages.FilePart(data=_PNG_B64, media_type="image/png"), - ], - ) - model = _image_model(httpx.MockTransport(handler)) - await model.generate([user_msg]) - - assert captured_body["prompt"] == "Edit this" - assert "files" in captured_body - assert len(captured_body["files"]) == 1 - assert captured_body["files"][0]["type"] == "file" - assert captured_body["files"][0]["mediaType"] == "image/png" - - @pytest.mark.asyncio - async def 
test_url_posts_to_image_model_endpoint(self) -> None: - captured_url: list[str] = [] - - def handler(req: httpx.Request) -> httpx.Response: - captured_url.append(str(req.url)) - return httpx.Response(200, json={"images": [_PNG_B64]}) - - model = _image_model(httpx.MockTransport(handler)) - await model.generate([_user("test")]) - - assert captured_url[0] == "https://gw.test/v3/ai/image-model" - - -# --------------------------------------------------------------------------- -# Error handling -# --------------------------------------------------------------------------- - - -class TestErrors: - @pytest.mark.asyncio - async def test_401_authentication_error(self) -> None: - def handler(req: httpx.Request) -> httpx.Response: - return httpx.Response( - 401, - json={ - "error": { - "message": "Invalid API key", - "type": "authentication_error", - } - }, - ) - - with pytest.raises(errors.GatewayAuthenticationError): - await _image_model(httpx.MockTransport(handler)).generate([_user("test")]) - - @pytest.mark.asyncio - async def test_429_rate_limit_error(self) -> None: - def handler(req: httpx.Request) -> httpx.Response: - return httpx.Response( - 429, - json={ - "error": { - "message": "Rate limited", - "type": "rate_limit_exceeded", - } - }, - ) - - with pytest.raises(errors.GatewayRateLimitError): - await _image_model(httpx.MockTransport(handler)).generate([_user("test")]) - - @pytest.mark.asyncio - async def test_empty_images_returns_empty_message(self) -> None: - """Gateway returns empty images array → message with no parts.""" - - def handler(req: httpx.Request) -> httpx.Response: - return httpx.Response(200, json={"images": []}) - - msg = await _image_model(httpx.MockTransport(handler)).generate([_user("test")]) - assert len(msg.images) == 0 diff --git a/tests/models/ai_gateway/test_gateway_video.py b/tests/models/ai_gateway/test_gateway_video.py deleted file mode 100644 index 07de0ecf..00000000 --- a/tests/models/ai_gateway/test_gateway_video.py +++ /dev/null @@ 
-1,354 +0,0 @@ -"""Integration tests for ``GatewayVideoModel``. - -Every test exercises the real ``model.generate()`` method with an injected -``httpx.MockTransport``, so the full production code path is covered: - - model.generate() - → extract prompt/image from messages - → httpx POST (mock) to /video-model with SSE accept - → SSE event parsing - → video data handling (base64 or URL download) - → return Message with FileParts -""" - -from __future__ import annotations - -import base64 -import json -from typing import Any -from unittest.mock import AsyncMock, patch - -import httpx -import pytest - -from vercel_ai_sdk.models.ai_gateway import GatewayVideoModel, errors -from vercel_ai_sdk.types import messages - -# MP4 magic bytes (ftyp box) -_MP4_HEADER = bytes( - [0x00, 0x00, 0x00, 0x20, 0x66, 0x74, 0x79, 0x70, 0x69, 0x73, 0x6F, 0x6D] -) -_MP4_B64 = base64.b64encode(_MP4_HEADER).decode() - -# WebM magic bytes -_WEBM_HEADER = bytes([0x1A, 0x45, 0xDF, 0xA3]) -_WEBM_B64 = base64.b64encode(_WEBM_HEADER).decode() - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _sse(*events: dict[str, Any]) -> str: - """Build SSE response text from event dicts.""" - return "".join(f"data: {json.dumps(e)}\n\n" for e in events) - - -def _video_model( - handler: httpx.MockTransport, - *, - model: str = "google/veo-3.0-generate-001", - api_key: str = "test-key", -) -> GatewayVideoModel: - return GatewayVideoModel( - model=model, - api_key=api_key, - base_url="https://gw.test/v3/ai", - _transport=handler, - ) - - -def _user(text: str) -> messages.Message: - return messages.Message( - role="user", - parts=[messages.TextPart(text=text)], - ) - - -# --------------------------------------------------------------------------- -# Basic generation -# --------------------------------------------------------------------------- - - -class TestGenerate: - 
@pytest.mark.asyncio - async def test_basic_video_generation_base64(self) -> None: - """Simple prompt → one MP4 video back via base64.""" - body = _sse( - { - "type": "result", - "videos": [ - {"type": "base64", "data": _MP4_B64, "mediaType": "video/mp4"} - ], - } - ) - - def handler(req: httpx.Request) -> httpx.Response: - return httpx.Response(200, text=body) - - model = _video_model(httpx.MockTransport(handler)) - msg = await model.generate([_user("A cat walking on a beach")]) - - assert msg.role == "assistant" - assert len(msg.videos) == 1 - assert msg.videos[0].data == _MP4_B64 - assert msg.videos[0].media_type == "video/mp4" - - @pytest.mark.asyncio - async def test_video_generation_url(self) -> None: - """Video returned as URL → downloaded automatically.""" - body = _sse( - { - "type": "result", - "videos": [ - { - "type": "url", - "url": "https://storage.example.com/video.mp4", - "mediaType": "video/mp4", - } - ], - } - ) - - def handler(req: httpx.Request) -> httpx.Response: - return httpx.Response(200, text=body) - - model = _video_model(httpx.MockTransport(handler)) - - with patch( - "vercel_ai_sdk.models.core.media.download.download", - new_callable=AsyncMock, - return_value=(_MP4_HEADER, "video/mp4"), - ) as mock_dl: - msg = await model.generate([_user("A sunset timelapse")]) - - mock_dl.assert_called_once_with("https://storage.example.com/video.mp4") - assert len(msg.videos) == 1 - assert msg.videos[0].data == _MP4_HEADER - assert msg.videos[0].media_type == "video/mp4" - - @pytest.mark.asyncio - async def test_multiple_videos(self) -> None: - body = _sse( - { - "type": "result", - "videos": [ - {"type": "base64", "data": _MP4_B64, "mediaType": "video/mp4"}, - {"type": "base64", "data": _WEBM_B64, "mediaType": "video/webm"}, - ], - } - ) - - def handler(req: httpx.Request) -> httpx.Response: - return httpx.Response(200, text=body) - - msg = await _video_model(httpx.MockTransport(handler)).generate( - [_user("Two versions")], n=2 - ) - assert 
len(msg.videos) == 2 - assert msg.videos[0].media_type == "video/mp4" - assert msg.videos[1].media_type == "video/webm" - - -# --------------------------------------------------------------------------- -# Request format -# --------------------------------------------------------------------------- - - -class TestRequest: - @pytest.mark.asyncio - async def test_protocol_headers(self) -> None: - captured: dict[str, str] = {} - - def handler(req: httpx.Request) -> httpx.Response: - captured.update(dict(req.headers)) - return httpx.Response( - 200, - text=_sse( - { - "type": "result", - "videos": [ - { - "type": "base64", - "data": _MP4_B64, - "mediaType": "video/mp4", - } - ], - } - ), - ) - - model = _video_model( - httpx.MockTransport(handler), - model="google/veo-3.0-generate-001", - api_key="sk-test", - ) - await model.generate([_user("test")]) - - assert captured["authorization"] == "Bearer sk-test" - assert captured["ai-video-model-specification-version"] == "3" - assert captured["ai-model-id"] == "google/veo-3.0-generate-001" - assert captured["accept"] == "text/event-stream" - assert captured["ai-gateway-auth-method"] == "api-key" - - @pytest.mark.asyncio - async def test_parameters_forwarded(self) -> None: - captured_body: dict[str, Any] = {} - - def handler(req: httpx.Request) -> httpx.Response: - captured_body.update(json.loads(req.content)) - return httpx.Response( - 200, - text=_sse( - { - "type": "result", - "videos": [ - { - "type": "base64", - "data": _MP4_B64, - "mediaType": "video/mp4", - } - ], - } - ), - ) - - model = _video_model(httpx.MockTransport(handler)) - await model.generate( - [_user("sunset")], - n=2, - aspect_ratio="16:9", - resolution="1920x1080", - duration=5.0, - fps=30, - seed=42, - provider_options={"google": {"enhancePrompt": True}}, - ) - - assert captured_body["prompt"] == "sunset" - assert captured_body["n"] == 2 - assert captured_body["aspectRatio"] == "16:9" - assert captured_body["resolution"] == "1920x1080" - assert 
captured_body["duration"] == 5.0 - assert captured_body["fps"] == 30 - assert captured_body["seed"] == 42 - assert captured_body["providerOptions"] == {"google": {"enhancePrompt": True}} - - @pytest.mark.asyncio - async def test_url_posts_to_video_model_endpoint(self) -> None: - captured_url: list[str] = [] - - def handler(req: httpx.Request) -> httpx.Response: - captured_url.append(str(req.url)) - return httpx.Response( - 200, - text=_sse( - { - "type": "result", - "videos": [ - { - "type": "base64", - "data": _MP4_B64, - "mediaType": "video/mp4", - } - ], - } - ), - ) - - model = _video_model(httpx.MockTransport(handler)) - await model.generate([_user("test")]) - - assert captured_url[0] == "https://gw.test/v3/ai/video-model" - - @pytest.mark.asyncio - async def test_image_to_video_input(self) -> None: - """Image in user message → image field in request body.""" - captured_body: dict[str, Any] = {} - - def handler(req: httpx.Request) -> httpx.Response: - captured_body.update(json.loads(req.content)) - return httpx.Response( - 200, - text=_sse( - { - "type": "result", - "videos": [ - { - "type": "base64", - "data": _MP4_B64, - "mediaType": "video/mp4", - } - ], - } - ), - ) - - png_b64 = base64.b64encode(b"\x89PNG").decode() - user_msg = messages.Message( - role="user", - parts=[ - messages.TextPart(text="Animate this"), - messages.FilePart(data=png_b64, media_type="image/png"), - ], - ) - model = _video_model(httpx.MockTransport(handler)) - await model.generate([user_msg]) - - assert captured_body["prompt"] == "Animate this" - assert "image" in captured_body - assert captured_body["image"]["type"] == "file" - assert captured_body["image"]["mediaType"] == "image/png" - - -# --------------------------------------------------------------------------- -# Error handling -# --------------------------------------------------------------------------- - - -class TestErrors: - @pytest.mark.asyncio - async def test_sse_error_event(self) -> None: - """Gateway returns an SSE 
error event → raises.""" - body = _sse( - { - "type": "error", - "message": "Content policy violation", - "errorType": "content_filter", - "statusCode": 400, - "param": None, - } - ) - - def handler(req: httpx.Request) -> httpx.Response: - return httpx.Response(200, text=body) - - with pytest.raises(errors.GatewayInvalidRequestError, match="Content policy"): - await _video_model(httpx.MockTransport(handler)).generate([_user("test")]) - - @pytest.mark.asyncio - async def test_401_authentication_error(self) -> None: - def handler(req: httpx.Request) -> httpx.Response: - return httpx.Response( - 401, - json={ - "error": { - "message": "Bad key", - "type": "authentication_error", - } - }, - ) - - with pytest.raises(errors.GatewayAuthenticationError): - await _video_model(httpx.MockTransport(handler)).generate([_user("test")]) - - @pytest.mark.asyncio - async def test_empty_sse_stream(self) -> None: - """SSE stream with no data events → raises.""" - - def handler(req: httpx.Request) -> httpx.Response: - return httpx.Response(200, text="") - - with pytest.raises(errors.GatewayResponseError, match="SSE stream ended"): - await _video_model(httpx.MockTransport(handler)).generate([_user("test")]) diff --git a/tests/models2/ai_gateway/test_generate_image.py b/tests/models/ai_gateway/test_generate_image.py similarity index 97% rename from tests/models2/ai_gateway/test_generate_image.py rename to tests/models/ai_gateway/test_generate_image.py index 2d8d0f82..ca91365a 100644 --- a/tests/models2/ai_gateway/test_generate_image.py +++ b/tests/models/ai_gateway/test_generate_image.py @@ -21,13 +21,13 @@ import httpx import pytest -from vercel_ai_sdk.models2.ai_gateway import errors -from vercel_ai_sdk.models2.ai_gateway.generate import ( +from vercel_ai_sdk.models.ai_gateway import errors +from vercel_ai_sdk.models.ai_gateway.generate import ( ImageParams, generate, ) -from vercel_ai_sdk.models2.core import client as client_ -from vercel_ai_sdk.models2.core import model as model_ 
+from vercel_ai_sdk.models.core import client as client_ +from vercel_ai_sdk.models.core import model as model_ from vercel_ai_sdk.types import messages # 1x1 transparent PNG (minimal valid PNG for magic-byte detection) diff --git a/tests/models2/ai_gateway/test_generate_video.py b/tests/models/ai_gateway/test_generate_video.py similarity index 97% rename from tests/models2/ai_gateway/test_generate_video.py rename to tests/models/ai_gateway/test_generate_video.py index 331aad07..06dc6b91 100644 --- a/tests/models2/ai_gateway/test_generate_video.py +++ b/tests/models/ai_gateway/test_generate_video.py @@ -22,13 +22,13 @@ import httpx import pytest -from vercel_ai_sdk.models2.ai_gateway import errors -from vercel_ai_sdk.models2.ai_gateway.generate import ( +from vercel_ai_sdk.models.ai_gateway import errors +from vercel_ai_sdk.models.ai_gateway.generate import ( VideoParams, generate, ) -from vercel_ai_sdk.models2.core import client as client_ -from vercel_ai_sdk.models2.core import model as model_ +from vercel_ai_sdk.models.core import client as client_ +from vercel_ai_sdk.models.core import model as model_ from vercel_ai_sdk.types import messages # MP4 magic bytes (ftyp box) @@ -130,7 +130,7 @@ def handler(req: httpx.Request) -> httpx.Response: client = _client(httpx.MockTransport(handler)) with patch( - "vercel_ai_sdk.models2.core.helpers.media.download", + "vercel_ai_sdk.models.core.helpers.media.download", new_callable=AsyncMock, return_value=(_MP4_HEADER, "video/mp4"), ) as mock_dl: diff --git a/tests/models/ai_gateway/test_protocol.py b/tests/models/ai_gateway/test_protocol.py index 38c27e73..512c83e3 100644 --- a/tests/models/ai_gateway/test_protocol.py +++ b/tests/models/ai_gateway/test_protocol.py @@ -1,15 +1,15 @@ """Tests for the v3 protocol serialization and deserialization. 
Focus areas: -- ``messages_to_v3_prompt``: the critical outgoing translation layer -- ``tools_to_v3`` / ``build_request_body``: using real ``@tool`` -- ``parse_stream_part``: the critical incoming translation layer -- ``parse_generate_result``: non-streaming response handling +- ``_messages_to_prompt``: the critical outgoing translation layer +- ``_build_request_body``: using real ``@tool`` +- ``_parse_stream_part``: the critical incoming translation layer - ``_parse_usage``: the two distinct wire formats """ from __future__ import annotations +import importlib import json from unittest.mock import AsyncMock, patch @@ -17,17 +17,20 @@ import pytest import vercel_ai_sdk as ai -from vercel_ai_sdk.models.ai_gateway import protocol -from vercel_ai_sdk.models.core import llm +from vercel_ai_sdk.models.core.helpers import streaming from vercel_ai_sdk.types import messages +# The ai_gateway __init__.py re-exports `stream` as a function, which +# shadows the module. Use importlib to get the actual module. 
+stream_mod = importlib.import_module("vercel_ai_sdk.models.ai_gateway.stream") + # --------------------------------------------------------------------------- -# messages_to_v3_prompt +# _messages_to_prompt # --------------------------------------------------------------------------- @pytest.mark.asyncio -class TestMessagesToV3Prompt: +class TestMessagesToPrompt: async def test_system_message(self) -> None: msgs = [ messages.Message( @@ -35,7 +38,7 @@ async def test_system_message(self) -> None: parts=[messages.TextPart(text="You are helpful.")], ) ] - result = await protocol.messages_to_v3_prompt(msgs) + result = await stream_mod._messages_to_prompt(msgs) assert result == [{"role": "system", "content": "You are helpful."}] async def test_user_message(self) -> None: @@ -45,7 +48,7 @@ async def test_user_message(self) -> None: parts=[messages.TextPart(text="Hello")], ) ] - result = await protocol.messages_to_v3_prompt(msgs) + result = await stream_mod._messages_to_prompt(msgs) assert result == [ { "role": "user", @@ -63,7 +66,7 @@ async def test_assistant_with_reasoning_and_text(self) -> None: ], ) ] - result = await protocol.messages_to_v3_prompt(msgs) + result = await stream_mod._messages_to_prompt(msgs) content = result[0]["content"] assert content[0] == {"type": "reasoning", "text": "Let me think..."} assert content[1] == {"type": "text", "text": "42"} @@ -85,7 +88,7 @@ async def test_tool_call_with_result_produces_two_messages(self) -> None: ], ) ] - result = await protocol.messages_to_v3_prompt(msgs) + result = await stream_mod._messages_to_prompt(msgs) assert len(result) == 2 # Assistant message has the tool-call @@ -114,13 +117,13 @@ async def test_tool_error_result(self) -> None: ], ) ] - result = await protocol.messages_to_v3_prompt(msgs) + result = await stream_mod._messages_to_prompt(msgs) tr = result[1]["content"][0] assert tr["output"]["type"] == "error-text" assert tr["output"]["value"] == "Connection timeout" async def 
test_user_message_with_image_url(self) -> None: - """FilePart with image URL → downloaded and converted to data: URL.""" + """FilePart with image URL -> downloaded and converted to data: URL.""" fake_jpeg = b"\xff\xd8\xff\xe0" msgs = [ messages.Message( @@ -134,11 +137,11 @@ async def test_user_message_with_image_url(self) -> None: ) ] with patch( - "vercel_ai_sdk.models.core.media.download.download", + "vercel_ai_sdk.models.core.helpers.media.download", new_callable=AsyncMock, return_value=(fake_jpeg, "image/jpeg"), ): - result = await protocol.messages_to_v3_prompt(msgs) + result = await stream_mod._messages_to_prompt(msgs) content = result[0]["content"] assert content[0] == {"type": "text", "text": "Look at this"} assert content[1]["type"] == "file" @@ -146,7 +149,7 @@ async def test_user_message_with_image_url(self) -> None: assert content[1]["data"].startswith("data:image/jpeg;base64,") async def test_user_message_with_file_bytes(self) -> None: - """FilePart with bytes → v3 file content part with data URL.""" + """FilePart with bytes -> v3 file content part with data URL.""" msgs = [ messages.Message( role="user", @@ -157,7 +160,7 @@ async def test_user_message_with_file_bytes(self) -> None: ], ) ] - result = await protocol.messages_to_v3_prompt(msgs) + result = await stream_mod._messages_to_prompt(msgs) part = result[0]["content"][0] assert part["type"] == "file" assert part["mediaType"] == "image/png" @@ -172,7 +175,7 @@ async def test_user_message_text_only_unchanged(self) -> None: parts=[messages.TextPart(text="Hello")], ) ] - result = await protocol.messages_to_v3_prompt(msgs) + result = await stream_mod._messages_to_prompt(msgs) assert result == [ {"role": "user", "content": [{"type": "text", "text": "Hello"}]} ] @@ -192,13 +195,13 @@ async def test_pending_tool_call_no_tool_message(self) -> None: ], ) ] - result = await protocol.messages_to_v3_prompt(msgs) + result = await stream_mod._messages_to_prompt(msgs) assert len(result) == 1 assert 
result[0]["role"] == "assistant" # --------------------------------------------------------------------------- -# tools_to_v3 / build_request_body — using real @tool +# _build_request_body — using real @tool # --------------------------------------------------------------------------- @@ -212,14 +215,14 @@ async def get_weather(city: str, units: str = "celsius") -> str: class TestBuildRequestBody: async def test_with_real_tool(self) -> None: """Verify @tool-produced schema round-trips through - build_request_body → JSON → gateway wire format.""" + _build_request_body -> JSON -> gateway wire format.""" msgs = [ messages.Message( role="user", parts=[messages.TextPart(text="What's the weather?")], ) ] - body = await protocol.build_request_body(msgs, tools=[get_weather]) + body = await stream_mod._build_request_body(msgs, tools=[get_weather]) assert "tools" in body tool_def = body["tools"][0] @@ -245,7 +248,7 @@ class WeatherResult(pydantic.BaseModel): parts=[messages.TextPart(text="Weather?")], ) ] - body = await protocol.build_request_body(msgs, output_type=WeatherResult) + body = await stream_mod._build_request_body(msgs, output_type=WeatherResult) assert "responseFormat" in body rf = body["responseFormat"] @@ -262,46 +265,46 @@ async def test_provider_options_passthrough(self) -> None: ) ] opts = {"gateway": {"order": ["bedrock", "openai"]}} - body = await protocol.build_request_body(msgs, provider_options=opts) + body = await stream_mod._build_request_body(msgs, provider_options=opts) assert body["providerOptions"] == opts # --------------------------------------------------------------------------- -# parse_stream_part — parametrized simple 1:1 mappings +# _parse_stream_part — parametrized simple 1:1 mappings # --------------------------------------------------------------------------- _SIMPLE_STREAM_PARTS = [ ( {"type": "text-start", "id": "t1"}, - llm.TextStart(block_id="t1"), + streaming.TextStart(block_id="t1"), ), ( {"type": "text-end", "id": "t1"}, - 
llm.TextEnd(block_id="t1"), + streaming.TextEnd(block_id="t1"), ), ( {"type": "reasoning-start", "id": "r1"}, - llm.ReasoningStart(block_id="r1"), + streaming.ReasoningStart(block_id="r1"), ), ( {"type": "reasoning-delta", "id": "r1", "delta": "hmm"}, - llm.ReasoningDelta(block_id="r1", delta="hmm"), + streaming.ReasoningDelta(block_id="r1", delta="hmm"), ), ( {"type": "reasoning-end", "id": "r1"}, - llm.ReasoningEnd(block_id="r1"), + streaming.ReasoningEnd(block_id="r1"), ), ( {"type": "tool-input-start", "id": "tc-1", "toolName": "search"}, - llm.ToolStart(tool_call_id="tc-1", tool_name="search"), + streaming.ToolStart(tool_call_id="tc-1", tool_name="search"), ), ( {"type": "tool-input-delta", "id": "tc-1", "delta": '{"q"'}, - llm.ToolArgsDelta(tool_call_id="tc-1", delta='{"q"'), + streaming.ToolArgsDelta(tool_call_id="tc-1", delta='{"q"'), ), ( {"type": "tool-input-end", "id": "tc-1"}, - llm.ToolEnd(tool_call_id="tc-1"), + streaming.ToolEnd(tool_call_id="tc-1"), ), ] @@ -312,9 +315,9 @@ async def test_provider_options_passthrough(self) -> None: ids=[w["type"] for w, _ in _SIMPLE_STREAM_PARTS], ) def test_parse_stream_part_simple( - wire: dict[str, object], expected: llm.StreamEvent + wire: dict[str, object], expected: streaming.StreamEvent ) -> None: - events = protocol.parse_stream_part(wire) + events = stream_mod._parse_stream_part(wire) assert len(events) == 1 assert events[0] == expected @@ -323,16 +326,16 @@ def test_parse_stream_part_simple( class TestParseStreamPartComplex: async def test_text_delta_uses_textDelta_key(self) -> None: """The gateway sends ``textDelta`` (camelCase), not ``delta``.""" - events = protocol.parse_stream_part( + events = stream_mod._parse_stream_part( {"type": "text-delta", "id": "t1", "textDelta": "Hello"} ) - assert isinstance(events[0], llm.TextDelta) + assert isinstance(events[0], streaming.TextDelta) assert events[0].delta == "Hello" async def test_tool_call_expands_to_three_events(self) -> None: """A complete ``tool-call`` 
part must expand into - ToolStart → ToolArgsDelta → ToolEnd.""" - events = protocol.parse_stream_part( + ToolStart -> ToolArgsDelta -> ToolEnd.""" + events = stream_mod._parse_stream_part( { "type": "tool-call", "toolCallId": "tc-1", @@ -341,14 +344,14 @@ async def test_tool_call_expands_to_three_events(self) -> None: } ) assert len(events) == 3 - assert isinstance(events[0], llm.ToolStart) + assert isinstance(events[0], streaming.ToolStart) assert events[0].tool_name == "get_weather" - assert isinstance(events[1], llm.ToolArgsDelta) + assert isinstance(events[1], streaming.ToolArgsDelta) assert json.loads(events[1].delta) == {"city": "SF"} - assert isinstance(events[2], llm.ToolEnd) + assert isinstance(events[2], streaming.ToolEnd) async def test_finish_flat_usage(self) -> None: - events = protocol.parse_stream_part( + events = stream_mod._parse_stream_part( { "type": "finish", "finishReason": "stop", @@ -359,14 +362,14 @@ async def test_finish_flat_usage(self) -> None: } ) done = events[0] - assert isinstance(done, llm.MessageDone) + assert isinstance(done, streaming.MessageDone) assert done.finish_reason == "stop" assert done.usage is not None assert done.usage.input_tokens == 10 assert done.usage.output_tokens == 20 async def test_finish_v3_nested_usage(self) -> None: - events = protocol.parse_stream_part( + events = stream_mod._parse_stream_part( { "type": "finish", "finishReason": { @@ -386,7 +389,7 @@ async def test_finish_v3_nested_usage(self) -> None: } ) done = events[0] - assert isinstance(done, llm.MessageDone) + assert isinstance(done, streaming.MessageDone) assert done.finish_reason == "tool-calls" assert done.usage is not None assert done.usage.input_tokens == 100 @@ -396,7 +399,7 @@ async def test_finish_v3_nested_usage(self) -> None: async def test_file_part(self) -> None: """A ``file`` stream part (inline image from Gemini/GPT-5) must produce a FileEvent.""" - events = protocol.parse_stream_part( + events = stream_mod._parse_stream_part( { "type": 
"file", "id": "f1", @@ -405,82 +408,21 @@ async def test_file_part(self) -> None: } ) assert len(events) == 1 - assert isinstance(events[0], llm.FileEvent) + assert isinstance(events[0], streaming.FileEvent) assert events[0].block_id == "f1" assert events[0].media_type == "image/png" assert events[0].data == "iVBORw0KGgo=" async def test_file_part_defaults(self) -> None: """A minimal ``file`` part uses sensible defaults.""" - events = protocol.parse_stream_part({"type": "file", "data": "somedata"}) + events = stream_mod._parse_stream_part({"type": "file", "data": "somedata"}) assert len(events) == 1 - assert isinstance(events[0], llm.FileEvent) + assert isinstance(events[0], streaming.FileEvent) assert events[0].media_type == "application/octet-stream" async def test_unknown_types_produce_no_events(self) -> None: for t in ("stream-start", "raw", "response-metadata", "banana"): - assert protocol.parse_stream_part({"type": t}) == [] - - -# --------------------------------------------------------------------------- -# parse_generate_result -# --------------------------------------------------------------------------- - - -@pytest.mark.asyncio -class TestParseGenerateResult: - async def test_text_content(self) -> None: - events = protocol.parse_generate_result( - { - "content": [{"type": "text", "text": "Hello!"}], - "finishReason": "stop", - "usage": {"prompt_tokens": 4, "completion_tokens": 10}, - } - ) - # TextStart + TextDelta + TextEnd + MessageDone - assert len(events) == 4 - assert isinstance(events[1], llm.TextDelta) - assert events[1].delta == "Hello!" 
- assert isinstance(events[3], llm.MessageDone) - - async def test_tool_call_content(self) -> None: - events = protocol.parse_generate_result( - { - "content": [ - { - "type": "tool-call", - "toolCallId": "tc-1", - "toolName": "search", - "input": {"query": "weather"}, - } - ], - "finishReason": "tool-calls", - } - ) - assert isinstance(events[0], llm.ToolStart) - assert isinstance(events[3], llm.MessageDone) - assert events[3].finish_reason == "tool-calls" - - async def test_file_content(self) -> None: - """A ``file`` part in non-streaming result produces a FileEvent.""" - events = protocol.parse_generate_result( - { - "content": [ - { - "type": "file", - "id": "f1", - "mediaType": "image/png", - "data": "iVBORw0KGgo=", - } - ], - "finishReason": "stop", - "usage": {"prompt_tokens": 10, "completion_tokens": 5}, - } - ) - file_events = [e for e in events if isinstance(e, llm.FileEvent)] - assert len(file_events) == 1 - assert file_events[0].media_type == "image/png" - assert isinstance(events[-1], llm.MessageDone) + assert stream_mod._parse_stream_part({"type": t}) == [] # --------------------------------------------------------------------------- @@ -491,12 +433,12 @@ async def test_file_content(self) -> None: @pytest.mark.asyncio class TestParseUsage: async def test_flat_format(self) -> None: - usage = protocol._parse_usage({"prompt_tokens": 10, "completion_tokens": 20}) + usage = stream_mod._parse_usage({"prompt_tokens": 10, "completion_tokens": 20}) assert usage.input_tokens == 10 assert usage.output_tokens == 20 async def test_v3_nested_format(self) -> None: - usage = protocol._parse_usage( + usage = stream_mod._parse_usage( { "inputTokens": { "total": 100, @@ -513,6 +455,6 @@ async def test_v3_nested_format(self) -> None: assert usage.reasoning_tokens == 10 async def test_non_dict_returns_empty(self) -> None: - usage = protocol._parse_usage("not a dict") + usage = stream_mod._parse_usage("not a dict") assert usage.input_tokens == 0 assert usage.output_tokens 
== 0 diff --git a/tests/models2/ai_gateway/test_stream.py b/tests/models/ai_gateway/test_stream.py similarity index 98% rename from tests/models2/ai_gateway/test_stream.py rename to tests/models/ai_gateway/test_stream.py index 559f4874..784dfac1 100644 --- a/tests/models2/ai_gateway/test_stream.py +++ b/tests/models/ai_gateway/test_stream.py @@ -23,14 +23,14 @@ import pytest import vercel_ai_sdk as ai -from vercel_ai_sdk.models2.ai_gateway import errors -from vercel_ai_sdk.models2.core import client as client_ -from vercel_ai_sdk.models2.core import model as model_ +from vercel_ai_sdk.models.ai_gateway import errors +from vercel_ai_sdk.models.core import client as client_ +from vercel_ai_sdk.models.core import model as model_ from vercel_ai_sdk.types import messages # The ai_gateway __init__.py re-exports `stream` as a function, which # shadows the module. Use importlib to get the actual module. -stream_mod = importlib.import_module("vercel_ai_sdk.models2.ai_gateway.stream") +stream_mod = importlib.import_module("vercel_ai_sdk.models.ai_gateway.stream") # --------------------------------------------------------------------------- # Helpers diff --git a/tests/models/anthropic/__init__.py b/tests/models/anthropic/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/models/anthropic/test_anthropic.py b/tests/models/anthropic/test_anthropic.py deleted file mode 100644 index 8c9633c7..00000000 --- a/tests/models/anthropic/test_anthropic.py +++ /dev/null @@ -1,390 +0,0 @@ -"""Anthropic provider: _messages_to_anthropic conversion tests.""" - -import base64 - -import pytest - -from vercel_ai_sdk.models.anthropic import _messages_to_anthropic -from vercel_ai_sdk.types.messages import FilePart, Message, TextPart, ToolPart - -pytestmark = pytest.mark.asyncio - - -async def test_tool_result_none_still_emits_tool_result() -> None: - """A tool that returns None must still produce a tool_result block. 
- - Regression: when part.result is None the converter skipped the tool_result, - leaving a tool_use without a matching tool_result. Anthropic rejects this - with: "tool_use ids were found without tool_result blocks immediately after". - """ - tool_part = ToolPart( - tool_call_id="toolu_01abc", - tool_name="send_notification", - tool_args="{}", - ) - tool_part.set_result(None) # tool returned None (fire-and-forget style) - - messages = [ - Message(role="assistant", parts=[tool_part]), - ] - - _system, anthropic_msgs = await _messages_to_anthropic(messages) - - # Should have: assistant message with tool_use, then user message with tool_result - assert len(anthropic_msgs) == 2, ( - f"Expected 2 messages (assistant + user/tool_result), " - f"got {len(anthropic_msgs)}: {anthropic_msgs}" - ) - - assistant_msg = anthropic_msgs[0] - assert assistant_msg["role"] == "assistant" - assert any(block["type"] == "tool_use" for block in assistant_msg["content"]) - - user_msg = anthropic_msgs[1] - assert user_msg["role"] == "user" - tool_results = [b for b in user_msg["content"] if b["type"] == "tool_result"] - assert len(tool_results) == 1 - assert tool_results[0]["tool_use_id"] == "toolu_01abc" - - -async def test_tool_with_normal_result() -> None: - """Baseline: a tool with a normal result produces the correct pair.""" - tool_part = ToolPart( - tool_call_id="toolu_02xyz", - tool_name="get_weather", - tool_args='{"city": "SF"}', - ) - tool_part.set_result({"temp": 62}) - - messages = [ - Message(role="assistant", parts=[tool_part]), - ] - - _system, anthropic_msgs = await _messages_to_anthropic(messages) - - assert len(anthropic_msgs) == 2 - assert anthropic_msgs[1]["content"][0]["content"] == "{'temp': 62}" - - -async def test_tool_error_produces_tool_result() -> None: - """Tool errors must also produce a tool_result block (with is_error=True).""" - tool_part = ToolPart( - tool_call_id="toolu_03err", - tool_name="failing_tool", - tool_args="{}", - ) - 
tool_part.set_error("Connection timeout") - - messages = [ - Message(role="assistant", parts=[tool_part]), - ] - - _system, anthropic_msgs = await _messages_to_anthropic(messages) - - assert len(anthropic_msgs) == 2 - tool_result = anthropic_msgs[1]["content"][0] - assert tool_result["type"] == "tool_result" - assert tool_result["is_error"] is True - assert tool_result["content"] == "Connection timeout" - - -async def test_multiple_tools_one_returns_none() -> None: - """When one of several tools returns None, all must have tool_results.""" - tool_a = ToolPart( - tool_call_id="toolu_a", - tool_name="tool_a", - tool_args="{}", - ) - tool_a.set_result("some result") - - tool_b = ToolPart( - tool_call_id="toolu_b", - tool_name="tool_b", - tool_args="{}", - ) - tool_b.set_result(None) # returns None - - messages = [ - Message(role="assistant", parts=[tool_a, tool_b]), - ] - - _system, anthropic_msgs = await _messages_to_anthropic(messages) - - assert len(anthropic_msgs) == 2 - - # Both tool_use blocks in assistant message - tool_uses = [b for b in anthropic_msgs[0]["content"] if b["type"] == "tool_use"] - assert len(tool_uses) == 2 - - # Both tool_result blocks in user message - tool_results = [ - b for b in anthropic_msgs[1]["content"] if b["type"] == "tool_result" - ] - assert len(tool_results) == 2 - - result_ids = {r["tool_use_id"] for r in tool_results} - assert result_ids == {"toolu_a", "toolu_b"} - - -# -- Multi-turn: consecutive user messages (tool_result + next user) ------- - - -async def test_multi_turn_no_consecutive_same_role_messages() -> None: - """Multi-turn with tools must not produce consecutive same-role messages. 
- - Regression: when a previous assistant turn includes a tool call (with - result), _messages_to_anthropic emits: - [assistant(tool_use)] [user(tool_result)] [user(next question)] - The two consecutive user messages violate Anthropic's alternating-role - requirement, causing: "tool_use ids were found without tool_result - blocks immediately after". - - The tool_result user message must be merged with the following user - message (or otherwise avoid consecutive same-role messages). - """ - tool = ToolPart( - tool_call_id="toolu_01abc", - tool_name="talk_to_mothership", - tool_args='{"question": "when?"}', - ) - tool.set_result({"value": "Soon."}) - - messages = [ - Message(role="user", parts=[TextPart(text="when will the robots take over?")]), - Message( - role="assistant", - parts=[ - TextPart(text="I'll check with the mothership."), - tool, - TextPart(text="The mothership has spoken: Soon."), - ], - ), - Message( - role="user", - parts=[TextPart(text="can you remember the first turn?")], - ), - ] - - _system, anthropic_msgs = await _messages_to_anthropic(messages) - - # Verify no consecutive same-role messages - for i in range(1, len(anthropic_msgs)): - assert anthropic_msgs[i]["role"] != anthropic_msgs[i - 1]["role"], ( - f"Consecutive same-role messages at indices {i - 1} and {i}: " - f"both are '{anthropic_msgs[i]['role']}'. " - f"Full messages: {anthropic_msgs}" - ) - - -async def test_multi_turn_tool_result_before_user_merged() -> None: - """When tool_result (user) is followed by a user message, they merge. - - The merged user message should contain both the tool_result blocks - and the text content from the following user message. 
- """ - tool = ToolPart( - tool_call_id="toolu_01abc", - tool_name="get_weather", - tool_args='{"city": "SF"}', - ) - tool.set_result("Sunny, 62F") - - messages = [ - Message(role="user", parts=[TextPart(text="what's the weather?")]), - Message(role="assistant", parts=[tool]), - Message(role="user", parts=[TextPart(text="thanks, what about tomorrow?")]), - ] - - _system, anthropic_msgs = await _messages_to_anthropic(messages) - - # Should be: user, assistant, user (tool_result + text) - assert len(anthropic_msgs) == 3 - assert anthropic_msgs[0]["role"] == "user" - assert anthropic_msgs[1]["role"] == "assistant" - assert anthropic_msgs[2]["role"] == "user" - - # The merged user message should contain the tool_result - user_content = anthropic_msgs[2]["content"] - assert isinstance(user_content, list) - tool_results = [b for b in user_content if b.get("type") == "tool_result"] - assert len(tool_results) == 1 - assert tool_results[0]["tool_use_id"] == "toolu_01abc" - - -async def test_stream_loop_second_iteration_messages() -> None: - """Simulates what stream_loop sends on the 2nd LLM call in a multi-turn. - - After the first stream_step returns a tool call, stream_loop appends - the assistant message (now with status=result after execute_tool) and - calls stream_step again. The messages must not have consecutive - same-role entries. 
- """ - tool = ToolPart( - tool_call_id="toolu_01abc", - tool_name="talk_to_mothership", - tool_args='{"question": "test"}', - ) - tool.set_result("answer") - - # These are the messages that stream_loop would pass to the 2nd stream_step: - # original user messages + assistant message from 1st step (with tool result) - messages = [ - Message(role="user", parts=[TextPart(text="ask the mothership")]), - Message(role="assistant", parts=[tool]), - # No user message follows — this is the loop, not a new user turn - ] - - _system, anthropic_msgs = await _messages_to_anthropic(messages) - - # Should be: user, assistant(tool_use), user(tool_result) - assert len(anthropic_msgs) == 3 - assert anthropic_msgs[0]["role"] == "user" - assert anthropic_msgs[1]["role"] == "assistant" - assert anthropic_msgs[2]["role"] == "user" - - # Verify the tool_result is present - tool_results = [ - b for b in anthropic_msgs[2]["content"] if b.get("type") == "tool_result" - ] - assert len(tool_results) == 1 - - -async def test_pending_tool_does_not_emit_tool_result() -> None: - """A tool with status='pending' must not produce a tool_result block. - - When stream_step returns a message mid-stream (before tool execution), - the ToolPart has status='pending'. The converter must emit only - the tool_use block — no tool_result. 
- """ - tool = ToolPart( - tool_call_id="toolu_pending", - tool_name="slow_tool", - tool_args='{"x": 1}', - ) - # Don't call set_result — status stays "pending" - - messages = [ - Message(role="user", parts=[TextPart(text="do something")]), - Message(role="assistant", parts=[tool]), - ] - - _system, anthropic_msgs = await _messages_to_anthropic(messages) - - # assistant message with tool_use, but NO user message with tool_result - assert len(anthropic_msgs) == 2 - assert anthropic_msgs[0]["role"] == "user" - assert anthropic_msgs[1]["role"] == "assistant" - assert any(b["type"] == "tool_use" for b in anthropic_msgs[1]["content"]) - - # No tool_result anywhere - for msg in anthropic_msgs: - if isinstance(msg["content"], list): - assert not any(b.get("type") == "tool_result" for b in msg["content"]) - - -# -- Multimodal user messages ------------------------------------------------ - - -async def test_user_text_only_is_plain_string() -> None: - """Text-only user messages should produce a plain content string.""" - msgs = [Message(role="user", parts=[TextPart(text="Hello")])] - _sys, result = await _messages_to_anthropic(msgs) - assert result[0]["content"] == "Hello" - - -async def test_user_image_url() -> None: - """Image URL → Anthropic image block with url source.""" - msgs = [ - Message( - role="user", - parts=[ - TextPart(text="Describe this"), - FilePart(data="https://example.com/cat.jpg", media_type="image/jpeg"), - ], - ) - ] - _sys, result = await _messages_to_anthropic(msgs) - content = result[0]["content"] - assert content[0] == {"type": "text", "text": "Describe this"} - assert content[1] == { - "type": "image", - "source": {"type": "url", "url": "https://example.com/cat.jpg"}, - } - - -async def test_user_image_base64() -> None: - """Base64 image → Anthropic image block with base64 source.""" - b64 = base64.b64encode(b"\x89PNG").decode() - msgs = [ - Message( - role="user", - parts=[FilePart(data=b64, media_type="image/png")], - ) - ] - _sys, result = 
await _messages_to_anthropic(msgs) - img = result[0]["content"][0] - assert img["type"] == "image" - assert img["source"]["type"] == "base64" - assert img["source"]["media_type"] == "image/png" - assert img["source"]["data"] == b64 - - -async def test_user_pdf_url() -> None: - """PDF URL → Anthropic document block with url source.""" - msgs = [ - Message( - role="user", - parts=[ - FilePart( - data="https://example.com/doc.pdf", media_type="application/pdf" - ) - ], - ) - ] - _sys, result = await _messages_to_anthropic(msgs) - doc = result[0]["content"][0] - assert doc["type"] == "document" - assert doc["source"] == {"type": "url", "url": "https://example.com/doc.pdf"} - - -async def test_user_pdf_base64() -> None: - """PDF base64 → Anthropic document block with base64 source.""" - b64 = base64.b64encode(b"%PDF-1.4").decode() - msgs = [ - Message( - role="user", - parts=[FilePart(data=b64, media_type="application/pdf")], - ) - ] - _sys, result = await _messages_to_anthropic(msgs) - doc = result[0]["content"][0] - assert doc["type"] == "document" - assert doc["source"]["type"] == "base64" - assert doc["source"]["media_type"] == "application/pdf" - - -async def test_user_text_plain_bytes() -> None: - """text/plain with bytes → Anthropic document with text source.""" - msgs = [ - Message( - role="user", - parts=[FilePart(data=b"Hello, world!", media_type="text/plain")], - ) - ] - _sys, result = await _messages_to_anthropic(msgs) - doc = result[0]["content"][0] - assert doc["type"] == "document" - assert doc["source"]["type"] == "text" - assert doc["source"]["data"] == "Hello, world!" 
- - -async def test_unsupported_media_type_raises() -> None: - """Unsupported media type → ValueError.""" - msgs = [ - Message( - role="user", - parts=[FilePart(data=b"\x00", media_type="video/mp4")], - ) - ] - with pytest.raises(ValueError, match="Unsupported media type"): - await _messages_to_anthropic(msgs) diff --git a/tests/models/core/media/__init__.py b/tests/models/core/media/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/models/core/media/test_data.py b/tests/models/core/media/test_data.py deleted file mode 100644 index 55783ad7..00000000 --- a/tests/models/core/media/test_data.py +++ /dev/null @@ -1,80 +0,0 @@ -"""Tests for media data-format helpers (URL detection, base-64, data URLs).""" - -from vercel_ai_sdk.models.core.media.data import ( - data_to_base64, - data_to_data_url, - is_url, - split_data_url, -) - -# -- is_url ---------------------------------------------------------------- - - -def test_is_url_http() -> None: - assert is_url("https://example.com/img.png") is True - assert is_url("http://example.com/img.png") is True - - -def test_is_url_data() -> None: - assert is_url("data:image/png;base64,abc") is True - - -def test_is_url_base64() -> None: - assert is_url("iVBORw0KGgo=") is False - - -# -- data_to_base64 ------------------------------------------------------- - - -def test_data_to_base64_bytes() -> None: - assert data_to_base64(b"\x01\x02\x03") == "AQID" - - -def test_data_to_base64_passthrough() -> None: - assert data_to_base64("AQID") == "AQID" - - -def test_data_to_base64_extracts_from_data_url() -> None: - """data: URLs must have the prefix stripped -- providers need raw base64.""" - result = data_to_base64("data:image/png;base64,AQID") - assert result == "AQID" - - -def test_data_to_base64_passthrough_http_url() -> None: - """HTTP URLs are passed through -- caller must handle.""" - url = "https://example.com/img.png" - assert data_to_base64(url) == url - - -# -- data_to_data_url 
------------------------------------------------------ - - -def test_data_to_data_url_from_bytes() -> None: - result = data_to_data_url(b"\x01\x02\x03", "image/png") - assert result == "data:image/png;base64,AQID" - - -def test_data_to_data_url_passthrough_url() -> None: - url = "https://example.com/img.png" - assert data_to_data_url(url, "image/png") == url - - -# -- split_data_url -------------------------------------------------------- - - -def test_split_data_url_valid() -> None: - mt, b64 = split_data_url("data:image/png;base64,iVBOR") - assert mt == "image/png" - assert b64 == "iVBOR" - - -def test_split_data_url_non_data_url() -> None: - mt, b64 = split_data_url("https://example.com/img.png") - assert mt is None - assert b64 is None - - -def test_split_data_url_malformed() -> None: - mt, b64 = split_data_url("data:") - assert mt is None - assert b64 is None diff --git a/tests/models/core/media/test_detect_media_type.py b/tests/models/core/media/test_detect_media_type.py deleted file mode 100644 index 6199a493..00000000 --- a/tests/models/core/media/test_detect_media_type.py +++ /dev/null @@ -1,460 +0,0 @@ -"""Tests for magic-byte media type detection. 
- -Ported from: .reference/ai/packages/ai/src/util/detect-media-type.test.ts -""" - -from __future__ import annotations - -import base64 - -from vercel_ai_sdk.models.core.media.detect import ( - AUDIO_SIGNATURES, - IMAGE_SIGNATURES, - detect_media_type, -) - -# --------------------------------------------------------------------------- -# Image detection -# --------------------------------------------------------------------------- - - -class TestGif: - def test_detect_gif_from_bytes(self) -> None: - data = bytes([0x47, 0x49, 0x46, 0xFF, 0xFF]) - assert detect_media_type(data, IMAGE_SIGNATURES) == "image/gif" - - def test_detect_gif_from_base64(self) -> None: - assert detect_media_type("R0lGabc123", IMAGE_SIGNATURES) == "image/gif" - - -class TestPng: - def test_detect_png_from_bytes(self) -> None: - data = bytes([0x89, 0x50, 0x4E, 0x47, 0xFF, 0xFF]) - assert detect_media_type(data, IMAGE_SIGNATURES) == "image/png" - - def test_detect_png_from_base64(self) -> None: - assert detect_media_type("iVBORwabc123", IMAGE_SIGNATURES) == "image/png" - - -class TestJpeg: - def test_detect_jpeg_from_bytes(self) -> None: - data = bytes([0xFF, 0xD8, 0xFF, 0xFF]) - assert detect_media_type(data, IMAGE_SIGNATURES) == "image/jpeg" - - def test_detect_jpeg_from_base64(self) -> None: - assert detect_media_type("/9j/abc123", IMAGE_SIGNATURES) == "image/jpeg" - - -class TestWebp: - def test_detect_webp_from_bytes(self) -> None: - # RIFF + 4 bytes (file size) + WEBP + VP8 data - data = bytes( - [ - 0x52, - 0x49, - 0x46, - 0x46, # "RIFF" - 0x24, - 0x00, - 0x00, - 0x00, # file size (wildcard in sig) - 0x57, - 0x45, - 0x42, - 0x50, # "WEBP" - 0x56, - 0x50, - 0x38, - 0x20, # "VP8 " (trailing data) - ] - ) - assert detect_media_type(data, IMAGE_SIGNATURES) == "image/webp" - - def test_detect_webp_from_base64(self) -> None: - data = bytes( - [ - 0x52, - 0x49, - 0x46, - 0x46, - 0x24, - 0x00, - 0x00, - 0x00, - 0x57, - 0x45, - 0x42, - 0x50, - 0x56, - 0x50, - 0x38, - 0x20, - ] - ) - b64 = 
base64.b64encode(data).decode() - assert detect_media_type(b64, IMAGE_SIGNATURES) == "image/webp" - - def test_riff_audio_not_detected_as_webp_bytes(self) -> None: - """RIFF + WAVE should NOT match WebP.""" - data = bytes( - [ - 0x52, - 0x49, - 0x46, - 0x46, - 0x24, - 0x00, - 0x00, - 0x00, - 0x57, - 0x41, - 0x56, - 0x45, # "WAVE", not "WEBP" - ] - ) - assert detect_media_type(data, IMAGE_SIGNATURES) is None - - def test_riff_audio_not_detected_as_webp_base64(self) -> None: - data = bytes( - [ - 0x52, - 0x49, - 0x46, - 0x46, - 0x24, - 0x00, - 0x00, - 0x00, - 0x57, - 0x41, - 0x56, - 0x45, - ] - ) - b64 = base64.b64encode(data).decode() - assert detect_media_type(b64, IMAGE_SIGNATURES) is None - - -class TestBmp: - def test_detect_bmp_from_bytes(self) -> None: - data = bytes([0x42, 0x4D, 0xFF, 0xFF]) - assert detect_media_type(data, IMAGE_SIGNATURES) == "image/bmp" - - def test_detect_bmp_from_base64(self) -> None: - data = bytes([0x42, 0x4D, 0xFF, 0xFF]) - b64 = base64.b64encode(data).decode() - assert detect_media_type(b64, IMAGE_SIGNATURES) == "image/bmp" - - -class TestTiff: - def test_detect_tiff_le_from_bytes(self) -> None: - data = bytes([0x49, 0x49, 0x2A, 0x00, 0xFF]) - assert detect_media_type(data, IMAGE_SIGNATURES) == "image/tiff" - - def test_detect_tiff_le_from_base64(self) -> None: - assert detect_media_type("SUkqAAabc123", IMAGE_SIGNATURES) == "image/tiff" - - def test_detect_tiff_be_from_bytes(self) -> None: - data = bytes([0x4D, 0x4D, 0x00, 0x2A, 0xFF]) - assert detect_media_type(data, IMAGE_SIGNATURES) == "image/tiff" - - def test_detect_tiff_be_from_base64(self) -> None: - assert detect_media_type("TU0AKgabc123", IMAGE_SIGNATURES) == "image/tiff" - - -class TestAvif: - def test_detect_avif_from_bytes(self) -> None: - data = bytes( - [ - 0x00, - 0x00, - 0x00, - 0x20, - 0x66, - 0x74, - 0x79, - 0x70, - 0x61, - 0x76, - 0x69, - 0x66, - 0xFF, - ] - ) - assert detect_media_type(data, IMAGE_SIGNATURES) == "image/avif" - - def 
test_detect_avif_from_base64(self) -> None: - assert ( - detect_media_type("AAAAIGZ0eXBhdmlmabc123", IMAGE_SIGNATURES) - == "image/avif" - ) - - -class TestHeic: - def test_detect_heic_from_bytes(self) -> None: - data = bytes( - [ - 0x00, - 0x00, - 0x00, - 0x20, - 0x66, - 0x74, - 0x79, - 0x70, - 0x68, - 0x65, - 0x69, - 0x63, - 0xFF, - ] - ) - assert detect_media_type(data, IMAGE_SIGNATURES) == "image/heic" - - def test_detect_heic_from_base64(self) -> None: - assert ( - detect_media_type("AAAAIGZ0eXBoZWljabc123", IMAGE_SIGNATURES) - == "image/heic" - ) - - -# --------------------------------------------------------------------------- -# Audio detection -# --------------------------------------------------------------------------- - - -class TestMp3: - def test_detect_mp3_from_bytes(self) -> None: - data = bytes([0xFF, 0xFB]) - assert detect_media_type(data, AUDIO_SIGNATURES) == "audio/mpeg" - - def test_detect_mp3_from_base64(self) -> None: - assert detect_media_type("//s=", AUDIO_SIGNATURES) == "audio/mpeg" - - def test_detect_mp3_with_id3v2_tags_from_bytes(self) -> None: - """ID3v2 header (10 bytes tag, size=4) followed by MP3 frame.""" - data = bytes( - [ - 0x49, - 0x44, - 0x33, # "ID3" - 0x04, - 0x00, # version - 0x00, # flags - 0x00, - 0x00, - 0x00, - 0x04, # size = 4 (syncsafe) - 0x00, - 0x00, - 0x00, - 0x00, # 4 bytes of tag data - 0xFF, - 0xFB, # MP3 frame sync - 0x90, - 0x00, - 0x00, - 0x00, - 0x00, - 0x00, - 0x00, - 0x00, - ] - ) - assert detect_media_type(data, AUDIO_SIGNATURES) == "audio/mpeg" - - def test_detect_mp3_with_id3v2_tags_from_base64(self) -> None: - data = bytes( - [ - 0x49, - 0x44, - 0x33, - 0x04, - 0x00, - 0x00, - 0x00, - 0x00, - 0x00, - 0x04, - 0x00, - 0x00, - 0x00, - 0x00, - 0xFF, - 0xFB, - 0x90, - 0x00, - 0x00, - 0x00, - 0x00, - 0x00, - 0x00, - 0x00, - ] - ) - b64 = base64.b64encode(data).decode() - assert detect_media_type(b64, AUDIO_SIGNATURES) == "audio/mpeg" - - -class TestWav: - def test_detect_wav_from_bytes(self) -> None: - data 
= bytes( - [ - 0x52, - 0x49, - 0x46, - 0x46, # "RIFF" - 0x24, - 0x00, - 0x00, - 0x00, # file size - 0x57, - 0x41, - 0x56, - 0x45, # "WAVE" - ] - ) - assert detect_media_type(data, AUDIO_SIGNATURES) == "audio/wav" - - def test_detect_wav_from_base64(self) -> None: - data = bytes( - [ - 0x52, - 0x49, - 0x46, - 0x46, - 0x24, - 0x00, - 0x00, - 0x00, - 0x57, - 0x41, - 0x56, - 0x45, - ] - ) - b64 = base64.b64encode(data).decode() - assert detect_media_type(b64, AUDIO_SIGNATURES) == "audio/wav" - - def test_webp_not_detected_as_wav_bytes(self) -> None: - """RIFF + WEBP should NOT match WAV.""" - data = bytes( - [ - 0x52, - 0x49, - 0x46, - 0x46, - 0x24, - 0x00, - 0x00, - 0x00, - 0x57, - 0x45, - 0x42, - 0x50, # "WEBP", not "WAVE" - ] - ) - assert detect_media_type(data, AUDIO_SIGNATURES) is None - - def test_webp_not_detected_as_wav_base64(self) -> None: - data = bytes( - [ - 0x52, - 0x49, - 0x46, - 0x46, - 0x24, - 0x00, - 0x00, - 0x00, - 0x57, - 0x45, - 0x42, - 0x50, - ] - ) - b64 = base64.b64encode(data).decode() - assert detect_media_type(b64, AUDIO_SIGNATURES) is None - - -class TestOgg: - def test_detect_ogg_from_bytes(self) -> None: - data = bytes([0x4F, 0x67, 0x67, 0x53]) - assert detect_media_type(data, AUDIO_SIGNATURES) == "audio/ogg" - - def test_detect_ogg_from_base64(self) -> None: - assert detect_media_type("T2dnUw", AUDIO_SIGNATURES) == "audio/ogg" - - -class TestFlac: - def test_detect_flac_from_bytes(self) -> None: - data = bytes([0x66, 0x4C, 0x61, 0x43]) - assert detect_media_type(data, AUDIO_SIGNATURES) == "audio/flac" - - def test_detect_flac_from_base64(self) -> None: - assert detect_media_type("ZkxhQw", AUDIO_SIGNATURES) == "audio/flac" - - -class TestAac: - def test_detect_aac_from_bytes(self) -> None: - data = bytes([0x40, 0x15, 0x00, 0x00]) - assert detect_media_type(data, AUDIO_SIGNATURES) == "audio/aac" - - def test_detect_aac_from_base64(self) -> None: - data = bytes([0x40, 0x15, 0x00, 0x00]) - b64 = base64.b64encode(data).decode() - assert 
detect_media_type(b64, AUDIO_SIGNATURES) == "audio/aac" - - -class TestMp4Audio: - def test_detect_mp4_from_bytes(self) -> None: - data = bytes([0x66, 0x74, 0x79, 0x70]) - assert detect_media_type(data, AUDIO_SIGNATURES) == "audio/mp4" - - def test_detect_mp4_from_base64(self) -> None: - assert detect_media_type("ZnR5cA", AUDIO_SIGNATURES) == "audio/mp4" - - -class TestWebmAudio: - def test_detect_webm_from_bytes(self) -> None: - data = bytes([0x1A, 0x45, 0xDF, 0xA3]) - assert detect_media_type(data, AUDIO_SIGNATURES) == "audio/webm" - - def test_detect_webm_from_base64(self) -> None: - assert detect_media_type("GkXfow==", AUDIO_SIGNATURES) == "audio/webm" - - -# --------------------------------------------------------------------------- -# Error / edge cases -# --------------------------------------------------------------------------- - - -class TestEdgeCases: - def test_unknown_image_format(self) -> None: - data = bytes([0x00, 0x01, 0x02, 0x03]) - assert detect_media_type(data, IMAGE_SIGNATURES) is None - - def test_unknown_audio_format(self) -> None: - data = bytes([0x00, 0x01, 0x02, 0x03]) - assert detect_media_type(data, AUDIO_SIGNATURES) is None - - def test_empty_bytes_image(self) -> None: - assert detect_media_type(b"", IMAGE_SIGNATURES) is None - - def test_empty_bytes_audio(self) -> None: - assert detect_media_type(b"", AUDIO_SIGNATURES) is None - - def test_short_bytes_image(self) -> None: - """Bytes shorter than longest signature should not crash.""" - data = bytes([0x89, 0x50]) # incomplete PNG - assert detect_media_type(data, IMAGE_SIGNATURES) is None - - def test_short_bytes_audio(self) -> None: - data = bytes([0x4F, 0x67]) # incomplete OGG - assert detect_media_type(data, AUDIO_SIGNATURES) is None - - def test_invalid_base64_image(self) -> None: - assert detect_media_type("invalid123", IMAGE_SIGNATURES) is None - - def test_invalid_base64_audio(self) -> None: - assert detect_media_type("invalid123", AUDIO_SIGNATURES) is None diff --git 
a/tests/models/core/media/test_models.py b/tests/models/core/media/test_models.py deleted file mode 100644 index ed5a81ea..00000000 --- a/tests/models/core/media/test_models.py +++ /dev/null @@ -1,198 +0,0 @@ -"""Tests for MediaModel: extraction and message assembly.""" - -from __future__ import annotations - -from typing import Any - -import pytest - -from vercel_ai_sdk.models.core.media import MediaModel, MediaResult -from vercel_ai_sdk.types.messages import FilePart, Message, TextPart, Usage - -# --------------------------------------------------------------------------- -# Concrete stub for testing the base class -# --------------------------------------------------------------------------- - - -class _StubMediaModel(MediaModel): - """Minimal concrete implementation that just returns what we tell it to.""" - - def __init__(self, result: MediaResult) -> None: - self._result = result - - async def make_request( - self, - prompt: str, - input_files: list[FilePart], - *, - n: int = 1, - provider_options: dict[str, Any] | None = None, - ) -> MediaResult: - return self._result - - -# --------------------------------------------------------------------------- -# _extract_prompt -# --------------------------------------------------------------------------- - - -class TestExtractPrompt: - def test_user_text(self) -> None: - msgs = [Message(role="user", parts=[TextPart(text="hello world")])] - assert MediaModel._extract_prompt(msgs) == "hello world" - - def test_system_and_user(self) -> None: - msgs = [ - Message(role="system", parts=[TextPart(text="be helpful")]), - Message(role="user", parts=[TextPart(text="draw a cat")]), - ] - assert MediaModel._extract_prompt(msgs) == "be helpful draw a cat" - - def test_ignores_assistant(self) -> None: - msgs = [ - Message(role="user", parts=[TextPart(text="hello")]), - Message(role="assistant", parts=[TextPart(text="ignored")]), - ] - assert MediaModel._extract_prompt(msgs) == "hello" - - def test_multiple_text_parts(self) -> 
None: - msgs = [ - Message( - role="user", - parts=[TextPart(text="first"), TextPart(text="second")], - ) - ] - assert MediaModel._extract_prompt(msgs) == "first second" - - def test_skips_non_text_parts(self) -> None: - msgs = [ - Message( - role="user", - parts=[ - TextPart(text="prompt"), - FilePart(data=b"\x89PNG", media_type="image/png"), - ], - ) - ] - assert MediaModel._extract_prompt(msgs) == "prompt" - - def test_empty_messages(self) -> None: - assert MediaModel._extract_prompt([]) == "" - - -# --------------------------------------------------------------------------- -# _extract_input_files -# --------------------------------------------------------------------------- - - -class TestExtractInputFiles: - def test_user_file_parts(self) -> None: - img = FilePart(data=b"\x89PNG", media_type="image/png") - pdf = FilePart(data=b"%PDF", media_type="application/pdf") - msgs = [Message(role="user", parts=[TextPart(text="hi"), img, pdf])] - result = MediaModel._extract_input_files(msgs) - assert result == [img, pdf] - - def test_ignores_assistant_files(self) -> None: - img = FilePart(data=b"\x89PNG", media_type="image/png") - msgs = [Message(role="assistant", parts=[img])] - assert MediaModel._extract_input_files(msgs) == [] - - def test_ignores_system_files(self) -> None: - img = FilePart(data=b"\x89PNG", media_type="image/png") - msgs = [Message(role="system", parts=[img])] - assert MediaModel._extract_input_files(msgs) == [] - - def test_returns_all_media_types(self) -> None: - """Unlike the old extract_input_images, this returns ALL file parts.""" - img = FilePart(data=b"\x89PNG", media_type="image/png") - audio = FilePart(data=b"\xff\xfb", media_type="audio/mpeg") - video = FilePart(data=b"\x00\x00", media_type="video/mp4") - msgs = [Message(role="user", parts=[img, audio, video])] - result = MediaModel._extract_input_files(msgs) - assert len(result) == 3 - - def test_empty_messages(self) -> None: - assert MediaModel._extract_input_files([]) == [] - - def 
test_multiple_user_messages(self) -> None: - img1 = FilePart(data=b"\x89PNG", media_type="image/png") - img2 = FilePart(data=b"\xff\xd8", media_type="image/jpeg") - msgs = [ - Message(role="user", parts=[img1]), - Message(role="user", parts=[img2]), - ] - result = MediaModel._extract_input_files(msgs) - assert result == [img1, img2] - - -# --------------------------------------------------------------------------- -# _build_message -# --------------------------------------------------------------------------- - - -class TestBuildMessage: - def test_wraps_files_in_message(self) -> None: - fp = FilePart(data=b"\x89PNG", media_type="image/png") - result = MediaResult(files=[fp]) - msg = MediaModel._build_message(result) - assert msg.role == "assistant" - assert len(msg.parts) == 1 - assert msg.images[0] is fp - - def test_includes_usage(self) -> None: - fp = FilePart(data=b"\x89PNG", media_type="image/png") - usage = Usage(input_tokens=10, output_tokens=20) - result = MediaResult(files=[fp], usage=usage) - msg = MediaModel._build_message(result) - assert msg.usage is not None - assert msg.usage.input_tokens == 10 - assert msg.usage.output_tokens == 20 - - def test_no_usage(self) -> None: - result = MediaResult(files=[]) - msg = MediaModel._build_message(result) - assert msg.usage is None - - def test_empty_files(self) -> None: - result = MediaResult(files=[]) - msg = MediaModel._build_message(result) - assert msg.parts == [] - - -# --------------------------------------------------------------------------- -# Integration: generate() calls make_request() and wraps result -# --------------------------------------------------------------------------- - - -class TestGenerateIntegration: - @pytest.mark.asyncio - async def test_generate_round_trip(self) -> None: - """The base class extracts prompt/files and wraps the result.""" - fp_out = FilePart(data="b64data", media_type="image/png") - usage = Usage(input_tokens=5, output_tokens=15) - stub = 
_StubMediaModel(MediaResult(files=[fp_out], usage=usage)) - - # We can't call generate() directly on MediaModel since it doesn't - # define one — subclasses do. But we can verify the pipeline by - # calling the helpers manually. - prompt = stub._extract_prompt( - [Message(role="user", parts=[TextPart(text="a sunset")])] - ) - assert prompt == "a sunset" - - input_files = stub._extract_input_files( - [ - Message( - role="user", - parts=[FilePart(data=b"\x89PNG", media_type="image/png")], - ) - ] - ) - assert len(input_files) == 1 - - result = await stub.make_request(prompt, input_files) - msg = stub._build_message(result) - assert msg.role == "assistant" - assert msg.images == [fp_out] - assert msg.usage == usage diff --git a/tests/models/core/test_llm.py b/tests/models/core/test_llm.py deleted file mode 100644 index ba1546e3..00000000 --- a/tests/models/core/test_llm.py +++ /dev/null @@ -1,295 +0,0 @@ -"""StreamHandler: event accumulation, state transitions, message building. -LanguageModel.buffer() with structured output.""" - -import json - -import pydantic -import pytest - -import vercel_ai_sdk as ai -from vercel_ai_sdk.models.core.llm import ( - FileEvent, - MessageDone, - ReasoningDelta, - ReasoningEnd, - ReasoningStart, - StreamHandler, - TextDelta, - TextEnd, - TextStart, - ToolArgsDelta, - ToolEnd, - ToolStart, -) -from vercel_ai_sdk.types.messages import ( - FilePart, - ReasoningPart, - TextPart, - ToolPart, - Usage, -) - -from ...conftest import MockLLM, text_msg - - -class _Weather(pydantic.BaseModel): - city: str - temperature: float - - -# -- Text streaming -------------------------------------------------------- - - -def test_text_lifecycle() -> None: - h = StreamHandler(message_id="m1") - m = h.handle_event(TextStart(block_id="b1")) - assert len(m.parts) == 1 - part = m.parts[0] - assert isinstance(part, TextPart) - assert part.state == "streaming" - assert part.text == "" - - m = h.handle_event(TextDelta(block_id="b1", delta="Hello")) - part = 
m.parts[0] - assert isinstance(part, TextPart) - assert part.text == "Hello" - assert part.delta == "Hello" - assert part.state == "streaming" - - m = h.handle_event(TextDelta(block_id="b1", delta=" world")) - part = m.parts[0] - assert isinstance(part, TextPart) - assert part.text == "Hello world" - assert part.delta == " world" - - m = h.handle_event(TextEnd(block_id="b1")) - part = m.parts[0] - assert isinstance(part, TextPart) - assert part.state == "done" - assert part.delta is None - - -# -- Reasoning streaming --------------------------------------------------- - - -def test_reasoning_lifecycle() -> None: - h = StreamHandler(message_id="m1") - h.handle_event(ReasoningStart(block_id="r1")) - m = h.handle_event(ReasoningDelta(block_id="r1", delta="thinking")) - part = m.parts[0] - assert isinstance(part, ReasoningPart) - assert part.text == "thinking" - assert part.state == "streaming" - - m = h.handle_event(ReasoningEnd(block_id="r1", signature="sig123")) - part = m.parts[0] - assert isinstance(part, ReasoningPart) - assert part.state == "done" - assert part.signature == "sig123" - - -# -- Tool streaming -------------------------------------------------------- - - -def test_tool_lifecycle() -> None: - h = StreamHandler(message_id="m1") - h.handle_event(ToolStart(tool_call_id="tc1", tool_name="get_weather")) - m = h.handle_event(ToolArgsDelta(tool_call_id="tc1", delta='{"ci')) - part = m.parts[0] - assert isinstance(part, ToolPart) - assert part.tool_name == "get_weather" - assert part.tool_args == '{"ci' - assert part.state == "streaming" - assert part.args_delta == '{"ci' - - m = h.handle_event(ToolArgsDelta(tool_call_id="tc1", delta='ty":"London"}')) - part = m.parts[0] - assert isinstance(part, ToolPart) - assert part.tool_args == '{"city":"London"}' - - m = h.handle_event(ToolEnd(tool_call_id="tc1")) - part = m.parts[0] - assert isinstance(part, ToolPart) - assert part.state == "done" - assert part.args_delta is None - - -# -- Multi-part messages 
--------------------------------------------------- - - -def test_reasoning_then_text_then_tool() -> None: - """Full message: reasoning block, text block, tool call.""" - h = StreamHandler(message_id="m1") - h.handle_event(ReasoningStart(block_id="r1")) - h.handle_event(ReasoningDelta(block_id="r1", delta="Let me think")) - h.handle_event(ReasoningEnd(block_id="r1")) - - h.handle_event(TextStart(block_id="t1")) - h.handle_event(TextDelta(block_id="t1", delta="I'll check")) - h.handle_event(TextEnd(block_id="t1")) - - h.handle_event(ToolStart(tool_call_id="tc1", tool_name="search")) - h.handle_event(ToolArgsDelta(tool_call_id="tc1", delta='{"q":"test"}')) - m = h.handle_event(ToolEnd(tool_call_id="tc1")) - - assert len(m.parts) == 3 - assert isinstance(m.parts[0], ReasoningPart) - assert isinstance(m.parts[1], TextPart) - assert isinstance(m.parts[2], ToolPart) - assert all( - p.state == "done" - for p in m.parts - if isinstance(p, (TextPart, ToolPart, ReasoningPart)) - ) - - -def test_multiple_tool_calls() -> None: - """Parallel tool calls in one message.""" - h = StreamHandler(message_id="m1") - h.handle_event(ToolStart(tool_call_id="tc1", tool_name="read_file")) - h.handle_event(ToolStart(tool_call_id="tc2", tool_name="list_files")) - - m = h.handle_event(ToolArgsDelta(tool_call_id="tc1", delta='{"path":"a.py"}')) - # Both tools should be in parts - tool_parts = [p for p in m.parts if isinstance(p, ToolPart)] - assert len(tool_parts) == 2 - # tc1 has args, tc2 is empty - assert tool_parts[0].tool_args == '{"path":"a.py"}' - assert tool_parts[1].tool_args == "" - - h.handle_event(ToolArgsDelta(tool_call_id="tc2", delta='{"dir":"."}')) - h.handle_event(ToolEnd(tool_call_id="tc1")) - m = h.handle_event(ToolEnd(tool_call_id="tc2")) - assert all( - p.state == "done" - for p in m.parts - if isinstance(p, (TextPart, ToolPart, ReasoningPart)) - ) - - -# -- MessageDone ----------------------------------------------------------- - - -def test_message_done_finalizes_all() 
-> None: - h = StreamHandler(message_id="m1") - h.handle_event(TextStart(block_id="t1")) - h.handle_event(TextDelta(block_id="t1", delta="hello")) - # Don't send TextEnd -- MessageDone should finalize everything - m = h.handle_event(MessageDone(finish_reason="end_turn")) - part = m.parts[0] - assert isinstance(part, TextPart) - assert part.state == "done" - assert m.is_done - - -def test_message_done_propagates_usage() -> None: - """Usage on MessageDone surfaces on the built Message.""" - usage = Usage(input_tokens=10, output_tokens=20) - h = StreamHandler(message_id="m1") - h.handle_event(TextStart(block_id="t1")) - h.handle_event(TextDelta(block_id="t1", delta="hi")) - - # Before MessageDone, usage should not be on the message - m = h.handle_event(TextEnd(block_id="t1")) - assert m.usage is None - - m = h.handle_event(MessageDone(usage=usage)) - assert m.usage is not None - assert m.usage.input_tokens == 10 - assert m.usage.output_tokens == 20 - assert m.usage.total_tokens == 30 - - -# -- Message properties propagate ------------------------------------------ - - -def test_message_id_propagates() -> None: - h = StreamHandler(message_id="custom-id") - m = h.handle_event(TextStart(block_id="b1")) - assert m.id == "custom-id" - - -def test_deltas_only_on_active_blocks() -> None: - """Delta should be None on inactive blocks, present only on active.""" - h = StreamHandler(message_id="m1") - h.handle_event(TextStart(block_id="t1")) - h.handle_event(TextDelta(block_id="t1", delta="first")) - h.handle_event(TextEnd(block_id="t1")) - - h.handle_event(TextStart(block_id="t2")) - m = h.handle_event(TextDelta(block_id="t2", delta="second")) - - text_parts = [p for p in m.parts if isinstance(p, TextPart)] - assert text_parts[0].delta is None # t1 is done - assert text_parts[1].delta == "second" # t2 is active - - -# -- LanguageModel.buffer() with structured output ------------------------- - - -@pytest.mark.asyncio -async def test_buffer_structured_output() -> None: - 
"""buffer() returns a message with a validated StructuredOutputPart.""" - json_text = '{"city":"Tokyo","temperature":28.5}' - llm = MockLLM([[text_msg(json_text)]]) - - msg = await llm.buffer(ai.make_messages(user="weather?"), output_type=_Weather) - - assert isinstance(msg.output, _Weather) - assert msg.output.city == "Tokyo" - - -@pytest.mark.asyncio -async def test_buffer_structured_output_invalid_json_raises() -> None: - """Bad LLM output with output_type should raise, not silently pass.""" - llm = MockLLM([[text_msg("not json")]]) - - with pytest.raises((json.JSONDecodeError, pydantic.ValidationError)): - await llm.buffer(ai.make_messages(user="weather?"), output_type=_Weather) - - -# -- File event (inline images from LLMs like Gemini/GPT-5) --------------- - - -def test_file_event_accumulates() -> None: - """FileEvent should produce a FilePart in the message.""" - h = StreamHandler(message_id="m1") - m = h.handle_event( - FileEvent(block_id="f1", media_type="image/png", data="iVBORw0KGgo=") - ) - file_parts = [p for p in m.parts if isinstance(p, FilePart)] - assert len(file_parts) == 1 - assert file_parts[0].media_type == "image/png" - assert file_parts[0].data == "iVBORw0KGgo=" - - -def test_file_event_with_text() -> None: - """A message can have both text and file parts (e.g. 
Gemini image gen).""" - h = StreamHandler(message_id="m1") - h.handle_event(TextStart(block_id="t1")) - h.handle_event(TextDelta(block_id="t1", delta="Here is your image:")) - h.handle_event(TextEnd(block_id="t1")) - h.handle_event( - FileEvent(block_id="f1", media_type="image/png", data="iVBORw0KGgo=") - ) - m = h.handle_event(MessageDone(finish_reason="stop")) - - assert len(m.parts) == 2 - assert isinstance(m.parts[0], TextPart) - assert m.parts[0].text == "Here is your image:" - assert isinstance(m.parts[1], FilePart) - assert m.parts[1].media_type == "image/png" - assert m.is_done - - -def test_multiple_file_events() -> None: - """Multiple FileEvents produce multiple FileParts.""" - h = StreamHandler(message_id="m1") - h.handle_event(FileEvent(block_id="f1", media_type="image/png", data="png_data")) - m = h.handle_event( - FileEvent(block_id="f2", media_type="image/jpeg", data="jpeg_data") - ) - file_parts = [p for p in m.parts if isinstance(p, FilePart)] - assert len(file_parts) == 2 - assert file_parts[0].media_type == "image/png" - assert file_parts[1].media_type == "image/jpeg" diff --git a/tests/models2/core/test_media.py b/tests/models/core/test_media.py similarity index 99% rename from tests/models2/core/test_media.py rename to tests/models/core/test_media.py index 1ac85cdc..eb77c96a 100644 --- a/tests/models2/core/test_media.py +++ b/tests/models/core/test_media.py @@ -9,7 +9,7 @@ import base64 -from vercel_ai_sdk.models2.core.helpers.media import ( +from vercel_ai_sdk.models.core.helpers.media import ( data_to_base64, data_to_data_url, detect_audio_media_type, diff --git a/tests/models2/core/test_streaming.py b/tests/models/core/test_streaming.py similarity index 99% rename from tests/models2/core/test_streaming.py rename to tests/models/core/test_streaming.py index a0927608..538d3a50 100644 --- a/tests/models2/core/test_streaming.py +++ b/tests/models/core/test_streaming.py @@ -1,7 +1,6 @@ """StreamHandler: event accumulation, state transitions, 
message building.""" - -from vercel_ai_sdk.models2.core.helpers.streaming import ( +from vercel_ai_sdk.models.core.helpers.streaming import ( FileEvent, MessageDone, ReasoningDelta, diff --git a/tests/models/openai/__init__.py b/tests/models/openai/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/models/openai/test_openai.py b/tests/models/openai/test_openai.py deleted file mode 100644 index 964886e7..00000000 --- a/tests/models/openai/test_openai.py +++ /dev/null @@ -1,245 +0,0 @@ -"""OpenAI provider: _messages_to_openai multimodal conversion tests.""" - -import base64 -from unittest.mock import AsyncMock, patch - -import pytest - -from vercel_ai_sdk.models.openai import _messages_to_openai -from vercel_ai_sdk.types.messages import FilePart, Message, TextPart - -# -- text-only (regression) ------------------------------------------------ - - -@pytest.mark.asyncio -async def test_user_text_only_is_plain_string() -> None: - """Text-only user messages should produce a plain content string, not array.""" - msgs = [Message(role="user", parts=[TextPart(text="Hello")])] - result = await _messages_to_openai(msgs) - assert result == [{"role": "user", "content": "Hello"}] - - -# -- images ---------------------------------------------------------------- - - -@pytest.mark.asyncio -async def test_user_image_url() -> None: - """Image URL → OpenAI image_url content part.""" - msgs = [ - Message( - role="user", - parts=[ - TextPart(text="What's this?"), - FilePart(data="https://example.com/cat.jpg", media_type="image/jpeg"), - ], - ) - ] - result = await _messages_to_openai(msgs) - content = result[0]["content"] - assert content[0] == {"type": "text", "text": "What's this?"} - assert content[1] == { - "type": "image_url", - "image_url": {"url": "https://example.com/cat.jpg"}, - } - - -@pytest.mark.asyncio -async def test_user_image_base64() -> None: - """Base64 image data → OpenAI image_url with data URL.""" - b64 = 
base64.b64encode(b"\x89PNG").decode() - msgs = [ - Message( - role="user", - parts=[FilePart(data=b64, media_type="image/png")], - ) - ] - result = await _messages_to_openai(msgs) - content = result[0]["content"] - assert content[0]["type"] == "image_url" - assert content[0]["image_url"]["url"] == f"data:image/png;base64,{b64}" - - -@pytest.mark.asyncio -async def test_user_image_bytes() -> None: - """Raw bytes image → OpenAI image_url with data URL.""" - raw = b"\x89PNG" - msgs = [ - Message( - role="user", - parts=[FilePart(data=raw, media_type="image/png")], - ) - ] - result = await _messages_to_openai(msgs) - url = result[0]["content"][0]["image_url"]["url"] - assert url.startswith("data:image/png;base64,") - - -@pytest.mark.asyncio -async def test_user_image_wildcard_becomes_jpeg() -> None: - """image/* media type is normalized to image/jpeg for the data URL.""" - msgs = [ - Message( - role="user", - parts=[FilePart(data="https://example.com/img", media_type="image/*")], - ) - ] - result = await _messages_to_openai(msgs) - # URL passthrough: no data URL conversion needed - assert result[0]["content"][0]["image_url"]["url"] == "https://example.com/img" - - -@pytest.mark.asyncio -async def test_user_image_data_url() -> None: - """data: URL image → base64 extracted correctly for image_url.""" - msgs = [ - Message( - role="user", - parts=[FilePart(data="data:image/png;base64,AQID", media_type="image/png")], - ) - ] - result = await _messages_to_openai(msgs) - # data: URLs pass through directly for images - assert result[0]["content"][0]["image_url"]["url"] == "data:image/png;base64,AQID" - - -# -- audio ----------------------------------------------------------------- - - -@pytest.mark.asyncio -async def test_user_audio_base64() -> None: - """Audio base64 → OpenAI input_audio part.""" - b64 = base64.b64encode(b"\xff\xfb").decode() - msgs = [ - Message( - role="user", - parts=[FilePart(data=b64, media_type="audio/wav")], - ) - ] - result = await 
_messages_to_openai(msgs) - part = result[0]["content"][0] - assert part["type"] == "input_audio" - assert part["input_audio"]["data"] == b64 - assert part["input_audio"]["format"] == "wav" - - -@pytest.mark.asyncio -async def test_user_audio_data_url_extracts_base64() -> None: - """Audio data: URL → base64 prefix stripped for input_audio.""" - msgs = [ - Message( - role="user", - parts=[FilePart(data="data:audio/wav;base64,AAAA", media_type="audio/wav")], - ) - ] - result = await _messages_to_openai(msgs) - part = result[0]["content"][0] - assert part["type"] == "input_audio" - assert part["input_audio"]["data"] == "AAAA" - - -@pytest.mark.asyncio -async def test_user_audio_url_downloads() -> None: - """Audio URLs are auto-downloaded since OpenAI requires base64.""" - fake_audio = b"\xff\xfb\x90\x00" - msgs = [ - Message( - role="user", - parts=[ - FilePart(data="https://example.com/clip.wav", media_type="audio/wav") - ], - ) - ] - with patch( - "vercel_ai_sdk.models.core.media.download.download", - new_callable=AsyncMock, - return_value=(fake_audio, "audio/wav"), - ): - result = await _messages_to_openai(msgs) - part = result[0]["content"][0] - assert part["type"] == "input_audio" - assert part["input_audio"]["format"] == "wav" - # Should be base64 of the downloaded bytes - assert part["input_audio"]["data"] == base64.b64encode(fake_audio).decode() - - -# -- PDF ------------------------------------------------------------------- - - -@pytest.mark.asyncio -async def test_user_pdf_base64() -> None: - """PDF base64 → OpenAI file part.""" - b64 = base64.b64encode(b"%PDF-1.4").decode() - msgs = [ - Message( - role="user", - parts=[ - FilePart(data=b64, media_type="application/pdf", filename="report.pdf") - ], - ) - ] - result = await _messages_to_openai(msgs) - part = result[0]["content"][0] - assert part["type"] == "file" - assert part["file"]["filename"] == "report.pdf" - assert part["file"]["file_data"].startswith("data:application/pdf;base64,") - - 
-@pytest.mark.asyncio -async def test_user_pdf_url_downloads() -> None: - """PDF URLs are auto-downloaded since OpenAI requires base64.""" - fake_pdf = b"%PDF-1.4 fake content" - msgs = [ - Message( - role="user", - parts=[ - FilePart( - data="https://example.com/doc.pdf", - media_type="application/pdf", - filename="doc.pdf", - ) - ], - ) - ] - with patch( - "vercel_ai_sdk.models.core.media.download.download", - new_callable=AsyncMock, - return_value=(fake_pdf, "application/pdf"), - ): - result = await _messages_to_openai(msgs) - part = result[0]["content"][0] - assert part["type"] == "file" - assert part["file"]["filename"] == "doc.pdf" - assert part["file"]["file_data"].startswith("data:application/pdf;base64,") - - -# -- text/* ---------------------------------------------------------------- - - -@pytest.mark.asyncio -async def test_user_text_file_bytes() -> None: - """text/* file with bytes data → decoded to text content part.""" - msgs = [ - Message( - role="user", - parts=[FilePart(data=b"Hello, world!", media_type="text/plain")], - ) - ] - result = await _messages_to_openai(msgs) - part = result[0]["content"][0] - assert part == {"type": "text", "text": "Hello, world!"} - - -# -- unsupported ----------------------------------------------------------- - - -@pytest.mark.asyncio -async def test_unsupported_media_type_raises() -> None: - """Unknown media type → ValueError.""" - msgs = [ - Message( - role="user", - parts=[FilePart(data=b"\x00", media_type="application/octet-stream")], - ) - ] - with pytest.raises(ValueError, match="Unsupported media type"): - await _messages_to_openai(msgs) diff --git a/tests/models2/__init__.py b/tests/models2/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/models2/ai_gateway/__init__.py b/tests/models2/ai_gateway/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/models2/ai_gateway/test_errors.py b/tests/models2/ai_gateway/test_errors.py deleted file mode 100644 
index 01f151ef..00000000 --- a/tests/models2/ai_gateway/test_errors.py +++ /dev/null @@ -1,139 +0,0 @@ -"""Tests for the gateway error factory. - -The factory ``create_gateway_error`` is the real point of contact: -it parses the JSON error response from the gateway server and -dispatches to the correct error class. These tests use payloads -matching the actual gateway wire format. -""" - -from __future__ import annotations - -import json - -from vercel_ai_sdk.models2.ai_gateway import errors - - -class TestGatewayErrorBase: - """Base class behaviour that all concrete errors inherit.""" - - def test_isinstance_hierarchy(self) -> None: - err = errors.GatewayRateLimitError("nope") - assert isinstance(err, errors.GatewayError) - assert isinstance(err, Exception) - - def test_generation_id_in_message(self) -> None: - err = errors.GatewayInternalServerError("boom", generation_id="gen-123") - assert "[gen-123]" in str(err) - assert err.generation_id == "gen-123" - - def test_cause_chained(self) -> None: - original = ValueError("original") - err = errors.GatewayInternalServerError("boom", cause=original) - assert err.__cause__ is original - - -class TestCreateGatewayError: - """The factory must dispatch on ``error.type`` from the response.""" - - def test_authentication_error_from_json_string(self) -> None: - body = json.dumps( - { - "error": { - "message": "Invalid API key", - "type": "authentication_error", - } - } - ) - err = errors.create_gateway_error( - response_body=body, - status_code=401, - api_key_provided=True, - ) - assert isinstance(err, errors.GatewayAuthenticationError) - assert err.status_code == 401 - # contextual message includes the key URL - assert "vercel.com/d?to=" in str(err) - - def test_invalid_request_error(self) -> None: - body = { - "error": { - "message": "Bad format", - "type": "invalid_request_error", - } - } - err = errors.create_gateway_error(response_body=body, status_code=400) - assert isinstance(err, errors.GatewayInvalidRequestError) - 
assert err.status_code == 400 - - def test_rate_limit_error(self) -> None: - body = { - "error": { - "message": "Rate limit exceeded", - "type": "rate_limit_exceeded", - } - } - err = errors.create_gateway_error(response_body=body, status_code=429) - assert isinstance(err, errors.GatewayRateLimitError) - - def test_model_not_found_extracts_model_id(self) -> None: - body = { - "error": { - "message": "Model xyz not found", - "type": "model_not_found", - "param": {"modelId": "xyz"}, - } - } - err = errors.create_gateway_error(response_body=body, status_code=404) - assert isinstance(err, errors.GatewayModelNotFoundError) - assert err.model_id == "xyz" - - def test_model_not_found_without_param(self) -> None: - body = { - "error": { - "message": "Not found", - "type": "model_not_found", - } - } - err = errors.create_gateway_error(response_body=body, status_code=404) - assert isinstance(err, errors.GatewayModelNotFoundError) - assert err.model_id is None - - def test_internal_server_error(self) -> None: - body = { - "error": { - "message": "Database down", - "type": "internal_server_error", - } - } - err = errors.create_gateway_error(response_body=body, status_code=500) - assert isinstance(err, errors.GatewayInternalServerError) - - def test_unknown_type_falls_back_to_internal(self) -> None: - body = { - "error": { - "message": "Something weird", - "type": "alien_error", - } - } - err = errors.create_gateway_error(response_body=body, status_code=500) - assert isinstance(err, errors.GatewayInternalServerError) - - def test_malformed_json_string(self) -> None: - err = errors.create_gateway_error(response_body="Not JSON", status_code=500) - assert isinstance(err, errors.GatewayResponseError) - - def test_missing_error_field(self) -> None: - body = {"ferror": {"message": "oops"}} - err = errors.create_gateway_error(response_body=body, status_code=404) - assert isinstance(err, errors.GatewayResponseError) - - def test_generation_id_extracted(self) -> None: - body = { - 
"error": { - "message": "Rate limit", - "type": "rate_limit_exceeded", - }, - "generationId": "gen-abc", - } - err = errors.create_gateway_error(response_body=body, status_code=429) - assert err.generation_id == "gen-abc" diff --git a/tests/models2/ai_gateway/test_protocol.py b/tests/models2/ai_gateway/test_protocol.py deleted file mode 100644 index c3afbc9a..00000000 --- a/tests/models2/ai_gateway/test_protocol.py +++ /dev/null @@ -1,460 +0,0 @@ -"""Tests for the v3 protocol serialization and deserialization. - -Focus areas: -- ``_messages_to_prompt``: the critical outgoing translation layer -- ``_build_request_body``: using real ``@tool`` -- ``_parse_stream_part``: the critical incoming translation layer -- ``_parse_usage``: the two distinct wire formats -""" - -from __future__ import annotations - -import importlib -import json -from unittest.mock import AsyncMock, patch - -import pydantic -import pytest - -import vercel_ai_sdk as ai -from vercel_ai_sdk.models2.core.helpers import streaming -from vercel_ai_sdk.types import messages - -# The ai_gateway __init__.py re-exports `stream` as a function, which -# shadows the module. Use importlib to get the actual module. 
-stream_mod = importlib.import_module("vercel_ai_sdk.models2.ai_gateway.stream") - -# --------------------------------------------------------------------------- -# _messages_to_prompt -# --------------------------------------------------------------------------- - - -@pytest.mark.asyncio -class TestMessagesToPrompt: - async def test_system_message(self) -> None: - msgs = [ - messages.Message( - role="system", - parts=[messages.TextPart(text="You are helpful.")], - ) - ] - result = await stream_mod._messages_to_prompt(msgs) - assert result == [{"role": "system", "content": "You are helpful."}] - - async def test_user_message(self) -> None: - msgs = [ - messages.Message( - role="user", - parts=[messages.TextPart(text="Hello")], - ) - ] - result = await stream_mod._messages_to_prompt(msgs) - assert result == [ - { - "role": "user", - "content": [{"type": "text", "text": "Hello"}], - } - ] - - async def test_assistant_with_reasoning_and_text(self) -> None: - msgs = [ - messages.Message( - role="assistant", - parts=[ - messages.ReasoningPart(text="Let me think..."), - messages.TextPart(text="42"), - ], - ) - ] - result = await stream_mod._messages_to_prompt(msgs) - content = result[0]["content"] - assert content[0] == {"type": "reasoning", "text": "Let me think..."} - assert content[1] == {"type": "text", "text": "42"} - - async def test_tool_call_with_result_produces_two_messages(self) -> None: - """A completed tool call must produce an assistant message - (with the tool-call) AND a tool message (with the result).""" - msgs = [ - messages.Message( - role="assistant", - parts=[ - messages.ToolPart( - tool_call_id="tc-1", - tool_name="get_weather", - tool_args='{"city": "SF"}', - status="result", - result={"temp": 72}, - ) - ], - ) - ] - result = await stream_mod._messages_to_prompt(msgs) - assert len(result) == 2 - - # Assistant message has the tool-call - tc = result[0]["content"][0] - assert tc["type"] == "tool-call" - assert tc["toolCallId"] == "tc-1" - assert 
tc["input"] == {"city": "SF"} - - # Tool message has the result - tr = result[1]["content"][0] - assert tr["type"] == "tool-result" - assert tr["output"] == {"type": "json", "value": {"temp": 72}} - - async def test_tool_error_result(self) -> None: - msgs = [ - messages.Message( - role="assistant", - parts=[ - messages.ToolPart( - tool_call_id="tc-1", - tool_name="get_weather", - tool_args="{}", - status="error", - result="Connection timeout", - ) - ], - ) - ] - result = await stream_mod._messages_to_prompt(msgs) - tr = result[1]["content"][0] - assert tr["output"]["type"] == "error-text" - assert tr["output"]["value"] == "Connection timeout" - - async def test_user_message_with_image_url(self) -> None: - """FilePart with image URL -> downloaded and converted to data: URL.""" - fake_jpeg = b"\xff\xd8\xff\xe0" - msgs = [ - messages.Message( - role="user", - parts=[ - messages.TextPart(text="Look at this"), - messages.FilePart( - data="https://example.com/cat.jpg", media_type="image/jpeg" - ), - ], - ) - ] - with patch( - "vercel_ai_sdk.models2.core.helpers.media.download", - new_callable=AsyncMock, - return_value=(fake_jpeg, "image/jpeg"), - ): - result = await stream_mod._messages_to_prompt(msgs) - content = result[0]["content"] - assert content[0] == {"type": "text", "text": "Look at this"} - assert content[1]["type"] == "file" - assert content[1]["mediaType"] == "image/jpeg" - assert content[1]["data"].startswith("data:image/jpeg;base64,") - - async def test_user_message_with_file_bytes(self) -> None: - """FilePart with bytes -> v3 file content part with data URL.""" - msgs = [ - messages.Message( - role="user", - parts=[ - messages.FilePart( - data=b"\x89PNG", media_type="image/png", filename="pic.png" - ), - ], - ) - ] - result = await stream_mod._messages_to_prompt(msgs) - part = result[0]["content"][0] - assert part["type"] == "file" - assert part["mediaType"] == "image/png" - assert part["data"].startswith("data:image/png;base64,") - assert part["filename"] 
== "pic.png" - - async def test_user_message_text_only_unchanged(self) -> None: - """Regression: text-only user messages still work.""" - msgs = [ - messages.Message( - role="user", - parts=[messages.TextPart(text="Hello")], - ) - ] - result = await stream_mod._messages_to_prompt(msgs) - assert result == [ - {"role": "user", "content": [{"type": "text", "text": "Hello"}]} - ] - - async def test_pending_tool_call_no_tool_message(self) -> None: - """A pending tool call should NOT produce a tool-result message.""" - msgs = [ - messages.Message( - role="assistant", - parts=[ - messages.ToolPart( - tool_call_id="tc-1", - tool_name="search", - tool_args="{}", - status="pending", - ) - ], - ) - ] - result = await stream_mod._messages_to_prompt(msgs) - assert len(result) == 1 - assert result[0]["role"] == "assistant" - - -# --------------------------------------------------------------------------- -# _build_request_body — using real @tool -# --------------------------------------------------------------------------- - - -@ai.tool -async def get_weather(city: str, units: str = "celsius") -> str: - """Get the current weather for a city.""" - return f"Sunny in {city}" - - -@pytest.mark.asyncio -class TestBuildRequestBody: - async def test_with_real_tool(self) -> None: - """Verify @tool-produced schema round-trips through - _build_request_body -> JSON -> gateway wire format.""" - msgs = [ - messages.Message( - role="user", - parts=[messages.TextPart(text="What's the weather?")], - ) - ] - body = await stream_mod._build_request_body(msgs, tools=[get_weather]) - - assert "tools" in body - tool_def = body["tools"][0] - assert tool_def["type"] == "function" - assert tool_def["name"] == "get_weather" - assert tool_def["description"] == ("Get the current weather for a city.") - # The schema comes from pydantic — verify structure, not exact dict - schema = tool_def["inputSchema"] - assert "properties" in schema - assert "city" in schema["properties"] - assert "units" in 
schema["properties"] - # 'city' is required (no default), 'units' is not (has default) - assert "city" in schema.get("required", []) - - async def test_with_output_type(self) -> None: - class WeatherResult(pydantic.BaseModel): - temp: float - condition: str - - msgs = [ - messages.Message( - role="user", - parts=[messages.TextPart(text="Weather?")], - ) - ] - body = await stream_mod._build_request_body(msgs, output_type=WeatherResult) - - assert "responseFormat" in body - rf = body["responseFormat"] - assert rf["type"] == "json" - assert rf["name"] == "WeatherResult" - assert "properties" in rf["schema"] - assert "temp" in rf["schema"]["properties"] - - async def test_provider_options_passthrough(self) -> None: - msgs = [ - messages.Message( - role="user", - parts=[messages.TextPart(text="Hi")], - ) - ] - opts = {"gateway": {"order": ["bedrock", "openai"]}} - body = await stream_mod._build_request_body(msgs, provider_options=opts) - assert body["providerOptions"] == opts - - -# --------------------------------------------------------------------------- -# _parse_stream_part — parametrized simple 1:1 mappings -# --------------------------------------------------------------------------- - -_SIMPLE_STREAM_PARTS = [ - ( - {"type": "text-start", "id": "t1"}, - streaming.TextStart(block_id="t1"), - ), - ( - {"type": "text-end", "id": "t1"}, - streaming.TextEnd(block_id="t1"), - ), - ( - {"type": "reasoning-start", "id": "r1"}, - streaming.ReasoningStart(block_id="r1"), - ), - ( - {"type": "reasoning-delta", "id": "r1", "delta": "hmm"}, - streaming.ReasoningDelta(block_id="r1", delta="hmm"), - ), - ( - {"type": "reasoning-end", "id": "r1"}, - streaming.ReasoningEnd(block_id="r1"), - ), - ( - {"type": "tool-input-start", "id": "tc-1", "toolName": "search"}, - streaming.ToolStart(tool_call_id="tc-1", tool_name="search"), - ), - ( - {"type": "tool-input-delta", "id": "tc-1", "delta": '{"q"'}, - streaming.ToolArgsDelta(tool_call_id="tc-1", delta='{"q"'), - ), - ( - {"type": 
"tool-input-end", "id": "tc-1"}, - streaming.ToolEnd(tool_call_id="tc-1"), - ), -] - - -@pytest.mark.parametrize( - ("wire", "expected"), - _SIMPLE_STREAM_PARTS, - ids=[w["type"] for w, _ in _SIMPLE_STREAM_PARTS], -) -def test_parse_stream_part_simple( - wire: dict[str, object], expected: streaming.StreamEvent -) -> None: - events = stream_mod._parse_stream_part(wire) - assert len(events) == 1 - assert events[0] == expected - - -@pytest.mark.asyncio -class TestParseStreamPartComplex: - async def test_text_delta_uses_textDelta_key(self) -> None: - """The gateway sends ``textDelta`` (camelCase), not ``delta``.""" - events = stream_mod._parse_stream_part( - {"type": "text-delta", "id": "t1", "textDelta": "Hello"} - ) - assert isinstance(events[0], streaming.TextDelta) - assert events[0].delta == "Hello" - - async def test_tool_call_expands_to_three_events(self) -> None: - """A complete ``tool-call`` part must expand into - ToolStart -> ToolArgsDelta -> ToolEnd.""" - events = stream_mod._parse_stream_part( - { - "type": "tool-call", - "toolCallId": "tc-1", - "toolName": "get_weather", - "input": {"city": "SF"}, - } - ) - assert len(events) == 3 - assert isinstance(events[0], streaming.ToolStart) - assert events[0].tool_name == "get_weather" - assert isinstance(events[1], streaming.ToolArgsDelta) - assert json.loads(events[1].delta) == {"city": "SF"} - assert isinstance(events[2], streaming.ToolEnd) - - async def test_finish_flat_usage(self) -> None: - events = stream_mod._parse_stream_part( - { - "type": "finish", - "finishReason": "stop", - "usage": { - "prompt_tokens": 10, - "completion_tokens": 20, - }, - } - ) - done = events[0] - assert isinstance(done, streaming.MessageDone) - assert done.finish_reason == "stop" - assert done.usage is not None - assert done.usage.input_tokens == 10 - assert done.usage.output_tokens == 20 - - async def test_finish_v3_nested_usage(self) -> None: - events = stream_mod._parse_stream_part( - { - "type": "finish", - "finishReason": { - 
"unified": "tool-calls", - "raw": "tool_calls", - }, - "usage": { - "inputTokens": { - "total": 100, - "cacheRead": 50, - }, - "outputTokens": { - "total": 200, - "reasoning": 30, - }, - }, - } - ) - done = events[0] - assert isinstance(done, streaming.MessageDone) - assert done.finish_reason == "tool-calls" - assert done.usage is not None - assert done.usage.input_tokens == 100 - assert done.usage.cache_read_tokens == 50 - assert done.usage.reasoning_tokens == 30 - - async def test_file_part(self) -> None: - """A ``file`` stream part (inline image from Gemini/GPT-5) - must produce a FileEvent.""" - events = stream_mod._parse_stream_part( - { - "type": "file", - "id": "f1", - "mediaType": "image/png", - "data": "iVBORw0KGgo=", - } - ) - assert len(events) == 1 - assert isinstance(events[0], streaming.FileEvent) - assert events[0].block_id == "f1" - assert events[0].media_type == "image/png" - assert events[0].data == "iVBORw0KGgo=" - - async def test_file_part_defaults(self) -> None: - """A minimal ``file`` part uses sensible defaults.""" - events = stream_mod._parse_stream_part({"type": "file", "data": "somedata"}) - assert len(events) == 1 - assert isinstance(events[0], streaming.FileEvent) - assert events[0].media_type == "application/octet-stream" - - async def test_unknown_types_produce_no_events(self) -> None: - for t in ("stream-start", "raw", "response-metadata", "banana"): - assert stream_mod._parse_stream_part({"type": t}) == [] - - -# --------------------------------------------------------------------------- -# _parse_usage -# --------------------------------------------------------------------------- - - -@pytest.mark.asyncio -class TestParseUsage: - async def test_flat_format(self) -> None: - usage = stream_mod._parse_usage({"prompt_tokens": 10, "completion_tokens": 20}) - assert usage.input_tokens == 10 - assert usage.output_tokens == 20 - - async def test_v3_nested_format(self) -> None: - usage = stream_mod._parse_usage( - { - "inputTokens": { - 
"total": 100, - "cacheRead": 30, - "cacheWrite": 5, - }, - "outputTokens": {"total": 50, "reasoning": 10}, - } - ) - assert usage.input_tokens == 100 - assert usage.output_tokens == 50 - assert usage.cache_read_tokens == 30 - assert usage.cache_write_tokens == 5 - assert usage.reasoning_tokens == 10 - - async def test_non_dict_returns_empty(self) -> None: - usage = stream_mod._parse_usage("not a dict") - assert usage.input_tokens == 0 - assert usage.output_tokens == 0 diff --git a/tests/models2/core/__init__.py b/tests/models2/core/__init__.py deleted file mode 100644 index e69de29b..00000000 From 21290d136283ab3832652791787a719c65fa9bc6 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Mon, 6 Apr 2026 08:11:56 -0700 Subject: [PATCH 18/18] Add pyright to CI, bump the version --- .github/workflows/ci.yml | 1 + pyproject.toml | 3 +- .../models/ai_gateway/generate.py | 16 +++++--- tests/conftest.py | 38 +++++++++---------- tests/telemetry/test_otel_handler.py | 14 +++++-- uv.lock | 26 ++++++++++++- 6 files changed, 65 insertions(+), 33 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f05ecb49..773e9325 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,5 +21,6 @@ jobs: - run: uv run ruff format --check src tests - run: uv run ruff check src tests - run: uv run mypy src tests + - run: uv run pyright src tests - run: uv run pytest diff --git a/pyproject.toml b/pyproject.toml index ac5a0edd..0b21d8de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "vercel-ai-sdk" -version = "0.0.1.dev8" +version = "0.0.1.dev9" description = "The AI Toolkit for Python" readme = "README.md" authors = [ @@ -30,6 +30,7 @@ dev = [ "mypy>=1.11", "ruff>=0.8", "opentelemetry-sdk>=1.0", + "pyright>=1.1.408", ] [tool.mypy] diff --git a/src/vercel_ai_sdk/models/ai_gateway/generate.py b/src/vercel_ai_sdk/models/ai_gateway/generate.py index 304c8bce..ab460b02 100644 --- 
a/src/vercel_ai_sdk/models/ai_gateway/generate.py +++ b/src/vercel_ai_sdk/models/ai_gateway/generate.py @@ -33,10 +33,12 @@ class ImageParams(pydantic.BaseModel): n: int = 1 size: str | None = None - aspect_ratio: str | None = pydantic.Field(None, alias="aspectRatio") + aspect_ratio: str | None = pydantic.Field( + default=None, serialization_alias="aspectRatio" + ) seed: int | None = None provider_options: dict[str, Any] = pydantic.Field( - default_factory=dict, alias="providerOptions" + default_factory=dict, serialization_alias="providerOptions" ) @@ -46,13 +48,15 @@ class VideoParams(pydantic.BaseModel): model_config = _PARAMS_CONFIG n: int = 1 - aspect_ratio: str | None = pydantic.Field(None, alias="aspectRatio") + aspect_ratio: str | None = pydantic.Field( + default=None, serialization_alias="aspectRatio" + ) resolution: str | None = None duration: int | None = None fps: int | None = None seed: int | None = None provider_options: dict[str, Any] = pydantic.Field( - default_factory=dict, alias="providerOptions" + default_factory=dict, serialization_alias="providerOptions" ) @@ -102,7 +106,7 @@ async def _generate_image( output_tokens=usage_data.get("outputTokens") or 0, ) - files: list[messages_.FilePart] = [] + files: list[messages_.Part] = [] for img_b64 in raw_images: media_type = media_.detect_image_media_type(img_b64) or "image/png" files.append(messages_.FilePart(data=img_b64, media_type=media_type)) @@ -169,7 +173,7 @@ async def _generate_video( ) raw_videos: list[dict[str, Any]] = event_data.get("videos", []) - files: list[messages_.FilePart] = [] + files: list[messages_.Part] = [] for video_data in raw_videos: vtype = video_data.get("type", "base64") media_type = video_data.get("mediaType", "video/mp4") diff --git a/tests/conftest.py b/tests/conftest.py index e10fef63..e949d1b9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import AsyncGenerator, Sequence -from typing import 
Any +from typing import Any, Literal import pydantic @@ -106,13 +106,14 @@ def mock_llm(responses: list[list[messages_.Message]]) -> MockAdapter: def text_msg( - text: str, *, id: str = "msg-1", state: str = "done", delta: str | None = None + text: str, + *, + id: str = "msg-1", + state: messages_.PartState | None = "done", + delta: str | None = None, ) -> messages_.Message: - return messages_.Message( - id=id, - role="assistant", - parts=[messages_.TextPart(text=text, state=state, delta=delta)], - ) + part: messages_.Part = messages_.TextPart(text=text, state=state, delta=delta) + return messages_.Message(id=id, role="assistant", parts=[part]) def tool_msg( @@ -121,20 +122,15 @@ def tool_msg( tc_id: str = "tc-1", name: str = "test_tool", args: str = "{}", - status: str = "pending", + status: Literal["pending", "result", "error"] = "pending", result: dict[str, object] | None = None, ) -> messages_.Message: - return messages_.Message( - id=id, - role="assistant", - parts=[ - messages_.ToolPart( - tool_call_id=tc_id, - tool_name=name, - tool_args=args, - status=status, - result=result, - state="done", - ) - ], + part: messages_.Part = messages_.ToolPart( + tool_call_id=tc_id, + tool_name=name, + tool_args=args, + status=status, + result=result, + state="done", ) + return messages_.Message(id=id, role="assistant", parts=[part]) diff --git a/tests/telemetry/test_otel_handler.py b/tests/telemetry/test_otel_handler.py index 3304f38d..a30dcf7c 100644 --- a/tests/telemetry/test_otel_handler.py +++ b/tests/telemetry/test_otel_handler.py @@ -50,8 +50,11 @@ async def test_text_only_spans(spans: InMemorySpanExporter) -> None: stream_span = next(s for s in finished if s.name == "ai.stream") # ai.stream is a child of ai.run - assert stream_span.parent is not None - assert stream_span.parent.span_id == run_span.context.span_id + stream_parent = stream_span.parent + assert stream_parent is not None + run_ctx = run_span.context + assert run_ctx is not None + assert 
stream_parent.span_id == run_ctx.span_id # run_id attribute is set assert run_span.attributes is not None @@ -84,5 +87,8 @@ async def test_tool_call_spans(spans: InMemorySpanExporter) -> None: # ai.tool is a child of ai.run (tools execute between steps) run_span = next(s for s in finished if s.name == "ai.run") - assert tool_span.parent is not None - assert tool_span.parent.span_id == run_span.context.span_id + tool_parent = tool_span.parent + assert tool_parent is not None + run_ctx = run_span.context + assert run_ctx is not None + assert tool_parent.span_id == run_ctx.span_id diff --git a/uv.lock b/uv.lock index e811f862..a79e345c 100644 --- a/uv.lock +++ b/uv.lock @@ -536,6 +536,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505", size = 4963, upload-time = "2025-04-22T14:54:22.983Z" }, ] +[[package]] +name = "nodeenv" +version = "1.10.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/24/bf/d1bda4f6168e0b2e9e5958945e01910052158313224ada5ce1fb2e1113b8/nodeenv-1.10.0.tar.gz", hash = "sha256:996c191ad80897d076bdfba80a41994c2b47c68e224c542b48feba42ba00f8bb", size = 55611, upload-time = "2025-12-20T14:08:54.006Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/b2/d0896bdcdc8d28a7fc5717c305f1a861c26e18c05047949fb371034d98bd/nodeenv-1.10.0-py2.py3-none-any.whl", hash = "sha256:5bb13e3eed2923615535339b3c620e76779af4cb4c6a90deccc9e36b274d3827", size = 23438, upload-time = "2025-12-20T14:08:52.782Z" }, +] + [[package]] name = "openai" version = "2.14.0" @@ -754,6 +763,19 @@ crypto = [ { name = "cryptography" }, ] +[[package]] +name = "pyright" +version = "1.1.408" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nodeenv" }, + { name = "typing-extensions" }, +] 
+sdist = { url = "https://files.pythonhosted.org/packages/74/b2/5db700e52554b8f025faa9c3c624c59f1f6c8841ba81ab97641b54322f16/pyright-1.1.408.tar.gz", hash = "sha256:f28f2321f96852fa50b5829ea492f6adb0e6954568d1caa3f3af3a5f555eb684", size = 4400578, upload-time = "2026-01-08T08:07:38.795Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0c/82/a2c93e32800940d9573fb28c346772a14778b84ba7524e691b324620ab89/pyright-1.1.408-py3-none-any.whl", hash = "sha256:090b32865f4fdb1e0e6cd82bf5618480d48eecd2eb2e70f960982a3d9a4c17c1", size = 6399144, upload-time = "2026-01-08T08:07:37.082Z" }, +] + [[package]] name = "pytest" version = "9.0.2" @@ -1049,7 +1071,7 @@ wheels = [ [[package]] name = "vercel-ai-sdk" -version = "0.0.1.dev8" +version = "0.0.1.dev9" source = { editable = "." } dependencies = [ { name = "anthropic" }, @@ -1065,6 +1087,7 @@ dependencies = [ dev = [ { name = "mypy" }, { name = "opentelemetry-sdk" }, + { name = "pyright" }, { name = "pytest" }, { name = "pytest-asyncio" }, { name = "python-dotenv" }, @@ -1087,6 +1110,7 @@ requires-dist = [ dev = [ { name = "mypy", specifier = ">=1.11" }, { name = "opentelemetry-sdk", specifier = ">=1.0" }, + { name = "pyright", specifier = ">=1.1.408" }, { name = "pytest", specifier = ">=8.0" }, { name = "pytest-asyncio", specifier = ">=0.24" }, { name = "python-dotenv", specifier = ">=1.2.1" },