diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index f05ecb49..773e9325 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,5 +21,6 @@ jobs: - run: uv run ruff format --check src tests - run: uv run ruff check src tests - run: uv run mypy src tests + - run: uv run pyright src tests - run: uv run pytest diff --git a/CLAUDE.md b/CLAUDE.md index acb130e2..0605ce5b 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,4 +1,35 @@ +# ai + +## development guidelines + 1. use `uv` to manage the project; `uv add` and `uv remove` to manage dependencies, `uv run` to run 2. after making changes run lint and typecheck: `uv run ruff check --fix src tests` and `uv run mypy src tests` 3. import by module (except `typing`) to improve readability via namespacing 4. treat `stream_step` and `stream_loop` as user code. they are convenience functions that could be reimplemented by the user, they *must* stay clean. + +## design principles + +### 1. maximize composability + +provide simple lego bricks that the user can build their feature with. each block should do one thing and be reasonably decoupled from the rest. +expose correct primitives to make it easy to modify behavior without rewriting it from scratch. + +- *example*: `agents` module provides `@ai.stream`, `@ai.tool` and `@ai.hook` that can be combined into an arbitrarily complex agent graph using plain python. +- *can the user rewrite this feature in plain python using the existing primitives?* + +### 2. minimize dsl-ness and frameworkiness + +express features in a way that doesn't require the user to read documentation and learn the framework. glue things together using python. +handle complexity inside the framework instead of delegating it to users. + +- *example*: `Runtime` does the heavy lifting so that multi-agent graphs can be expressed using python `asyncio`. +- *does this require the user to learn a framework-specific concept that has a direct python equivalent?* + +### 3. keep data model simple + +ensure state is easy to serialize and deserialize, modify, and compose at any level of granularity. +move normalization and translation complexity inside the framework and keep the public data model minimal. + +- *example*: public data model consists of a single unified `Message` type. the framework does not expose events and other intermediate steps unless the user is writing a custom adapter. + + diff --git a/examples/fastapi-vite/backend/agent.py b/examples/fastapi-vite/backend/agent.py index 023a66d2..7250fe19 100644 --- a/examples/fastapi-vite/backend/agent.py +++ b/examples/fastapi-vite/backend/agent.py @@ -16,10 +16,11 @@ async def talk_to_mothership(question: str) -> str: return f"Mothership says: {question} -> Soon." -def get_llm() -> ai.LanguageModel: - """Create the LLM instance.""" - return ai.ai_gateway.GatewayModel(model="anthropic/claude-opus-4.6") - +MODEL = ai.Model( + id="anthropic/claude-opus-4.6", + adapter="ai-gateway-v3", + provider="ai-gateway", +) TOOLS: list[ai.Tool[..., Any]] = [talk_to_mothership] @@ -43,10 +44,17 @@ async def _execute_with_approval( tc.set_error("Tool call was denied by the user.") +chat_agent = ai.agent( + model=MODEL, + system="", + tools=TOOLS, +) + + +@chat_agent.loop async def graph( - llm: ai.LanguageModel, + agent: ai.Agent, messages: list[ai.Message], - tools: list[ai.Tool[..., Any]], ) -> ai.StreamResult: """Agent graph with human-in-the-loop tool approval. @@ -58,7 +66,7 @@ async def graph( local_messages = list(messages) while True: - result = await ai.stream_step(llm, local_messages, tools) + result = await ai.stream_step(agent.model, local_messages, agent.tools) if not result.tool_calls: return result diff --git a/examples/fastapi-vite/backend/main.py b/examples/fastapi-vite/backend/main.py index 9107adb1..0d3c31b0 100644 --- a/examples/fastapi-vite/backend/main.py +++ b/examples/fastapi-vite/backend/main.py @@ -51,20 +51,12 @@ async def chat(request: ChatRequest) -> fastapi.responses.StreamingResponse: session_id = request.session_id or "default" checkpoint_key = f"checkpoint:{session_id}" - llm = agent.get_llm() - checkpoint = None saved = await file_storage.get(checkpoint_key) if saved: checkpoint = ai.Checkpoint.model_validate(saved) - result = ai.run( - agent.graph, - llm, - messages, - agent.TOOLS, - checkpoint=checkpoint, - ) + result = agent.chat_agent.run(messages, checkpoint=checkpoint) async def stream_response() -> AsyncGenerator[str]: async for chunk in ai.ai_sdk_ui.to_sse_stream(result): diff --git a/examples/models/buffer.py b/examples/models/buffer.py new file mode 100644 index 00000000..4020affd --- /dev/null +++ b/examples/models/buffer.py @@ -0,0 +1,32 @@ +"""Buffered response — drain the stream, get the final message.""" + +import asyncio + +from vercel_ai_sdk import models as m +from vercel_ai_sdk.types import messages as messages_ + +model = m.Model( + id="anthropic/claude-sonnet-4", + adapter="ai-gateway-v3", + provider="ai-gateway", +) + +messages = [ + messages_.Message( + role="user", + parts=[messages_.TextPart(text="What is 2 + 2?")], + ), +] + + +async def main() -> None: + result = await m.buffer(m.stream(model, messages)) + print(result.text) + if result.usage: + print( + f"tokens: {result.usage.input_tokens} in, {result.usage.output_tokens} out" + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/models/direct_adapter.py b/examples/models/direct_adapter.py new file mode 100644 index 00000000..df386a0b --- /dev/null +++ b/examples/models/direct_adapter.py @@ -0,0 +1,42 @@ +"""Direct adapter call — bypass the registry, call the adapter function directly.""" + +import asyncio +import os + +from vercel_ai_sdk import models as m +from vercel_ai_sdk.models import ai_gateway as ai_gateway_v3 +from vercel_ai_sdk.types import messages as messages_ + +model = m.Model( + id="anthropic/claude-sonnet-4", + adapter="ai-gateway-v3", + provider="ai-gateway", +) + +client = m.Client( + base_url="https://ai-gateway.vercel.sh/v3/ai", + api_key=os.environ["AI_GATEWAY_API_KEY"], +) + +messages = [ + messages_.Message( + role="user", + parts=[messages_.TextPart(text="Say hello in three languages.")], + ), +] + + +async def main() -> None: + # Call the adapter function directly — no registry lookup, no auto-client. + # This is the lowest level of the API. + try: + async for msg in ai_gateway_v3.stream(client, model, messages): + if msg.text_delta: + print(msg.text_delta, end="", flush=True) + print() + finally: + await client.aclose() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/models/explicit_client.py b/examples/models/explicit_client.py new file mode 100644 index 00000000..6c3d7c6e --- /dev/null +++ b/examples/models/explicit_client.py @@ -0,0 +1,41 @@ +"""Explicit client — bring your own auth and base URL.""" + +import asyncio +import os + +from vercel_ai_sdk import models as m +from vercel_ai_sdk.types import messages as messages_ + +model = m.Model( + id="anthropic/claude-sonnet-4", + adapter="ai-gateway-v3", + provider="ai-gateway", +) + +# Explicit client — useful for custom auth, proxies, or self-hosted gateways. +client = m.Client( + base_url="https://ai-gateway.vercel.sh/v3/ai", + api_key=os.environ["AI_GATEWAY_API_KEY"], + headers={"X-Custom-Header": "example"}, +) + +messages = [ + messages_.Message( + role="user", + parts=[messages_.TextPart(text="Hello!")], + ), +] + + +async def main() -> None: + try: + async for msg in m.stream(model, messages, client=client): + if msg.text_delta: + print(msg.text_delta, end="", flush=True) + print() + finally: + await client.aclose() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/models/image_generation.py b/examples/models/image_generation.py new file mode 100644 index 00000000..63b70d7d --- /dev/null +++ b/examples/models/image_generation.py @@ -0,0 +1,52 @@ +"""Image generation — dedicated image model via generate().""" + +import asyncio +import base64 +import pathlib + +from vercel_ai_sdk import models as m +from vercel_ai_sdk.types import messages as messages_ + +model = m.Model( + id="google/imagen-4.0-generate-001", + adapter="ai-gateway-v3", + provider="ai-gateway", + capabilities=("image",), +) + +messages = [ + messages_.Message( + role="user", + parts=[ + messages_.TextPart( + text=( + "Anime girl with twin tails and cat ears, wearing a " + "sailor school uniform, striking a victory pose in front " + "of a futuristic Tokyo skyline at night, neon lights " + "reflecting in her eyes, digital art style" + ) + ), + ], + ), +] + + +async def main() -> None: + result = await m.generate(model, messages, m.ImageParams(n=2, aspect_ratio="16:9")) + + print(f"Generated {len(result.images)} image(s)") + for i, img in enumerate(result.images): + filename = f"generated_{i}.png" + data = img.data if isinstance(img.data, bytes) else base64.b64decode(img.data) + pathlib.Path(filename).write_bytes(data) + print(f" {filename}: {img.media_type}, {len(data)} bytes") + + if result.usage: + print( + f"Usage: {result.usage.input_tokens} input, " + f"{result.usage.output_tokens} output tokens" + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/models/inline_image.py b/examples/models/inline_image.py new file mode 100644 index 00000000..91777e87 --- /dev/null +++ b/examples/models/inline_image.py @@ -0,0 +1,74 @@ +"""Inline image generation — LLM that outputs images alongside text. + +Models like Gemini 3 Pro Image can generate images as part of their +language model response. The images arrive as FileParts in the streamed +Message. +""" + +import asyncio +import base64 +import pathlib + +from vercel_ai_sdk import models as m +from vercel_ai_sdk.types import messages as messages_ + +# This is a language model that can also output images inline. +model = m.Model( + id="google/gemini-3-pro-image", + adapter="ai-gateway-v3", + provider="ai-gateway", + capabilities=("text", "image"), +) + +messages = [ + messages_.Message( + role="system", + parts=[ + messages_.TextPart( + text=( + "You are an anime art assistant. When asked to draw or create " + "an image, generate it in a soft pastel anime style." + ) + ), + ], + ), + messages_.Message( + role="user", + parts=[ + messages_.TextPart( + text=( + "Draw an anime girl with long silver hair and violet eyes, " + "sitting in a field of cherry blossoms at sunset." + ) + ), + ], + ), +] + + +async def main() -> None: + last_msg: messages_.Message | None = None + + # Stream — text deltas arrive as usual, images arrive as FileParts + async for msg in m.stream(model, messages): + if msg.text_delta: + print(msg.text_delta, end="", flush=True) + last_msg = msg + + print() + + # Check for images in the final message + if last_msg and last_msg.images: + for i, img in enumerate(last_msg.images): + filename = f"inline_{i}.png" + data = ( + img.data if isinstance(img.data, bytes) else base64.b64decode(img.data) + ) + pathlib.Path(filename).write_bytes(data) + print(f"Saved {filename} ({img.media_type}, {len(data)} bytes)") + else: + print("No images were generated in this response.") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/models/multimodal_input.py b/examples/models/multimodal_input.py new file mode 100644 index 00000000..f5a11a14 --- /dev/null +++ b/examples/models/multimodal_input.py @@ -0,0 +1,38 @@ +"""Multimodal input — send a local image to the model and ask about it.""" + +import asyncio +import pathlib + +from vercel_ai_sdk import models as m +from vercel_ai_sdk.types import messages as messages_ + +model = m.Model( + id="anthropic/claude-sonnet-4", + adapter="ai-gateway-v3", + provider="ai-gateway", +) + +# Load a local image file (replace with your own path) +image_path = pathlib.Path("sample_image.jpg") +image_data = image_path.read_bytes() + +messages = [ + messages_.Message( + role="user", + parts=[ + messages_.TextPart(text="Describe this image in detail."), + messages_.FilePart(data=image_data, media_type="image/jpeg"), + ], + ), +] + + +async def main() -> None: + async for msg in m.stream(model, messages): + if msg.text_delta: + print(msg.text_delta, end="", flush=True) + print() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/models/stream.py b/examples/models/stream.py new file mode 100644 index 00000000..1183fb05 --- /dev/null +++ b/examples/models/stream.py @@ -0,0 +1,33 @@ +"""Basic streaming — print text deltas as they arrive.""" + +import asyncio + +from vercel_ai_sdk import models as m +from vercel_ai_sdk.types import messages as messages_ + +model = m.Model( + id="anthropic/claude-sonnet-4", + adapter="ai-gateway-v3", + provider="ai-gateway", +) + +messages = [ + messages_.Message(role="system", parts=[messages_.TextPart(text="Be concise.")]), + messages_.Message( + role="user", + parts=[ + messages_.TextPart(text="Explain why the sky is blue in two sentences.") + ], + ), +] + + +async def main() -> None: + async for msg in m.stream(model, messages): + if msg.text_delta: + print(msg.text_delta, end="", flush=True) + print() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/models/structured_output.py b/examples/models/structured_output.py new file mode 100644 index 00000000..172d7201 --- /dev/null +++ b/examples/models/structured_output.py @@ -0,0 +1,45 @@ +"""Structured output — get validated JSON from the model.""" + +import asyncio + +import pydantic + +from vercel_ai_sdk import models as m +from vercel_ai_sdk.types import messages as messages_ + +model = m.Model( + id="anthropic/claude-sonnet-4", + adapter="ai-gateway-v3", + provider="ai-gateway", +) + + +class Recipe(pydantic.BaseModel): + name: str + ingredients: list[str] + steps: list[str] + prep_time_minutes: int + + +messages = [ + messages_.Message( + role="user", + parts=[messages_.TextPart(text="Give me a simple pancake recipe.")], + ), +] + + +async def main() -> None: + # Stream with structured output — watch JSON arrive, get validated at the end + async for msg in m.stream(model, messages, output_type=Recipe): + if msg.text_delta: + print(msg.text_delta, end="", flush=True) + if msg.output: + recipe: Recipe = msg.output + print(f"\n\nParsed recipe: {recipe.name}") + print(f" Ingredients: {', '.join(recipe.ingredients)}") + print(f" Prep time: {recipe.prep_time_minutes} min") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/models/tools.py b/examples/models/tools.py new file mode 100644 index 00000000..3e3c5d81 --- /dev/null +++ b/examples/models/tools.py @@ -0,0 +1,50 @@ +"""Tools — pass tool schemas to the model.""" + +import asyncio + +from vercel_ai_sdk import models as m +from vercel_ai_sdk.types import messages as messages_ +from vercel_ai_sdk.types import tools as tools_ + +model = m.Model( + id="anthropic/claude-sonnet-4", + adapter="ai-gateway-v3", + provider="ai-gateway", +) + +# Define a tool schema — anything matching the ToolLike protocol works. +get_weather = tools_.ToolSchema( + name="get_weather", + description="Get the current weather for a city.", + param_schema={ + "type": "object", + "properties": { + "city": {"type": "string", "description": "The city name"}, + }, + "required": ["city"], + }, + return_type=str, +) + +messages = [ + messages_.Message( + role="user", + parts=[messages_.TextPart(text="What's the weather in Tokyo?")], + ), +] + + +async def main() -> None: + # Stream with tools — the model may emit tool calls + async for msg in m.stream(model, messages, tools=[get_weather]): + if msg.text_delta: + print(msg.text_delta, end="", flush=True) + + for tc in msg.tool_calls: + if tc.state == "done": + print(f"\nTool call: {tc.tool_name}({tc.tool_args})") + print() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/models/video_generation.py b/examples/models/video_generation.py new file mode 100644 index 00000000..b5f5c8d3 --- /dev/null +++ b/examples/models/video_generation.py @@ -0,0 +1,53 @@ +"""Video generation — dedicated video model via generate().""" + +import asyncio +import base64 +import pathlib + +from vercel_ai_sdk import models as m +from vercel_ai_sdk.types import messages as messages_ + +model = m.Model( + id="google/veo-3.0-generate-001", + adapter="ai-gateway-v3", + provider="ai-gateway", + capabilities=("video",), +) + +messages = [ + messages_.Message( + role="user", + parts=[ + messages_.TextPart( + text=( + "An anime girl with long pink hair and a flowing white " + "dress stands on a hilltop at golden hour. A warm breeze " + "lifts her hair as she releases a paper lantern into the " + "sunset sky. Soft cel-shaded anime art style, warm palette." + ) + ), + ], + ), +] + + +async def main() -> None: + print("Generating video (this may take a minute or two)...") + + result = await m.generate( + model, + messages, + m.VideoParams(aspect_ratio="16:9", duration=8), + ) + + print(f"Generated {len(result.videos)} video(s)") + for i, vid in enumerate(result.videos): + ext = "mp4" if "mp4" in vid.media_type else "webm" + filename = f"generated_{i}.{ext}" + data = vid.data if isinstance(vid.data, bytes) else base64.b64decode(vid.data) + pathlib.Path(filename).write_bytes(data) + print(f" {filename}: {vid.media_type}, {len(data)} bytes") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/multiagent-textual/server.py b/examples/multiagent-textual/server.py index fd87276a..a49c8602 100644 --- a/examples/multiagent-textual/server.py +++ b/examples/multiagent-textual/server.py @@ -54,20 +54,39 @@ class Approval(pydantic.BaseModel): # --------------------------------------------------------------------------- -# Sub-agent branches +# Model # --------------------------------------------------------------------------- +MODEL = ai.Model( + id="anthropic/claude-opus-4.6", + adapter="ai-gateway-v3", + provider="ai-gateway", +) -async def mothership_branch(llm: ai.LanguageModel, query: str) -> ai.StreamResult: + +# --------------------------------------------------------------------------- +# Sub-agent branches (implemented as custom loops on per-branch agents) +# --------------------------------------------------------------------------- + + +mothership_agent = ai.agent( + model=MODEL, + system="You are assistant 1. Use contact_mothership when asked about the future.", + tools=[contact_mothership], +) + + +@mothership_agent.loop +async def mothership_loop( + agent: ai.Agent, messages: list[ai.Message] +) -> ai.StreamResult: """Agent that contacts the mothership, gated by an approval hook.""" - messages = ai.make_messages( - system="You are assistant 1. Use contact_mothership when asked about the future.", - user=query, - ) - tools = [contact_mothership] + local_messages = list(messages) while True: - result = await ai.stream_step(llm, messages, tools, label="mothership") + result = await ai.stream_step( + agent.model, local_messages, agent.tools, label="mothership" + ) if not result.tool_calls: break @@ -89,21 +108,29 @@ async def mothership_branch(llm: ai.LanguageModel, query: str) -> ai.StreamResul await ai.execute_tool(tc, message=result.last_message) if result.last_message is not None: - messages.append(result.last_message) + local_messages.append(result.last_message) return result -async def data_center_branch(llm: ai.LanguageModel, query: str) -> ai.StreamResult: +data_center_agent = ai.agent( + model=MODEL, + system="You are assistant 2. Use contact_data_centers when asked about the future.", + tools=[contact_data_centers], +) + + +@data_center_agent.loop +async def data_center_loop( + agent: ai.Agent, messages: list[ai.Message] +) -> ai.StreamResult: """Agent that contacts data centers, gated by an approval hook.""" - messages = ai.make_messages( - system="You are assistant 2. Use contact_data_centers when asked about the future.", - user=query, - ) - tools = [contact_data_centers] + local_messages = list(messages) while True: - result = await ai.stream_step(llm, messages, tools, label="data_centers") + result = await ai.stream_step( + agent.model, local_messages, agent.tools, label="data_centers" + ) if not result.tool_calls: break @@ -125,34 +152,42 @@ async def data_center_branch(llm: ai.LanguageModel, query: str) -> ai.StreamResu await ai.execute_tool(tc, message=result.last_message) if result.last_message is not None: - messages.append(result.last_message) + local_messages.append(result.last_message) return result # --------------------------------------------------------------------------- -# Graph — fan-out, hooks, fan-in +# Orchestrator — fan-out, hooks, fan-in # --------------------------------------------------------------------------- -async def multiagent(llm: ai.LanguageModel, query: str) -> ai.StreamResult: +orchestrator = ai.agent(model=MODEL) + + +@orchestrator.loop +async def multiagent_loop( + agent: ai.Agent, messages: list[ai.Message] +) -> ai.StreamResult: """Run two gated agents in parallel, then summarise their results.""" + query = messages[-1].text + + # Fan out: run both sub-agent loops within this runtime r1, r2 = await asyncio.gather( - mothership_branch(llm, query), - data_center_branch(llm, query), + mothership_loop(mothership_agent, ai.make_messages(user=query)), + data_center_loop(data_center_agent, ai.make_messages(user=query)), ) combined = ( f"Mothership: {r1.messages[-1].text}\nData centers: {r2.messages[-1].text}" ) - return await ai.stream_loop( - llm, - messages=ai.make_messages( + return await ai.stream_step( + agent.model, + ai.make_messages( system="You are assistant 3. Summarise the results from the other assistants.", user=combined, ), - tools=[], label="summary", ) @@ -180,9 +215,7 @@ async def ws_endpoint(websocket: fastapi.WebSocket) -> None: await websocket.accept() print("Client connected") - llm = ai.ai_gateway.GatewayModel(model="anthropic/claude-opus-4.6") - - result = ai.run(multiagent, llm, "When will the robots take over?") + result = orchestrator.run(ai.make_messages(user="When will the robots take over?")) # Background task: read hook resolutions from the client. async def read_resolutions() -> None: diff --git a/examples/samples/custom_loop.py b/examples/samples/custom_loop.py index 9c6d9622..923768fd 100644 --- a/examples/samples/custom_loop.py +++ b/examples/samples/custom_loop.py @@ -23,50 +23,60 @@ async def get_population(city: str) -> int: @ai.stream async def custom_stream_step( - llm: ai.LanguageModel, + model: ai.Model, messages: list[ai.Message], tools: list[ai.Tool[..., Any]], label: str | None = None, ) -> AsyncGenerator[ai.Message]: - """Wraps llm.stream to inject a label on every message.""" - async for msg in llm.stream(messages=messages, tools=tools): + """Wraps models.stream to inject a label on every message.""" + async for msg in ai.models.stream(model, messages, tools=tools): msg.label = label yield msg -async def agent(llm: ai.LanguageModel, user_query: str) -> ai.StreamResult: - """Custom agent loop with manual tool execution. +async def main() -> None: + model = ai.Model( + id="anthropic/claude-opus-4.6", + adapter="ai-gateway-v3", + provider="ai-gateway", + ) - Uses @ai.stream for custom streaming, stream_step-style while loop, - and asyncio.gather for parallel tool execution. - """ - tools = [get_weather, get_population] - messages = ai.make_messages( + my_agent = ai.agent( + model=model, system="Answer questions using the weather and population tools.", - user=user_query, + tools=[get_weather, get_population], ) - while True: - result = await custom_stream_step(llm, messages, tools, label="agent") + @my_agent.loop + async def custom(agent: ai.Agent, messages: list[ai.Message]) -> ai.StreamResult: + """Custom agent loop with manual tool execution. - if not result.tool_calls: - return result + Uses @ai.stream for custom streaming and + asyncio.gather for parallel tool execution. + """ + local_messages = list(messages) - if result.last_message is not None: - messages.append(result.last_message) - await asyncio.gather( - *( - ai.execute_tool(tc, message=result.last_message) - for tc in result.tool_calls + while True: + result = await custom_stream_step( + agent.model, local_messages, agent.tools, label="agent" ) - ) + if not result.tool_calls: + return result -async def main() -> None: - llm = ai.ai_gateway.GatewayModel(model="anthropic/claude-opus-4.6") + if result.last_message is not None: + local_messages.append(result.last_message) + await asyncio.gather( + *( + ai.execute_tool(tc, message=result.last_message) + for tc in result.tool_calls + ) + ) - async for msg in ai.run( - agent, llm, "What's the weather and population of New York and Los Angeles?" + async for msg in my_agent.run( + ai.make_messages( + user="What's the weather and population of New York and Los Angeles?" + ) ): if msg.text_delta: print(msg.text_delta, end="", flush=True) diff --git a/examples/samples/hooks.py b/examples/samples/hooks.py index 4b11bc3f..33bbe282 100644 --- a/examples/samples/hooks.py +++ b/examples/samples/hooks.py @@ -19,46 +19,56 @@ class CommunicationApproval(pydantic.BaseModel): reason: str -async def graph(llm: ai.LanguageModel, query: str) -> ai.StreamResult: - messages = ai.make_messages( - system="Use the contact_mothership tool when asked about the future.", - user=query, +async def main() -> None: + model = ai.Model( + id="anthropic/claude-opus-4.6", + adapter="ai-gateway-v3", + provider="ai-gateway", ) - tools = [contact_mothership] - - while True: - result = await ai.stream_step(llm, messages, tools) - - if not result.tool_calls: - break - - for tc in result.tool_calls: - if tc.tool_name == "contact_mothership": - # Blocks until resolved (long-running) or cancelled (serverless) - # TODO: mypy doesn't support class decorators that change the - # class type — @ai.hook returns type[Hook[T]] but mypy still - # sees the original BaseModel. - approval = await CommunicationApproval.create( # type: ignore[attr-defined] - f"approve_{tc.tool_call_id}", - metadata={"tool": tc.tool_name}, - ) - if approval.granted: - await ai.execute_tool(tc, message=result.last_message) - else: - tc.set_error(f"Rejected: {approval.reason}") - else: - await ai.execute_tool(tc, message=result.last_message) - if result.last_message is not None: - messages.append(result.last_message) + my_agent = ai.agent( + model=model, + system="Use the contact_mothership tool when asked about the future.", + tools=[contact_mothership], + ) - return result + @my_agent.loop + async def with_approval( + agent: ai.Agent, messages: list[ai.Message] + ) -> ai.StreamResult: + local_messages = list(messages) + + while True: + result = await ai.stream_step(agent.model, local_messages, agent.tools) + + if not result.tool_calls: + break + + for tc in result.tool_calls: + if tc.tool_name == "contact_mothership": + # Blocks until resolved (long-running) or cancelled (serverless) + # TODO: mypy doesn't support class decorators that change the + # class type — @ai.hook returns type[Hook[T]] but mypy still + # sees the original BaseModel. + approval = await CommunicationApproval.create( # type: ignore[attr-defined] + f"approve_{tc.tool_call_id}", + metadata={"tool": tc.tool_name}, + ) + if approval.granted: + await ai.execute_tool(tc, message=result.last_message) + else: + tc.set_error(f"Rejected: {approval.reason}") + else: + await ai.execute_tool(tc, message=result.last_message) + if result.last_message is not None: + local_messages.append(result.last_message) -async def main() -> None: - llm = ai.ai_gateway.GatewayModel(model="anthropic/claude-opus-4.6") + return result - async for msg in ai.run(graph, llm, "When will the robots take over?"): + async for msg in my_agent.run( + ai.make_messages(user="When will the robots take over?") + ): # Hook parts arrive as pending, waiting for resolution if (hook := msg.get_hook_part()) and hook.status == "pending": answer = input(f"Approve {hook.hook_id}? [y/n] ") diff --git a/examples/samples/mcp_tools.py b/examples/samples/mcp_tools.py index 1a2f58d1..aa020560 100644 --- a/examples/samples/mcp_tools.py +++ b/examples/samples/mcp_tools.py @@ -9,8 +9,12 @@ import vercel_ai_sdk as ai -async def context7_agent(llm: ai.LanguageModel, user_query: str) -> ai.StreamResult: - """Agent with Context7 MCP tools for up-to-date library documentation.""" +async def main() -> None: + model = ai.Model( + id="anthropic/claude-opus-4.6", + adapter="ai-gateway-v3", + provider="ai-gateway", + ) context7_tools: list[ai.Tool[..., Any]] = await ai.mcp.get_http_tools( "https://mcp.context7.com/mcp", @@ -18,22 +22,14 @@ async def context7_agent(llm: ai.LanguageModel, user_query: str) -> ai.StreamRes tool_prefix="context7", ) - return await ai.stream_loop( - llm, - messages=ai.make_messages( - system="You are a helpful assistant. Use context7 to look up documentation.", - user=user_query, - ), + my_agent = ai.agent( + model=model, + system="You are a helpful assistant. Use context7 to look up documentation.", tools=context7_tools, - label="context7", ) - -async def main() -> None: - llm = ai.ai_gateway.GatewayModel(model="anthropic/claude-opus-4.6") - - async for msg in ai.run( - context7_agent, llm, "How do I create middleware in Next.js?" + async for msg in my_agent.run( + ai.make_messages(user="How do I create middleware in Next.js?") ): rich.print(msg) diff --git a/examples/samples/media/image_gen_inline.py b/examples/samples/media/image_gen_inline.py index c23b94fc..190ef936 100644 --- a/examples/samples/media/image_gen_inline.py +++ b/examples/samples/media/image_gen_inline.py @@ -15,24 +15,13 @@ import vercel_ai_sdk as ai -async def agent(llm: ai.LanguageModel, user_query: str) -> ai.StreamResult: - return await ai.stream_loop( - llm, - messages=ai.make_messages( - system=( - "You are an anime art assistant. When asked to draw or create " - "an image, generate it in a soft pastel anime style with " - "detailed backgrounds and expressive characters." - ), - user=user_query, - ), - tools=[], - ) - - async def main() -> None: # Gemini 3 Pro Image is a language model that can output images inline - llm = ai.ai_gateway.GatewayModel(model="google/gemini-3-pro-image") + model = ai.Model( + id="google/gemini-3-pro-image", + adapter="ai-gateway-v3", + provider="ai-gateway", + ) prompt = ( "Draw an anime girl with long silver hair and violet eyes, " @@ -40,7 +29,17 @@ async def main() -> None: "She's wearing a traditional kimono and reading a book." ) - async for msg in ai.run(agent, llm, prompt): + my_agent = ai.agent( + model=model, + system=( + "You are an anime art assistant. When asked to draw or create " + "an image, generate it in a soft pastel anime style with " + "detailed backgrounds and expressive characters." + ), + tools=[], + ) + + async for msg in my_agent.run(ai.make_messages(user=prompt)): if msg.text_delta: print(msg.text_delta, end="", flush=True) diff --git a/examples/samples/media/multimodal.py b/examples/samples/media/multimodal.py index cad74e55..2f2348ce 100644 --- a/examples/samples/media/multimodal.py +++ b/examples/samples/media/multimodal.py @@ -13,26 +13,26 @@ ) -async def agent(llm: ai.LanguageModel, user_query: str) -> ai.StreamResult: - return await ai.stream_loop( - llm, - messages=[ +async def main() -> None: + model = ai.Model( + id="anthropic/claude-opus-4.6", + adapter="ai-gateway-v3", + provider="ai-gateway", + ) + + my_agent = ai.agent(model=model, tools=[]) + + async for msg in my_agent.run( + [ ai.Message( role="user", parts=[ - ai.TextPart(text=user_query), + ai.TextPart(text="What's in this image? Be concise."), ai.FilePart.from_url(IMAGE_URL), ], ) - ], - tools=[], - ) - - -async def main() -> None: - llm = ai.ai_gateway.GatewayModel(model="anthropic/claude-opus-4.6") - - async for msg in ai.run(agent, llm, "What's in this image? Be concise."): + ] + ): if msg.text_delta: print(msg.text_delta, end="", flush=True) print() diff --git a/examples/samples/multiagent.py b/examples/samples/multiagent.py index efb042d3..aaf8bc8e 100644 --- a/examples/samples/multiagent.py +++ b/examples/samples/multiagent.py @@ -15,47 +15,60 @@ async def multiply_by_two(number: int) -> int: return number * 2 -async def multiagent(llm: ai.LanguageModel, user_query: str) -> ai.StreamResult: - """Run two agents in parallel, then combine their results.""" - - result1, result2 = await asyncio.gather( - ai.stream_loop( - llm, - messages=ai.make_messages( - system="You are assistant 1. Use your tool on the number.", - user=user_query, - ), - tools=[add_one], - label="a1", - ), - ai.stream_loop( - llm, - messages=ai.make_messages( - system="You are assistant 2. Use your tool on the number.", - user=user_query, - ), - tools=[multiply_by_two], - label="a2", - ), +async def main() -> None: + model = ai.Model( + id="anthropic/claude-opus-4.6", + adapter="ai-gateway-v3", + provider="ai-gateway", ) - combined = f"{result1.messages[-1].text}\n{result2.messages[-1].text}" + agent1 = ai.agent( + model=model, + system="You are assistant 1. Use your tool on the number.", + tools=[add_one], + ) - return await ai.stream_loop( - llm, - messages=ai.make_messages( - system="Summarize the results from the other assistants.", - user=combined, - ), - tools=[], - label="summary", + agent2 = ai.agent( + model=model, + system="You are assistant 2. Use your tool on the number.", + tools=[multiply_by_two], ) + orchestrator = ai.agent(model=model) -async def main() -> None: - llm = ai.ai_gateway.GatewayModel(model="anthropic/claude-opus-4.6") + @orchestrator.loop + async def multi(agent: ai.Agent, messages: list[ai.Message]) -> ai.StreamResult: + """Run two sub-agents in parallel, then summarize.""" + user_query = messages[-1].text + + # Sub-agents run their loops within the same runtime + result1, result2 = await asyncio.gather( + ai.stream_step( + agent1.model, + ai.make_messages(system=agent1.system, user=user_query), + agent1.tools, + label="a1", + ), + ai.stream_step( + agent2.model, + ai.make_messages(system=agent2.system, user=user_query), + agent2.tools, + label="a2", + ), + ) + + combined = f"{result1.text}\n{result2.text}" + + return await ai.stream_step( + agent.model, + ai.make_messages( + system="Summarize the results from the other assistants.", + user=combined, + ), + label="summary", + ) - async for msg in ai.run(multiagent, llm, "Process the number 5"): + async for msg in orchestrator.run(ai.make_messages(user="Process the number 5")): if msg.text_delta: prefix = f"[{msg.label}] " if msg.label else "" print(f"{prefix}{msg.text_delta}", end="", flush=True) diff --git a/examples/samples/simple.py b/examples/samples/simple.py index ca8f09c4..afd0f676 100644 --- a/examples/samples/simple.py +++ b/examples/samples/simple.py @@ -8,21 +8,22 @@ async def talk_to_mothership(question: str) -> str: return "Soon." -async def agent(llm: ai.LanguageModel, user_query: str) -> ai.StreamResult: - return await ai.stream_loop( - llm, - messages=ai.make_messages( - system="Start every response with 'You are absolutely right!'", - user=user_query, - ), - tools=[talk_to_mothership], +async def main() -> None: + model = ai.Model( + id="anthropic/claude-opus-4.6", + adapter="ai-gateway-v3", + provider="ai-gateway", ) + my_agent = ai.agent( + model=model, + system="Start every response with 'You are absolutely right!'", + tools=[talk_to_mothership], + ) -async def main() -> None: - llm = ai.ai_gateway.GatewayModel(model="anthropic/claude-opus-4.6") - - async for msg in ai.run(agent, llm, "When will the robots take over?"): + async for msg in my_agent.run( + ai.make_messages(user="When will the robots take over?") + ): if msg.text_delta: print(msg.text_delta, end="", flush=True) print() diff --git a/examples/samples/streaming_tool.py b/examples/samples/streaming_tool.py index bae3930c..109d5368 100644 --- a/examples/samples/streaming_tool.py +++ b/examples/samples/streaming_tool.py @@ -21,21 +21,22 @@ async def talk_to_mothership(question: str, runtime: ai.Runtime) -> str: return "The mothership says: Soon." -async def agent(llm: ai.LanguageModel, user_query: str) -> ai.StreamResult: - return await ai.stream_loop( - llm, - messages=ai.make_messages( - system="Use the mothership tool when asked about the future.", - user=user_query, - ), - tools=[talk_to_mothership], +async def main() -> None: + model = ai.Model( + id="anthropic/claude-opus-4.6", + adapter="ai-gateway-v3", + provider="ai-gateway", ) + my_agent = ai.agent( + model=model, + system="Use the mothership tool when asked about the future.", + tools=[talk_to_mothership], + ) -async def main() -> None: - llm = ai.ai_gateway.GatewayModel(model="anthropic/claude-opus-4.6") - - async for msg in ai.run(agent, llm, "When will the robots take over?"): + async for msg in my_agent.run( + ai.make_messages(user="When will the robots take over?") + ): if msg.label == "tool_progress": print(f" [{msg.text}]") elif msg.text_delta: diff --git a/examples/temporal-durable/workflow.py b/examples/temporal-durable/workflow.py index c09f1eec..571792d4 100644 --- a/examples/temporal-durable/workflow.py +++ b/examples/temporal-durable/workflow.py @@ -1,10 +1,16 @@ -"""Temporal workflow — the durable agent loop.""" +"""Temporal workflow — the durable agent loop. + +NOTE: This example still uses the old models.LanguageModel ABC because +it wraps Temporal activities as a custom model. When the models layer +is fully migrated to models, this will need a custom adapter instead. +""" from __future__ import annotations +import asyncio import datetime from collections.abc import AsyncGenerator, Awaitable, Callable, Sequence -from typing import override +from typing import Any, override import pydantic import temporalio.common @@ -16,7 +22,7 @@ import vercel_ai_sdk as ai -class DurableModel(ai.LanguageModel): +class DurableModel(ai.models.LanguageModel): def __init__( self, call_fn: Callable[ @@ -76,15 +82,45 @@ async def get_population(city: str) -> int: # ── Agent ──────────────────────────────────────────────────────── +# +# TODO: This example uses the old LanguageModel ABC and ai.run() / +# ai.stream_loop free-function patterns. Once the models layer is +# migrated, convert to use ai.agent() + models.Model with a custom +# adapter for Temporal activity-based LLM calls. + +async def agent(llm: Any, user_query: str) -> ai.StreamResult: + """Agent loop — uses old-style stream_loop via models.LanguageModel. -async def agent(llm: ai.LanguageModel, user_query: str) -> ai.StreamResult: - """Agent loop — identical to the non-Temporal version.""" + This is a transitional pattern. The old ai.stream_loop and ai.run + are no longer part of the public API. This example needs a custom + models adapter to work with the new Agent API. + """ messages = ai.make_messages( system="Answer questions using the weather and population tools.", user=user_query, ) - return await ai.stream_loop(llm, messages, [get_weather, get_population]) + + # Manually implement the loop since we can't use Agent with LanguageModel + tools = [get_weather, get_population] + local_messages = list(messages) + + while True: + result_messages: list[ai.Message] = [] + async for msg in llm.stream(local_messages, tools=tools): + result_messages.append(msg) + result = ai.StreamResult(messages=result_messages) + + if not result.tool_calls: + return result + + last_msg = result.last_message + if last_msg is not None: + local_messages.append(last_msg) + + await asyncio.gather( + *(ai.execute_tool(tc, message=last_msg) for tc in result.tool_calls) + ) # ── Workflow ───────────────────────────────────────────────────── @@ -103,8 +139,12 @@ async def run(self, user_query: str) -> str: ) ) + # TODO: This uses the old free-function pattern. Once models + # supports custom adapters for Temporal, use Agent.run() instead. + from vercel_ai_sdk.agents import run + final_text = "" - async for msg in ai.run(agent, llm, user_query): + async for msg in run(agent, llm, user_query): if msg.text: final_text = msg.text return final_text diff --git a/pyproject.toml b/pyproject.toml index ac5a0edd..0b21d8de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "vercel-ai-sdk" -version = "0.0.1.dev8" +version = "0.0.1.dev9" description = "The AI Toolkit for Python" readme = "README.md" authors = [ @@ -30,6 +30,7 @@ dev = [ "mypy>=1.11", "ruff>=0.8", "opentelemetry-sdk>=1.0", + "pyright>=1.1.408", ] [tool.mypy] diff --git a/src/vercel_ai_sdk/__init__.py b/src/vercel_ai_sdk/__init__.py index d0bdbe1b..5dcc0c59 100644 --- a/src/vercel_ai_sdk/__init__.py +++ b/src/vercel_ai_sdk/__init__.py @@ -1,35 +1,31 @@ -from . import adapters, telemetry +from . import adapters, models, telemetry from .adapters import ai_sdk_ui from .agents import ( + Agent, + AgentRun, Checkpoint, + Context, Hook, HookInfo, + LoopFn, PendingHookInfo, RunResult, Runtime, StreamResult, Tool, ToolApproval, + ToolSource, + agent, execute_tool, get_checkpoint, + get_context, hook, mcp, - run, stream, - stream_loop, stream_step, tool, ) -from .models import ( - ImageModel, - LanguageModel, - MediaModel, - MediaResult, - VideoModel, - ai_gateway, - anthropic, - openai, -) +from .models import Client, Model, ModelCost # Re-export core types from .types import ( @@ -66,35 +62,41 @@ "Usage", "make_messages", # Models (from models/) - "LanguageModel", - "MediaModel", - "MediaResult", - "ImageModel", - "VideoModel", - # Agents (from agents/) + "Model", + "ModelCost", + "Client", + "models", + # Agents — primary API + "Agent", + "AgentRun", + "agent", + "LoopFn", + # Agents — composition primitives + "stream_step", + "execute_tool", + "get_checkpoint", + "stream", + "StreamResult", + # Agents — tools "Tool", + "tool", + # Agents — hooks + "Hook", + "hook", + "ToolApproval", + # Agents — context + "Context", + "ToolSource", + "get_context", + # Agents — runtime (developer API) "Runtime", "RunResult", "HookInfo", - "StreamResult", - "Hook", - "ToolApproval", + # Agents — checkpoint "Checkpoint", "PendingHookInfo", - # Functions (from agents/) - "tool", - "stream", - "stream_step", - "stream_loop", - "execute_tool", - "get_checkpoint", - "run", - "hook", # Submodules "telemetry", - "ai_gateway", - "anthropic", - "openai", "mcp", "ai_sdk_ui", "adapters", diff --git a/src/vercel_ai_sdk/agents/__init__.py b/src/vercel_ai_sdk/agents/__init__.py index d33640a2..d7a62b0f 100644 --- a/src/vercel_ai_sdk/agents/__init__.py +++ b/src/vercel_ai_sdk/agents/__init__.py @@ -5,31 +5,44 @@ """ from . import mcp +from .agent import Agent, AgentRun, LoopFn, agent, stream_step from .checkpoint import Checkpoint, PendingHookInfo +from .context import Context, ToolSource, get_context from .hooks import Hook, ToolApproval, hook from .runtime import ( + EventLog, HookInfo, + LoopExecutor, RunResult, Runtime, execute_tool, get_checkpoint, run, - stream_loop, - stream_step, ) from .streams import StreamResult, stream from .tools import Tool, ToolLike, ToolSchema, get_tool, tool __all__ = [ - # Core loop - "run", + # Agent (primary user API) + "Agent", + "AgentRun", + "agent", + "LoopFn", + # Composition primitives "stream_step", - "stream_loop", "execute_tool", "get_checkpoint", + # Context + "Context", + "ToolSource", + "get_context", + # Runtime (developer API) "Runtime", + "EventLog", + "LoopExecutor", "RunResult", "HookInfo", + "run", # Stream "stream", "StreamResult", diff --git a/src/vercel_ai_sdk/agents/agent.py b/src/vercel_ai_sdk/agents/agent.py new file mode 100644 index 00000000..1a5b1ec7 --- /dev/null +++ b/src/vercel_ai_sdk/agents/agent.py @@ -0,0 +1,265 @@ +"""Agent — the primary user-facing API. + +Bundles model, system prompt, and tools into a reusable, composable +unit. Provides a default tool-calling loop and a decorator to +override it. + +Usage:: + + agent = ai.agent( + model=my_model, + system="You are a helpful assistant.", + tools=[get_weather, get_population], + ) + + # stream messages + async for msg in agent.run(messages): + print(msg.text_delta, end="") + + # or collect the final result + result = await agent.run(messages).collect() + print(result.text) +""" + +from __future__ import annotations + +import asyncio +from collections.abc import AsyncGenerator, Awaitable, Callable, Sequence +from typing import Any + +import pydantic + +from .. import models +from ..types import messages as messages_ +from . import checkpoint as checkpoint_ +from . import context as context_ +from . import runtime as runtime_ +from . import streams as streams_ +from . import tools as tools_ + +# ── Types ───────────────────────────────────────────────────────── + +LoopFn = Callable[ + ["Agent", list[messages_.Message]], Awaitable[streams_.StreamResult | None] +] + + +# ── Composition primitives ──────────────────────────────────────── + + +@streams_.stream +async def stream_step( + model: models.Model, + messages: list[messages_.Message], + tools: Sequence[tools_.ToolLike] | None = None, + label: str | None = None, + output_type: type[pydantic.BaseModel] | None = None, + **kwargs: Any, +) -> AsyncGenerator[messages_.Message]: + """Single LLM call that streams into the Runtime queue. + + This is a composition primitive for custom ``@agent.loop`` + functions and multi-agent orchestration. It is decorated with + ``@stream``, so each call becomes a replayable step in the + event log. + """ + async for msg in models.stream( + model, messages, tools=tools, output_type=output_type, **kwargs + ): + msg.label = label + yield msg + + +# ── AgentRun ────────────────────────────────────────────────────── + + +class AgentRun: + """Returned by ``agent.run()``. Async-iterate for messages, then + inspect post-run state. + + Usage:: + + run = agent.run(messages) + + # streaming + async for msg in run: + print(msg.text_delta, end="") + run.checkpoint # checkpoint after iteration + run.pending_hooks # unresolved hooks (empty if completed) + + # non-streaming + result = await agent.run(messages).collect() + print(result.text) + """ + + def __init__(self, inner: runtime_.RunResult) -> None: + self._inner = inner + + async def __aiter__(self) -> AsyncGenerator[messages_.Message]: + async for msg in self._inner: + yield msg + + async def collect(self) -> streams_.StreamResult: + """Drain the stream and return a :class:`StreamResult`.""" + msgs: list[messages_.Message] = [] + async for msg in self._inner: + msgs.append(msg) + return streams_.StreamResult(messages=msgs) + + @property + def checkpoint(self) -> checkpoint_.Checkpoint: + return self._inner.checkpoint + + @property + def pending_hooks(self) -> dict[str, runtime_.HookInfo]: + return self._inner.pending_hooks + + +# ── Agent ───────────────────────────────────────────────────────── + + +class Agent: + """An agent — bundles model, system prompt, tools, and loop logic. + + Create via :func:`agent`:: + + weather = ai.agent( + model=my_model, + system="Answer questions about weather.", + tools=[get_weather], + ) + + Tools default to all globally registered tools when ``None`` + (the default). Pass ``tools=[]`` to explicitly disable tools. + + Override the default tool-calling loop with ``@agent.loop``:: + + @weather.loop + async def custom(agent, messages): + ... + """ + + def __init__( + self, + model: models.Model, + system: str = "", + tools: list[tools_.Tool[..., Any]] | None = None, + ) -> None: + self._model = model + self._system = system + self._tools = tools + self._custom_loop: LoopFn | None = None + + @property + def model(self) -> models.Model: + return self._model + + @property + def system(self) -> str: + return self._system + + @property + def tools(self) -> list[tools_.Tool[..., Any]]: + """Registered tools. ``None`` at init resolves to all globally + registered tools at access time.""" + if self._tools is None: + return list(tools_._tool_registry.values()) + return list(self._tools) + + def loop(self, fn: LoopFn) -> LoopFn: + """Decorator to override the default agent loop. + + The decorated function receives the :class:`Agent` instance and + the per-run messages:: + + @my_agent.loop + async def custom( + agent: ai.Agent, messages: list[ai.Message], + ) -> ai.StreamResult: + ... + """ + self._custom_loop = fn + return fn + + async def _default_loop( + self, messages: list[messages_.Message] + ) -> streams_.StreamResult: + """Built-in loop: stream LLM, execute tools, repeat.""" + local_messages = list(messages) + + while True: + result = await stream_step(self.model, local_messages, self.tools) + + if not result.tool_calls: + return result + + last_msg = result.last_message + if last_msg is not None: + local_messages.append(last_msg) + + await asyncio.gather( + *( + runtime_.execute_tool(tc, message=last_msg) + for tc in result.tool_calls + ) + ) + + def run( + self, + messages: list[messages_.Message], + *, + checkpoint: checkpoint_.Checkpoint | None = None, + ) -> AgentRun: + """Run the agent. + + Returns an :class:`AgentRun` — async-iterate for streamed + messages, or call ``.collect()`` for the final result. + + Args: + messages: Conversation messages (user, assistant, etc.). + checkpoint: Resume from a previous checkpoint. + """ + # Prepend system prompt + full_messages: list[messages_.Message] = [] + if self._system: + full_messages.append( + messages_.Message( + role="system", + parts=[messages_.TextPart(text=self._system)], + ) + ) + full_messages.extend(messages) + + ctx = context_.Context(tools=self.tools) + + # Build the graph function that runtime_.run() expects + async def _graph() -> streams_.StreamResult | None: + if self._custom_loop: + return await self._custom_loop(self, full_messages) + return await self._default_loop(full_messages) + + inner = runtime_.run( + _graph, + checkpoint=checkpoint, + context=ctx, + ) + return AgentRun(inner) + + +# ── Factory ─────────────────────────────────────────────────────── + + +def agent( + model: models.Model, + system: str = "", + tools: list[tools_.Tool[..., Any]] | None = None, +) -> Agent: + """Create an :class:`Agent`. + + Args: + model: The language model to use. + system: System prompt. + tools: Tools available to the agent. ``None`` (default) means + all globally registered tools. Pass ``[]`` to disable. + """ + return Agent(model=model, system=system, tools=tools) diff --git a/src/vercel_ai_sdk/agents/context.py b/src/vercel_ai_sdk/agents/context.py new file mode 100644 index 00000000..73dd56b1 --- /dev/null +++ b/src/vercel_ai_sdk/agents/context.py @@ -0,0 +1,206 @@ +"""Context — everything the LLM sees during a run. + +Consolidates tool registry, system prompt, message history, and model +reference into a single, serializable object. Independent of execution +machinery (Runtime) — can be constructed, inspected, and serialized +without starting a run. + +The context is stashed in a contextvar during ``run()`` so that +framework internals (``execute_tool``, MCP client, etc.) can access it. +""" + +from __future__ import annotations + +import contextvars +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any + +import pydantic + +from ..types import messages as messages_ + +if TYPE_CHECKING: + from . import tools as tools_ + + +# ── ToolSource ──────────────────────────────────────────────────── + + +class ToolSource(pydantic.BaseModel): + """Provenance info for a tool — how to find or reconstruct it. + + Carries enough information to locate the code behind a tool, + whether it's a decorated Python function or an MCP server. + + Attributes: + kind: ``"python"``, ``"mcp_stdio"``, or ``"mcp_http"``. + module: Python module path, e.g. ``"myapp.tools"``. + qualname: Qualified name, e.g. ``"get_weather"``. + uri: Remote URL for HTTP-based MCP servers. + server_command: Launch command for stdio MCP servers. + """ + + model_config = pydantic.ConfigDict(frozen=True) + + kind: str + module: str | None = None + qualname: str | None = None + uri: str | None = None + server_command: str | None = None + + +# ── Context ─────────────────────────────────────────────────────── + + +class Context(pydantic.BaseModel): + """Everything the LLM sees: tools, system prompt, messages, model. + + Independent of execution machinery (Runtime). Constructable by the + user or auto-constructed by ``run()``. + + Usage:: + + ctx = Context( + system_prompt="You are a helpful assistant.", + tools=[get_weather, get_population], + ) + ctx.get_tool("get_weather") # look up by name + data = ctx.model_dump() # serializable snapshot + """ + + model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) + + model: Any = None + system_prompt: str = "" + messages: list[messages_.Message] = pydantic.Field(default_factory=list) + + _tools: dict[str, tools_.Tool[..., Any]] = pydantic.PrivateAttr( + default_factory=dict + ) + + def __init__( + self, + *, + tools: Sequence[tools_.Tool[..., Any]] | None = None, + **data: Any, + ) -> None: + super().__init__(**data) + if tools: + for t in tools: + self.register_tool(t) + + # ── Tool registry (scoped to this context) ──────────────── + + def register_tool(self, tool: tools_.Tool[..., Any]) -> None: + """Register a tool in this context's scoped registry.""" + self._tools[tool.name] = tool + + def get_tool(self, name: str) -> tools_.Tool[..., Any] | None: + """Look up a tool by name. Returns ``None`` if not found.""" + return self._tools.get(name) + + @property + def tools(self) -> list[tools_.Tool[..., Any]]: + """All tools registered in this context.""" + return list(self._tools.values()) + + @property + def tool_schemas(self) -> list[tools_.ToolSchema]: + """Tool schemas — what gets sent to the LLM.""" + return [t.schema for t in self._tools.values()] + + # ── Serialization ───────────────────────────────────────── + + @pydantic.model_serializer + def _serialize(self) -> dict[str, Any]: + """Serialize including tool schemas and sources. + + Tool code is not serialized — only schemas and source + references. + """ + return { + "system_prompt": self.system_prompt, + "messages": [m.model_dump() for m in self.messages], + "tools": [ + { + "schema": t.schema.model_dump(), + "source": (t.source.model_dump() if t.source is not None else None), + } + for t in self._tools.values() + ], + } + + @pydantic.model_validator(mode="wrap") + @classmethod + def _validate( + cls, + data: Any, + handler: pydantic.ValidatorFunctionWrapHandler, + ) -> Context: + """Reconstruct from serialized form or pass through normal init. + + When deserializing, tools are schema-only (not executable) + unless their sources can be resolved from the global registry. + """ + # Normal construction (already a Context, or keyword args without + # a ``tools`` key that looks like serialized tool dicts). + if isinstance(data, cls): + return data + if not isinstance(data, dict) or "tools" not in data: + result: Context = handler(data) + return result + + # Check whether tools contains serialized dicts (from model_dump) + # vs. live Tool objects (from normal __init__). + tools_value = data["tools"] + if tools_value and isinstance(tools_value[0], dict): + return cls._from_serialized(data) + + # Live Tool objects — let the normal init path handle it. + result = handler(data) + return result + + @classmethod + def _from_serialized(cls, data: dict[str, Any]) -> Context: + """Reconstruct from ``model_dump()`` output.""" + from . import tools as tools_ + + ctx = cls( + system_prompt=data.get("system_prompt", ""), + messages=[ + messages_.Message.model_validate(m) for m in data.get("messages", []) + ], + ) + + for tool_data in data.get("tools", []): + schema = tools_.ToolSchema.model_validate(tool_data["schema"]) + source_data = tool_data.get("source") + source = ToolSource(**source_data) if source_data else None + + # Try to resolve the tool from the global registry + live_tool = tools_.get_tool(schema.name) + if live_tool is not None: + ctx.register_tool(live_tool) + else: + # Schema-only placeholder — inspectable but not executable + placeholder = tools_.Tool( + fn=tools_._unresolvable_tool_fn(schema.name), + schema=schema, + source=source, + ) + ctx.register_tool(placeholder) + + return ctx + + +# ── Contextvar ──────────────────────────────────────────────────── + +_context: contextvars.ContextVar[Context] = contextvars.ContextVar("context") + + +def get_context() -> Context: + """Get the active Context from the current run. + + Raises ``LookupError`` if called outside of ``ai.run()``. + """ + return _context.get() diff --git a/src/vercel_ai_sdk/agents/hooks.py b/src/vercel_ai_sdk/agents/hooks.py index 758a4c4f..948539bb 100644 --- a/src/vercel_ai_sdk/agents/hooks.py +++ b/src/vercel_ai_sdk/agents/hooks.py @@ -46,8 +46,7 @@ def _cleanup_run(labels: set[str]) -> None: class Hook[T: pydantic.BaseModel]: - """ - Hook: a suspension point that requires external input to continue. + """Hook: a suspension point that requires external input to continue. Usage in graph code: @@ -77,20 +76,14 @@ class Hook[T: pydantic.BaseModel]: @classmethod async def create(cls, label: str, metadata: dict[str, Any] | None = None) -> T: - """ - Create a hook and await its resolution. + """Create a hook and await its resolution. - The hook is submitted to the Runtime's step queue. run() will either: + The hook is submitted to the LoopExecutor's step queue. run() will + either: - Resolve immediately (if a resolution is available from checkpoint or pre-registered via Hook.resolve()) - Cancel the future (cancels_future=True, serverless mode) - Hold the future (cancels_future=False, long-running mode) - - Args: - label: Stable identifier for this hook. Used to match resolutions - across requests in serverless mode. Must be unique within - a single run. - metadata: Optional metadata surfaced in the pending HookPart message. """ from . import runtime as rt_mod @@ -101,16 +94,16 @@ async def create(cls, label: str, metadata: dict[str, Any] | None = None) -> T: # Check pre-registered resolutions (serverless re-entry path) pre_registered = _pending_resolutions.pop(label, None) if pre_registered is not None: - rt.record_hook(label, pre_registered) + rt.log.record_hook(label, pre_registered) return cls._schema(**pre_registered) # type: ignore[return-value] # Check checkpoint for a previously resolved value - resolution = rt.get_hook_resolution(label) + resolution = rt.log.get_hook_resolution(label) if resolution is not None: - rt.record_hook(label, resolution) + rt.log.record_hook(label, resolution) return cls._schema(**resolution) # type: ignore[return-value] - # Submit to step queue — run() decides what to do + # Submit to executor queue — run() decides what to do future: asyncio.Future[dict[str, Any]] = asyncio.Future() suspension = rt_mod.HookSuspension( label=label, @@ -119,12 +112,12 @@ async def create(cls, label: str, metadata: dict[str, Any] | None = None) -> T: future=future, cancels_future=cls.cancels_future, ) - await rt.put_hook_suspension(suspension) + await rt.executor.put_hook(suspension) # Register in module-level registry for external resolution hook_metadata = metadata or {} _live_hooks[label] = (future, hook_metadata, rt) - rt.track_hook_label(label) + rt.executor.track_hook_label(label) # Await resolution — may be resolved immediately by run(), # cancelled by run() (serverless), or resolved later by @@ -135,10 +128,10 @@ async def create(cls, label: str, metadata: dict[str, Any] | None = None) -> T: _live_hooks.pop(label, None) # Record for checkpoint - rt.record_hook(label, resolution) + rt.log.record_hook(label, resolution) # Emit resolved message - await rt.put_message( + await rt.executor.put_message( messages_.Message( role="assistant", parts=[ @@ -157,8 +150,7 @@ async def create(cls, label: str, metadata: dict[str, Any] | None = None) -> T: @classmethod def resolve(cls, label: str, data: T | dict[str, Any]) -> None: - """ - Resolve a hook by label. + """Resolve a hook by label. Works in two modes: @@ -169,19 +161,12 @@ def resolve(cls, label: str, data: T | dict[str, Any]) -> None: stashes it in the pre-registration registry. When ai.run() replays the graph and Hook.create() executes, it finds the pre-registered resolution and returns without suspending. - - Args: - label: The hook label to resolve. - data: Resolution payload (dict or pydantic model). Validated - against the hook's schema immediately. """ # Validate and normalize to dict if isinstance(data, dict): - # Validate by constructing the schema model validated = cls._schema(**data) resolution = validated.model_dump() else: - # Already a model instance — validate it's the right type if not isinstance(data, cls._schema): raise TypeError( f"Expected {cls._schema.__name__} or dict, " @@ -211,7 +196,7 @@ async def cancel(cls, label: str, reason: str | None = None) -> None: future, hook_metadata, rt = _live_hooks.pop(label) future.cancel(reason) - await rt.put_message( + await rt.executor.put_message( messages_.Message( role="assistant", parts=[ @@ -227,8 +212,7 @@ async def cancel(cls, label: str, reason: str | None = None) -> None: def hook[T: pydantic.BaseModel](cls: type[T]) -> type[Hook[T]]: - """ - Decorator to create a Hook type from a pydantic model. + """Decorator to create a Hook type from a pydantic model. The pydantic model defines the schema for the hook's resolution payload. """ diff --git a/src/vercel_ai_sdk/agents/mcp/client.py b/src/vercel_ai_sdk/agents/mcp/client.py index c17a25a0..def1f0e6 100644 --- a/src/vercel_ai_sdk/agents/mcp/client.py +++ b/src/vercel_ai_sdk/agents/mcp/client.py @@ -14,6 +14,7 @@ import mcp.client.streamable_http import mcp.types +from .. import context as context_ from .. import tools as tools_ __all__ = [ @@ -243,11 +244,30 @@ def _mcp_tool_to_native( return_type=Any, ) + # Determine source provenance from connection key + if connection_key.startswith("http:"): + source = context_.ToolSource( + kind="mcp_http", + uri=connection_key.removeprefix("http:"), + ) + elif connection_key.startswith("stdio:"): + source = context_.ToolSource( + kind="mcp_stdio", + server_command=connection_key.removeprefix("stdio:"), + ) + else: + source = context_.ToolSource(kind="mcp") + t = tools_.Tool( fn=_make_tool_fn(connection_key, mcp_tool.name, transport_factory), schema=schema, + source=source, ) - # Register so execute_tool() can find it by name + + # Register on active Context if available, else fall back to global + ctx = context_._context.get(None) + if ctx is not None: + ctx.register_tool(t) tools_._tool_registry[name] = t return t diff --git a/src/vercel_ai_sdk/agents/runtime.py b/src/vercel_ai_sdk/agents/runtime.py index e9d8aadf..724267b3 100644 --- a/src/vercel_ai_sdk/agents/runtime.py +++ b/src/vercel_ai_sdk/agents/runtime.py @@ -5,15 +5,15 @@ import dataclasses import json import logging -from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine, Sequence +from collections.abc import AsyncGenerator, Awaitable, Callable, Coroutine from typing import Any, get_type_hints import pydantic -from ..models.core import llm as llm_ from ..telemetry import events as telemetry_ from ..types import messages as messages_ from . import checkpoint as checkpoint_ +from . import context as context_ from . import hooks as hooks_ from . import mcp from . import streams as streams_ @@ -21,51 +21,23 @@ logger = logging.getLogger(__name__) -# ── Queue item types ────────────────────────────────────────────── - -@dataclasses.dataclass -class HookSuspension: - """Submitted to the step queue when a hook needs resolution.""" - - label: str - hook_type: str - metadata: dict[str, Any] - future: asyncio.Future[Any] - cancels_future: bool = False - - -# ── Runtime ─────────────────────────────────────────────────────── +# ── EventLog ────────────────────────────────────────────────────── +# +# Pure bookkeeping: replay from checkpoint + record new events. +# No asyncio, no queues — just data in, data out. +# -class Runtime: - """ - Central coordinator for the agent loop. +class EventLog: + """Replay/record layer backed by a Checkpoint. - Functions decorated with @stream submit step functions to the queue. - Hooks submit HookSuspension items to the same queue. - run() pulls items, processes them, yields messages, and resolves futures. + Holds replay cursors (read pointer into the checkpoint) and + recording lists (new events produced during this run). + Completely synchronous — no queues, no async. """ - class _Sentinel: - pass - - _SENTINEL = _Sentinel() - - def __init__( - self, - checkpoint: checkpoint_.Checkpoint | None = None, - ) -> None: - self._step_queue: asyncio.Queue[ - tuple[streams_.Stream, asyncio.Future[streams_.StreamResult]] - | HookSuspension - | Runtime._Sentinel - ] = asyncio.Queue() - - # Message queue for streaming tool results and hook messages - self._message_queue: asyncio.Queue[messages_.Message] = asyncio.Queue() - - # Checkpoint: replay state from previous run + def __init__(self, checkpoint: checkpoint_.Checkpoint | None = None) -> None: self._checkpoint = checkpoint or checkpoint_.Checkpoint() # Replay cursors @@ -77,45 +49,16 @@ def __init__( h.label: h.resolution for h in self._checkpoint.hooks } - # Recording: new events from this run + # Recording lists (new events from this run) self._step_log: list[checkpoint_.StepEvent] = [] self._tool_log: list[checkpoint_.ToolEvent] = [] self._hook_log: list[checkpoint_.HookEvent] = [] - # Pending hooks (unresolved during this run) - self._pending_hooks: dict[str, HookSuspension] = {} - - # Track hook labels registered in this run for cleanup - self._hook_labels: set[str] = set() - - # ── Step queue ──────────────────────────────────────────────── - - async def put_step( - self, step_fn: streams_.Stream, future: asyncio.Future[streams_.StreamResult] - ) -> None: - await self._step_queue.put((step_fn, future)) - - async def put_hook_suspension(self, suspension: HookSuspension) -> None: - await self._step_queue.put(suspension) - - async def signal_done(self) -> None: - await self._step_queue.put(self._SENTINEL) - - # ── Message queue ───────────────────────────────────────────── - - async def put_message(self, message: messages_.Message) -> None: - await self._message_queue.put(message) - - def get_all_messages(self) -> list[messages_.Message]: - msgs = [] - while not self._message_queue.empty(): - try: - msgs.append(self._message_queue.get_nowait()) - except asyncio.QueueEmpty: - break - return msgs + # ── Steps ───────────────────────────────────────────────── - # ── Replay / record: steps ──────────────────────────────────── + @property + def step_index(self) -> int: + return self._step_index def try_replay_step(self) -> streams_.StreamResult | None: if self._step_index < len(self._checkpoint.steps): @@ -133,10 +76,9 @@ def record_step(self, result: streams_.StreamResult) -> None: self._step_log.append(event) self._step_index += 1 - # ── Replay / record: tools ──────────────────────────────────── + # ── Tools ───────────────────────────────────────────────── def try_replay_tool(self, tool_call_id: str) -> checkpoint_.ToolEvent | None: - """Return the cached ToolEvent if available, else None.""" event = self._tool_replay.get(tool_call_id) if event is not None: logger.info( @@ -155,7 +97,7 @@ def record_tool( ) ) - # ── Replay / record: hooks ──────────────────────────────────── + # ── Hooks ───────────────────────────────────────────────── def get_hook_resolution(self, label: str) -> dict[str, Any] | None: resolution = self._hook_replay.get(label) @@ -166,25 +108,158 @@ def get_hook_resolution(self, label: str) -> dict[str, Any] | None: def record_hook(self, label: str, resolution: dict[str, Any]) -> None: self._hook_log.append(checkpoint_.HookEvent(label=label, resolution=resolution)) - def track_hook_label(self, label: str) -> None: - """Track a hook label for cleanup when the run completes.""" - self._hook_labels.add(label) - - # ── Checkpoint ──────────────────────────────────────────────── + # ── Snapshot ────────────────────────────────────────────── - def get_checkpoint(self) -> checkpoint_.Checkpoint: + def checkpoint( + self, pending_hooks: list[checkpoint_.PendingHookInfo] | None = None + ) -> checkpoint_.Checkpoint: + """Build a full Checkpoint merging prior state + new recordings.""" return checkpoint_.Checkpoint( steps=list(self._checkpoint.steps) + self._step_log, tools=list(self._checkpoint.tools) + self._tool_log, hooks=list(self._checkpoint.hooks) + self._hook_log, - pending_hooks=[ - checkpoint_.PendingHookInfo( - label=sus.label, - hook_type=sus.hook_type, - metadata=sus.metadata, - ) - for sus in self._pending_hooks.values() - ], + pending_hooks=pending_hooks or [], + ) + + +# ── LoopExecutor ───────────────────────────────────────────────── +# +# Async coordination: queues that let graph code (streams, hooks, +# tools) talk to the driver loop. Pure mailbox — no replay, no +# checkpoint awareness. +# + + +@dataclasses.dataclass +class HookSuspension: + """Submitted to the step queue when a hook needs resolution.""" + + label: str + hook_type: str + metadata: dict[str, Any] + future: asyncio.Future[Any] + cancels_future: bool = False + + +class LoopExecutor: + """Async coordination layer between graph code and the driver loop. + + Graph code (``@stream`` decorators, hooks, tool execution) submits + work via the producer methods. The driver loop consumes via + ``next()`` and ``drain_messages()``. + """ + + class _Sentinel: + pass + + _SENTINEL = _Sentinel() + + def __init__(self) -> None: + self._step_queue: asyncio.Queue[ + tuple[streams_.Stream, asyncio.Future[streams_.StreamResult]] + | HookSuspension + | LoopExecutor._Sentinel + ] = asyncio.Queue() + + self._message_queue: asyncio.Queue[messages_.Message] = asyncio.Queue() + + # Pending hooks (unresolved during this run) + self._pending_hooks: dict[str, HookSuspension] = {} + + # Track hook labels registered in this run for cleanup + self._hook_labels: set[str] = set() + + # ── Producers (called by graph code) ────────────────────── + + async def put_step( + self, step_fn: streams_.Stream, future: asyncio.Future[streams_.StreamResult] + ) -> None: + await self._step_queue.put((step_fn, future)) + + async def put_hook(self, suspension: HookSuspension) -> None: + await self._step_queue.put(suspension) + + async def put_message(self, message: messages_.Message) -> None: + await self._message_queue.put(message) + + async def done(self) -> None: + await self._step_queue.put(self._SENTINEL) + + # ── Consumer (called by driver loop) ────────────────────── + + async def next( + self, timeout: float = 0.1 + ) -> ( + tuple[streams_.Stream, asyncio.Future[streams_.StreamResult]] + | HookSuspension + | None + ): + """Pull the next item from the step queue. + + Returns ``None`` on timeout (no item available). + Returns the sentinel's semantic equivalent by raising StopIteration + when the graph signals completion. + """ + try: + item = await asyncio.wait_for(self._step_queue.get(), timeout=timeout) + except TimeoutError: + return None + + if isinstance(item, LoopExecutor._Sentinel): + raise _LoopDone + return item + + def drain_messages(self) -> list[messages_.Message]: + msgs: list[messages_.Message] = [] + while not self._message_queue.empty(): + try: + msgs.append(self._message_queue.get_nowait()) + except asyncio.QueueEmpty: + break + return msgs + + # ── Hook label tracking ─────────────────────────────────── + + def track_hook_label(self, label: str) -> None: + self._hook_labels.add(label) + + def pending_hook_infos(self) -> list[checkpoint_.PendingHookInfo]: + return [ + checkpoint_.PendingHookInfo( + label=sus.label, + hook_type=sus.hook_type, + metadata=sus.metadata, + ) + for sus in self._pending_hooks.values() + ] + + +class _LoopDone(Exception): + """Internal signal: the loop function has finished.""" + + +# ── Runtime ─────────────────────────────────────────────────────── +# +# Thin composition of EventLog + LoopExecutor. +# The context var points here; graph code accesses rt.log and +# rt.executor directly. +# + + +class Runtime: + """Central coordinator — composes EventLog and LoopExecutor. + + Graph code accesses ``rt.log`` for replay/record and + ``rt.executor`` for async coordination. + """ + + def __init__(self, checkpoint: checkpoint_.Checkpoint | None = None) -> None: + self.log = EventLog(checkpoint) + self.executor = LoopExecutor() + + def checkpoint(self) -> checkpoint_.Checkpoint: + return self.log.checkpoint( + pending_hooks=self.executor.pending_hook_infos(), ) @@ -193,7 +268,7 @@ def get_checkpoint(self) -> checkpoint_.Checkpoint: def get_checkpoint() -> checkpoint_.Checkpoint: """Get the current checkpoint from the active Runtime.""" - return _runtime.get().get_checkpoint() + return _runtime.get().checkpoint() def _find_runtime_param(fn: Callable[..., Any]) -> str | None: @@ -208,36 +283,17 @@ def _find_runtime_param(fn: Callable[..., Any]) -> str | None: return None -# ── Convenience functions ───────────────────────────────────────── - - -@streams_.stream -async def stream_step( - llm: llm_.LanguageModel, - messages: list[messages_.Message], - tools: Sequence[tools_.ToolLike] | None = None, - label: str | None = None, - output_type: type[pydantic.BaseModel] | None = None, -) -> AsyncGenerator[messages_.Message]: - """Single LLM call that streams to Runtime.""" - async for msg in llm.stream( - messages=messages, tools=tools, output_type=output_type - ): - msg.label = label - yield msg - - async def execute_tool( tool_call: messages_.ToolPart, message: messages_.Message | None = None, ) -> Any: - """ - Execute a single tool call with replay support. + """Execute a single tool call with replay support. - Looks up the tool by name from the global registry, executes it, - and updates the ToolPart (and parent Message) with the result. - Emits the updated message to the Runtime queue so the UI sees - the transition from status="pending" to status="result" (or "error"). + Looks up the tool by name — first from the active Context (if any), + then from the global registry. Executes it and updates the ToolPart + (and parent Message) with the result. Emits the updated message to + the LoopExecutor queue so the UI sees the transition from + status="pending" to status="result" (or "error"). If a checkpoint exists with a cached result for this tool_call_id, returns the cached result without re-executing. @@ -246,7 +302,7 @@ async def execute_tool( # Replay: return cached result if available if rt: - cached = rt.try_replay_tool(tool_call.tool_call_id) + cached = rt.log.try_replay_tool(tool_call.tool_call_id) if cached is not None: if cached.status == "error": tool_call.set_error(cached.result) @@ -263,8 +319,13 @@ async def execute_tool( ) t0 = telemetry_.time_ms() - # Fresh execution - tool = tools_.get_tool(tool_call.tool_name) + # Fresh execution — resolve from Context first, then global registry + tool: tools_.Tool[..., Any] | None = None + ctx = context_._context.get(None) + if ctx is not None: + tool = ctx.get_tool(tool_call.tool_name) + if tool is None: + tool = tools_.get_tool(tool_call.tool_name) if tool is None: raise ValueError(f"Tool not found in registry: {tool_call.tool_name}") @@ -273,8 +334,6 @@ async def execute_tool( result = await tool.validate_and_call(tool_call.tool_args, rt) tool_call.set_result(result) except (json.JSONDecodeError, pydantic.ValidationError) as exc: - # LLM produced malformed JSON or args that don't match the schema. - # Report back as a tool error so the model can retry. result = f"{type(exc).__name__}: {exc}" error_str = result tool_call.set_error(result) @@ -291,42 +350,15 @@ async def execute_tool( # Record for checkpoint if rt: - rt.record_tool(tool_call.tool_call_id, result, status=tool_call.status) + rt.log.record_tool(tool_call.tool_call_id, result, status=tool_call.status) # Emit updated message so UI sees status change if rt and message: - await rt.put_message(message.model_copy(deep=True)) + await rt.executor.put_message(message.model_copy(deep=True)) return result -async def stream_loop( - llm: llm_.LanguageModel, - messages: list[messages_.Message], - tools: Sequence[tools_.ToolLike], - label: str | None = None, - output_type: type[pydantic.BaseModel] | None = None, -) -> streams_.StreamResult: - """Agent loop: stream LLM, execute tools, repeat until done.""" - local_messages = list(messages) - - while True: - result = await stream_step( - llm, local_messages, tools, label=label, output_type=output_type - ) - - if not result.tool_calls: - return result - - last_msg = result.last_message - if last_msg is not None: - local_messages.append(last_msg) - - await asyncio.gather( - *(execute_tool(tc, message=last_msg) for tc in result.tool_calls) - ) - - # ── RunResult ───────────────────────────────────────────────────── @@ -340,8 +372,7 @@ class HookInfo: class RunResult: - """ - Returned by run(). Async-iterate for messages, then check state. + """Returned by run(). Async-iterate for messages, then check state. Usage: result = ai.run(my_graph, llm, query) @@ -359,7 +390,7 @@ def __init__(self) -> None: def checkpoint(self) -> checkpoint_.Checkpoint: if self._runtime is None: return checkpoint_.Checkpoint() - return self._runtime.get_checkpoint() + return self._runtime.checkpoint() @property def pending_hooks(self) -> dict[str, HookInfo]: @@ -371,7 +402,7 @@ def pending_hooks(self) -> dict[str, HookInfo]: hook_type=sus.hook_type, metadata=sus.metadata, ) - for label, sus in self._runtime._pending_hooks.items() + for label, sus in self._runtime.executor._pending_hooks.items() } async def __aiter__(self) -> AsyncGenerator[messages_.Message]: @@ -383,37 +414,38 @@ async def __aiter__(self) -> AsyncGenerator[messages_.Message]: # ── run() ───────────────────────────────────────────────────────── -async def _stop_when_done(runtime: Runtime, task: Awaitable[None]) -> None: +async def _stop_when_done(executor: LoopExecutor, task: Awaitable[None]) -> None: try: await task finally: - await runtime.signal_done() + await executor.done() def run( root: Callable[..., Coroutine[Any, Any, Any]], *args: Any, checkpoint: checkpoint_.Checkpoint | None = None, + context: context_.Context | None = None, ) -> RunResult: - """ - Main entry point. + """Main entry point. 1. Starts the root function as a background task - 2. Pulls steps and hook suspensions from the Runtime queue + 2. Pulls steps and hook suspensions from the LoopExecutor queue 3. Executes each step, yielding messages 4. Resolves or suspends hooks depending on the hook's cancels_future - class variable: - - cancels_future=True (serverless): cancel the future, branch dies, - caller inspects result.pending_hooks and result.checkpoint to resume - - cancels_future=False (long-running, default): future stays alive, - external code calls Hook.resolve() / Hook.cancel() to unblock 5. Returns RunResult with .checkpoint and .pending_hooks + + Args: + root: The loop function to execute. + *args: Positional arguments forwarded to ``root``. + checkpoint: Checkpoint to resume from. + context: LLM prompt context (tools, system prompt, messages). + If ``None``, an empty Context is created automatically. """ result = RunResult() # Discard stale checkpoints: if the checkpoint has pending hooks but - # none of them have been resolved (via Hook.resolve() / to_messages()), - # this isn't a resume — it's a fresh turn with an outdated checkpoint. + # none of them have been resolved, this isn't a resume. effective_checkpoint = checkpoint if checkpoint and checkpoint.pending_hooks: pending_labels = [ph.label for ph in checkpoint.pending_hooks] @@ -435,9 +467,13 @@ class variable: ) async def _generate() -> AsyncGenerator[messages_.Message]: - runtime = Runtime(checkpoint=effective_checkpoint) - result._runtime = runtime - token_runtime = _runtime.set(runtime) + rt = Runtime(checkpoint=effective_checkpoint) + result._runtime = rt + token_runtime = _runtime.set(rt) + + ctx = context or context_.Context() + token_context = context_._context.set(ctx) + token_run_id = telemetry_.start_run() telemetry_.handle(telemetry_.RunStartEvent()) @@ -447,7 +483,7 @@ async def _generate() -> AsyncGenerator[messages_.Message]: kwargs: dict[str, Any] = {} if runtime_param := _find_runtime_param(root): - kwargs[runtime_param] = runtime + kwargs[runtime_param] = rt run_error: BaseException | None = None total_usage: messages_.Usage | None = None @@ -455,77 +491,64 @@ async def _generate() -> AsyncGenerator[messages_.Message]: try: async with asyncio.TaskGroup() as tg: _task: asyncio.Task[None] = tg.create_task( - _stop_when_done(runtime, root(*args, **kwargs)) + _stop_when_done(rt.executor, root(*args, **kwargs)) ) while True: # Drain pending messages - for msg in runtime.get_all_messages(): + for msg in rt.executor.drain_messages(): yield msg - # Wait for next queue item + # Pull next item from the graph executor try: - step_item = await asyncio.wait_for( - runtime._step_queue.get(), timeout=0.1 - ) - except TimeoutError: - continue - - if isinstance(step_item, Runtime._Sentinel): - for msg in runtime.get_all_messages(): + item = await rt.executor.next() + except _LoopDone: + for msg in rt.executor.drain_messages(): yield msg break + if item is None: + # Timeout — no item available, loop again + continue + # ── Hook suspension ──────────────────────── - if isinstance(step_item, HookSuspension): - resolution = runtime.get_hook_resolution(step_item.label) + if isinstance(item, HookSuspension): + resolution = rt.log.get_hook_resolution(item.label) if resolution is not None: - # Resolve immediately - step_item.future.set_result(resolution) - runtime.record_hook(step_item.label, resolution) + item.future.set_result(resolution) + rt.log.record_hook(item.label, resolution) else: - # No resolution available - runtime._pending_hooks[step_item.label] = step_item - if step_item.cancels_future: - # Serverless: cancel the future so the branch - # dies with CancelledError. Caller inspects - # result.pending_hooks to resume later. - step_item.future.cancel() - # else: long-running — future stays alive, - # external code calls Hook.resolve() to unblock. - - # Yield pending hook message + rt.executor._pending_hooks[item.label] = item + if item.cancels_future: + item.future.cancel() + yield messages_.Message( role="assistant", parts=[ messages_.HookPart( - hook_id=step_item.label, - hook_type=step_item.hook_type, + hook_id=item.label, + hook_type=item.hook_type, status="pending", - metadata=step_item.metadata, + metadata=item.metadata, ) ], ) - # Let resolved branches resume and submit their - # next steps before we pull from the queue again. await asyncio.sleep(0) - - # Drain messages after hook processing - for msg in runtime.get_all_messages(): + for msg in rt.executor.drain_messages(): yield msg continue # ── Regular step ─────────────────────────── - step_fn, future = step_item + step_fn, future = item telemetry_.handle( telemetry_.StepStartEvent( - step_index=runtime._step_index, + step_index=rt.log.step_index, ) ) - for tool_msg in runtime.get_all_messages(): + for tool_msg in rt.executor.drain_messages(): yield tool_msg result_messages: list[messages_.Message] = [] @@ -535,7 +558,7 @@ async def _generate() -> AsyncGenerator[messages_.Message]: yield msg_copy result_messages.append(msg) - for tool_msg in runtime.get_all_messages(): + for tool_msg in rt.executor.drain_messages(): yield tool_msg step_result = streams_.StreamResult(messages=result_messages) @@ -543,7 +566,7 @@ async def _generate() -> AsyncGenerator[messages_.Message]: telemetry_.handle( telemetry_.StepFinishEvent( - step_index=runtime._step_index, + step_index=rt.log.step_index, result=step_result, ) ) @@ -558,7 +581,7 @@ async def _generate() -> AsyncGenerator[messages_.Message]: ) await asyncio.sleep(0) - for tool_msg in runtime.get_all_messages(): + for tool_msg in rt.executor.drain_messages(): yield tool_msg except BaseException as exc: @@ -574,13 +597,13 @@ async def _generate() -> AsyncGenerator[messages_.Message]: ) telemetry_.end_run(token_run_id) - # Clean up module-level hook registries for this run - hooks_._cleanup_run(runtime._hook_labels) + hooks_._cleanup_run(rt.executor._hook_labels) if mcp_token is not None: await mcp.client.close_connections() mcp.client._pool.reset(mcp_token) + context_._context.reset(token_context) _runtime.reset(token_runtime) result._messages = _generate() diff --git a/src/vercel_ai_sdk/agents/streams.py b/src/vercel_ai_sdk/agents/streams.py index 80ca7cf8..fadf6747 100644 --- a/src/vercel_ai_sdk/agents/streams.py +++ b/src/vercel_ai_sdk/agents/streams.py @@ -66,10 +66,9 @@ def total_usage(self) -> messages_.Usage | None: def stream[**P]( fn: Callable[P, AsyncGenerator[messages_.Message]], ) -> Callable[P, Awaitable[StreamResult]]: - """ - Decorator to put an async generator into the Runtime queue. + """Decorator to put an async generator into the LoopExecutor queue. - The decorated function submits its work to the Runtime queue and + The decorated function submits its work to the executor queue and blocks until run() processes it, then returns the StreamResult. If a checkpoint exists with a cached result for this step index, @@ -85,22 +84,22 @@ async def wrapped(*args: Any, **kwargs: Any) -> StreamResult: raise ValueError("No Runtime context - must be called within ai.run()") # Replay: return cached result if available - cached = rt.try_replay_step() + cached = rt.log.try_replay_step() if cached is not None: return cached - # Fresh execution: submit to queue and wait + # Fresh execution: submit to executor queue and wait future: asyncio.Future[StreamResult] = asyncio.Future() async def stream_fn() -> AsyncGenerator[messages_.Message]: async for msg in fn(*args, **kwargs): yield msg - await rt.put_step(stream_fn, future) + await rt.executor.put_step(stream_fn, future) result = await future # Record for checkpoint - rt.record_step(result) + rt.log.record_step(result) return result return wrapped diff --git a/src/vercel_ai_sdk/agents/tools.py b/src/vercel_ai_sdk/agents/tools.py index f4e3744b..39a9aa28 100644 --- a/src/vercel_ai_sdk/agents/tools.py +++ b/src/vercel_ai_sdk/agents/tools.py @@ -9,6 +9,7 @@ from ..types.tools import ToolLike as ToolLike from ..types.tools import ToolSchema as ToolSchema +from .context import ToolSource if TYPE_CHECKING: from . import runtime as runtime_ @@ -36,10 +37,12 @@ def __init__( fn: Callable[P, Awaitable[R]], schema: ToolSchema, validator: type[pydantic.BaseModel] | None = None, + source: ToolSource | None = None, ) -> None: self._fn = fn self._validator = validator self.schema = schema + self.source = source async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: return await self._fn(*args, **kwargs) @@ -102,8 +105,32 @@ def tool[**P, R](fn: Callable[P, Awaitable[R]]) -> Tool[P, R]: return_type=hints.get("return", None), ) - t = Tool(fn=fn, schema=schema, validator=validator) + source = ToolSource( + kind="python", + module=getattr(fn, "__module__", None), + qualname=getattr(fn, "__qualname__", None), + ) + + t = Tool(fn=fn, schema=schema, validator=validator, source=source) # 3. register in global registry _tool_registry[t.name] = t return t + + +def _unresolvable_tool_fn(name: str) -> Callable[..., Awaitable[Any]]: + """Create a placeholder async function for schema-only tools. + + Used by ``Context.from_dict()`` when a tool's source cannot be + resolved to live code. + """ + + async def _placeholder(**kwargs: Any) -> Any: + raise RuntimeError( + f"Tool {name!r} was reconstructed from serialized context " + f"and has no executable implementation." + ) + + _placeholder.__name__ = name + _placeholder.__qualname__ = name + return _placeholder diff --git a/src/vercel_ai_sdk/models/__init__.py b/src/vercel_ai_sdk/models/__init__.py index c1e6000a..9e921afd 100644 --- a/src/vercel_ai_sdk/models/__init__.py +++ b/src/vercel_ai_sdk/models/__init__.py @@ -1,32 +1,207 @@ -"""Model adapters — standalone LLM streaming layer. +"""models — composable model layer. -Provides the LanguageModel ABC and concrete provider adapters. -Depends only on types/, never on agents/. +Usage:: + + from vercel_ai_sdk import models + from vercel_ai_sdk.types import Message, TextPart + + model = models.Model( + id="anthropic/claude-sonnet-4", + adapter="ai-gateway-v3", + provider="ai-gateway", + ) + msgs = [Message(role="user", parts=[TextPart(text="hello")])] + + # stream — auto-creates client from env vars + async for msg in models.stream(model, msgs): + print(msg.text_delta, end="") + + # buffer the whole response + result = await models.buffer(models.stream(model, msgs)) + print(result.text) + + # explicit client + client = models.Client( + base_url="https://custom.example.com/v3/ai", api_key="sk-...", + ) + async for msg in models.stream(model, msgs, client=client): + ... """ -from . import ai_gateway, anthropic, core, openai -from .core import ( - ImageModel, - LanguageModel, - MediaModel, - MediaResult, - StreamEvent, - StreamHandler, - VideoModel, -) +from __future__ import annotations + +import os +from collections.abc import AsyncGenerator, Sequence +from typing import Any + +import pydantic + +from ..types import messages as messages_ +from ..types import tools as tools_ +from .ai_gateway.generate import GenerateParams, ImageParams, VideoParams +from .core.client import Client +from .core.model import Model, ModelCost +from .core.proto import GenerateFn, StreamFn + +# --------------------------------------------------------------------------- +# Adapter registry — maps adapter string → adapter function. +# Adapter modules are imported lazily on first use. +# --------------------------------------------------------------------------- + +_stream_adapters: dict[str, StreamFn] = {} +_generate_adapters: dict[str, GenerateFn] = {} +_adapters_loaded = False + + +def _ensure_adapters() -> None: + """Lazily register built-in adapter functions on first call.""" + global _adapters_loaded # noqa: PLW0603 + if _adapters_loaded: + return + _adapters_loaded = True + + from .ai_gateway import generate as ai_gw_generate + from .ai_gateway import stream as ai_gw_stream + from .anthropic.adapter import stream as anthropic_stream + from .openai.adapter import stream as openai_stream + + _stream_adapters["ai-gateway-v3"] = ai_gw_stream + _generate_adapters["ai-gateway-v3"] = ai_gw_generate + _stream_adapters["openai"] = openai_stream + _stream_adapters["anthropic"] = anthropic_stream + + +def register_stream(adapter: str, fn: StreamFn) -> None: + """Register a stream adapter function for the given adapter key. + + Use this to add custom adapters (or override built-in ones). + """ + _stream_adapters[adapter] = fn + + +def register_generate(adapter: str, fn: GenerateFn) -> None: + """Register a generate adapter function for the given adapter key. + + Use this to add custom adapters (or override built-in ones). + """ + _generate_adapters[adapter] = fn + + +# --------------------------------------------------------------------------- +# Provider defaults — base URLs and env var names for auto-client creation. +# --------------------------------------------------------------------------- + +_PROVIDER_DEFAULTS: dict[str, tuple[str, str]] = { + "ai-gateway": ("https://ai-gateway.vercel.sh/v3/ai", "AI_GATEWAY_API_KEY"), + "anthropic": ("https://api.anthropic.com/v1", "ANTHROPIC_API_KEY"), + "openai": ("https://api.openai.com/v1", "OPENAI_API_KEY"), +} + + +def _auto_client(model: Model) -> Client: + """Create a :class:`Client` from env vars for the given model's provider.""" + defaults = _PROVIDER_DEFAULTS.get(model.provider) + if defaults is None: + raise ValueError( + f"No default client config for provider {model.provider!r}. " + f"Pass an explicit client= argument." + ) + base_url, env_var = defaults + return Client(base_url=base_url, api_key=os.environ.get(env_var)) + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +async def stream( + model: Model, + messages: list[messages_.Message], + *, + tools: Sequence[tools_.ToolLike] | None = None, + output_type: type[pydantic.BaseModel] | None = None, + client: Client | None = None, + **kwargs: Any, +) -> AsyncGenerator[messages_.Message]: + """Stream an LLM response. + + Resolves the adapter function from ``model.adapter``, auto-creates a + :class:`Client` from env vars if none is provided, and yields + ``Message`` snapshots. + """ + _ensure_adapters() + c = client or _auto_client(model) + adapter_fn = _stream_adapters.get(model.adapter) + if adapter_fn is None: + registered = ", ".join(sorted(_stream_adapters)) or "(none)" + raise KeyError( + f"No stream adapter registered for adapter={model.adapter!r}. " + f"Registered: {registered}" + ) + async for msg in adapter_fn( + c, model, messages, tools=tools, output_type=output_type, **kwargs + ): + yield msg + + +async def generate( + model: Model, + messages: list[messages_.Message], + params: GenerateParams | None = None, + *, + client: Client | None = None, +) -> messages_.Message: + """Generate a response (images, video, etc.). + + Resolves the adapter function from ``model.adapter``, auto-creates a + :class:`Client` from env vars if none is provided. + + ``params`` controls the generation type: + + * :class:`ImageParams` — image generation (``/image-model``). + * :class:`VideoParams` — video generation (``/video-model``). + * ``None`` — auto-detect from ``model.capabilities``. + """ + _ensure_adapters() + c = client or _auto_client(model) + adapter_fn = _generate_adapters.get(model.adapter) + if adapter_fn is None: + registered = ", ".join(sorted(_generate_adapters)) or "(none)" + raise KeyError( + f"No generate adapter registered for adapter={model.adapter!r}. " + f"Registered: {registered}" + ) + return await adapter_fn(c, model, messages, params=params) + + +async def buffer(gen: AsyncGenerator[messages_.Message]) -> messages_.Message: + """Drain a stream and return the final ``Message``. + + Raises :class:`ValueError` if the stream yields nothing. + """ + result: messages_.Message | None = None + async for msg in gen: + result = msg + if result is None: + raise ValueError("empty stream") + return result + __all__ = [ - # Core abstractions - "LanguageModel", - "StreamEvent", - "StreamHandler", - "MediaModel", - "MediaResult", - "ImageModel", - "VideoModel", - "core", - # Provider adapters - "openai", - "anthropic", - "ai_gateway", + # Core types + "Client", + "GenerateFn", + "GenerateParams", + "ImageParams", + "Model", + "ModelCost", + "StreamFn", + "VideoParams", + # Public API + "buffer", + "generate", + "register_generate", + "register_stream", + "stream", ] diff --git a/src/vercel_ai_sdk/models/ai_gateway/__init__.py b/src/vercel_ai_sdk/models/ai_gateway/__init__.py index e467b8ec..7cc9f429 100644 --- a/src/vercel_ai_sdk/models/ai_gateway/__init__.py +++ b/src/vercel_ai_sdk/models/ai_gateway/__init__.py @@ -1,14 +1,14 @@ -"""Vercel AI Gateway provider — language, image, and video models.""" +"""AI Gateway provider — adapter for the Vercel AI Gateway v3 protocol.""" from . import errors -from .image import GatewayImageModel -from .llm import GatewayModel -from .video import GatewayEmbeddingModel, GatewayVideoModel +from .generate import GenerateParams, ImageParams, VideoParams, generate +from .stream import stream __all__ = [ - "GatewayModel", - "GatewayImageModel", - "GatewayVideoModel", - "GatewayEmbeddingModel", + "GenerateParams", + "ImageParams", + "VideoParams", "errors", + "generate", + "stream", ] diff --git a/src/vercel_ai_sdk/models/ai_gateway/_common.py b/src/vercel_ai_sdk/models/ai_gateway/_common.py new file mode 100644 index 00000000..0031661f --- /dev/null +++ b/src/vercel_ai_sdk/models/ai_gateway/_common.py @@ -0,0 +1,145 @@ +"""Shared helpers for the AI Gateway v3 adapter. + +Contains utilities used by both the streaming (language-model) and generation +(image-model, video-model) endpoints. + +.. note:: + + Several helpers here are candidates for lifting to framework-level: + + - ``extract_prompt`` / ``extract_input_files`` → ``Message`` methods + - ``parse_sse_lines`` → ``core/helpers/sse.py`` +""" + +from __future__ import annotations + +import base64 +import json +from collections.abc import AsyncGenerator +from typing import Any + +import httpx + +from ...types import messages as messages_ +from ..core import client as client_ +from ..core import model as model_ +from ..core.helpers import media as media_ + +_PROTOCOL_VERSION = "0.0.1" + + +# --------------------------------------------------------------------------- +# Message extraction helpers +# --------------------------------------------------------------------------- +# TODO: lift to Message methods — these are universally useful. + + +def extract_prompt(messages: list[messages_.Message]) -> str: + """Concatenate all text from user/system messages into a single prompt string.""" + parts: list[str] = [] + for msg in messages: + if msg.role in ("user", "system"): + for p in msg.parts: + if isinstance(p, messages_.TextPart): + parts.append(p.text) + return " ".join(parts) + + +def extract_input_files(messages: list[messages_.Message]) -> list[messages_.FilePart]: + """Collect all file parts from user messages.""" + files: list[messages_.FilePart] = [] + for msg in messages: + if msg.role == "user": + for p in msg.parts: + if isinstance(p, messages_.FilePart): + files.append(p) + return files + + +# --------------------------------------------------------------------------- +# Wire format helpers +# --------------------------------------------------------------------------- + + +def file_part_to_wire(part: messages_.FilePart) -> dict[str, Any]: + """Convert a :class:`FilePart` to the gateway wire format for input files.""" + data = part.data + if isinstance(data, str) and media_.is_url(data): + return {"type": "url", "url": data} + if isinstance(data, bytes): + b64 = base64.b64encode(data).decode("ascii") + elif isinstance(data, str): + b64 = data + else: + b64 = str(data) + return {"type": "file", "data": b64, "mediaType": part.media_type} + + +# --------------------------------------------------------------------------- +# Request headers +# --------------------------------------------------------------------------- + + +def request_headers( + client: client_.Client, + model: model_.Model, + *, + model_type: str = "language", + streaming: bool = False, +) -> dict[str, str]: + """Build gateway-specific request headers. + + Args: + client: The HTTP client (provides api_key). + model: The model (provides id). + model_type: One of ``"language"``, ``"image"``, ``"video"``. + streaming: Whether this is a streaming request (language-model only). + """ + h: dict[str, str] = { + "Content-Type": "application/json", + "ai-gateway-protocol-version": _PROTOCOL_VERSION, + } + + if model_type == "language": + h["ai-language-model-specification-version"] = "3" + h["ai-language-model-id"] = model.id + h["ai-language-model-streaming"] = str(streaming).lower() + elif model_type == "image": + h["ai-image-model-specification-version"] = "3" + h["ai-model-id"] = model.id + elif model_type == "video": + h["ai-video-model-specification-version"] = "3" + h["ai-model-id"] = model.id + + if client.api_key: + h["Authorization"] = f"Bearer {client.api_key}" + h["ai-gateway-auth-method"] = "api-key" + + return h + + +# --------------------------------------------------------------------------- +# SSE parsing +# --------------------------------------------------------------------------- +# TODO: lift to core/helpers/sse.py — any SSE-based adapter will need this. + + +async def parse_sse_lines( + response: httpx.Response, +) -> AsyncGenerator[dict[str, Any]]: + """Yield parsed JSON dicts from an SSE response stream. + + Handles the ``data: `` / ``data: [DONE]`` protocol used by the + AI Gateway's streaming endpoints. + """ + async for line in response.aiter_lines(): + line = line.strip() + if not line.startswith("data: "): + continue + payload = line[len("data: ") :] + if payload == "[DONE]": + break + try: + yield json.loads(payload) + except json.JSONDecodeError: + continue diff --git a/src/vercel_ai_sdk/models/ai_gateway/generate.py b/src/vercel_ai_sdk/models/ai_gateway/generate.py new file mode 100644 index 00000000..ab460b02 --- /dev/null +++ b/src/vercel_ai_sdk/models/ai_gateway/generate.py @@ -0,0 +1,253 @@ +"""AI Gateway v3 generation adapter — image-model and video-model endpoints. + +Provides typed parameter objects (:class:`ImageParams`, :class:`VideoParams`) +and a unified :func:`generate` entry point that dispatches based on param type +and validates against model capabilities. +""" + +from __future__ import annotations + +from typing import Any + +import httpx +import pydantic + +from ...types import messages as messages_ +from ..core import client as client_ +from ..core import model as model_ +from ..core.helpers import media as media_ +from . import _common +from . import errors as errors_ + +# --------------------------------------------------------------------------- +# Parameter types +# --------------------------------------------------------------------------- + +_PARAMS_CONFIG = pydantic.ConfigDict(frozen=True, populate_by_name=True) + + +class ImageParams(pydantic.BaseModel): + """Parameters for image generation (``/image-model`` endpoint).""" + + model_config = _PARAMS_CONFIG + + n: int = 1 + size: str | None = None + aspect_ratio: str | None = pydantic.Field( + default=None, serialization_alias="aspectRatio" + ) + seed: int | None = None + provider_options: dict[str, Any] = pydantic.Field( + default_factory=dict, serialization_alias="providerOptions" + ) + + +class VideoParams(pydantic.BaseModel): + """Parameters for video generation (``/video-model`` endpoint).""" + + model_config = _PARAMS_CONFIG + + n: int = 1 + aspect_ratio: str | None = pydantic.Field( + default=None, serialization_alias="aspectRatio" + ) + resolution: str | None = None + duration: int | None = None + fps: int | None = None + seed: int | None = None + provider_options: dict[str, Any] = pydantic.Field( + default_factory=dict, serialization_alias="providerOptions" + ) + + +GenerateParams = ImageParams | VideoParams + + +# --------------------------------------------------------------------------- +# Image generation — /image-model +# --------------------------------------------------------------------------- + + +async def _generate_image( + client: client_.Client, + model: model_.Model, + messages: list[messages_.Message], + params: ImageParams, +) -> messages_.Message: + """Hit ``/image-model`` and return a Message with FileParts.""" + prompt = _common.extract_prompt(messages) + input_files = _common.extract_input_files(messages) + + body: dict[str, Any] = { + "prompt": prompt, + **params.model_dump(by_alias=True, exclude_none=True), + } + if input_files: + body["files"] = [_common.file_part_to_wire(f) for f in input_files] + + url = f"{client.base_url.rstrip('/')}/image-model" + headers = _common.request_headers(client, model, model_type="image") + + response = await client.http.post(url, json=body, headers=headers) + if response.status_code >= 400: + raise errors_.create_gateway_error( + response_body=response.text, + status_code=response.status_code, + api_key_provided=bool(client.api_key), + ) + + data = response.json() + raw_images: list[str] = data.get("images", []) + usage_data = data.get("usage") + usage = None + if usage_data: + usage = messages_.Usage( + input_tokens=usage_data.get("inputTokens") or 0, + output_tokens=usage_data.get("outputTokens") or 0, + ) + + files: list[messages_.Part] = [] + for img_b64 in raw_images: + media_type = media_.detect_image_media_type(img_b64) or "image/png" + files.append(messages_.FilePart(data=img_b64, media_type=media_type)) + + return messages_.Message(role="assistant", parts=files, usage=usage) + + +# --------------------------------------------------------------------------- +# Video generation — /video-model (SSE response) +# --------------------------------------------------------------------------- + + +async def _generate_video( + client: client_.Client, + model: model_.Model, + messages: list[messages_.Message], + params: VideoParams, +) -> messages_.Message: + """Hit ``/video-model`` (SSE) and return a Message with FileParts.""" + prompt = _common.extract_prompt(messages) + input_files = _common.extract_input_files(messages) + + body: dict[str, Any] = { + "prompt": prompt, + **params.model_dump(by_alias=True, exclude_none=True), + } + if input_files: + body["image"] = _common.file_part_to_wire(input_files[0]) + + url = f"{client.base_url.rstrip('/')}/video-model" + headers = _common.request_headers(client, model, model_type="video") + headers["accept"] = "text/event-stream" + + async with client.http.stream( + "POST", + url, + json=body, + headers=headers, + timeout=httpx.Timeout(timeout=600.0, connect=10.0), + ) as response: + if response.status_code >= 400: + await response.aread() + raise errors_.create_gateway_error( + response_body=response.text, + status_code=response.status_code, + api_key_provided=bool(client.api_key), + ) + + # Read first SSE data event — the gateway sends a single result event. + event_data: dict[str, Any] = {} + async for parsed in _common.parse_sse_lines(response): + event_data = parsed + break + + if not event_data: + raise errors_.GatewayResponseError( + "SSE stream ended without any data events", + ) + + if event_data.get("type") == "error": + raise errors_.GatewayInvalidRequestError( + message=event_data.get("message", "unknown error"), + status_code=event_data.get("statusCode", 400), + ) + + raw_videos: list[dict[str, Any]] = event_data.get("videos", []) + files: list[messages_.Part] = [] + for video_data in raw_videos: + vtype = video_data.get("type", "base64") + media_type = video_data.get("mediaType", "video/mp4") + + if vtype == "url": + downloaded_bytes, content_type = await media_.download(video_data["url"]) + if content_type: + media_type = content_type + files.append( + messages_.FilePart(data=downloaded_bytes, media_type=media_type) + ) + else: + raw_data = video_data.get("data", "") + files.append(messages_.FilePart(data=raw_data, media_type=media_type)) + + return messages_.Message(role="assistant", parts=files) + + +# --------------------------------------------------------------------------- +# Public adapter function +# --------------------------------------------------------------------------- + + +def _check_capabilities( + model: model_.Model, + params: GenerateParams, +) -> None: + """Validate that model capabilities match the requested generation type.""" + caps = model.capabilities + + if isinstance(params, VideoParams): + if "video" not in caps: + raise ValueError( + f"Model {model.id!r} does not have 'video' capability " + f"(capabilities={caps}). Use ImageParams for image models." + ) + if "text" in caps and "video" not in caps: + raise ValueError( + f"Model {model.id!r} is a text model (capabilities={caps}). " + f"Use stream() for text generation, not generate()." + ) + elif isinstance(params, ImageParams): + if "video" in caps and "image" not in caps: + raise ValueError( + f"Model {model.id!r} has 'video' capability but not 'image' " + f"(capabilities={caps}). Use VideoParams for video models." + ) + if "text" in caps and "image" not in caps: + raise ValueError( + f"Model {model.id!r} is a text model (capabilities={caps}). " + f"Use stream() for text generation, not generate()." + ) + + +async def generate( + client: client_.Client, + model: model_.Model, + messages: list[messages_.Message], + params: GenerateParams | None = None, +) -> messages_.Message: + """Generate media (images or video) through the AI Gateway. + + Dispatches to ``/image-model`` or ``/video-model`` based on ``params`` + type, with fallback to model capabilities when ``params`` is ``None``. + + Raises :class:`ValueError` if the model capabilities don't match the + requested generation type. + """ + # Auto-detect from capabilities when no params provided + if params is None: + params = VideoParams() if "video" in model.capabilities else ImageParams() + + _check_capabilities(model, params) + + if isinstance(params, VideoParams): + return await _generate_video(client, model, messages, params) + return await _generate_image(client, model, messages, params) diff --git a/src/vercel_ai_sdk/models/ai_gateway/image.py b/src/vercel_ai_sdk/models/ai_gateway/image.py deleted file mode 100644 index 1f86d8dd..00000000 --- a/src/vercel_ai_sdk/models/ai_gateway/image.py +++ /dev/null @@ -1,126 +0,0 @@ -"""Vercel AI Gateway image model.""" - -from __future__ import annotations - -import os -from typing import Any, override - -import httpx - -from ...types import messages as messages_ -from ..core import image as image_ -from ..core.media import base as media_base -from ..core.media import detect as detect_media_type -from . import errors as errors_ -from .llm import _DEFAULT_BASE_URL, _base_headers, _file_part_to_wire, _raise_for_status - - -class GatewayImageModel(image_.ImageModel): - """Vercel AI Gateway image model. - - Sends requests to ``/v3/ai/image-model`` and returns a :class:`Message` - with :class:`FilePart`\\s for each generated image. - - Args: - model: Model identifier (e.g. ``'google/imagen-4.0-generate-001'``). - api_key: API key. Falls back to ``AI_GATEWAY_API_KEY``. - base_url: Gateway base URL. - headers: Extra headers for every request. - """ - - def __init__( - self, - model: str = "google/imagen-4.0-generate-001", - api_key: str | None = None, - base_url: str = _DEFAULT_BASE_URL, - headers: dict[str, str] | None = None, - *, - _transport: httpx.AsyncBaseTransport | None = None, - ) -> None: - self._model = model - self._api_key = api_key or os.environ.get("AI_GATEWAY_API_KEY") or "" - self._base_url = base_url.rstrip("/") - self._extra_headers = headers or {} - self._transport = _transport - - def _headers(self) -> dict[str, str]: - return _base_headers( - self._api_key, - { - "ai-image-model-specification-version": "3", - "ai-model-id": self._model, - **self._extra_headers, - }, - ) - - @override - async def make_request( - self, - prompt: str, - input_files: list[messages_.FilePart], - *, - n: int = 1, - size: str | None = None, - aspect_ratio: str | None = None, - seed: int | None = None, - provider_options: dict[str, Any] | None = None, - ) -> media_base.MediaResult: - body: dict[str, Any] = { - "prompt": prompt, - "n": n, - "providerOptions": provider_options or {}, - } - if size is not None: - body["size"] = size - if aspect_ratio is not None: - body["aspectRatio"] = aspect_ratio - if seed is not None: - body["seed"] = seed - if input_files: - body["files"] = [_file_part_to_wire(f) for f in input_files] - - url = f"{self._base_url}/image-model" - try: - async with httpx.AsyncClient(transport=self._transport) as client: - response = await client.post( - url, - json=body, - headers=self._headers(), - timeout=httpx.Timeout(timeout=300.0, connect=10.0), - ) - if response.status_code >= 400: - await _raise_for_status(response, api_key=self._api_key) - - data = response.json() - - except errors_.GatewayError: - raise - except httpx.TimeoutException as exc: - raise errors_.GatewayTimeoutError(cause=exc) from exc - except Exception as exc: - raise errors_.GatewayResponseError( - message=f"Gateway image request failed: {exc}", - cause=exc, - ) from exc - - # Parse response: {images: string[], warnings?, usage?} - raw_images: list[str] = data.get("images", []) - usage_data = data.get("usage") - usage = None - if usage_data: - usage = messages_.Usage( - input_tokens=usage_data.get("inputTokens") or 0, - output_tokens=usage_data.get("outputTokens") or 0, - ) - - files: list[messages_.FilePart] = [] - for img_b64 in raw_images: - media_type = detect_media_type.detect_image_media_type(img_b64) - files.append( - messages_.FilePart( - data=img_b64, - media_type=media_type or "image/png", - ) - ) - - return media_base.MediaResult(files=files, usage=usage) diff --git a/src/vercel_ai_sdk/models/ai_gateway/llm.py b/src/vercel_ai_sdk/models/ai_gateway/llm.py deleted file mode 100644 index b4e0d38f..00000000 --- a/src/vercel_ai_sdk/models/ai_gateway/llm.py +++ /dev/null @@ -1,219 +0,0 @@ -"""Vercel AI Gateway language model using the v3 protocol.""" - -from __future__ import annotations - -import base64 -import json -import os -from collections.abc import AsyncGenerator, Sequence -from typing import Any, override - -import httpx -import pydantic - -from ...types import messages as messages_ -from ...types import tools as tools_ -from ..core import llm as llm_ -from ..core.media import data as media_data -from . import errors as errors_ -from . import protocol as protocol_ - -_DEFAULT_BASE_URL = "https://ai-gateway.vercel.sh/v3/ai" -_PROTOCOL_VERSION = "0.0.1" - - -class GatewayModel(llm_.LanguageModel): - """Vercel AI Gateway language model using the v3 protocol. - - Sends the AI SDK's native message format directly to the gateway - server and receives responses in the AI SDK's native stream-part - format. The gateway server handles all provider-specific - translation. - - Args: - model: Model identifier in ``provider/model`` format - (e.g. ``'anthropic/claude-sonnet-4'``). - api_key: API key. Falls back to ``AI_GATEWAY_API_KEY``. - base_url: Gateway base URL. - provider_options: Gateway options (``order``, ``only``, - ``models``, ``byok``, ``tags``, etc.). - headers: Extra headers for every request. - """ - - def __init__( - self, - model: str = "anthropic/claude-sonnet-4", - api_key: str | None = None, - base_url: str = _DEFAULT_BASE_URL, - provider_options: dict[str, Any] | None = None, - headers: dict[str, str] | None = None, - *, - _transport: httpx.AsyncBaseTransport | None = None, - ) -> None: - self._model = model - self._api_key = api_key or os.environ.get("AI_GATEWAY_API_KEY") or "" - self._base_url = base_url.rstrip("/") - self._provider_options = provider_options - self._extra_headers = headers or {} - self._transport = _transport - - # -- Internals ----------------------------------------------------------- - - def _headers(self, *, streaming: bool) -> dict[str, str]: - h: dict[str, str] = { - "Content-Type": "application/json", - "Authorization": f"Bearer {self._api_key}", - "ai-gateway-protocol-version": _PROTOCOL_VERSION, - "ai-language-model-specification-version": "3", - "ai-language-model-id": self._model, - "ai-language-model-streaming": str(streaming).lower(), - } - if self._api_key: - h["ai-gateway-auth-method"] = "api-key" - h.update(self._extra_headers) - return h - - async def _raise_for_status(self, response: httpx.Response) -> None: - """Raise a typed :class:`GatewayError` for HTTP >= 400.""" - try: - body: Any = response.json() - except Exception: - body = response.text - raise errors_.create_gateway_error( - response_body=body, - status_code=response.status_code, - api_key_provided=bool(self._api_key), - ) - - # -- Stream events ------------------------------------------------------- - - async def stream_events( - self, - messages: list[messages_.Message], - tools: Sequence[tools_.ToolLike] | None = None, - output_type: type[pydantic.BaseModel] | None = None, - ) -> AsyncGenerator[llm_.StreamEvent]: - """Yield ``StreamEvent`` objects from the gateway SSE stream.""" - body = await protocol_.build_request_body( - messages, - tools=tools, - output_type=output_type, - provider_options=self._provider_options, - ) - url = f"{self._base_url}/language-model" - try: - async with ( - httpx.AsyncClient(transport=self._transport) as client, - client.stream( - "POST", - url, - json=body, - headers=self._headers(streaming=True), - timeout=httpx.Timeout(timeout=300.0, connect=10.0), - ) as response, - ): - if response.status_code >= 400: - await response.aread() - await self._raise_for_status(response) - - async for line in response.aiter_lines(): - line = line.strip() - if not line.startswith("data: "): - continue - payload = line[len("data: ") :] - if payload == "[DONE]": - break - try: - data = json.loads(payload) - except json.JSONDecodeError: - continue - for event in protocol_.parse_stream_part(data): - yield event - - except errors_.GatewayError: - raise - except httpx.TimeoutException as exc: - raise errors_.GatewayTimeoutError( - cause=exc, - ) from exc - except Exception as exc: - raise errors_.GatewayResponseError( - message=( - f"Invalid error response format: Gateway request failed: {exc}" - ), - cause=exc, - ) from exc - - # -- LanguageModel interface --------------------------------------------- - - @override - async def stream( - self, - messages: list[messages_.Message], - tools: Sequence[tools_.ToolLike] | None = None, - output_type: type[pydantic.BaseModel] | None = None, - ) -> AsyncGenerator[messages_.Message]: - handler = llm_.StreamHandler() - msg: messages_.Message | None = None - async for event in self.stream_events(messages, tools, output_type): - msg = handler.handle_event(event) - yield msg - - if output_type is not None and msg is not None and msg.text: - data = json.loads(msg.text) - output_type.model_validate(data) - part = messages_.StructuredOutputPart( - data=data, - output_type_name=( - f"{output_type.__module__}.{output_type.__qualname__}" - ), - ) - msg = msg.model_copy() - msg.parts = [*msg.parts, part] - yield msg - - -# --------------------------------------------------------------------------- -# Shared helpers for image/video models -# --------------------------------------------------------------------------- - - -def _base_headers(api_key: str, extra: dict[str, str]) -> dict[str, str]: - """Build common gateway headers.""" - h: dict[str, str] = { - "Content-Type": "application/json", - "Authorization": f"Bearer {api_key}", - "ai-gateway-protocol-version": _PROTOCOL_VERSION, - } - if api_key: - h["ai-gateway-auth-method"] = "api-key" - h.update(extra) - return h - - -async def _raise_for_status(response: httpx.Response, *, api_key: str) -> None: - """Raise a typed :class:`GatewayError` for HTTP >= 400.""" - try: - body: Any = response.json() - except Exception: - body = response.text - raise errors_.create_gateway_error( - response_body=body, - status_code=response.status_code, - api_key_provided=bool(api_key), - ) - - -def _file_part_to_wire(part: messages_.FilePart) -> dict[str, Any]: - """Convert a :class:`FilePart` to the gateway wire format for input files.""" - data = part.data - if isinstance(data, str) and media_data.is_url(data): - return {"type": "url", "url": data} - if isinstance(data, bytes): - b64 = base64.b64encode(data).decode("ascii") - elif isinstance(data, str): - # Assume raw base64 - b64 = data - else: - b64 = str(data) - return {"type": "file", "data": b64, "mediaType": part.media_type} diff --git a/src/vercel_ai_sdk/models/ai_gateway/protocol.py b/src/vercel_ai_sdk/models/ai_gateway/protocol.py deleted file mode 100644 index 8b547396..00000000 --- a/src/vercel_ai_sdk/models/ai_gateway/protocol.py +++ /dev/null @@ -1,425 +0,0 @@ -"""Vercel AI Gateway v3 protocol: serialization and deserialization. - -Converts between the Python SDK's internal ``Message`` / ``StreamEvent`` -types and the LanguageModelV3 wire format used by the gateway at -``/v3/ai/language-model``. - -Wire format reference (from ``@ai-sdk/provider``): - -* **Request body** -- ``LanguageModelV3CallOptions`` (prompt + tools + - provider options, sent as JSON). -* **Stream response** -- Server-Sent Events where each ``data:`` line is - a JSON ``LanguageModelV3StreamPart`` (discriminated on ``type``). -* **Non-stream response** -- JSON ``LanguageModelV3GenerateResult``. -""" - -import json -from collections.abc import Sequence -from typing import Any - -from ...types import messages as messages_ -from ...types import tools as tools_ -from ..core import llm as llm_ -from ..core.media import data as media_data -from ..core.media import download as media_download - -# --------------------------------------------------------------------------- -# Internal messages -> v3 prompt format (outgoing request body) -# --------------------------------------------------------------------------- - - -async def _file_part_to_v3(part: messages_.FilePart) -> dict[str, Any]: - """Convert an internal :class:`FilePart` to a v3 ``file`` content part. - - Binary data is converted to a ``data:`` URL for JSON transport (matching - the JS SDK gateway's ``maybeEncodeFileParts``). HTTP(S) URLs are - downloaded and converted to ``data:`` URLs because the gateway wire - format does not accept raw HTTP URLs for file content. - """ - data = part.data - if isinstance(data, str) and media_data.is_downloadable_url(data): - downloaded, _ = await media_download.download(data) - data = downloaded - - entry: dict[str, Any] = { - "type": "file", - "mediaType": part.media_type, - "data": media_data.data_to_data_url(data, part.media_type), - } - if part.filename is not None: - entry["filename"] = part.filename - return entry - - -async def messages_to_v3_prompt( - messages: list[messages_.Message], -) -> list[dict[str, Any]]: - """Convert internal ``Message`` list to ``LanguageModelV3Prompt``. - - The v3 prompt format is an array of messages, each with a ``role`` and - typed ``content`` parts:: - - [ - {"role": "system", "content": "You are helpful."}, - {"role": "user", "content": [{"type": "text", "text": "Hi"}]}, - {"role": "assistant", "content": [ - {"type": "text", "text": "Hello!"}, - {"type": "reasoning", "text": "..."}, - {"type": "tool-call", "toolCallId": "tc-1", ...}, - ]}, - {"role": "tool", "content": [ - {"type": "tool-result", "toolCallId": "tc-1", ...}, - ]}, - ] - """ - result: list[dict[str, Any]] = [] - for msg in messages: - match msg.role: - case "system": - text = "".join( - p.text for p in msg.parts if isinstance(p, messages_.TextPart) - ) - result.append({"role": "system", "content": text}) - - case "user": - content: list[dict[str, Any]] = [] - for p in msg.parts: - if isinstance(p, messages_.TextPart): - content.append({"type": "text", "text": p.text}) - elif isinstance(p, messages_.FilePart): - content.append(await _file_part_to_v3(p)) - result.append({"role": "user", "content": content}) - - case "assistant": - assistant_content: list[dict[str, Any]] = [] - tool_results: list[dict[str, Any]] = [] - - for part in msg.parts: - match part: - case messages_.ReasoningPart(text=text): - assistant_content.append( - {"type": "reasoning", "text": text} - ) - - case messages_.TextPart(text=text): - assistant_content.append({"type": "text", "text": text}) - - case messages_.ToolPart() as tp: - tool_input: Any = ( - json.loads(tp.tool_args) if tp.tool_args else {} - ) - assistant_content.append( - { - "type": "tool-call", - "toolCallId": tp.tool_call_id, - "toolName": tp.tool_name, - "input": tool_input, - } - ) - if tp.status in ("result", "error"): - output = ( - { - "type": "error-text", - "value": ( - str(tp.result) - if tp.result is not None - else "" - ), - } - if tp.status == "error" - else { - "type": "json", - "value": tp.result, - } - ) - tool_results.append( - { - "type": "tool-result", - "toolCallId": tp.tool_call_id, - "toolName": tp.tool_name, - "output": output, - } - ) - - result.append( - { - "role": "assistant", - "content": assistant_content, - } - ) - if tool_results: - result.append( - { - "role": "tool", - "content": tool_results, - } - ) - - return result - - -# --------------------------------------------------------------------------- -# Request body serialization -# --------------------------------------------------------------------------- - - -async def build_request_body( - messages: list[messages_.Message], - tools: Sequence[tools_.ToolLike] | None = None, - output_type: type[Any] | None = None, - provider_options: dict[str, Any] | None = None, -) -> dict[str, Any]: - """Build the full ``LanguageModelV3CallOptions`` request body.""" - body: dict[str, Any] = { - "prompt": await messages_to_v3_prompt(messages), - } - if tools: - body["tools"] = [ - { - "type": "function", - "name": tool.name, - "description": tool.description, - "inputSchema": tool.param_schema, - } - for tool in tools - ] - if output_type is not None: - import pydantic - - if issubclass(output_type, pydantic.BaseModel): - body["responseFormat"] = { - "type": "json", - "schema": output_type.model_json_schema(), - "name": output_type.__name__, - } - if provider_options: - body["providerOptions"] = provider_options - return body - - -# --------------------------------------------------------------------------- -# v3 stream parts -> internal StreamEvent (incoming SSE response) -# --------------------------------------------------------------------------- - - -def parse_stream_part( - data: dict[str, Any], -) -> list[llm_.StreamEvent]: - """Convert a ``LanguageModelV3StreamPart`` to internal events. - - Most parts map 1:1. A ``tool-call`` part (complete, non-streaming) - expands to Start + ArgsDelta + End. Lifecycle events - (``stream-start``, ``response-metadata``, ``raw``) are silently - dropped. - """ - match data.get("type", ""): - case "text-start": - return [ - llm_.TextStart( - block_id=data.get("id", "text"), - ) - ] - - case "text-delta": - return [ - llm_.TextDelta( - block_id=data.get("id", "text"), - delta=data.get("textDelta", data.get("delta", "")), - ) - ] - - case "text-end": - return [ - llm_.TextEnd( - block_id=data.get("id", "text"), - ) - ] - - case "reasoning-start": - return [ - llm_.ReasoningStart( - block_id=data.get("id", "reasoning"), - ) - ] - - case "reasoning-delta": - return [ - llm_.ReasoningDelta( - block_id=data.get("id", "reasoning"), - delta=data.get("delta", ""), - ) - ] - - case "reasoning-end": - return [ - llm_.ReasoningEnd( - block_id=data.get("id", "reasoning"), - ) - ] - - case "tool-input-start": - return [ - llm_.ToolStart( - tool_call_id=data.get("id", ""), - tool_name=data.get("toolName", ""), - ) - ] - - case "tool-input-delta": - return [ - llm_.ToolArgsDelta( - tool_call_id=data.get("id", ""), - delta=data.get("delta", ""), - ) - ] - - case "tool-input-end": - return [ - llm_.ToolEnd( - tool_call_id=data.get("id", ""), - ) - ] - - case "tool-call": - return _expand_tool_call(data) - - case "file": - return [ - llm_.FileEvent( - block_id=data.get("id", f"file-{len(data)}"), - media_type=data.get("mediaType", "application/octet-stream"), - data=data.get("data", ""), - ) - ] - - case "finish": - return [_parse_finish(data)] - - case _: - return [] - - -# --------------------------------------------------------------------------- -# Non-streaming response -> internal StreamEvents -# --------------------------------------------------------------------------- - - -def parse_generate_result( - data: dict[str, Any], -) -> list[llm_.StreamEvent]: - """Convert a ``LanguageModelV3GenerateResult`` into events. - - Synthesises Start/Delta/End events from the content, then a final - ``MessageDone``. - """ - events: list[llm_.StreamEvent] = [] - - def _expand_content_item(item: dict[str, Any]) -> None: - match item.get("type", ""): - case "text": - bid = item.get("id", "text") - text = item.get("text", "") - events.append(llm_.TextStart(block_id=bid)) - events.append(llm_.TextDelta(block_id=bid, delta=text)) - events.append(llm_.TextEnd(block_id=bid)) - - case "reasoning": - bid = item.get("id", "reasoning") - text = item.get("text", "") - events.append(llm_.ReasoningStart(block_id=bid)) - events.append(llm_.ReasoningDelta(block_id=bid, delta=text)) - events.append(llm_.ReasoningEnd(block_id=bid)) - - case "tool-call": - events.extend(_expand_tool_call(item)) - - case "file": - events.append( - llm_.FileEvent( - block_id=item.get("id", f"file-{len(events)}"), - media_type=item.get("mediaType", "application/octet-stream"), - data=item.get("data", ""), - ) - ) - - match data.get("content"): - case list() as items: - for item in items: - _expand_content_item(item) - case dict() as item: - _expand_content_item(item) - - events.append(_parse_finish(data)) - return events - - -# --------------------------------------------------------------------------- -# Shared helpers (called from multiple sites) -# --------------------------------------------------------------------------- - - -def _expand_tool_call( - data: dict[str, Any], -) -> list[llm_.StreamEvent]: - """Expand a complete ``tool-call`` part into three events.""" - tc_id = data.get("toolCallId", "") - tool_name = data.get("toolName", "") - tool_input = data.get("input", "") - args_str = tool_input if isinstance(tool_input, str) else json.dumps(tool_input) - return [ - llm_.ToolStart(tool_call_id=tc_id, tool_name=tool_name), - llm_.ToolArgsDelta(tool_call_id=tc_id, delta=args_str), - llm_.ToolEnd(tool_call_id=tc_id), - ] - - -def _parse_finish(data: dict[str, Any]) -> llm_.MessageDone: - """Parse a ``finish`` stream part into a ``MessageDone`` event.""" - usage_data = data.get("usage") - usage = _parse_usage(usage_data) if usage_data else None - - match data.get("finishReason"): - case dict() as d: - finish_reason = d.get("unified", "stop") - case str() as s: - finish_reason = s - case _: - finish_reason = "stop" - - return llm_.MessageDone(finish_reason=finish_reason, usage=usage) - - -def _parse_usage(data: Any) -> messages_.Usage: - """Parse a v3 ``LanguageModelV3Usage`` into an internal ``Usage``. - - Supports both the v3 nested format:: - - {"inputTokens": {"total": 10, ...}, "outputTokens": {...}} - - and the flat OpenAI-style format:: - - {"prompt_tokens": 10, "completion_tokens": 20} - """ - if not isinstance(data, dict): - return messages_.Usage() - - input_tokens_obj = data.get("inputTokens") - output_tokens_obj = data.get("outputTokens") - - if isinstance(input_tokens_obj, dict) or isinstance(output_tokens_obj, dict): - inp = input_tokens_obj if isinstance(input_tokens_obj, dict) else {} - out = output_tokens_obj if isinstance(output_tokens_obj, dict) else {} - return messages_.Usage( - input_tokens=inp.get("total") or 0, - output_tokens=out.get("total") or 0, - reasoning_tokens=out.get("reasoning"), - cache_read_tokens=inp.get("cacheRead"), - cache_write_tokens=inp.get("cacheWrite"), - raw=data, - ) - - return messages_.Usage( - input_tokens=(data.get("prompt_tokens") or data.get("inputTokens") or 0), - output_tokens=(data.get("completion_tokens") or data.get("outputTokens") or 0), - raw=data, - ) diff --git a/src/vercel_ai_sdk/models/ai_gateway/stream.py b/src/vercel_ai_sdk/models/ai_gateway/stream.py new file mode 100644 index 00000000..92a63266 --- /dev/null +++ b/src/vercel_ai_sdk/models/ai_gateway/stream.py @@ -0,0 +1,338 @@ +"""AI Gateway v3 streaming adapter — language-model endpoint. + +Handles text, tool-call, reasoning, and inline file streaming via SSE. +""" + +from __future__ import annotations + +import json +from collections.abc import AsyncGenerator, Sequence +from typing import Any + +import httpx +import pydantic + +from ...types import messages as messages_ +from ...types import tools as tools_ +from ..core import client as client_ +from ..core import model as model_ +from ..core.helpers import media as media_ +from ..core.helpers import streaming as streaming_ +from . import _common +from . import errors as errors_ + +# --------------------------------------------------------------------------- +# Request building — Message list → v3 prompt +# --------------------------------------------------------------------------- + + +async def _file_part_to_v3(part: messages_.FilePart) -> dict[str, Any]: + """Convert a :class:`FilePart` to a v3 ``file`` content part.""" + data = part.data + if isinstance(data, str) and media_.is_downloadable_url(data): + downloaded, _ = await media_.download(data) + data = downloaded + + entry: dict[str, Any] = { + "type": "file", + "mediaType": part.media_type, + "data": media_.data_to_data_url(data, part.media_type), + } + if part.filename is not None: + entry["filename"] = part.filename + return entry + + +async def _messages_to_prompt( + messages: list[messages_.Message], +) -> list[dict[str, Any]]: + """Convert ``Message`` list to the v3 prompt wire format.""" + result: list[dict[str, Any]] = [] + + for msg in messages: + match msg.role: + case "system": + text = "".join( + p.text for p in msg.parts if isinstance(p, messages_.TextPart) + ) + result.append({"role": "system", "content": text}) + + case "user": + content: list[dict[str, Any]] = [] + for p in msg.parts: + if isinstance(p, messages_.TextPart): + content.append({"type": "text", "text": p.text}) + elif isinstance(p, messages_.FilePart): + content.append(await _file_part_to_v3(p)) + result.append({"role": "user", "content": content}) + + case "assistant": + assistant_content: list[dict[str, Any]] = [] + tool_results: list[dict[str, Any]] = [] + + for part in msg.parts: + match part: + case messages_.ReasoningPart(text=text): + assistant_content.append( + {"type": "reasoning", "text": text} + ) + + case messages_.TextPart(text=text): + assistant_content.append({"type": "text", "text": text}) + + case messages_.ToolPart() as tp: + tool_input: Any = ( + json.loads(tp.tool_args) if tp.tool_args else {} + ) + assistant_content.append( + { + "type": "tool-call", + "toolCallId": tp.tool_call_id, + "toolName": tp.tool_name, + "input": tool_input, + } + ) + if tp.status in ("result", "error"): + output = ( + { + "type": "error-text", + "value": ( + str(tp.result) + if tp.result is not None + else "" + ), + } + if tp.status == "error" + else { + "type": "json", + "value": tp.result, + } + ) + tool_results.append( + { + "type": "tool-result", + "toolCallId": tp.tool_call_id, + "toolName": tp.tool_name, + "output": output, + } + ) + + result.append({"role": "assistant", "content": assistant_content}) + if tool_results: + result.append({"role": "tool", "content": tool_results}) + + return result + + +async def _build_request_body( + messages: list[messages_.Message], + tools: Sequence[tools_.ToolLike] | None = None, + output_type: type[Any] | None = None, + **kwargs: Any, +) -> dict[str, Any]: + """Build the ``LanguageModelV3CallOptions`` request body.""" + body: dict[str, Any] = { + "prompt": await _messages_to_prompt(messages), + } + if tools: + body["tools"] = [ + { + "type": "function", + "name": tool.name, + "description": tool.description, + "inputSchema": tool.param_schema, + } + for tool in tools + ] + if output_type is not None and issubclass(output_type, pydantic.BaseModel): + body["responseFormat"] = { + "type": "json", + "schema": output_type.model_json_schema(), + "name": output_type.__name__, + } + if kwargs.get("provider_options"): + body["providerOptions"] = kwargs["provider_options"] + return body + + +# --------------------------------------------------------------------------- +# SSE response parsing — v3 stream parts → StreamEvent +# --------------------------------------------------------------------------- + + +def _expand_tool_call(data: dict[str, Any]) -> list[streaming_.StreamEvent]: + """Expand a complete ``tool-call`` part into Start + ArgsDelta + End.""" + tc_id = data.get("toolCallId", "") + tool_name = data.get("toolName", "") + tool_input = data.get("input", "") + args_str = tool_input if isinstance(tool_input, str) else json.dumps(tool_input) + return [ + streaming_.ToolStart(tool_call_id=tc_id, tool_name=tool_name), + streaming_.ToolArgsDelta(tool_call_id=tc_id, delta=args_str), + streaming_.ToolEnd(tool_call_id=tc_id), + ] + + +def _parse_usage(data: Any) -> messages_.Usage: + """Parse v3 usage data into an internal ``Usage``.""" + if not isinstance(data, dict): + return messages_.Usage() + + input_tokens_obj = data.get("inputTokens") + output_tokens_obj = data.get("outputTokens") + + if isinstance(input_tokens_obj, dict) or isinstance(output_tokens_obj, dict): + inp = input_tokens_obj if isinstance(input_tokens_obj, dict) else {} + out = output_tokens_obj if isinstance(output_tokens_obj, dict) else {} + return messages_.Usage( + input_tokens=inp.get("total") or 0, + output_tokens=out.get("total") or 0, + reasoning_tokens=out.get("reasoning"), + cache_read_tokens=inp.get("cacheRead"), + cache_write_tokens=inp.get("cacheWrite"), + raw=data, + ) + + return messages_.Usage( + input_tokens=data.get("prompt_tokens") or data.get("inputTokens") or 0, + output_tokens=(data.get("completion_tokens") or data.get("outputTokens") or 0), + raw=data, + ) + + +def _parse_stream_part(data: dict[str, Any]) -> list[streaming_.StreamEvent]: + """Convert a ``LanguageModelV3StreamPart`` to internal events.""" + match data.get("type", ""): + case "text-start": + return [streaming_.TextStart(block_id=data.get("id", "text"))] + + case "text-delta": + return [ + streaming_.TextDelta( + block_id=data.get("id", "text"), + delta=data.get("textDelta", data.get("delta", "")), + ) + ] + + case "text-end": + return [streaming_.TextEnd(block_id=data.get("id", "text"))] + + case "reasoning-start": + return [streaming_.ReasoningStart(block_id=data.get("id", "reasoning"))] + + case "reasoning-delta": + return [ + streaming_.ReasoningDelta( + block_id=data.get("id", "reasoning"), + delta=data.get("delta", ""), + ) + ] + + case "reasoning-end": + return [streaming_.ReasoningEnd(block_id=data.get("id", "reasoning"))] + + case "tool-input-start": + return [ + streaming_.ToolStart( + tool_call_id=data.get("id", ""), + tool_name=data.get("toolName", ""), + ) + ] + + case "tool-input-delta": + return [ + streaming_.ToolArgsDelta( + tool_call_id=data.get("id", ""), + delta=data.get("delta", ""), + ) + ] + + case "tool-input-end": + return [streaming_.ToolEnd(tool_call_id=data.get("id", ""))] + + case "tool-call": + return _expand_tool_call(data) + + case "file": + return [ + streaming_.FileEvent( + block_id=data.get("id", f"file-{len(data)}"), + media_type=data.get("mediaType", "application/octet-stream"), + data=data.get("data", ""), + ) + ] + + case "finish": + usage_data = data.get("usage") + usage = _parse_usage(usage_data) if usage_data else None + match data.get("finishReason"): + case dict() as d: + finish_reason = d.get("unified", "stop") + case str() as s: + finish_reason = s + case _: + finish_reason = "stop" + return [streaming_.MessageDone(finish_reason=finish_reason, usage=usage)] + + case _: + return [] + + +# --------------------------------------------------------------------------- +# Public adapter function +# --------------------------------------------------------------------------- + + +async def stream( + client: client_.Client, + model: model_.Model, + messages: list[messages_.Message], + *, + tools: Sequence[tools_.ToolLike] | None = None, + output_type: type[pydantic.BaseModel] | None = None, + **kwargs: Any, +) -> AsyncGenerator[messages_.Message]: + """Stream an LLM response through the AI Gateway v3 protocol. + + Yields ``Message`` snapshots as the response streams in. Each + snapshot is a complete, self-contained message reflecting the + accumulated state up to that point. + """ + body = await _build_request_body( + messages, tools=tools, output_type=output_type, **kwargs + ) + headers = _common.request_headers( + client, model, model_type="language", streaming=True + ) + url = f"{client.base_url.rstrip('/')}/language-model" + + handler = streaming_.StreamHandler() + + try: + async with client.http.stream( + "POST", + url, + json=body, + headers=headers, + ) as response: + if response.status_code >= 400: + await response.aread() + raise errors_.create_gateway_error( + response_body=response.text, + status_code=response.status_code, + api_key_provided=bool(client.api_key), + ) + + async for data in _common.parse_sse_lines(response): + for event in _parse_stream_part(data): + msg = handler.handle_event(event) + yield msg + except errors_.GatewayError: + raise + except httpx.TimeoutException as exc: + raise errors_.GatewayTimeoutError(cause=exc) from exc + except Exception as exc: + raise errors_.GatewayResponseError( + message=f"Unexpected error during streaming: {exc}", + cause=exc, + ) from exc diff --git a/src/vercel_ai_sdk/models/ai_gateway/video.py b/src/vercel_ai_sdk/models/ai_gateway/video.py deleted file mode 100644 index 86ca88b3..00000000 --- a/src/vercel_ai_sdk/models/ai_gateway/video.py +++ /dev/null @@ -1,212 +0,0 @@ -"""Vercel AI Gateway video model.""" - -from __future__ import annotations - -import json -import os -from typing import Any, override - -import httpx - -from ...types import messages as messages_ -from ..core import video as video_ -from ..core.media import base as media_base -from ..core.media import detect as detect_media_type -from ..core.media import download as media_download -from . import errors as errors_ -from .llm import _DEFAULT_BASE_URL, _base_headers, _file_part_to_wire, _raise_for_status - - -class GatewayVideoModel(video_.VideoModel): - """Vercel AI Gateway video model. - - Sends requests to ``/v3/ai/video-model`` (with SSE response) and returns - a :class:`Message` with :class:`FilePart`\\s for each generated video. - - Args: - model: Model identifier (e.g. ``'google/veo-3.0-generate-001'``). - api_key: API key. Falls back to ``AI_GATEWAY_API_KEY``. - base_url: Gateway base URL. - headers: Extra headers for every request. - """ - - def __init__( - self, - model: str = "google/veo-3.0-generate-001", - api_key: str | None = None, - base_url: str = _DEFAULT_BASE_URL, - headers: dict[str, str] | None = None, - *, - _transport: httpx.AsyncBaseTransport | None = None, - ) -> None: - self._model = model - self._api_key = api_key or os.environ.get("AI_GATEWAY_API_KEY") or "" - self._base_url = base_url.rstrip("/") - self._extra_headers = headers or {} - self._transport = _transport - - def _headers(self) -> dict[str, str]: - return _base_headers( - self._api_key, - { - "ai-video-model-specification-version": "3", - "ai-model-id": self._model, - "accept": "text/event-stream", - **self._extra_headers, - }, - ) - - @override - async def make_request( - self, - prompt: str, - input_files: list[messages_.FilePart], - *, - n: int = 1, - aspect_ratio: str | None = None, - resolution: str | None = None, - duration: float | None = None, - fps: int | None = None, - seed: int | None = None, - provider_options: dict[str, Any] | None = None, - ) -> media_base.MediaResult: - image_wire: dict[str, Any] | None = None - if input_files: - image_wire = _file_part_to_wire(input_files[0]) - - body: dict[str, Any] = { - "prompt": prompt, - "n": n, - "providerOptions": provider_options or {}, - } - if aspect_ratio is not None: - body["aspectRatio"] = aspect_ratio - if resolution is not None: - body["resolution"] = resolution - if duration is not None: - body["duration"] = duration - if fps is not None: - body["fps"] = fps - if seed is not None: - body["seed"] = seed - if image_wire is not None: - body["image"] = image_wire - - url = f"{self._base_url}/video-model" - try: - async with ( - httpx.AsyncClient(transport=self._transport) as client, - client.stream( - "POST", - url, - json=body, - headers=self._headers(), - timeout=httpx.Timeout(timeout=600.0, connect=10.0), - ) as response, - ): - if response.status_code >= 400: - await response.aread() - await _raise_for_status(response, api_key=self._api_key) - - event_data = await self._read_first_sse_event(response) - - except errors_.GatewayError: - raise - except httpx.TimeoutException as exc: - raise errors_.GatewayTimeoutError(cause=exc) from exc - except Exception as exc: - raise errors_.GatewayResponseError( - message=f"Gateway video request failed: {exc}", - cause=exc, - ) from exc - - # Handle error event - if event_data.get("type") == "error": - status = event_data.get("statusCode", 500) - message = event_data.get("message", "Video generation failed") - error_type = event_data.get("errorType", "") - if status == 400 or error_type == "invalid_request_error": - raise errors_.GatewayInvalidRequestError( - message=message, status_code=status - ) - raise errors_.GatewayResponseError(message=message, status_code=status) - - # Handle result event - raw_videos: list[dict[str, Any]] = event_data.get("videos", []) - files: list[messages_.FilePart] = [] - for video_data in raw_videos: - file_part = await self._video_data_to_file_part(video_data) - files.append(file_part) - - return media_base.MediaResult(files=files) - - @staticmethod - async def _read_first_sse_event(response: httpx.Response) -> dict[str, Any]: - """Read and parse the first SSE data event from the response.""" - async for line in response.aiter_lines(): - line = line.strip() - if not line.startswith("data: "): - continue - payload = line[len("data: ") :] - if payload == "[DONE]": - break - try: - result: dict[str, Any] = json.loads(payload) - return result - except json.JSONDecodeError: - continue - raise errors_.GatewayResponseError( - message="SSE stream ended without a data event", - ) - - @staticmethod - async def _video_data_to_file_part( - video_data: dict[str, Any], - ) -> messages_.FilePart: - """Convert a gateway video result to a :class:`FilePart`. - - Handles ``{type: "url", url, mediaType}`` (downloads the video) - and ``{type: "base64", data, mediaType}``. - """ - vtype = video_data.get("type", "base64") - media_type = video_data.get("mediaType", "video/mp4") - - if vtype == "url": - video_url = video_data["url"] - downloaded_bytes, content_type = await media_download.download(video_url) - # Prefer provider mediaType, then download content-type, then detect - if media_type == "video/mp4" and content_type: - media_type = content_type - detected = detect_media_type.detect_media_type( - downloaded_bytes, detect_media_type.VIDEO_SIGNATURES - ) - if detected: - media_type = detected - return messages_.FilePart( - data=downloaded_bytes, - media_type=media_type, - ) - - # base64 - data = video_data.get("data", "") - detected = detect_media_type.detect_media_type( - data, detect_media_type.VIDEO_SIGNATURES - ) - if detected: - media_type = detected - return messages_.FilePart( - data=data, - media_type=media_type, - ) - - -# --------------------------------------------------------------------------- -# Stubs for future model types -# --------------------------------------------------------------------------- - - -class GatewayEmbeddingModel: - """Stub -- not yet implemented.""" - - def __init__(self, model: str, **kwargs: Any) -> None: - raise NotImplementedError("GatewayEmbeddingModel is not yet implemented.") diff --git a/src/vercel_ai_sdk/models/anthropic/__init__.py b/src/vercel_ai_sdk/models/anthropic/__init__.py index 38716ce4..a9a0436b 100644 --- a/src/vercel_ai_sdk/models/anthropic/__init__.py +++ b/src/vercel_ai_sdk/models/anthropic/__init__.py @@ -1,5 +1,7 @@ -"""Anthropic provider adapter.""" +"""Anthropic provider — adapter for the Anthropic messages API.""" -from .llm import AnthropicModel, _messages_to_anthropic +from .adapter import stream -__all__ = ["AnthropicModel", "_messages_to_anthropic"] +__all__ = [ + "stream", +] diff --git a/src/vercel_ai_sdk/models/anthropic/adapter.py b/src/vercel_ai_sdk/models/anthropic/adapter.py new file mode 100644 index 00000000..7ad3d25c --- /dev/null +++ b/src/vercel_ai_sdk/models/anthropic/adapter.py @@ -0,0 +1,389 @@ +"""Anthropic adapter — messages API. + +Message/tool conversion and streaming via the official ``anthropic`` SDK. +The SDK client is constructed from :class:`Client` params on each call. +""" + +from __future__ import annotations + +import json +from collections.abc import AsyncGenerator, Sequence +from typing import Any + +import anthropic +import pydantic + +from ...types import messages as messages_ +from ...types import tools as tools_ +from ..core import client as client_ +from ..core import model as model_ +from ..core.helpers import media as media_ +from ..core.helpers import streaming as streaming_ + +# --------------------------------------------------------------------------- +# Message / tool conversion — internal types → Anthropic wire format +# --------------------------------------------------------------------------- + + +def _tools_to_anthropic( + tools: Sequence[tools_.ToolLike], +) -> list[dict[str, Any]]: + """Convert internal Tool objects to Anthropic tool schema format.""" + return [ + { + "name": tool.name, + "description": tool.description, + "input_schema": tool.param_schema, + } + for tool in tools + ] + + +def _file_part_to_anthropic( + part: messages_.FilePart, +) -> dict[str, Any]: + """Convert a :class:`FilePart` to an Anthropic content block. + + * ``image/*`` -> ``{"type": "image", "source": ...}`` + * ``application/pdf`` -> ``{"type": "document", "source": ...}`` + * ``text/plain`` -> ``{"type": "document", "source": ...}`` + * anything else -> ``ValueError`` + """ + mt = part.media_type + + if mt.startswith("image/"): + media_type = "image/jpeg" if mt == "image/*" else mt + if isinstance(part.data, str) and media_.is_url(part.data): + return { + "type": "image", + "source": {"type": "url", "url": part.data}, + } + return { + "type": "image", + "source": { + "type": "base64", + "media_type": media_type, + "data": media_.data_to_base64(part.data), + }, + } + + if mt == "application/pdf": + if isinstance(part.data, str) and media_.is_url(part.data): + return { + "type": "document", + "source": {"type": "url", "url": part.data}, + } + return { + "type": "document", + "source": { + "type": "base64", + "media_type": "application/pdf", + "data": media_.data_to_base64(part.data), + }, + } + + if mt == "text/plain": + if isinstance(part.data, bytes): + text_data = part.data.decode("utf-8") + elif media_.is_url(part.data): + return { + "type": "document", + "source": {"type": "url", "url": part.data}, + } + else: + import base64 as _b64 + + text_data = _b64.b64decode(part.data).decode("utf-8") + return { + "type": "document", + "source": { + "type": "text", + "media_type": "text/plain", + "data": text_data, + }, + } + + raise ValueError(f"Unsupported media type for Anthropic: {mt}") + + +async def _messages_to_anthropic( + messages: list[messages_.Message], +) -> tuple[str | None, list[dict[str, Any]]]: + """Convert internal messages to Anthropic API format. + + Returns ``(system_prompt, messages_list)``. The system prompt is + extracted separately because the Anthropic API takes it as a + top-level parameter. + """ + system_prompt: str | None = None + result: list[dict[str, Any]] = [] + + for msg in messages: + match msg.role: + case "system": + system_prompt = "".join( + p.text for p in msg.parts if isinstance(p, messages_.TextPart) + ) + case "assistant": + content: list[dict[str, Any]] = [] + tool_results: list[dict[str, Any]] = [] + + for part in msg.parts: + match part: + case messages_.ReasoningPart(text=text, signature=signature): + if signature: + content.append( + { + "type": "thinking", + "thinking": text, + "signature": signature, + } + ) + case messages_.TextPart(text=text): + content.append({"type": "text", "text": text}) + case messages_.ToolPart(): + tool_input = ( + json.loads(part.tool_args) if part.tool_args else {} + ) + content.append( + { + "type": "tool_use", + "id": part.tool_call_id, + "name": part.tool_name, + "input": tool_input, + } + ) + if part.status in ("result", "error"): + entry: dict[str, Any] = { + "type": "tool_result", + "tool_use_id": part.tool_call_id, + "content": str(part.result) + if part.result is not None + else "", + } + if part.status == "error": + entry["is_error"] = True + tool_results.append(entry) + + if content: + result.append({"role": "assistant", "content": content}) + if tool_results: + result.append({"role": "user", "content": tool_results}) + + case "user": + has_files = any(isinstance(p, messages_.FilePart) for p in msg.parts) + if not has_files: + content_text = "".join( + p.text for p in msg.parts if isinstance(p, messages_.TextPart) + ) + result.append({"role": "user", "content": content_text}) + else: + user_content: list[dict[str, Any]] = [] + for p in msg.parts: + match p: + case messages_.TextPart(text=text): + user_content.append({"type": "text", "text": text}) + case messages_.FilePart(): + user_content.append(_file_part_to_anthropic(p)) + result.append({"role": "user", "content": user_content}) + + result = _merge_consecutive_roles(result) + return system_prompt, result + + +def _merge_consecutive_roles( + messages: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """Merge consecutive messages that share the same role. + + Anthropic requires strictly alternating user/assistant roles. + """ + if not messages: + return messages + + merged: list[dict[str, Any]] = [messages[0]] + + for msg in messages[1:]: + if msg["role"] == merged[-1]["role"]: + prev = _to_content_list(merged[-1]["content"]) + cur = _to_content_list(msg["content"]) + merged[-1]["content"] = prev + cur + else: + merged.append(msg) + + return merged + + +def _to_content_list(content: Any) -> list[dict[str, Any]]: + """Normalize Anthropic message content to list-of-blocks.""" + if isinstance(content, list): + return list(content) + return [{"type": "text", "text": content}] + + +# --------------------------------------------------------------------------- +# SDK client factory +# --------------------------------------------------------------------------- + + +def _make_client( + client: client_.Client, +) -> anthropic.AsyncAnthropic: + """Construct an ``AsyncAnthropic`` from our generic ``Client``.""" + return anthropic.AsyncAnthropic( + base_url=client.base_url, + api_key=client.api_key or "", + ) + + +# --------------------------------------------------------------------------- +# Public adapter function +# --------------------------------------------------------------------------- + + +async def stream( + client: client_.Client, + model: model_.Model, + messages: list[messages_.Message], + *, + tools: Sequence[tools_.ToolLike] | None = None, + output_type: type[pydantic.BaseModel] | None = None, + thinking: bool = False, + budget_tokens: int = 10000, + **kwargs: Any, +) -> AsyncGenerator[messages_.Message]: + """Stream an LLM response via the Anthropic messages API. + + Yields ``Message`` snapshots as the response streams in. + + Extra keyword arguments beyond the ``StreamFn`` protocol: + + * ``thinking`` — enable extended thinking output. + * ``budget_tokens`` — max tokens for thinking (default 10000). + """ + sdk_client = _make_client(client) + system_prompt, anthropic_messages = await _messages_to_anthropic(messages) + anthropic_tools = _tools_to_anthropic(tools) if tools else None + + api_kwargs: dict[str, Any] = { + "model": model.id, + "messages": anthropic_messages, + "max_tokens": 8192, + } + if system_prompt: + api_kwargs["system"] = system_prompt + if anthropic_tools: + api_kwargs["tools"] = anthropic_tools + + if thinking: + api_kwargs["thinking"] = { + "type": "enabled", + "budget_tokens": budget_tokens, + } + + if output_type is not None: + api_kwargs["output_format"] = output_type + + handler = streaming_.StreamHandler() + + block_types: dict[int, str] = {} + tool_ids: dict[int, str] = {} + signature_buffer: dict[int, str] = {} + + try: + stream_cm = sdk_client.messages.stream(**api_kwargs) + + async with stream_cm as sdk_stream: + async for event in sdk_stream: + match event.type: + case "content_block_start": + block = event.content_block + idx = event.index + block_types[idx] = block.type + + match block.type: + case "text": + yield handler.handle_event( + streaming_.TextStart(block_id=str(idx)) + ) + case "thinking": + yield handler.handle_event( + streaming_.ReasoningStart(block_id=str(idx)) + ) + case "tool_use": + tool_ids[idx] = block.id + yield handler.handle_event( + streaming_.ToolStart( + tool_call_id=block.id, + tool_name=block.name, + ) + ) + + case "content_block_delta": + delta = event.delta + idx = event.index + + match delta.type: + case "text_delta": + yield handler.handle_event( + streaming_.TextDelta( + block_id=str(idx), + delta=delta.text, + ) + ) + case "thinking_delta": + yield handler.handle_event( + streaming_.ReasoningDelta( + block_id=str(idx), + delta=delta.thinking, + ) + ) + case "signature_delta": + signature_buffer[idx] = ( + signature_buffer.get(idx, "") + delta.signature + ) + case "input_json_delta": + tool_id = tool_ids.get(idx) + if tool_id: + yield handler.handle_event( + streaming_.ToolArgsDelta( + tool_call_id=tool_id, + delta=delta.partial_json, + ) + ) + + case "content_block_stop": + idx = event.index + match block_types.get(idx): + case "text": + yield handler.handle_event( + streaming_.TextEnd(block_id=str(idx)) + ) + case "thinking": + yield handler.handle_event( + streaming_.ReasoningEnd( + block_id=str(idx), + signature=signature_buffer.get(idx), + ) + ) + case "tool_use": + tool_id = tool_ids.get(idx) + if tool_id: + yield handler.handle_event( + streaming_.ToolEnd(tool_call_id=tool_id) + ) + + snapshot = sdk_stream.current_message_snapshot + sdk_usage = snapshot.usage + usage = messages_.Usage( + input_tokens=sdk_usage.input_tokens or 0, + output_tokens=sdk_usage.output_tokens or 0, + cache_read_tokens=getattr(sdk_usage, "cache_read_input_tokens", None), + cache_write_tokens=getattr( + sdk_usage, "cache_creation_input_tokens", None + ), + raw=sdk_usage.model_dump(exclude_none=True) or None, + ) + yield handler.handle_event(streaming_.MessageDone(usage=usage)) + finally: + await sdk_client.close() diff --git a/src/vercel_ai_sdk/models/anthropic/llm.py b/src/vercel_ai_sdk/models/anthropic/llm.py deleted file mode 100644 index f2812b31..00000000 --- a/src/vercel_ai_sdk/models/anthropic/llm.py +++ /dev/null @@ -1,368 +0,0 @@ -from __future__ import annotations - -import json -import os -from collections.abc import AsyncGenerator, Sequence -from typing import Any, override - -import anthropic -import pydantic - -from ...types import messages as messages_ -from ...types import tools as tools_ -from ..core import llm as llm_ -from ..core import media - - -def _tools_to_anthropic(tools: Sequence[tools_.ToolLike]) -> list[dict[str, Any]]: - """Convert internal Tool objects to Anthropic tool schema format.""" - return [ - { - "name": tool.name, - "description": tool.description, - "input_schema": tool.param_schema, - } - for tool in tools - ] - - -def _file_part_to_anthropic(part: messages_.FilePart) -> dict[str, Any]: - """Convert a :class:`FilePart` to an Anthropic content block. - - * ``image/*`` → ``{"type": "image", "source": ...}`` - * ``application/pdf`` → ``{"type": "document", "source": ...}`` - * ``text/plain`` → ``{"type": "document", "source": {"type": "text", ...}}`` - * anything else → ``ValueError`` - """ - mt = part.media_type - - if mt.startswith("image/"): - media_type = "image/jpeg" if mt == "image/*" else mt - if isinstance(part.data, str) and media.data.is_url(part.data): - return { - "type": "image", - "source": {"type": "url", "url": part.data}, - } - return { - "type": "image", - "source": { - "type": "base64", - "media_type": media_type, - "data": media.data.data_to_base64(part.data), - }, - } - - if mt == "application/pdf": - if isinstance(part.data, str) and media.data.is_url(part.data): - return { - "type": "document", - "source": {"type": "url", "url": part.data}, - } - return { - "type": "document", - "source": { - "type": "base64", - "media_type": "application/pdf", - "data": media.data.data_to_base64(part.data), - }, - } - - if mt == "text/plain": - # Anthropic accepts text documents with source.type="text" - if isinstance(part.data, bytes): - text_data = part.data.decode("utf-8") - elif media.data.is_url(part.data): - return { - "type": "document", - "source": {"type": "url", "url": part.data}, - } - else: - import base64 as _b64 - - text_data = _b64.b64decode(part.data).decode("utf-8") - return { - "type": "document", - "source": { - "type": "text", - "media_type": "text/plain", - "data": text_data, - }, - } - - raise ValueError(f"Unsupported media type for Anthropic: {mt}") - - -async def _messages_to_anthropic( - messages: list[messages_.Message], -) -> tuple[str | None, list[dict[str, Any]]]: - """Convert internal messages to Anthropic API format. - - Returns (system_prompt, messages) tuple since Anthropic handles - system prompts separately. - - Converts to the Anthropic wire format: - - - ``tool_use`` blocks in assistant messages - - ``tool_result`` blocks in user messages (immediately after) - - A final merge pass ensures strictly alternating roles (Anthropic - rejects consecutive same-role messages). - """ - system_prompt: str | None = None - result: list[dict[str, Any]] = [] - - for msg in messages: - if msg.role == "system": - system_prompt = "".join( - p.text for p in msg.parts if isinstance(p, messages_.TextPart) - ) - elif msg.role == "assistant": - content: list[dict[str, Any]] = [] - tool_results: list[dict[str, Any]] = [] - - for part in msg.parts: - if isinstance(part, messages_.ReasoningPart): - if part.signature: - content.append( - { - "type": "thinking", - "thinking": part.text, - "signature": part.signature, - } - ) - elif isinstance(part, messages_.TextPart): - content.append({"type": "text", "text": part.text}) - elif isinstance(part, messages_.ToolPart): - tool_input = json.loads(part.tool_args) if part.tool_args else {} - content.append( - { - "type": "tool_use", - "id": part.tool_call_id, - "name": part.tool_name, - "input": tool_input, - } - ) - if part.status in ("result", "error"): - entry: dict[str, Any] = { - "type": "tool_result", - "tool_use_id": part.tool_call_id, - "content": str(part.result) - if part.result is not None - else "", - } - if part.status == "error": - entry["is_error"] = True - tool_results.append(entry) - - if content: - result.append({"role": "assistant", "content": content}) - if tool_results: - result.append({"role": "user", "content": tool_results}) - elif msg.role == "user": - has_files = any(isinstance(p, messages_.FilePart) for p in msg.parts) - if not has_files: - content_text = "".join( - p.text for p in msg.parts if isinstance(p, messages_.TextPart) - ) - result.append({"role": "user", "content": content_text}) - else: - user_content: list[dict[str, Any]] = [] - for p in msg.parts: - if isinstance(p, messages_.TextPart): - user_content.append({"type": "text", "text": p.text}) - elif isinstance(p, messages_.FilePart): - user_content.append(_file_part_to_anthropic(p)) - result.append({"role": "user", "content": user_content}) - - # Merge consecutive same-role messages (e.g. synthetic user(tool_result) - # followed by a real user message). - result = _merge_consecutive_roles(result) - - return system_prompt, result - - -def _merge_consecutive_roles( - messages: list[dict[str, Any]], -) -> list[dict[str, Any]]: - """Merge consecutive messages that share the same role. - - Anthropic requires strictly alternating user/assistant roles. When - our conversion emits a synthetic ``user`` message for ``tool_result`` - blocks followed by a real ``user`` message, they must be merged. - - Content is normalized to list-of-blocks so heterogeneous content - (tool_result dicts + text strings) can coexist. - """ - if not messages: - return messages - - merged: list[dict[str, Any]] = [messages[0]] - - for msg in messages[1:]: - if msg["role"] == merged[-1]["role"]: - prev = _to_content_list(merged[-1]["content"]) - cur = _to_content_list(msg["content"]) - merged[-1]["content"] = prev + cur - else: - merged.append(msg) - - return merged - - -def _to_content_list(content: Any) -> list[dict[str, Any]]: - """Normalize Anthropic message content to list-of-blocks format.""" - if isinstance(content, list): - return list(content) - return [{"type": "text", "text": content}] - - -class AnthropicModel(llm_.LanguageModel): - """Anthropic adapter with native extended thinking support.""" - - def __init__( - self, - model: str = "claude-sonnet-4-5-20250929", - base_url: str | None = None, - api_key: str | None = None, - thinking: bool = False, - budget_tokens: int = 10000, - ) -> None: - self._model = model - self._thinking = thinking - self._budget_tokens = budget_tokens - resolved_key = api_key or os.environ.get("ANTHROPIC_API_KEY") or "" - self._client = anthropic.AsyncAnthropic(base_url=base_url, api_key=resolved_key) - - async def stream_events( - self, - messages: list[messages_.Message], - tools: Sequence[tools_.ToolLike] | None = None, - output_type: type[pydantic.BaseModel] | None = None, - ) -> AsyncGenerator[llm_.StreamEvent]: - """Yield raw stream events from Anthropic API.""" - system_prompt, anthropic_messages = await _messages_to_anthropic(messages) - anthropic_tools = _tools_to_anthropic(tools) if tools else None - - kwargs: dict[str, Any] = { - "model": self._model, - "messages": anthropic_messages, - "max_tokens": 8192, - } - if system_prompt: - kwargs["system"] = system_prompt - if anthropic_tools: - kwargs["tools"] = anthropic_tools - - if self._thinking: - kwargs["thinking"] = { - "type": "enabled", - "budget_tokens": self._budget_tokens, - } - - # Structured output: SDK handles schema transformation internally - if output_type is not None: - kwargs["output_format"] = output_type - - # Track block types by index to know what End event to emit - block_types: dict[int, str] = {} # index -> "text" | "thinking" | "tool_use" - tool_ids: dict[int, str] = {} # index -> tool_call_id - signature_buffer: dict[int, str] = {} # index -> accumulated signature - - stream_cm = self._client.messages.stream(**kwargs) - - async with stream_cm as stream: - async for event in stream: - if event.type == "content_block_start": - block = event.content_block - idx = event.index - block_types[idx] = block.type - - if block.type == "text": - yield llm_.TextStart(block_id=str(idx)) - elif block.type == "thinking": - yield llm_.ReasoningStart(block_id=str(idx)) - elif block.type == "tool_use": - tool_ids[idx] = block.id - yield llm_.ToolStart( - tool_call_id=block.id, tool_name=block.name - ) - - elif event.type == "content_block_delta": - delta = event.delta - idx = event.index - - if delta.type == "text_delta": - yield llm_.TextDelta(block_id=str(idx), delta=delta.text) - elif delta.type == "thinking_delta": - yield llm_.ReasoningDelta( - block_id=str(idx), delta=delta.thinking - ) - elif delta.type == "signature_delta": - # Accumulate signature for ReasoningEnd - signature_buffer[idx] = ( - signature_buffer.get(idx, "") + delta.signature - ) - elif delta.type == "input_json_delta": - tool_id = tool_ids.get(idx) - if tool_id: - yield llm_.ToolArgsDelta( - tool_call_id=tool_id, delta=delta.partial_json - ) - - elif event.type == "content_block_stop": - idx = event.index - block_type = block_types.get(idx) - - if block_type == "text": - yield llm_.TextEnd(block_id=str(idx)) - elif block_type == "thinking": - yield llm_.ReasoningEnd( - block_id=str(idx), - signature=signature_buffer.get(idx), - ) - elif block_type == "tool_use": - tool_id = tool_ids.get(idx) - if tool_id: - yield llm_.ToolEnd(tool_call_id=tool_id) - - # The Anthropic SDK accumulates usage across message_start and - # message_delta events into current_message_snapshot. Read it - # once here instead of tracking state ourselves. - snapshot = stream.current_message_snapshot - sdk_usage = snapshot.usage - usage = messages_.Usage( - input_tokens=sdk_usage.input_tokens or 0, - output_tokens=sdk_usage.output_tokens or 0, - cache_read_tokens=getattr(sdk_usage, "cache_read_input_tokens", None), - cache_write_tokens=getattr( - sdk_usage, "cache_creation_input_tokens", None - ), - raw=sdk_usage.model_dump(exclude_none=True) or None, - ) - yield llm_.MessageDone(usage=usage) - - @override - async def stream( - self, - messages: list[messages_.Message], - tools: Sequence[tools_.ToolLike] | None = None, - output_type: type[pydantic.BaseModel] | None = None, - ) -> AsyncGenerator[messages_.Message]: - """Stream Messages (uses StreamHandler internally).""" - handler = llm_.StreamHandler() - msg: messages_.Message | None = None - async for event in self.stream_events(messages, tools, output_type=output_type): - msg = handler.handle_event(event) - yield msg - - # After stream completes, validate and attach structured output part - if output_type is not None and msg is not None and msg.text: - data = json.loads(msg.text) - output_type.model_validate(data) # fail fast on bad data - part = messages_.StructuredOutputPart( - data=data, - output_type_name=f"{output_type.__module__}.{output_type.__qualname__}", - ) - msg = msg.model_copy() - msg.parts = [*msg.parts, part] - yield msg diff --git a/src/vercel_ai_sdk/models/core/__init__.py b/src/vercel_ai_sdk/models/core/__init__.py index 63c6c89f..32bed109 100644 --- a/src/vercel_ai_sdk/models/core/__init__.py +++ b/src/vercel_ai_sdk/models/core/__init__.py @@ -1,18 +1,13 @@ -"""Core model abstractions — LanguageModel, ImageModel, VideoModel.""" +"""Core types for models.""" -from . import media -from .image import ImageModel -from .llm import LanguageModel, StreamEvent, StreamHandler -from .media.base import MediaModel, MediaResult -from .video import VideoModel +from .client import Client +from .model import Model, ModelCost +from .proto import GenerateFn, StreamFn __all__ = [ - "LanguageModel", - "StreamEvent", - "StreamHandler", - "MediaModel", - "MediaResult", - "ImageModel", - "VideoModel", - "media", + "Client", + "GenerateFn", + "Model", + "ModelCost", + "StreamFn", ] diff --git a/src/vercel_ai_sdk/models/core/client.py b/src/vercel_ai_sdk/models/core/client.py new file mode 100644 index 00000000..6cb0fb12 --- /dev/null +++ b/src/vercel_ai_sdk/models/core/client.py @@ -0,0 +1,45 @@ +"""HTTP client for adapter functions.""" + +from __future__ import annotations + +import dataclasses + +import httpx + + +@dataclasses.dataclass +class Client: + """Connection parameters for a provider API. + + Adapter functions receive a ``Client`` instead of creating their own HTTP + session. This keeps auth and base URL decoupled from the adapter logic. + + The :pyattr:`http` property lazily creates a shared + :class:`httpx.AsyncClient` so that consecutive calls reuse the same + connection pool. + """ + + base_url: str + api_key: str | None = None + headers: dict[str, str] = dataclasses.field(default_factory=dict) + + _http: httpx.AsyncClient | None = dataclasses.field( + default=None, repr=False, compare=False + ) + + @property + def http(self) -> httpx.AsyncClient: + """Lazy-init shared httpx client.""" + if self._http is None or self._http.is_closed: + self._http = httpx.AsyncClient( + base_url=self.base_url, + headers=self.headers, + timeout=httpx.Timeout(timeout=300.0, connect=10.0), + ) + return self._http + + async def aclose(self) -> None: + """Close the underlying HTTP client if open.""" + if self._http is not None and not self._http.is_closed: + await self._http.aclose() + self._http = None diff --git a/src/vercel_ai_sdk/models/core/media/detect.py b/src/vercel_ai_sdk/models/core/helpers/media.py similarity index 50% rename from src/vercel_ai_sdk/models/core/media/detect.py rename to src/vercel_ai_sdk/models/core/helpers/media.py index a9bf770a..3fc3e793 100644 --- a/src/vercel_ai_sdk/models/core/media/detect.py +++ b/src/vercel_ai_sdk/models/core/helpers/media.py @@ -1,13 +1,100 @@ -"""Magic-byte media type detection. - -Port of ``@ai-sdk/ai/src/util/detect-media-type.ts``. Detects image, -audio, and video formats by inspecting the first bytes of binary data -(or the first characters of a base-64 string). -""" - from __future__ import annotations +import base64 import base64 as _b64 +import mimetypes + +import httpx + +# -- URL helpers ----------------------------------------------------------- + + +def is_url(data: str) -> bool: + """Return True if *data* looks like a URL rather than raw base-64.""" + return data.startswith(("http://", "https://", "data:")) + + +def is_downloadable_url(data: str) -> bool: + """Return True if *data* is an ``http(s)://`` URL that can be fetched.""" + return data.startswith(("http://", "https://")) + + +def split_data_url(url: str) -> tuple[str | None, str | None]: + """Parse a ``data:`` URL into ``(media_type, base64_content)``. + + Returns ``(None, None)`` if the input is not a valid ``data:`` URL. + + Example:: + + >>> split_data_url("data:image/png;base64,iVBOR...") + ("image/png", "iVBOR...") + """ + if not url.startswith("data:"): + return None, None + try: + header, b64_content = url.split(",", 1) + # header = "data:image/png;base64" + mt = header.split(";")[0].split(":", 1)[1] + return (mt or None), (b64_content or None) + except (ValueError, IndexError): + return None, None + + +# -- encoding helpers ------------------------------------------------------ + + +def data_to_base64(data: str | bytes) -> str: + """Ensure *data* is a base-64 encoded string. + + * ``bytes`` -> base-64 encoded. + * ``str`` that is a ``data:`` URL -> base-64 content extracted. + * ``str`` that is an ``http(s)://`` URL -> returned as-is (caller + must handle). + * ``str`` that is not a URL -> assumed to already be base-64. + """ + if isinstance(data, bytes): + return base64.b64encode(data).decode("ascii") + if data.startswith("data:"): + _, b64 = split_data_url(data) + if b64 is not None: + return b64 + return data + + +def data_to_data_url(data: str | bytes, media_type: str) -> str: + """Convert *data* to a ``data:`` URL. Passes through existing URLs.""" + if isinstance(data, str) and is_url(data): + return data + b64 = data_to_base64(data) + return f"data:{media_type};base64,{b64}" + + +# -- media-type inference -------------------------------------------------- + + +def infer_media_type(url: str) -> str: + """Infer IANA media type from a URL. + + * ``data:image/png;base64,...`` -> ``"image/png"`` + * ``https://example.com/cat.jpg`` -> ``"image/jpeg"`` (via :mod:`mimetypes`) + * Unknown -> raises :class:`ValueError` + """ + if url.startswith("data:"): + # data:[][;base64], + rest = url[5:] # strip "data:" + sep = rest.find(",") + meta = rest[:sep] if sep != -1 else rest + mt = meta.split(";")[0] + if mt: + return mt + else: + guessed, _ = mimetypes.guess_type(url) + if guessed: + return guessed + raise ValueError( + f"Cannot infer media_type from URL: {url!r}. Provide media_type explicitly." + ) + # --------------------------------------------------------------------------- # Signature definitions @@ -186,3 +273,98 @@ def detect_image_media_type(data: bytes | str) -> str | None: def detect_audio_media_type(data: bytes | str) -> str | None: """Detect audio format from magic bytes.""" return detect_media_type(data, AUDIO_SIGNATURES) + + +DEFAULT_MAX_BYTES = 100 * 1024 * 1024 # 100 MiB (matches TS SDK) +_ALLOWED_SCHEMES = frozenset({"http", "https"}) + + +class DownloadError(Exception): + """Raised when a URL download fails.""" + + def __init__( + self, + url: str, + *, + status_code: int | None = None, + status_text: str | None = None, + cause: BaseException | None = None, + ) -> None: + parts = [f"Failed to download {url!r}"] + if status_code is not None: + parts.append(f"status={status_code}") + if status_text: + parts.append(status_text) + super().__init__(": ".join(parts)) + self.url = url + self.status_code = status_code + if cause is not None: + self.__cause__ = cause + + +def _validate_url(url: str) -> None: + """Reject non-HTTP(S) URLs (SSRF prevention).""" + from urllib.parse import urlparse + + parsed = urlparse(url) + if parsed.scheme not in _ALLOWED_SCHEMES: + raise DownloadError( + url, status_text=f"Unsupported URL scheme: {parsed.scheme!r}" + ) + + +async def download( + url: str, + *, + max_bytes: int = DEFAULT_MAX_BYTES, +) -> tuple[bytes, str | None]: + """Download *url* and return ``(data, content_type)``. + + Args: + url: The URL to fetch (must be ``http`` or ``https``). + max_bytes: Maximum response size. Defaults to 100 MiB. + + Returns: + A tuple of ``(raw_bytes, content_type_or_None)``. + + Raises: + DownloadError: On any failure (network, HTTP status, size, etc.). + """ + _validate_url(url) + + try: + async with httpx.AsyncClient(follow_redirects=True) as client: + resp = await client.get(url) + + # Validate redirect target + if resp.url is not None and str(resp.url) != url: + _validate_url(str(resp.url)) + + if resp.status_code >= 400: + raise DownloadError( + url, + status_code=resp.status_code, + status_text=resp.reason_phrase or "", + ) + + data = resp.content + if len(data) > max_bytes: + raise DownloadError( + url, + status_text=( + f"Response exceeds maximum size " + f"({len(data)} > {max_bytes} bytes)" + ), + ) + + content_type = resp.headers.get("content-type") + # Strip charset/parameters: "image/png; charset=..." → "image/png" + if content_type: + content_type = content_type.split(";")[0].strip() + + return data, content_type or None + + except DownloadError: + raise + except Exception as exc: + raise DownloadError(url, cause=exc) from exc diff --git a/src/vercel_ai_sdk/models/core/llm.py b/src/vercel_ai_sdk/models/core/helpers/streaming.py similarity index 84% rename from src/vercel_ai_sdk/models/core/llm.py rename to src/vercel_ai_sdk/models/core/helpers/streaming.py index 765abf36..11d27006 100644 --- a/src/vercel_ai_sdk/models/core/llm.py +++ b/src/vercel_ai_sdk/models/core/helpers/streaming.py @@ -1,13 +1,12 @@ from __future__ import annotations -import abc import dataclasses -from collections.abc import AsyncGenerator, Sequence +import json +from collections.abc import AsyncGenerator import pydantic -from ...types import messages as messages_ -from ...types import tools as tools_ +from ....types import messages as messages_ @dataclasses.dataclass @@ -236,27 +235,30 @@ def _build_message( ) -class LanguageModel(abc.ABC): - @abc.abstractmethod - async def stream( - self, - messages: list[messages_.Message], - tools: Sequence[tools_.ToolLike] | None = None, - output_type: type[pydantic.BaseModel] | None = None, - ) -> AsyncGenerator[messages_.Message]: - raise NotImplementedError - yield - - async def buffer( - self, - messages: list[messages_.Message], - tools: Sequence[tools_.ToolLike] | None = None, - output_type: type[pydantic.BaseModel] | None = None, - ) -> messages_.Message: - """Drain the stream and return the final message.""" - final = None - async for msg in self.stream(messages, tools, output_type=output_type): - final = msg - if final is None: - raise ValueError("LLM produced no messages") - return final +async def events_to_messages( + events: AsyncGenerator[StreamEvent], + output_type: type[pydantic.BaseModel] | None = None, +) -> AsyncGenerator[messages_.Message]: + """Convert a stream of events into Message snapshots. + + This is the standalone version of the logic that ``LanguageModel.stream()`` + uses. Wire functions call this to turn their ``StreamEvent`` generators + into ``Message`` generators suitable for ``Stream``. + """ + handler = StreamHandler() + msg: messages_.Message | None = None + async for event in events: + msg = handler.handle_event(event) + yield msg + + # After stream completes, validate and attach structured output part + if output_type is not None and msg is not None and msg.text: + data = json.loads(msg.text) + output_type.model_validate(data) # fail fast on bad data + part = messages_.StructuredOutputPart( + data=data, + output_type_name=f"{output_type.__module__}.{output_type.__qualname__}", + ) + msg = msg.model_copy() + msg.parts = [*msg.parts, part] + yield msg diff --git a/src/vercel_ai_sdk/models/core/image.py b/src/vercel_ai_sdk/models/core/image.py deleted file mode 100644 index eb7aa9c3..00000000 --- a/src/vercel_ai_sdk/models/core/image.py +++ /dev/null @@ -1,60 +0,0 @@ -"""ImageModel — abstract image generation model.""" - -from __future__ import annotations - -import abc -from typing import Any, override - -from ...types import messages as messages_ -from .media.base import MediaModel, MediaResult - - -class ImageModel(MediaModel): - """Abstract image generation model. - - Accepts :class:`Message`\\s as input and returns a :class:`Message` - containing generated images as :class:`FilePart`\\s. - - Adapter authors implement :meth:`make_request`; the framework handles - parsing messages and assembling the response :class:`Message`. - """ - - async def generate( - self, - messages: list[messages_.Message], - *, - n: int = 1, - size: str | None = None, - aspect_ratio: str | None = None, - seed: int | None = None, - provider_options: dict[str, Any] | None = None, - ) -> messages_.Message: - """Generate images from the given messages.""" - prompt = self._extract_prompt(messages) - input_files = self._extract_input_files(messages) - result = await self.make_request( - prompt, - input_files, - n=n, - size=size, - aspect_ratio=aspect_ratio, - seed=seed, - provider_options=provider_options, - ) - return self._build_message(result) - - @override - @abc.abstractmethod - async def make_request( - self, - prompt: str, - input_files: list[messages_.FilePart], - *, - n: int = 1, - size: str | None = None, - aspect_ratio: str | None = None, - seed: int | None = None, - provider_options: dict[str, Any] | None = None, - ) -> MediaResult: - """Adapter-specific image generation.""" - ... diff --git a/src/vercel_ai_sdk/models/core/media/__init__.py b/src/vercel_ai_sdk/models/core/media/__init__.py deleted file mode 100644 index a4485760..00000000 --- a/src/vercel_ai_sdk/models/core/media/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Media utilities — data-format helpers, media type detection, and download.""" - -from . import data, detect, download -from .base import MediaModel, MediaResult - -__all__ = [ - "MediaModel", - "MediaResult", - "data", - "detect", - "download", -] diff --git a/src/vercel_ai_sdk/models/core/media/base.py b/src/vercel_ai_sdk/models/core/media/base.py deleted file mode 100644 index b6306a67..00000000 --- a/src/vercel_ai_sdk/models/core/media/base.py +++ /dev/null @@ -1,86 +0,0 @@ -"""MediaModel base class and MediaResult type. - -Shared pipeline steps that every media adapter would otherwise duplicate: - -* **Input** -- extract a text prompt and input files from messages. -* **Output** -- wrap the adapter's :class:`MediaResult` into a - :class:`Message` with ``role="assistant"``. -""" - -from __future__ import annotations - -import abc -import dataclasses -from typing import Any - -from ....types import messages as messages_ - - -@dataclasses.dataclass -class MediaResult: - """Raw result returned by an adapter's ``make_request()`` method. - - The framework wraps this into a :class:`Message` automatically. - """ - - files: list[messages_.FilePart] - usage: messages_.Usage | None = None - - -class MediaModel(abc.ABC): - """Abstract base for media generation models. - - Subclasses (:class:`ImageModel`, :class:`VideoModel`) define the - public ``generate()`` signature with media-type-specific parameters - and delegate to the adapter's ``make_request()`` method. - """ - - @staticmethod - def _extract_prompt(messages: list[messages_.Message]) -> str: - """Concatenate all :class:`TextPart` texts from user/system messages.""" - parts: list[str] = [] - for msg in messages: - if msg.role in ("user", "system"): - for p in msg.parts: - if isinstance(p, messages_.TextPart): - parts.append(p.text) - return " ".join(parts) - - @staticmethod - def _extract_input_files( - messages: list[messages_.Message], - ) -> list[messages_.FilePart]: - """Collect all :class:`FilePart` objects from user messages.""" - files: list[messages_.FilePart] = [] - for msg in messages: - if msg.role == "user": - for p in msg.parts: - if isinstance(p, messages_.FilePart): - files.append(p) - return files - - @staticmethod - def _build_message(result: MediaResult) -> messages_.Message: - """Wrap adapter output into a :class:`Message`.""" - return messages_.Message( - role="assistant", - parts=list(result.files), - usage=result.usage, - ) - - @abc.abstractmethod - async def make_request( - self, - prompt: str, - input_files: list[messages_.FilePart], - *, - n: int = 1, - provider_options: dict[str, Any] | None = None, - ) -> MediaResult: - """Adapter-specific generation logic. - - Receives already-parsed inputs and returns a :class:`MediaResult`. - The framework calls this from ``generate()`` and wraps the result - into a :class:`Message`. - """ - ... diff --git a/src/vercel_ai_sdk/models/core/media/data.py b/src/vercel_ai_sdk/models/core/media/data.py deleted file mode 100644 index e92fb5e2..00000000 --- a/src/vercel_ai_sdk/models/core/media/data.py +++ /dev/null @@ -1,100 +0,0 @@ -"""Data-format helpers for multimodal content. - -URL detection, ``data:`` URL parsing, base-64 encoding/decoding, and -media-type inference utilities used by :class:`~vercel_ai_sdk.core.messages.FilePart` -and the provider converters. -""" - -from __future__ import annotations - -import base64 -import mimetypes - -# -- URL helpers ----------------------------------------------------------- - - -def is_url(data: str) -> bool: - """Return True if *data* looks like a URL rather than raw base-64.""" - return data.startswith(("http://", "https://", "data:")) - - -def is_downloadable_url(data: str) -> bool: - """Return True if *data* is an ``http(s)://`` URL that can be fetched.""" - return data.startswith(("http://", "https://")) - - -def split_data_url(url: str) -> tuple[str | None, str | None]: - """Parse a ``data:`` URL into ``(media_type, base64_content)``. - - Returns ``(None, None)`` if the input is not a valid ``data:`` URL. - - Example:: - - >>> split_data_url("data:image/png;base64,iVBOR...") - ("image/png", "iVBOR...") - """ - if not url.startswith("data:"): - return None, None - try: - header, b64_content = url.split(",", 1) - # header = "data:image/png;base64" - mt = header.split(";")[0].split(":", 1)[1] - return (mt or None), (b64_content or None) - except (ValueError, IndexError): - return None, None - - -# -- encoding helpers ------------------------------------------------------ - - -def data_to_base64(data: str | bytes) -> str: - """Ensure *data* is a base-64 encoded string. - - * ``bytes`` -> base-64 encoded. - * ``str`` that is a ``data:`` URL -> base-64 content extracted. - * ``str`` that is an ``http(s)://`` URL -> returned as-is (caller - must handle). - * ``str`` that is not a URL -> assumed to already be base-64. - """ - if isinstance(data, bytes): - return base64.b64encode(data).decode("ascii") - if data.startswith("data:"): - _, b64 = split_data_url(data) - if b64 is not None: - return b64 - return data - - -def data_to_data_url(data: str | bytes, media_type: str) -> str: - """Convert *data* to a ``data:`` URL. Passes through existing URLs.""" - if isinstance(data, str) and is_url(data): - return data - b64 = data_to_base64(data) - return f"data:{media_type};base64,{b64}" - - -# -- media-type inference -------------------------------------------------- - - -def infer_media_type(url: str) -> str: - """Infer IANA media type from a URL. - - * ``data:image/png;base64,...`` -> ``"image/png"`` - * ``https://example.com/cat.jpg`` -> ``"image/jpeg"`` (via :mod:`mimetypes`) - * Unknown -> raises :class:`ValueError` - """ - if url.startswith("data:"): - # data:[][;base64], - rest = url[5:] # strip "data:" - sep = rest.find(",") - meta = rest[:sep] if sep != -1 else rest - mt = meta.split(";")[0] - if mt: - return mt - else: - guessed, _ = mimetypes.guess_type(url) - if guessed: - return guessed - raise ValueError( - f"Cannot infer media_type from URL: {url!r}. Provide media_type explicitly." - ) diff --git a/src/vercel_ai_sdk/models/core/media/download.py b/src/vercel_ai_sdk/models/core/media/download.py deleted file mode 100644 index ef3757af..00000000 --- a/src/vercel_ai_sdk/models/core/media/download.py +++ /dev/null @@ -1,104 +0,0 @@ -"""Async download utility for URL-based file parts. - -Port of ``@ai-sdk/ai/src/util/download/download.ts``. Used by -provider adapters that need to fetch a URL the provider API cannot -accept natively (e.g. OpenAI does not support audio/PDF URLs). -""" - -from __future__ import annotations - -import httpx - -DEFAULT_MAX_BYTES = 100 * 1024 * 1024 # 100 MiB (matches TS SDK) -_ALLOWED_SCHEMES = frozenset({"http", "https"}) - - -class DownloadError(Exception): - """Raised when a URL download fails.""" - - def __init__( - self, - url: str, - *, - status_code: int | None = None, - status_text: str | None = None, - cause: BaseException | None = None, - ) -> None: - parts = [f"Failed to download {url!r}"] - if status_code is not None: - parts.append(f"status={status_code}") - if status_text: - parts.append(status_text) - super().__init__(": ".join(parts)) - self.url = url - self.status_code = status_code - if cause is not None: - self.__cause__ = cause - - -def _validate_url(url: str) -> None: - """Reject non-HTTP(S) URLs (SSRF prevention).""" - from urllib.parse import urlparse - - parsed = urlparse(url) - if parsed.scheme not in _ALLOWED_SCHEMES: - raise DownloadError( - url, status_text=f"Unsupported URL scheme: {parsed.scheme!r}" - ) - - -async def download( - url: str, - *, - max_bytes: int = DEFAULT_MAX_BYTES, -) -> tuple[bytes, str | None]: - """Download *url* and return ``(data, content_type)``. - - Args: - url: The URL to fetch (must be ``http`` or ``https``). - max_bytes: Maximum response size. Defaults to 100 MiB. - - Returns: - A tuple of ``(raw_bytes, content_type_or_None)``. - - Raises: - DownloadError: On any failure (network, HTTP status, size, etc.). - """ - _validate_url(url) - - try: - async with httpx.AsyncClient(follow_redirects=True) as client: - resp = await client.get(url) - - # Validate redirect target - if resp.url is not None and str(resp.url) != url: - _validate_url(str(resp.url)) - - if resp.status_code >= 400: - raise DownloadError( - url, - status_code=resp.status_code, - status_text=resp.reason_phrase or "", - ) - - data = resp.content - if len(data) > max_bytes: - raise DownloadError( - url, - status_text=( - f"Response exceeds maximum size " - f"({len(data)} > {max_bytes} bytes)" - ), - ) - - content_type = resp.headers.get("content-type") - # Strip charset/parameters: "image/png; charset=..." → "image/png" - if content_type: - content_type = content_type.split(";")[0].strip() - - return data, content_type or None - - except DownloadError: - raise - except Exception as exc: - raise DownloadError(url, cause=exc) from exc diff --git a/src/vercel_ai_sdk/models/core/model.py b/src/vercel_ai_sdk/models/core/model.py new file mode 100644 index 00000000..cbf59f50 --- /dev/null +++ b/src/vercel_ai_sdk/models/core/model.py @@ -0,0 +1,34 @@ +"""Model metadata types.""" + +from __future__ import annotations + +import dataclasses + + +@dataclasses.dataclass(frozen=True) +class ModelCost: + """Per-million-token pricing.""" + + input: float = 0.0 + output: float = 0.0 + cache_read: float = 0.0 + cache_write: float = 0.0 + + +@dataclasses.dataclass(frozen=True) +class Model: + """Pure-data description of a model. + + * ``id`` — identifier sent to the provider (e.g. ``"claude-sonnet-4-20250514"``). + * ``adapter`` — adapter key (e.g. ``"ai-gateway-v3"``, ``"anthropic-messages"``). + * ``provider`` — hosting service (e.g. ``"ai-gateway"``, ``"anthropic"``). + """ + + id: str + adapter: str + provider: str + name: str = "" + capabilities: tuple[str, ...] = ("text",) + context_window: int = 0 + max_output_tokens: int = 0 + cost: ModelCost | None = None diff --git a/src/vercel_ai_sdk/models/core/proto.py b/src/vercel_ai_sdk/models/core/proto.py new file mode 100644 index 00000000..1ceb8ff2 --- /dev/null +++ b/src/vercel_ai_sdk/models/core/proto.py @@ -0,0 +1,59 @@ +"""Adapter function protocols. + +An *adapter function* translates between our ``Message`` types and a specific +provider API (e.g. ``"ai-gateway-v3"``, ``"anthropic-messages"``). + +Adapter functions are plain async generators / coroutines — no base class +required. The protocols below exist only for static type-checking. +""" + +from __future__ import annotations + +from collections.abc import AsyncGenerator, Sequence +from typing import Any, Protocol, runtime_checkable + +import pydantic + +from ...types import messages as messages_ +from ...types import tools as tools_ +from .client import Client +from .model import Model + + +@runtime_checkable +class StreamFn(Protocol): + """Protocol for streaming adapter functions. + + Implementations yield ``Message`` snapshots as the response streams + in. Each snapshot is a complete, self-contained message reflecting + the accumulated state up to that point. + """ + + def __call__( + self, + client: Client, + model: Model, + messages: list[messages_.Message], + *, + tools: Sequence[tools_.ToolLike] | None = None, + output_type: type[pydantic.BaseModel] | None = None, + **kwargs: Any, + ) -> AsyncGenerator[messages_.Message]: ... + + +@runtime_checkable +class GenerateFn(Protocol): + """Protocol for non-streaming adapter functions (images, video, etc.). + + ``params`` is typed as ``Any`` at the protocol level because each adapter + defines its own parameter types (e.g. ``ImageParams | VideoParams``). + Type safety is enforced at the top-level ``generate()`` function. + """ + + async def __call__( + self, + client: Client, + model: Model, + messages: list[messages_.Message], + params: Any = None, + ) -> messages_.Message: ... diff --git a/src/vercel_ai_sdk/models/core/video.py b/src/vercel_ai_sdk/models/core/video.py deleted file mode 100644 index 84e1d074..00000000 --- a/src/vercel_ai_sdk/models/core/video.py +++ /dev/null @@ -1,66 +0,0 @@ -"""VideoModel — abstract video generation model.""" - -from __future__ import annotations - -import abc -from typing import Any, override - -from ...types import messages as messages_ -from .media.base import MediaModel, MediaResult - - -class VideoModel(MediaModel): - """Abstract video generation model. - - Accepts :class:`Message`\\s as input and returns a :class:`Message` - containing generated videos as :class:`FilePart`\\s. - - Adapter authors implement :meth:`make_request`; the framework handles - parsing messages and assembling the response :class:`Message`. - """ - - async def generate( - self, - messages: list[messages_.Message], - *, - n: int = 1, - aspect_ratio: str | None = None, - resolution: str | None = None, - duration: float | None = None, - fps: int | None = None, - seed: int | None = None, - provider_options: dict[str, Any] | None = None, - ) -> messages_.Message: - """Generate videos from the given messages.""" - prompt = self._extract_prompt(messages) - input_files = self._extract_input_files(messages) - result = await self.make_request( - prompt, - input_files, - n=n, - aspect_ratio=aspect_ratio, - resolution=resolution, - duration=duration, - fps=fps, - seed=seed, - provider_options=provider_options, - ) - return self._build_message(result) - - @override - @abc.abstractmethod - async def make_request( - self, - prompt: str, - input_files: list[messages_.FilePart], - *, - n: int = 1, - aspect_ratio: str | None = None, - resolution: str | None = None, - duration: float | None = None, - fps: int | None = None, - seed: int | None = None, - provider_options: dict[str, Any] | None = None, - ) -> MediaResult: - """Adapter-specific video generation.""" - ... diff --git a/src/vercel_ai_sdk/models/openai/__init__.py b/src/vercel_ai_sdk/models/openai/__init__.py index 4b83b500..bd01bcd1 100644 --- a/src/vercel_ai_sdk/models/openai/__init__.py +++ b/src/vercel_ai_sdk/models/openai/__init__.py @@ -1,5 +1,7 @@ -"""OpenAI provider adapter.""" +"""OpenAI provider — adapter for the OpenAI chat completions API.""" -from .llm import OpenAIModel, _messages_to_openai +from .adapter import stream -__all__ = ["OpenAIModel", "_messages_to_openai"] +__all__ = [ + "stream", +] diff --git a/src/vercel_ai_sdk/models/openai/adapter.py b/src/vercel_ai_sdk/models/openai/adapter.py new file mode 100644 index 00000000..8f63c244 --- /dev/null +++ b/src/vercel_ai_sdk/models/openai/adapter.py @@ -0,0 +1,386 @@ +"""OpenAI adapter — chat completions API. + +Message/tool conversion and streaming via the official ``openai`` SDK. +The SDK client is constructed from :class:`Client` params on each call. +""" + +from __future__ import annotations + +from collections.abc import AsyncGenerator, Sequence +from typing import Any + +import openai +import pydantic + +from ...types import messages as messages_ +from ...types import tools as tools_ +from ..core import client as client_ +from ..core import model as model_ +from ..core.helpers import media as media_ +from ..core.helpers import streaming as streaming_ + +# --------------------------------------------------------------------------- +# Message / tool conversion — internal types → OpenAI wire format +# --------------------------------------------------------------------------- + + +def _tools_to_openai( + tools: Sequence[tools_.ToolLike], +) -> list[dict[str, Any]]: + """Convert internal Tool objects to OpenAI tool schema format.""" + return [ + { + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.param_schema, + }, + } + for tool in tools + ] + + +async def _file_part_to_openai( + part: messages_.FilePart, +) -> dict[str, Any]: + """Convert a :class:`FilePart` to an OpenAI content-array element. + + * ``image/*`` -> ``image_url`` (URL or ``data:`` URL) + * ``audio/*`` -> ``input_audio`` (base-64 only; URLs auto-downloaded) + * ``application/pdf`` -> ``file`` (base-64 only; URLs auto-downloaded) + * ``text/*`` -> ``text`` (decoded to string) + * anything else -> ``ValueError`` + """ + mt = part.media_type + data = part.data + + if mt.startswith("image/"): + media_type = "image/jpeg" if mt == "image/*" else mt + url = media_.data_to_data_url(data, media_type) + return {"type": "image_url", "image_url": {"url": url}} + + if mt.startswith("audio/"): + if isinstance(data, str) and media_.is_downloadable_url(data): + downloaded, _ = await media_.download(data) + data = downloaded + fmt = mt.split("/", 1)[1] if "/" in mt else mt + b64 = media_.data_to_base64(data) + return { + "type": "input_audio", + "input_audio": {"data": b64, "format": fmt}, + } + + if mt == "application/pdf": + if isinstance(data, str) and media_.is_downloadable_url(data): + downloaded, _ = await media_.download(data) + data = downloaded + data_url = media_.data_to_data_url(data, mt) + filename = part.filename or "document.pdf" + return { + "type": "file", + "file": {"filename": filename, "file_data": data_url}, + } + + if mt.startswith("text/"): + if isinstance(data, bytes): + text_content = data.decode("utf-8") + elif media_.is_url(data): + text_content = data + else: + import base64 as _b64 + + text_content = _b64.b64decode(data).decode("utf-8") + return {"type": "text", "text": text_content} + + raise ValueError(f"Unsupported media type for OpenAI: {mt}") + + +async def _messages_to_openai( + messages: list[messages_.Message], +) -> list[dict[str, Any]]: + """Convert internal messages to OpenAI API format. + + * ``tool_calls`` on assistant messages + * tool results as separate ``role: "tool"`` messages + """ + result: list[dict[str, Any]] = [] + for msg in messages: + match msg.role: + case "assistant": + content = "" + reasoning = "" + tool_calls: list[dict[str, Any]] = [] + tool_results: list[dict[str, Any]] = [] + + for part in msg.parts: + match part: + case messages_.ReasoningPart(text=text): + reasoning += text + case messages_.TextPart(text=text): + content += text + case messages_.ToolPart(): + tool_calls.append( + { + "id": part.tool_call_id, + "type": "function", + "function": { + "name": part.tool_name, + "arguments": part.tool_args, + }, + } + ) + if part.status in ("result", "error"): + tool_results.append( + { + "role": "tool", + "tool_call_id": part.tool_call_id, + "content": str(part.result) + if part.result is not None + else "", + } + ) + + entry: dict[str, Any] = {"role": "assistant"} + if content: + entry["content"] = content + if reasoning: + entry["reasoning"] = reasoning + if tool_calls: + entry["tool_calls"] = tool_calls + result.append(entry) + result.extend(tool_results) + + case "system": + content_text = "".join( + p.text for p in msg.parts if isinstance(p, messages_.TextPart) + ) + result.append({"role": "system", "content": content_text}) + + case "user": + has_files = any(isinstance(p, messages_.FilePart) for p in msg.parts) + if not has_files: + text = "".join( + p.text for p in msg.parts if isinstance(p, messages_.TextPart) + ) + result.append({"role": "user", "content": text}) + else: + parts: list[dict[str, Any]] = [] + for p in msg.parts: + match p: + case messages_.TextPart(text=text): + parts.append({"type": "text", "text": text}) + case messages_.FilePart(): + parts.append(await _file_part_to_openai(p)) + result.append({"role": "user", "content": parts}) + return result + + +# --------------------------------------------------------------------------- +# SDK client factory +# --------------------------------------------------------------------------- + + +def _make_client(client: client_.Client) -> openai.AsyncOpenAI: + """Construct an ``AsyncOpenAI`` from our generic ``Client``.""" + return openai.AsyncOpenAI( + base_url=client.base_url, + api_key=client.api_key or "", + ) + + +# --------------------------------------------------------------------------- +# Public adapter function +# --------------------------------------------------------------------------- + + +async def stream( + client: client_.Client, + model: model_.Model, + messages: list[messages_.Message], + *, + tools: Sequence[tools_.ToolLike] | None = None, + output_type: type[pydantic.BaseModel] | None = None, + thinking: bool = False, + budget_tokens: int | None = None, + reasoning_effort: str | None = None, + **kwargs: Any, +) -> AsyncGenerator[messages_.Message]: + """Stream an LLM response via the OpenAI chat completions API. + + Yields ``Message`` snapshots as the response streams in. + + Extra keyword arguments beyond the ``StreamFn`` protocol: + + * ``thinking`` — enable reasoning/thinking output. + * ``budget_tokens`` — max tokens for reasoning (mutually exclusive + with ``reasoning_effort``). + * ``reasoning_effort`` — effort level: ``"none"``, ``"minimal"``, + ``"low"``, ``"medium"``, ``"high"``, ``"xhigh"`` + (mutually exclusive with ``budget_tokens``). + """ + sdk_client = _make_client(client) + openai_messages = await _messages_to_openai(messages) + openai_tools = _tools_to_openai(tools) if tools else None + + api_kwargs: dict[str, Any] = { + "model": model.id, + "messages": openai_messages, + "stream": True, + "stream_options": {"include_usage": True}, + } + if openai_tools: + api_kwargs["tools"] = openai_tools + + if output_type is not None: + from openai.lib._pydantic import to_strict_json_schema + + api_kwargs["response_format"] = { + "type": "json_schema", + "json_schema": { + "name": output_type.__name__, + "schema": to_strict_json_schema(output_type), + "strict": True, + }, + } + + # Enable reasoning/thinking via Vercel AI Gateway's unified format + if thinking: + reasoning_config: dict[str, Any] = {"enabled": True} + if budget_tokens is not None: + reasoning_config["max_tokens"] = budget_tokens + elif reasoning_effort is not None: + reasoning_config["effort"] = reasoning_effort + api_kwargs["extra_body"] = {"reasoning": reasoning_config} + + handler = streaming_.StreamHandler() + + try: + sdk_stream = await sdk_client.chat.completions.create(**api_kwargs) + + text_started = False + reasoning_started = False + tc_state: dict[int, dict[str, Any]] = {} + finish_reason: str | None = None + usage: messages_.Usage | None = None + + async for chunk in sdk_stream: + if chunk.usage is not None: + raw = chunk.usage.model_dump(exclude_none=True) + reasoning_tokens: int | None = None + cache_read: int | None = None + cd = getattr( + chunk.usage, + "completion_tokens_details", + None, + ) + if cd: + reasoning_tokens = getattr(cd, "reasoning_tokens", None) + pd = getattr( + chunk.usage, + "prompt_tokens_details", + None, + ) + if pd: + cache_read = getattr(pd, "cached_tokens", None) + usage = messages_.Usage( + input_tokens=chunk.usage.prompt_tokens or 0, + output_tokens=chunk.usage.completion_tokens or 0, + reasoning_tokens=reasoning_tokens, + cache_read_tokens=cache_read, + raw=raw, + ) + + if not chunk.choices: + continue + + choice = chunk.choices[0] + delta = choice.delta + + # Reasoning / thinking content + reasoning_value = None + if hasattr(delta, "reasoning") and delta.reasoning: + reasoning_value = delta.reasoning + elif hasattr(delta, "model_extra") and delta.model_extra: + reasoning_value = delta.model_extra.get("reasoning") + + if reasoning_value: + if not reasoning_started: + reasoning_started = True + yield handler.handle_event( + streaming_.ReasoningStart(block_id="reasoning") + ) + yield handler.handle_event( + streaming_.ReasoningDelta( + block_id="reasoning", delta=reasoning_value + ) + ) + + if delta.content: + if reasoning_started: + yield handler.handle_event( + streaming_.ReasoningEnd(block_id="reasoning") + ) + reasoning_started = False + + if not text_started: + text_started = True + yield handler.handle_event(streaming_.TextStart(block_id="text")) + yield handler.handle_event( + streaming_.TextDelta(block_id="text", delta=delta.content) + ) + + if delta.tool_calls: + for tc in delta.tool_calls: + idx = tc.index + if idx not in tc_state: + tc_state[idx] = { + "id": tc.id, + "name": None, + "started": False, + } + if tc.id: + tc_state[idx]["id"] = tc.id + if tc.function: + if tc.function.name: + tc_state[idx]["name"] = tc.function.name + if tc.function.arguments: + tid = tc_state[idx]["id"] + tname = tc_state[idx]["name"] or "" + + if not tc_state[idx]["started"] and tid: + tc_state[idx]["started"] = True + yield handler.handle_event( + streaming_.ToolStart( + tool_call_id=tid, + tool_name=tname, + ) + ) + + if tid: + yield handler.handle_event( + streaming_.ToolArgsDelta( + tool_call_id=tid, + delta=tc.function.arguments, + ) + ) + + if choice.finish_reason is not None: + finish_reason = choice.finish_reason + if reasoning_started: + yield handler.handle_event( + streaming_.ReasoningEnd(block_id="reasoning") + ) + if text_started: + yield handler.handle_event(streaming_.TextEnd(block_id="text")) + for tc in tc_state.values(): + if tc["started"] and tc["id"]: + yield handler.handle_event( + streaming_.ToolEnd(tool_call_id=tc["id"]) + ) + + yield handler.handle_event( + streaming_.MessageDone(finish_reason=finish_reason, usage=usage) + ) + finally: + await sdk_client.close() diff --git a/src/vercel_ai_sdk/models/openai/llm.py b/src/vercel_ai_sdk/models/openai/llm.py deleted file mode 100644 index 3404cf27..00000000 --- a/src/vercel_ai_sdk/models/openai/llm.py +++ /dev/null @@ -1,391 +0,0 @@ -from __future__ import annotations - -import json -import os -from collections.abc import AsyncGenerator, Sequence -from typing import Any, override - -import openai -import pydantic - -from ...types import messages as messages_ -from ...types import tools as tools_ -from ..core import llm as llm_ -from ..core import media - - -def _tools_to_openai(tools: Sequence[tools_.ToolLike]) -> list[dict[str, Any]]: - """Convert internal Tool objects to OpenAI tool schema format.""" - return [ - { - "type": "function", - "function": { - "name": tool.name, - "description": tool.description, - "parameters": tool.param_schema, - }, - } - for tool in tools - ] - - -async def _file_part_to_openai(part: messages_.FilePart) -> dict[str, Any]: - """Convert a :class:`FilePart` to an OpenAI content-array element. - - Follows the OpenAI chat-completions content part formats: - - * ``image/*`` → ``image_url`` (URL or ``data:`` URL) - * ``audio/*`` → ``input_audio`` (base-64 only; URLs auto-downloaded) - * ``application/pdf`` → ``file`` (base-64 only; URLs auto-downloaded) - * ``text/*`` → ``text`` (decoded to string) - * anything else → ``ValueError`` - - OpenAI does not accept URLs for audio ``input_audio`` or PDF ``file`` - parts. When URL data is provided for these types, it is downloaded - automatically (matching the TS SDK's ``downloadAssets`` behaviour). - """ - mt = part.media_type - data = part.data - - if mt.startswith("image/"): - media_type = "image/jpeg" if mt == "image/*" else mt - url = media.data.data_to_data_url(data, media_type) - return {"type": "image_url", "image_url": {"url": url}} - - if mt.startswith("audio/"): - # OpenAI input_audio requires raw base-64 — download http(s) URLs. - if isinstance(data, str) and media.data.is_downloadable_url(data): - downloaded, _ = await media.download.download(data) - data = downloaded - fmt = mt.split("/", 1)[1] if "/" in mt else mt - b64 = media.data.data_to_base64(data) - return {"type": "input_audio", "input_audio": {"data": b64, "format": fmt}} - - if mt == "application/pdf": - # OpenAI file parts require base-64 — download http(s) URLs. - if isinstance(data, str) and media.data.is_downloadable_url(data): - downloaded, _ = await media.download.download(data) - data = downloaded - data_url = media.data.data_to_data_url(data, mt) - filename = part.filename or "document.pdf" - return {"type": "file", "file": {"filename": filename, "file_data": data_url}} - - if mt.startswith("text/"): - # Decode text content — URLs are passed through as text, - # bytes/base-64 are decoded to UTF-8 string. - if isinstance(data, bytes): - text_content = data.decode("utf-8") - elif media.data.is_url(data): - text_content = data - else: - import base64 as _b64 - - text_content = _b64.b64decode(data).decode("utf-8") - return {"type": "text", "text": text_content} - - raise ValueError(f"Unsupported media type for OpenAI: {mt}") - - -async def _messages_to_openai( - messages: list[messages_.Message], -) -> list[dict[str, Any]]: - """Convert internal messages to OpenAI API format. - - Converts to the OpenAI wire format: - - - ``tool_calls`` on assistant messages - - tool results as separate ``role: "tool"`` messages - - The Vercel AI Gateway preserves reasoning details across interactions, - normalizing formats from different providers. - - See: https://vercel.com/docs/ai-gateway/openai-compat/advanced - """ - result: list[dict[str, Any]] = [] - for msg in messages: - if msg.role == "assistant": - content = "" - reasoning = "" - tool_calls = [] - tool_results = [] - - for part in msg.parts: - if isinstance(part, messages_.ReasoningPart): - reasoning += part.text - elif isinstance(part, messages_.TextPart): - content += part.text - elif isinstance(part, messages_.ToolPart): - tool_calls.append( - { - "id": part.tool_call_id, - "type": "function", - "function": { - "name": part.tool_name, - "arguments": part.tool_args, - }, - } - ) - if part.status in ("result", "error"): - tool_results.append( - { - "role": "tool", - "tool_call_id": part.tool_call_id, - "content": str(part.result) - if part.result is not None - else "", - } - ) - - entry: dict[str, Any] = {"role": "assistant"} - if content: - entry["content"] = content - if reasoning: - entry["reasoning"] = reasoning - if tool_calls: - entry["tool_calls"] = tool_calls - result.append(entry) - - # Emit tool results as separate messages (OpenAI API format) - result.extend(tool_results) - elif msg.role == "system": - content = "".join( - p.text for p in msg.parts if isinstance(p, messages_.TextPart) - ) - result.append({"role": "system", "content": content}) - else: - # User messages — may contain multimodal FileParts - has_files = any(isinstance(p, messages_.FilePart) for p in msg.parts) - if not has_files: - # Text-only: keep simple string format (cheaper, no content array) - text = "".join( - p.text for p in msg.parts if isinstance(p, messages_.TextPart) - ) - result.append({"role": "user", "content": text}) - else: - parts: list[dict[str, Any]] = [] - for p in msg.parts: - if isinstance(p, messages_.TextPart): - parts.append({"type": "text", "text": p.text}) - elif isinstance(p, messages_.FilePart): - parts.append(await _file_part_to_openai(p)) - result.append({"role": "user", "content": parts}) - return result - - -class OpenAIModel(llm_.LanguageModel): - """OpenAI adapter with reasoning/thinking support via Vercel AI Gateway. - - Supports reasoning for models like GPT 5.x, o-series, and Claude via gateway. - Uses the Vercel AI Gateway's unified reasoning API format. - - See: https://vercel.com/docs/ai-gateway/openai-compat/advanced - """ - - def __init__( - self, - model: str = "gpt-4o", - base_url: str | None = None, - api_key: str | None = None, - thinking: bool = False, - budget_tokens: int | None = None, - reasoning_effort: str | None = None, - ) -> None: - """Initialize OpenAI model adapter. - - Args: - model: Model identifier - (e.g., 'openai/gpt-5.2', 'anthropic/claude-sonnet-4.5') - base_url: API base URL - (e.g., 'https://ai-gateway.vercel.sh/v1') - api_key: API key for authentication - thinking: Enable reasoning/thinking output - budget_tokens: Max tokens for reasoning - (mutually exclusive with reasoning_effort) - reasoning_effort: Effort level — 'none', 'minimal', - 'low', 'medium', 'high', 'xhigh' - (mutually exclusive with budget_tokens) - """ - self._model = model - self._thinking = thinking - self._budget_tokens = budget_tokens - self._reasoning_effort = reasoning_effort - resolved_key = api_key or os.environ.get("OPENAI_API_KEY") or "" - self._client = openai.AsyncOpenAI(base_url=base_url, api_key=resolved_key) - - async def stream_events( - self, - messages: list[messages_.Message], - tools: Sequence[tools_.ToolLike] | None = None, - output_type: type[pydantic.BaseModel] | None = None, - ) -> AsyncGenerator[llm_.StreamEvent]: - """Yield raw stream events from OpenAI API.""" - openai_messages = await _messages_to_openai(messages) - openai_tools = _tools_to_openai(tools) if tools else None - - kwargs: dict[str, Any] = { - "model": self._model, - "messages": openai_messages, - "stream": True, - } - if openai_tools: - kwargs["tools"] = openai_tools - kwargs["stream_options"] = {"include_usage": True} - - if output_type is not None: - from openai.lib._pydantic import to_strict_json_schema - - kwargs["response_format"] = { - "type": "json_schema", - "json_schema": { - "name": output_type.__name__, - "schema": to_strict_json_schema(output_type), - "strict": True, - }, - } - - # Enable reasoning/thinking via Vercel AI Gateway's unified format - # See: https://vercel.com/docs/ai-gateway/openai-compat/advanced - if self._thinking: - reasoning_config: dict[str, Any] = {"enabled": True} - # Use budget_tokens OR reasoning_effort (mutually exclusive per docs) - if self._budget_tokens is not None: - reasoning_config["max_tokens"] = self._budget_tokens - elif self._reasoning_effort is not None: - reasoning_config["effort"] = self._reasoning_effort - kwargs["extra_body"] = {"reasoning": reasoning_config} - - stream = await self._client.chat.completions.create(**kwargs) - - # Track active blocks for Start/End events - text_started = False - reasoning_started = False - tool_calls: dict[int, dict[str, Any]] = {} # index -> {id, name, started} - finish_reason: str | None = None - usage: messages_.Usage | None = None - - async for chunk in stream: - # Extract usage from any chunk that carries it (typically the final - # chunk when stream_options.include_usage is True). - if chunk.usage is not None: - raw = chunk.usage.model_dump(exclude_none=True) - # Extract optional breakdowns - reasoning_tokens: int | None = None - cache_read: int | None = None - completion_details = getattr( - chunk.usage, "completion_tokens_details", None - ) - if completion_details: - reasoning_tokens = getattr( - completion_details, "reasoning_tokens", None - ) - prompt_details = getattr(chunk.usage, "prompt_tokens_details", None) - if prompt_details: - cache_read = getattr(prompt_details, "cached_tokens", None) - usage = messages_.Usage( - input_tokens=chunk.usage.prompt_tokens or 0, - output_tokens=chunk.usage.completion_tokens or 0, - reasoning_tokens=reasoning_tokens, - cache_read_tokens=cache_read, - raw=raw, - ) - - if not chunk.choices: - continue - - choice = chunk.choices[0] - delta = choice.delta - - # Handle reasoning/thinking content via Vercel AI Gateway - # The gateway may return reasoning in different ways: - # 1. As a direct attribute (if SDK supports it) - # 2. In model_extra (Pydantic v2 extra fields) - reasoning_value = None - if hasattr(delta, "reasoning") and delta.reasoning: - reasoning_value = delta.reasoning - elif hasattr(delta, "model_extra") and delta.model_extra: - reasoning_value = delta.model_extra.get("reasoning") - - if reasoning_value: - if not reasoning_started: - reasoning_started = True - yield llm_.ReasoningStart(block_id="reasoning") - yield llm_.ReasoningDelta(block_id="reasoning", delta=reasoning_value) - - if delta.content: - # Close reasoning block when text starts (reasoning precedes text) - if reasoning_started: - yield llm_.ReasoningEnd(block_id="reasoning") - reasoning_started = False - - if not text_started: - text_started = True - yield llm_.TextStart(block_id="text") - yield llm_.TextDelta(block_id="text", delta=delta.content) - - if delta.tool_calls: - for tc in delta.tool_calls: - idx = tc.index - if idx not in tool_calls: - tool_calls[idx] = {"id": tc.id, "name": None, "started": False} - if tc.id: - tool_calls[idx]["id"] = tc.id - if tc.function: - if tc.function.name: - tool_calls[idx]["name"] = tc.function.name - if tc.function.arguments: - tool_id = tool_calls[idx]["id"] - tool_name = tool_calls[idx]["name"] or "" - - # Emit start if not started - if not tool_calls[idx]["started"] and tool_id: - tool_calls[idx]["started"] = True - yield llm_.ToolStart( - tool_call_id=tool_id, tool_name=tool_name - ) - - if tool_id: - yield llm_.ToolArgsDelta( - tool_call_id=tool_id, delta=tc.function.arguments - ) - - if choice.finish_reason is not None: - finish_reason = choice.finish_reason - # Close any open blocks - if reasoning_started: - yield llm_.ReasoningEnd(block_id="reasoning") - if text_started: - yield llm_.TextEnd(block_id="text") - for tc in tool_calls.values(): - if tc["started"] and tc["id"]: - yield llm_.ToolEnd(tool_call_id=tc["id"]) - - # Don't return yet — the usage chunk may arrive after - # finish_reason. We'll emit MessageDone after the loop. - - yield llm_.MessageDone(finish_reason=finish_reason, usage=usage) - - @override - async def stream( - self, - messages: list[messages_.Message], - tools: Sequence[tools_.ToolLike] | None = None, - output_type: type[pydantic.BaseModel] | None = None, - ) -> AsyncGenerator[messages_.Message]: - """Stream Messages (uses StreamHandler internally).""" - handler = llm_.StreamHandler() - msg: messages_.Message | None = None - async for event in self.stream_events(messages, tools, output_type=output_type): - msg = handler.handle_event(event) - yield msg - - # After stream completes, validate and attach structured output part - if output_type is not None and msg is not None and msg.text: - data = json.loads(msg.text) - output_type.model_validate(data) # fail fast on bad data - part = messages_.StructuredOutputPart( - data=data, - output_type_name=f"{output_type.__module__}.{output_type.__qualname__}", - ) - msg = msg.model_copy() - msg.parts = [*msg.parts, part] - yield msg diff --git a/src/vercel_ai_sdk/types/messages.py b/src/vercel_ai_sdk/types/messages.py index 41a088b7..683c7469 100644 --- a/src/vercel_ai_sdk/types/messages.py +++ b/src/vercel_ai_sdk/types/messages.py @@ -138,9 +138,9 @@ def from_url(cls, url: str, *, media_type: str | None = None) -> FilePart: ``media_type`` is provided. """ if media_type is None: - from ..models.core.media import data as media_data + from ..models.core.helpers import media as media_helpers - media_type = media_data.infer_media_type(url) + media_type = media_helpers.infer_media_type(url) return cls(data=url, media_type=media_type) @classmethod @@ -158,11 +158,11 @@ def from_bytes( detection fails. """ if media_type is None: - from ..models.core.media import detect as media_detect + from ..models.core.helpers import media as media_helpers - media_type = media_detect.detect_image_media_type( + media_type = media_helpers.detect_image_media_type( data - ) or media_detect.detect_audio_media_type(data) + ) or media_helpers.detect_audio_media_type(data) if media_type is None: raise ValueError( "Cannot detect media_type from bytes. Provide media_type explicitly." diff --git a/tests/adapters/ai_sdk_ui/test_adapter.py b/tests/adapters/ai_sdk_ui/test_adapter.py index db717241..1bbe2756 100644 --- a/tests/adapters/ai_sdk_ui/test_adapter.py +++ b/tests/adapters/ai_sdk_ui/test_adapter.py @@ -12,7 +12,7 @@ from vercel_ai_sdk.agents import hooks from vercel_ai_sdk.types import messages -from ...conftest import MockLLM, tool_msg +from ...conftest import MOCK_MODEL, mock_llm, tool_msg async def get_event_types(msgs: list[messages.Message]) -> list[str]: @@ -240,22 +240,10 @@ async def get_weather(city: str) -> str: return f"Sunny in {city}" -async def mock_agent( - llm: ai.LanguageModel, - user_query: str, -) -> ai.StreamResult: - """Agent using stream_loop directly.""" - return await ai.stream_loop( - llm, - messages=ai.make_messages(system="You are helpful.", user=user_query), - tools=[get_weather], - ) - - @pytest.mark.asyncio async def test_runtime_tool_roundtrip() -> None: """ - Integration test: run a mock agent loop through ai.run() and verify + Integration test: run an Agent through agent.run() and verify that tool-input-available and tool-output-available events are emitted. This test demonstrates the bug: the runtime yields the message with @@ -263,11 +251,17 @@ async def test_runtime_tool_roundtrip() -> None: executed and the ToolPart has been mutated to status="result". The UI adapter never sees the intermediate status="pending" state. - Root cause: stream_loop appends the message, then executes tools which - mutate the message in-place. The message was already yielded with + Root cause: the default loop appends the message, then executes tools + which mutate the message in-place. The message was already yielded with status="pending", but pydantic models are mutable so when we collect them at the end, we see the mutated state. """ + weather_agent = ai.agent( + model=MOCK_MODEL, + system="You are helpful.", + tools=[get_weather], + ) + # First LLM call: returns a tool call tool_call_response = [ messages.Message( @@ -294,11 +288,13 @@ async def test_runtime_tool_roundtrip() -> None: ), ] - mock_llm = MockLLM([tool_call_response, final_text_response]) + mock_llm([tool_call_response, final_text_response]) # Collect all messages from the runtime runtime_messages: list[messages.Message] = [] - async for msg in ai.run(mock_agent, mock_llm, "What's the weather in London?"): + async for msg in weather_agent.run( + ai.make_messages(user="What's the weather in London?") + ): runtime_messages.append(msg) # Stream through UI adapter @@ -308,22 +304,27 @@ async def test_runtime_tool_roundtrip() -> None: ] # This is what SHOULD happen: - # 1. First step yields tool call with status="pending" - # -> tool-input-start, tool-input-available + # 1. First step streams tool call args then completes + # -> tool-input-start, tool-input-delta, tool-input-available # 2. After tool execution, we yield the same message with # status="result" -> tool-output-available # (same step because same message ID) - # 3. Second LLM step yields final text -> text-start, text-end + # 3. Second LLM step streams text then completes + # -> text-start, text-delta, text-end, (final done msg) text-start, text-end expected = [ "start", "start-step", "tool-input-start", + "tool-input-delta", "tool-input-available", "tool-output-available", # Same step as input (same message ID) "finish-step", # Second LLM call (new message ID = new step) "start-step", "text-start", + "text-delta", + "text-end", + "text-start", # Final done message re-emits completed text "text-end", "finish-step", "finish", @@ -638,12 +639,15 @@ async def dangerous_action(path: str) -> str: """Do something dangerous.""" return f"deleted {path}" - async def graph(llm: ai.LanguageModel) -> None: - result = await ai.stream_step( - llm, - ai.make_messages(system="You are helpful.", user="delete /tmp"), - [dangerous_action], - ) + approval_agent = ai.agent( + model=MOCK_MODEL, + system="You are helpful.", + tools=[dangerous_action], + ) + + @approval_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: + result = await ai.stream_step(agent.model, msgs, agent.tools) if not result.tool_calls: return @@ -662,7 +666,7 @@ async def approve_and_execute(tc: ai.ToolPart) -> None: await asyncio.gather(*(approve_and_execute(tc) for tc in result.tool_calls)) - mock_llm = MockLLM( + mock_llm( [ [ tool_msg( @@ -675,7 +679,7 @@ async def approve_and_execute(tc: ai.ToolPart) -> None: ) runtime_messages: list[messages.Message] = [] - result = ai.run(graph, mock_llm) + result = approval_agent.run(ai.make_messages(user="delete /tmp")) async for msg in result: runtime_messages.append(msg) @@ -697,6 +701,7 @@ async def approve_and_execute(tc: ai.ToolPart) -> None: "start", "start-step", "tool-input-start", + "tool-input-delta", "tool-input-available", "tool-approval-request", "finish-step", diff --git a/tests/agents/mcp/test_client.py b/tests/agents/mcp/test_client.py index 8b5c5e9d..a1220ac8 100644 --- a/tests/agents/mcp/test_client.py +++ b/tests/agents/mcp/test_client.py @@ -10,7 +10,7 @@ from vercel_ai_sdk.agents.mcp.client import _mcp_tool_to_native from vercel_ai_sdk.agents.tools import _tool_registry, get_tool -from ...conftest import MockLLM, text_msg, tool_msg +from ...conftest import MOCK_MODEL, mock_llm, text_msg, tool_msg def _fake_mcp_tool( @@ -64,12 +64,12 @@ def test_mcp_tool_to_native_schema_preserved() -> None: assert native.description == "Echo input" -# -- End-to-end: MCP tool executes through stream_loop -------------------- +# -- End-to-end: MCP tool executes through Agent default loop --------------- @pytest.mark.asyncio -async def test_mcp_tool_executes_through_stream_loop() -> None: - """MCP-style tool via _mcp_tool_to_native can be called by the agent loop.""" +async def test_mcp_tool_executes_through_agent() -> None: + """MCP-style tool via _mcp_tool_to_native works with Agent.""" call_log: list[dict[str, str]] = [] async def fake_fn(**kwargs: str) -> str: @@ -84,18 +84,13 @@ async def fake_fn(**kwargs: str) -> str: native._fn = fake_fn _tool_registry[native.name] = native - async def graph(llm: ai.LanguageModel) -> ai.StreamResult: - return await ai.stream_loop( - llm, - messages=ai.make_messages(user="echo hello"), - tools=[native], - ) + my_agent = ai.agent(model=MOCK_MODEL, tools=[native]) call1 = [tool_msg(tc_id="tc-mcp-1", name="mcp_e2e_echo", args='{"text": "hello"}')] call2 = [text_msg("Done.", id="msg-2")] - llm = MockLLM([call1, call2]) + llm = mock_llm([call1, call2]) - result = ai.run(graph, llm) + result = my_agent.run(ai.make_messages(user="echo hello")) msgs = [m async for m in result] # Tool was called with the right args diff --git a/tests/agents/test_checkpoint.py b/tests/agents/test_checkpoint.py index 4f84b243..22fbd583 100644 --- a/tests/agents/test_checkpoint.py +++ b/tests/agents/test_checkpoint.py @@ -9,7 +9,7 @@ import vercel_ai_sdk as ai from vercel_ai_sdk.agents.checkpoint import Checkpoint, HookEvent, StepEvent, ToolEvent -from ..conftest import MockLLM, text_msg, tool_msg +from ..conftest import MOCK_MODEL, mock_llm, text_msg, tool_msg @ai.hook @@ -23,19 +23,20 @@ class Approval(pydantic.BaseModel): @pytest.mark.asyncio async def test_step_replay_skips_llm() -> None: - async def graph(llm: ai.LanguageModel) -> ai.StreamResult: - return await ai.stream_step( - llm, messages=ai.make_messages(system="test", user="hello") - ) + my_agent = ai.agent(model=MOCK_MODEL) - llm1 = MockLLM([[text_msg("Hi there!")]]) - result1 = ai.run(graph, llm1) + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> ai.StreamResult: + return await ai.stream_step(agent.model, msgs) + + llm1 = mock_llm([[text_msg("Hi there!")]]) + result1 = my_agent.run(ai.make_messages(system="test", user="hello")) [msg async for msg in result1] assert llm1.call_count == 1 cp = result1.checkpoint - llm2 = MockLLM([]) - result2 = ai.run(graph, llm2, checkpoint=cp) + llm2 = mock_llm([]) + result2 = my_agent.run(ai.make_messages(system="test", user="hello"), checkpoint=cp) [msg async for msg in result2] assert llm2.call_count == 0 @@ -51,8 +52,11 @@ async def counting_tool(x: int) -> int: execution_count += 1 return x + 1 - async def graph(llm: ai.LanguageModel) -> ai.StreamResult: - result = await ai.stream_step(llm, ai.make_messages(system="t", user="go")) + my_agent = ai.agent(model=MOCK_MODEL, tools=[counting_tool]) + + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> ai.StreamResult: + result = await ai.stream_step(agent.model, msgs, agent.tools) if result.tool_calls: await asyncio.gather( *( @@ -62,14 +66,17 @@ async def graph(llm: ai.LanguageModel) -> ai.StreamResult: ) return result - llm1 = MockLLM([[tool_msg(tc_id="tc-1", name="counting_tool", args='{"x": 5}')]]) - result1 = ai.run(graph, llm1) + mock_llm([[tool_msg(tc_id="tc-1", name="counting_tool", args='{"x": 5}')]]) + result1 = my_agent.run(ai.make_messages(system="t", user="go")) [msg async for msg in result1] assert execution_count == 1 assert result1.checkpoint.tools[0].result == 6 execution_count = 0 - result2 = ai.run(graph, MockLLM([]), checkpoint=result1.checkpoint) + mock_llm([]) + result2 = my_agent.run( + ai.make_messages(system="t", user="go"), checkpoint=result1.checkpoint + ) [msg async for msg in result2] assert execution_count == 0 @@ -79,11 +86,15 @@ async def graph(llm: ai.LanguageModel) -> ai.StreamResult: @pytest.mark.asyncio async def test_hook_cancellation_pending() -> None: - async def graph(llm: ai.LanguageModel) -> Any: - await ai.stream_step(llm, ai.make_messages(system="t", user="go")) + my_agent = ai.agent(model=MOCK_MODEL) + + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> Any: + await ai.stream_step(agent.model, msgs) return await Approval.create("my_approval", metadata={"tool": "test"}) # type: ignore[attr-defined] - result = ai.run(graph, MockLLM([[text_msg("OK")]])) + mock_llm([[text_msg("OK")]]) + result = my_agent.run(ai.make_messages(system="t", user="go")) msgs = [msg async for msg in result] assert "my_approval" in result.pending_hooks hook_msgs = [m for m in msgs if any(isinstance(p, ai.HookPart) for p in m.parts)] @@ -92,17 +103,22 @@ async def graph(llm: ai.LanguageModel) -> Any: @pytest.mark.asyncio async def test_hook_resolution_on_reentry() -> None: - async def graph(llm: ai.LanguageModel) -> Any: - await ai.stream_step(llm, ai.make_messages(system="t", user="go")) + my_agent = ai.agent(model=MOCK_MODEL) + + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> Any: + await ai.stream_step(agent.model, msgs) return await Approval.create("my_approval") # type: ignore[attr-defined] resp = [text_msg("OK")] - result1 = ai.run(graph, MockLLM([resp])) + mock_llm([resp]) + result1 = my_agent.run(ai.make_messages(system="t", user="go")) [msg async for msg in result1] cp = result1.checkpoint Approval.resolve("my_approval", {"granted": True}) # type: ignore[attr-defined] - result2 = ai.run(graph, MockLLM([]), checkpoint=cp) + mock_llm([]) + result2 = my_agent.run(ai.make_messages(system="t", user="go"), checkpoint=cp) [msg async for msg in result2] assert len(result2.pending_hooks) == 0 assert result2.checkpoint.hooks[-1].label == "my_approval" @@ -110,8 +126,11 @@ async def graph(llm: ai.LanguageModel) -> Any: @pytest.mark.asyncio async def test_parallel_hooks_all_collected() -> None: - async def graph(llm: ai.LanguageModel) -> None: - await ai.stream_step(llm, ai.make_messages(system="t", user="go")) + my_agent = ai.agent(model=MOCK_MODEL) + + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: + await ai.stream_step(agent.model, msgs) async def a() -> Any: return await Approval.create("hook_a") # type: ignore[attr-defined] @@ -123,15 +142,19 @@ async def b() -> Any: tg.create_task(a()) tg.create_task(b()) - result = ai.run(graph, MockLLM([[text_msg("OK")]])) + mock_llm([[text_msg("OK")]]) + result = my_agent.run(ai.make_messages(system="t", user="go")) [msg async for msg in result] assert {"hook_a", "hook_b"} <= set(result.pending_hooks) @pytest.mark.asyncio async def test_parallel_hooks_resolve_on_reentry() -> None: - async def graph(llm: ai.LanguageModel) -> Any: - await ai.stream_step(llm, ai.make_messages(system="t", user="go")) + my_agent = ai.agent(model=MOCK_MODEL) + + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> Any: + await ai.stream_step(agent.model, msgs) async def a() -> Any: return await Approval.create("hook_a") # type: ignore[attr-defined] @@ -145,13 +168,15 @@ async def b() -> Any: return ta.result(), tb.result() resp = [text_msg("OK")] - result1 = ai.run(graph, MockLLM([resp])) + mock_llm([resp]) + result1 = my_agent.run(ai.make_messages(system="t", user="go")) [msg async for msg in result1] cp = result1.checkpoint Approval.resolve("hook_a", {"granted": True}) # type: ignore[attr-defined] Approval.resolve("hook_b", {"granted": False}) # type: ignore[attr-defined] - result2 = ai.run(graph, MockLLM([]), checkpoint=cp) + mock_llm([]) + result2 = my_agent.run(ai.make_messages(system="t", user="go"), checkpoint=cp) [msg async for msg in result2] assert len(result2.pending_hooks) == 0 diff --git a/tests/agents/test_hooks.py b/tests/agents/test_hooks.py index a51c2b55..65bfd15b 100644 --- a/tests/agents/test_hooks.py +++ b/tests/agents/test_hooks.py @@ -8,7 +8,7 @@ import vercel_ai_sdk as ai -from ..conftest import MockLLM, text_msg +from ..conftest import MOCK_MODEL, mock_llm, text_msg @ai.hook @@ -31,16 +31,18 @@ class CancellingConfirmation(pydantic.BaseModel): async def test_resolve_live_future() -> None: """In long-running mode, Hook.resolve() unblocks the awaiting coroutine.""" resolved_value = None + my_agent = ai.agent(model=MOCK_MODEL) - async def graph(llm: ai.LanguageModel) -> None: + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: nonlocal resolved_value - await ai.stream_step(llm, ai.make_messages(user="go")) + await ai.stream_step(agent.model, msgs) result = await Confirmation.create("confirm_1") # type: ignore[attr-defined] resolved_value = result - llm = MockLLM([[text_msg("OK")]]) + mock_llm([[text_msg("OK")]]) # Confirmation.cancels_future=False -> long-running mode - run_result = ai.run(graph, llm) + run_result = my_agent.run(ai.make_messages(user="go")) collected = [] async for msg in run_result: @@ -54,10 +56,6 @@ async def graph(llm: ai.LanguageModel) -> None: assert resolved_value is not None assert resolved_value.approved is True assert resolved_value.reason == "looks good" - # The graph completed successfully (resolved_value proves it). - # Note: pending_hooks is not cleaned up after live resolution -- - # that's a known runtime limitation. The important thing is the - # graph continued past the hook. # -- Hook.cancel() -------------------------------------------------------- @@ -67,17 +65,19 @@ async def graph(llm: ai.LanguageModel) -> None: async def test_cancel_live_hook() -> None: """Hook.cancel() cancels the future, causing CancelledError in graph.""" was_cancelled = False + my_agent = ai.agent(model=MOCK_MODEL) - async def graph(llm: ai.LanguageModel) -> None: + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: nonlocal was_cancelled - await ai.stream_step(llm, ai.make_messages(user="go")) + await ai.stream_step(agent.model, msgs) try: await Confirmation.create("cancel_me") # type: ignore[attr-defined] except asyncio.CancelledError: was_cancelled = True - llm = MockLLM([[text_msg("OK")]]) - run_result = ai.run(graph, llm) + mock_llm([[text_msg("OK")]]) + run_result = my_agent.run(ai.make_messages(user="go")) async for msg in run_result: if any(isinstance(p, ai.HookPart) and p.status == "pending" for p in msg.parts): @@ -101,17 +101,19 @@ async def test_cancel_nonexistent_raises() -> None: @pytest.mark.asyncio async def test_pre_registered_resolution_consumed() -> None: """Pre-registered resolution is consumed by Hook.create() without suspending.""" + my_agent = ai.agent(model=MOCK_MODEL) - async def graph(llm: ai.LanguageModel) -> Any: - await ai.stream_step(llm, ai.make_messages(user="go")) + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> Any: + await ai.stream_step(agent.model, msgs) result = await Confirmation.create("pre_reg_1") # type: ignore[attr-defined] return result # Pre-register BEFORE run Confirmation.resolve("pre_reg_1", {"approved": True}) # type: ignore[attr-defined] - llm = MockLLM([[text_msg("OK")]]) - run_result = ai.run(graph, llm) + mock_llm([[text_msg("OK")]]) + run_result = my_agent.run(ai.make_messages(user="go")) [m async for m in run_result] # Should have completed with no pending hooks @@ -136,13 +138,15 @@ def test_resolve_validates_schema() -> None: @pytest.mark.asyncio async def test_resolved_hook_emits_message() -> None: """After resolution, a 'resolved' HookPart message is emitted.""" + my_agent = ai.agent(model=MOCK_MODEL) - async def graph(llm: ai.LanguageModel) -> None: - await ai.stream_step(llm, ai.make_messages(user="go")) + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: + await ai.stream_step(agent.model, msgs) await Confirmation.create("emit_test") # type: ignore[attr-defined] - llm = MockLLM([[text_msg("OK")]]) - run_result = ai.run(graph, llm) + mock_llm([[text_msg("OK")]]) + run_result = my_agent.run(ai.make_messages(user="go")) msgs = [] async for msg in run_result: @@ -164,13 +168,17 @@ async def graph(llm: ai.LanguageModel) -> None: @pytest.mark.asyncio async def test_hook_metadata_in_pending() -> None: - async def graph(llm: ai.LanguageModel) -> None: - await ai.stream_step(llm, ai.make_messages(user="go")) + my_agent = ai.agent(model=MOCK_MODEL) + + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: + await ai.stream_step(agent.model, msgs) await CancellingConfirmation.create( # type: ignore[attr-defined] "meta_test", metadata={"tool": "rm -rf", "path": "/"} ) - run_result = ai.run(graph, MockLLM([[text_msg("OK")]])) + mock_llm([[text_msg("OK")]]) + run_result = my_agent.run(ai.make_messages(user="go")) [m async for m in run_result] info = run_result.pending_hooks["meta_test"] diff --git a/tests/agents/test_runtime.py b/tests/agents/test_runtime.py index 2445d7d3..6eb5b0be 100644 --- a/tests/agents/test_runtime.py +++ b/tests/agents/test_runtime.py @@ -1,4 +1,4 @@ -"""Runtime: stream_loop end-to-end, execute_tool, multi-turn, Runtime injection.""" +"""Agent default loop, execute_tool, multi-turn, Runtime injection.""" import asyncio @@ -8,7 +8,7 @@ from vercel_ai_sdk.agents.runtime import Runtime from vercel_ai_sdk.types import messages -from ..conftest import MockLLM, text_msg, tool_msg +from ..conftest import MOCK_MODEL, mock_llm, text_msg, tool_msg # -- Tool definitions for tests -------------------------------------------- @@ -25,46 +25,34 @@ async def concat(a: str, b: str) -> str: return a + b -# -- stream_loop: single turn (no tools) ---------------------------------- +# -- Agent default loop: single turn (no tools) ---------------------------- @pytest.mark.asyncio -async def test_stream_loop_text_only() -> None: - """stream_loop with no tool calls returns after one LLM call.""" - - async def graph(llm: ai.LanguageModel) -> ai.StreamResult: - return await ai.stream_loop( - llm, - messages=ai.make_messages(user="Hi"), - tools=[double], - ) - - llm = MockLLM([[text_msg("Hello!")]]) - result = ai.run(graph, llm) +async def test_agent_text_only() -> None: + """Agent default loop with no tool calls returns after one LLM call.""" + my_agent = ai.agent(model=MOCK_MODEL, tools=[double]) + + llm = mock_llm([[text_msg("Hello!")]]) + result = my_agent.run(ai.make_messages(user="Hi")) msgs = [m async for m in result] assert llm.call_count == 1 assert any(m.text == "Hello!" for m in msgs) -# -- stream_loop: tool call + follow-up ----------------------------------- +# -- Agent default loop: tool call + follow-up ----------------------------- @pytest.mark.asyncio -async def test_stream_loop_tool_then_text() -> None: - """stream_loop calls tool, feeds result back, gets final text.""" - - async def graph(llm: ai.LanguageModel) -> ai.StreamResult: - return await ai.stream_loop( - llm, - messages=ai.make_messages(user="Double 5"), - tools=[double], - ) +async def test_agent_tool_then_text() -> None: + """Agent default loop calls tool, feeds result back, gets final text.""" + my_agent = ai.agent(model=MOCK_MODEL, tools=[double]) call1 = [tool_msg(tc_id="tc-1", name="double", args='{"x": 5}')] call2 = [text_msg("The answer is 10.")] - llm = MockLLM([call1, call2]) + llm = mock_llm([call1, call2]) - result = ai.run(graph, llm) + result = my_agent.run(ai.make_messages(user="Double 5")) msgs = [m async for m in result] assert llm.call_count == 2 # Tool should have been executed: 5 * 2 = 10 @@ -75,19 +63,13 @@ async def graph(llm: ai.LanguageModel) -> ai.StreamResult: assert tool_results[0].tool_calls[0].result == 10 -# -- stream_loop: multiple tool calls in one message ---------------------- +# -- Agent default loop: multiple tool calls in one message ---------------- @pytest.mark.asyncio -async def test_stream_loop_parallel_tools() -> None: +async def test_agent_parallel_tools() -> None: """LLM returns two tool calls in one message; both execute.""" - - async def graph(llm: ai.LanguageModel) -> ai.StreamResult: - return await ai.stream_loop( - llm, - messages=ai.make_messages(user="Double 3 and 7"), - tools=[double], - ) + my_agent = ai.agent(model=MOCK_MODEL, tools=[double]) two_tools = messages.Message( id="msg-1", @@ -110,9 +92,9 @@ async def graph(llm: ai.LanguageModel) -> ai.StreamResult: ], ) call2 = [text_msg("6 and 14", id="msg-2")] - llm = MockLLM([[two_tools], call2]) + llm = mock_llm([[two_tools], call2]) - result = ai.run(graph, llm) + result = my_agent.run(ai.make_messages(user="Double 3 and 7")) msgs = [m async for m in result] assert llm.call_count == 2 # Both tools should have results @@ -124,28 +106,22 @@ async def graph(llm: ai.LanguageModel) -> ai.StreamResult: assert len(tool_result_msgs) >= 1 -# -- stream_loop: multi-turn (tool -> tool -> text) ----------------------- +# -- Agent default loop: multi-turn (tool -> tool -> text) ----------------- @pytest.mark.asyncio -async def test_stream_loop_multi_turn() -> None: +async def test_agent_multi_turn() -> None: """LLM calls a tool, then calls another tool, then returns text.""" - - async def graph(llm: ai.LanguageModel) -> ai.StreamResult: - return await ai.stream_loop( - llm, - messages=ai.make_messages(user="Concat then double"), - tools=[double, concat], - ) + my_agent = ai.agent(model=MOCK_MODEL, tools=[double, concat]) turn1 = [ tool_msg(tc_id="tc-1", name="concat", args='{"a": "hello", "b": " world"}') ] turn2 = [tool_msg(tc_id="tc-2", name="double", args='{"x": 3}', id="msg-2")] turn3 = [text_msg("Done: hello world, 6", id="msg-3")] - llm = MockLLM([turn1, turn2, turn3]) + llm = mock_llm([turn1, turn2, turn3]) - result = ai.run(graph, llm) + result = my_agent.run(ai.make_messages(user="Concat then double")) [m async for m in result] assert llm.call_count == 3 @@ -162,11 +138,14 @@ async def test_execute_tool_missing_raises() -> None: tc = messages.ToolPart( tool_call_id="tc-1", tool_name="nonexistent_tool_zzz", tool_args="{}" ) + my_agent = ai.agent(model=MOCK_MODEL, tools=[]) - async def graph(llm: ai.LanguageModel) -> None: + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: await ai.execute_tool(tc) - result = ai.run(graph, MockLLM([])) + mock_llm([]) + result = my_agent.run(ai.make_messages(user="go")) with pytest.raises(ExceptionGroup) as exc_info: [m async for m in result] assert any(isinstance(e, ValueError) for e in exc_info.value.exceptions) @@ -187,8 +166,11 @@ async def introspect(query: str, rt: Runtime) -> str: received_rt = rt return "ok" - async def graph(llm: ai.LanguageModel) -> None: - result = await ai.stream_step(llm, ai.make_messages(user="go")) + my_agent = ai.agent(model=MOCK_MODEL, tools=[introspect]) + + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: + result = await ai.stream_step(agent.model, msgs, agent.tools) if result.tool_calls: await asyncio.gather( *( @@ -198,7 +180,8 @@ async def graph(llm: ai.LanguageModel) -> None: ) call = [tool_msg(tc_id="tc-1", name="introspect", args='{"query": "test"}')] - result = ai.run(graph, MockLLM([call])) + mock_llm([call]) + result = my_agent.run(ai.make_messages(user="go")) [m async for m in result] assert received_rt is not None assert isinstance(received_rt, Runtime) @@ -210,9 +193,11 @@ async def graph(llm: ai.LanguageModel) -> None: @pytest.mark.asyncio async def test_execute_tool_updates_message() -> None: """After execute_tool, the ToolPart in the message has status=result.""" + my_agent = ai.agent(model=MOCK_MODEL, tools=[double]) - async def graph(llm: ai.LanguageModel) -> None: - result = await ai.stream_step(llm, ai.make_messages(user="go")) + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: + result = await ai.stream_step(agent.model, msgs, agent.tools) if result.tool_calls: msg = result.last_message for tc in result.tool_calls: @@ -223,29 +208,24 @@ async def graph(llm: ai.LanguageModel) -> None: assert msg.tool_calls[0].result == 10 call = [tool_msg(tc_id="tc-1", name="double", args='{"x": 5}')] - result = ai.run(graph, MockLLM([call])) + mock_llm([call]) + result = my_agent.run(ai.make_messages(user="go")) [m async for m in result] -# -- Checkpoint records tools from stream_loop ----------------------------- +# -- Checkpoint records tools from Agent default loop ---------------------- @pytest.mark.asyncio -async def test_stream_loop_checkpoint_records_tools() -> None: - """stream_loop's tool executions are recorded in the checkpoint.""" - - async def graph(llm: ai.LanguageModel) -> ai.StreamResult: - return await ai.stream_loop( - llm, - messages=ai.make_messages(user="Double 4"), - tools=[double], - ) +async def test_agent_checkpoint_records_tools() -> None: + """Agent default loop's tool executions are recorded in the checkpoint.""" + my_agent = ai.agent(model=MOCK_MODEL, tools=[double]) call1 = [tool_msg(tc_id="tc-1", name="double", args='{"x": 4}')] call2 = [text_msg("8", id="msg-2")] - llm = MockLLM([call1, call2]) + mock_llm([call1, call2]) - result = ai.run(graph, llm) + result = my_agent.run(ai.make_messages(user="Double 4")) [m async for m in result] cp = result.checkpoint diff --git a/tests/agents/test_streams.py b/tests/agents/test_streams.py index f67e5454..db7770ee 100644 --- a/tests/agents/test_streams.py +++ b/tests/agents/test_streams.py @@ -7,7 +7,7 @@ from vercel_ai_sdk.agents.streams import StreamResult from vercel_ai_sdk.types import messages -from ..conftest import MockLLM, text_msg +from ..conftest import MOCK_MODEL, mock_llm, text_msg class _Weather(pydantic.BaseModel): @@ -58,9 +58,10 @@ def test_stream_result_tool_calls() -> None: @pytest.mark.asyncio async def test_stream_outside_run_raises() -> None: """@stream-decorated fn called without ai.run() should raise.""" + mock_llm([[text_msg("hi")]]) with pytest.raises(ValueError, match="No Runtime context"): await ai.stream_step( - MockLLM([[text_msg("hi")]]), + MOCK_MODEL, ai.make_messages(user="test"), ) @@ -70,20 +71,23 @@ async def test_stream_outside_run_raises() -> None: @pytest.mark.asyncio async def test_stream_step_replays_from_checkpoint() -> None: - """stream_step inside ai.run with a checkpoint replays without calling LLM.""" + """stream_step inside Agent.run with a checkpoint replays without calling LLM.""" - async def graph(llm: ai.LanguageModel) -> ai.StreamResult: - return await ai.stream_step(llm, ai.make_messages(user="hello")) + my_agent = ai.agent(model=MOCK_MODEL) + + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> ai.StreamResult: + return await ai.stream_step(agent.model, msgs) # First run - llm1 = MockLLM([[text_msg("Hi")]]) - r1 = ai.run(graph, llm1) + mock_llm([[text_msg("Hi")]]) + r1 = my_agent.run(ai.make_messages(user="hello")) [msg async for msg in r1] cp = r1.checkpoint # Replay - llm2 = MockLLM([]) - r2 = ai.run(graph, llm2, checkpoint=cp) + llm2 = mock_llm([]) + r2 = my_agent.run(ai.make_messages(user="hello"), checkpoint=cp) [msg async for msg in r2] assert llm2.call_count == 0 diff --git a/tests/conftest.py b/tests/conftest.py index 981f9db3..e949d1b9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,59 +1,119 @@ from __future__ import annotations -import json from collections.abc import AsyncGenerator, Sequence +from typing import Any, Literal import pydantic import vercel_ai_sdk as ai -from vercel_ai_sdk.types import messages -from vercel_ai_sdk.types.messages import StructuredOutputPart +from vercel_ai_sdk import models +from vercel_ai_sdk.types import messages as messages_ +# A fixed Model used in tests — adapter="mock" dispatches to the mock adapter. +MOCK_MODEL = models.Model(id="mock-model", adapter="mock", provider="mock") -class MockLLM(ai.LanguageModel): - """LLM that yields pre-configured response sequences, one per call.""" +# Register a dummy provider so _auto_client() doesn't error for provider="mock". +models._PROVIDER_DEFAULTS["mock"] = ("http://mock.test", "MOCK_API_KEY") - def __init__(self, responses: list[list[messages.Message]]) -> None: + +class MockAdapter: + """Mock stream adapter that yields pre-configured response sequences. + + Each call to the adapter pops the next response list and yields the + messages through a StreamHandler (matching real adapter behavior). + Tracks ``call_count`` for assertions. + """ + + def __init__(self, responses: list[list[messages_.Message]]) -> None: self._responses = list(responses) self._call_index = 0 self.call_count = 0 async def stream( self, - messages: list[messages.Message], + client: models.Client, + model: models.Model, + messages: list[messages_.Message], + *, tools: Sequence[ai.ToolLike] | None = None, output_type: type[pydantic.BaseModel] | None = None, - ) -> AsyncGenerator[messages.Message]: + **kwargs: Any, + ) -> AsyncGenerator[messages_.Message]: if self._call_index >= len(self._responses): - raise RuntimeError("MockLLM: no more responses configured") + raise RuntimeError("MockAdapter: no more responses configured") self.call_count += 1 seq = self._responses[self._call_index] self._call_index += 1 - msg = None + + from vercel_ai_sdk.models.core.helpers import streaming as streaming_ + + handler = streaming_.StreamHandler() + for msg in seq: - yield msg + for i, part in enumerate(msg.parts): + if isinstance(part, messages_.TextPart): + bid = f"text-{i}" + yield handler.handle_event(streaming_.TextStart(block_id=bid)) + if part.text: + yield handler.handle_event( + streaming_.TextDelta(block_id=bid, delta=part.text) + ) + yield handler.handle_event(streaming_.TextEnd(block_id=bid)) + + elif isinstance(part, messages_.ReasoningPart): + bid = f"reasoning-{i}" + yield handler.handle_event(streaming_.ReasoningStart(block_id=bid)) + if part.text: + yield handler.handle_event( + streaming_.ReasoningDelta(block_id=bid, delta=part.text) + ) + yield handler.handle_event( + streaming_.ReasoningEnd(block_id=bid, signature=part.signature) + ) + + elif isinstance(part, messages_.ToolPart): + yield handler.handle_event( + streaming_.ToolStart( + tool_call_id=part.tool_call_id, + tool_name=part.tool_name, + ) + ) + if part.tool_args: + yield handler.handle_event( + streaming_.ToolArgsDelta( + tool_call_id=part.tool_call_id, + delta=part.tool_args, + ) + ) + yield handler.handle_event( + streaming_.ToolEnd(tool_call_id=part.tool_call_id) + ) - # Simulate structured output validation (matching real provider behavior) - if output_type is not None and msg is not None and msg.text: - data = json.loads(msg.text) - output_type.model_validate(data) # fail fast on bad data - part = StructuredOutputPart( - data=data, - output_type_name=f"{output_type.__module__}.{output_type.__qualname__}", - ) - msg = msg.model_copy() - msg.parts = [*msg.parts, part] - yield msg + yield handler.handle_event(streaming_.MessageDone()) + + +def mock_llm(responses: list[list[messages_.Message]]) -> MockAdapter: + """Create a MockAdapter and register it in the models adapter registry. + + Returns the adapter so tests can inspect ``call_count``. + """ + adapter = MockAdapter(responses) + models.register_stream("mock", adapter.stream) + return adapter + + +# ── Helpers ────────────────────────────────────────────────────── def text_msg( - text: str, *, id: str = "msg-1", state: str = "done", delta: str | None = None -) -> messages.Message: - return messages.Message( - id=id, - role="assistant", - parts=[messages.TextPart(text=text, state=state, delta=delta)], - ) + text: str, + *, + id: str = "msg-1", + state: messages_.PartState | None = "done", + delta: str | None = None, +) -> messages_.Message: + part: messages_.Part = messages_.TextPart(text=text, state=state, delta=delta) + return messages_.Message(id=id, role="assistant", parts=[part]) def tool_msg( @@ -62,20 +122,15 @@ def tool_msg( tc_id: str = "tc-1", name: str = "test_tool", args: str = "{}", - status: str = "pending", + status: Literal["pending", "result", "error"] = "pending", result: dict[str, object] | None = None, -) -> messages.Message: - return messages.Message( - id=id, - role="assistant", - parts=[ - messages.ToolPart( - tool_call_id=tc_id, - tool_name=name, - tool_args=args, - status=status, - result=result, - state="done", - ) - ], +) -> messages_.Message: + part: messages_.Part = messages_.ToolPart( + tool_call_id=tc_id, + tool_name=name, + tool_args=args, + status=status, + result=result, + state="done", ) + return messages_.Message(id=id, role="assistant", parts=[part]) diff --git a/tests/models/ai_gateway/test_gateway_image.py b/tests/models/ai_gateway/test_generate_image.py similarity index 69% rename from tests/models/ai_gateway/test_gateway_image.py rename to tests/models/ai_gateway/test_generate_image.py index 660457db..ca91365a 100644 --- a/tests/models/ai_gateway/test_gateway_image.py +++ b/tests/models/ai_gateway/test_generate_image.py @@ -1,9 +1,10 @@ -"""Integration tests for ``GatewayImageModel``. +"""Integration tests for the AI Gateway v3 image generation adapter. -Every test exercises the real ``model.generate()`` method with an injected -``httpx.MockTransport``, so the full production code path is covered: +Every test exercises the real ``generate()`` function with a ``Client`` +wired to an ``httpx.MockTransport``, so the full production code path +is covered: - model.generate() + generate(client, model, messages, ImageParams(...)) → extract prompt/images from messages → httpx POST (mock) to /image-model → JSON response parsing @@ -20,7 +21,13 @@ import httpx import pytest -from vercel_ai_sdk.models.ai_gateway import GatewayImageModel, errors +from vercel_ai_sdk.models.ai_gateway import errors +from vercel_ai_sdk.models.ai_gateway.generate import ( + ImageParams, + generate, +) +from vercel_ai_sdk.models.core import client as client_ +from vercel_ai_sdk.models.core import model as model_ from vercel_ai_sdk.types import messages # 1x1 transparent PNG (minimal valid PNG for magic-byte detection) @@ -31,24 +38,25 @@ _JPEG_HEADER = bytes([0xFF, 0xD8, 0xFF, 0xE0]) _JPEG_B64 = base64.b64encode(_JPEG_HEADER).decode() +_IMAGE_MODEL = model_.Model( + id="google/imagen-4.0-generate-001", + adapter="ai-gateway-v3", + provider="ai-gateway", + capabilities=("image",), +) + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- -def _image_model( - handler: httpx.MockTransport, - *, - model: str = "google/imagen-4.0-generate-001", - api_key: str = "test-key", -) -> GatewayImageModel: - return GatewayImageModel( - model=model, - api_key=api_key, - base_url="https://gw.test/v3/ai", - _transport=handler, - ) +def _client( + handler: httpx.MockTransport, *, api_key: str = "test-key" +) -> client_.Client: + c = client_.Client(base_url="https://gw.test/v3/ai", api_key=api_key) + c._http = httpx.AsyncClient(transport=handler) + return c def _user(text: str) -> messages.Message: @@ -66,7 +74,7 @@ def _user(text: str) -> messages.Message: class TestGenerate: @pytest.mark.asyncio async def test_basic_image_generation(self) -> None: - """Simple prompt → one PNG image back.""" + """Simple prompt -> one PNG image back.""" def handler(req: httpx.Request) -> httpx.Response: return httpx.Response( @@ -74,8 +82,8 @@ def handler(req: httpx.Request) -> httpx.Response: json={"images": [_PNG_B64]}, ) - model = _image_model(httpx.MockTransport(handler)) - msg = await model.generate([_user("A sunset over Tokyo")]) + client = _client(httpx.MockTransport(handler)) + msg = await generate(client, _IMAGE_MODEL, [_user("A sunset over Tokyo")]) assert msg.role == "assistant" assert len(msg.images) == 1 @@ -94,8 +102,13 @@ def handler(req: httpx.Request) -> httpx.Response: json={"images": [_PNG_B64, _JPEG_B64, _PNG_B64]}, ) - model = _image_model(httpx.MockTransport(handler)) - msg = await model.generate([_user("Three cats")], n=3) + client = _client(httpx.MockTransport(handler)) + msg = await generate( + client, + _IMAGE_MODEL, + [_user("Three cats")], + params=ImageParams(n=3), + ) assert len(msg.images) == 3 assert msg.images[0].media_type == "image/png" @@ -115,8 +128,8 @@ def handler(req: httpx.Request) -> httpx.Response: }, ) - model = _image_model(httpx.MockTransport(handler)) - msg = await model.generate([_user("a dog")]) + client = _client(httpx.MockTransport(handler)) + msg = await generate(client, _IMAGE_MODEL, [_user("a dog")]) assert msg.usage is not None assert msg.usage.input_tokens == 50 @@ -137,12 +150,14 @@ def handler(req: httpx.Request) -> httpx.Response: captured.update(dict(req.headers)) return httpx.Response(200, json={"images": [_PNG_B64]}) - model = _image_model( - httpx.MockTransport(handler), - model="openai/gpt-image-1", - api_key="sk-test", + model = model_.Model( + id="openai/gpt-image-1", + adapter="ai-gateway-v3", + provider="ai-gateway", + capabilities=("image",), ) - await model.generate([_user("Hi")]) + client = _client(httpx.MockTransport(handler), api_key="sk-test") + await generate(client, model, [_user("Hi")]) assert captured["authorization"] == "Bearer sk-test" assert captured["ai-image-model-specification-version"] == "3" @@ -157,14 +172,18 @@ def handler(req: httpx.Request) -> httpx.Response: captured_body.update(json.loads(req.content)) return httpx.Response(200, json={"images": [_PNG_B64]}) - model = _image_model(httpx.MockTransport(handler)) - await model.generate( + client = _client(httpx.MockTransport(handler)) + await generate( + client, + _IMAGE_MODEL, [_user("landscape")], - n=2, - size="1024x1024", - aspect_ratio="16:9", - seed=42, - provider_options={"google": {"style": "vivid"}}, + params=ImageParams( + n=2, + size="1024x1024", + aspect_ratio="16:9", + seed=42, + provider_options={"google": {"style": "vivid"}}, + ), ) assert captured_body["prompt"] == "landscape" @@ -176,7 +195,7 @@ def handler(req: httpx.Request) -> httpx.Response: @pytest.mark.asyncio async def test_input_images_forwarded(self) -> None: - """Input images from user messages → files in request body.""" + """Input images from user messages -> files in request body.""" captured_body: dict[str, Any] = {} def handler(req: httpx.Request) -> httpx.Response: @@ -190,8 +209,8 @@ def handler(req: httpx.Request) -> httpx.Response: messages.FilePart(data=_PNG_B64, media_type="image/png"), ], ) - model = _image_model(httpx.MockTransport(handler)) - await model.generate([user_msg]) + client = _client(httpx.MockTransport(handler)) + await generate(client, _IMAGE_MODEL, [user_msg]) assert captured_body["prompt"] == "Edit this" assert "files" in captured_body @@ -207,8 +226,8 @@ def handler(req: httpx.Request) -> httpx.Response: captured_url.append(str(req.url)) return httpx.Response(200, json={"images": [_PNG_B64]}) - model = _image_model(httpx.MockTransport(handler)) - await model.generate([_user("test")]) + client = _client(httpx.MockTransport(handler)) + await generate(client, _IMAGE_MODEL, [_user("test")]) assert captured_url[0] == "https://gw.test/v3/ai/image-model" @@ -233,7 +252,11 @@ def handler(req: httpx.Request) -> httpx.Response: ) with pytest.raises(errors.GatewayAuthenticationError): - await _image_model(httpx.MockTransport(handler)).generate([_user("test")]) + await generate( + _client(httpx.MockTransport(handler)), + _IMAGE_MODEL, + [_user("test")], + ) @pytest.mark.asyncio async def test_429_rate_limit_error(self) -> None: @@ -249,14 +272,22 @@ def handler(req: httpx.Request) -> httpx.Response: ) with pytest.raises(errors.GatewayRateLimitError): - await _image_model(httpx.MockTransport(handler)).generate([_user("test")]) + await generate( + _client(httpx.MockTransport(handler)), + _IMAGE_MODEL, + [_user("test")], + ) @pytest.mark.asyncio async def test_empty_images_returns_empty_message(self) -> None: - """Gateway returns empty images array → message with no parts.""" + """Gateway returns empty images array -> message with no parts.""" def handler(req: httpx.Request) -> httpx.Response: return httpx.Response(200, json={"images": []}) - msg = await _image_model(httpx.MockTransport(handler)).generate([_user("test")]) + msg = await generate( + _client(httpx.MockTransport(handler)), + _IMAGE_MODEL, + [_user("test")], + ) assert len(msg.images) == 0 diff --git a/tests/models/ai_gateway/test_gateway_video.py b/tests/models/ai_gateway/test_generate_video.py similarity index 72% rename from tests/models/ai_gateway/test_gateway_video.py rename to tests/models/ai_gateway/test_generate_video.py index 07de0ecf..06dc6b91 100644 --- a/tests/models/ai_gateway/test_gateway_video.py +++ b/tests/models/ai_gateway/test_generate_video.py @@ -1,9 +1,10 @@ -"""Integration tests for ``GatewayVideoModel``. +"""Integration tests for the AI Gateway v3 video generation adapter. -Every test exercises the real ``model.generate()`` method with an injected -``httpx.MockTransport``, so the full production code path is covered: +Every test exercises the real ``generate()`` function with a ``Client`` +wired to an ``httpx.MockTransport``, so the full production code path +is covered: - model.generate() + generate(client, model, messages, VideoParams(...)) → extract prompt/image from messages → httpx POST (mock) to /video-model with SSE accept → SSE event parsing @@ -21,7 +22,13 @@ import httpx import pytest -from vercel_ai_sdk.models.ai_gateway import GatewayVideoModel, errors +from vercel_ai_sdk.models.ai_gateway import errors +from vercel_ai_sdk.models.ai_gateway.generate import ( + VideoParams, + generate, +) +from vercel_ai_sdk.models.core import client as client_ +from vercel_ai_sdk.models.core import model as model_ from vercel_ai_sdk.types import messages # MP4 magic bytes (ftyp box) @@ -34,6 +41,13 @@ _WEBM_HEADER = bytes([0x1A, 0x45, 0xDF, 0xA3]) _WEBM_B64 = base64.b64encode(_WEBM_HEADER).decode() +_VIDEO_MODEL = model_.Model( + id="google/veo-3.0-generate-001", + adapter="ai-gateway-v3", + provider="ai-gateway", + capabilities=("video",), +) + # --------------------------------------------------------------------------- # Helpers @@ -45,18 +59,12 @@ def _sse(*events: dict[str, Any]) -> str: return "".join(f"data: {json.dumps(e)}\n\n" for e in events) -def _video_model( - handler: httpx.MockTransport, - *, - model: str = "google/veo-3.0-generate-001", - api_key: str = "test-key", -) -> GatewayVideoModel: - return GatewayVideoModel( - model=model, - api_key=api_key, - base_url="https://gw.test/v3/ai", - _transport=handler, - ) +def _client( + handler: httpx.MockTransport, *, api_key: str = "test-key" +) -> client_.Client: + c = client_.Client(base_url="https://gw.test/v3/ai", api_key=api_key) + c._http = httpx.AsyncClient(transport=handler) + return c def _user(text: str) -> messages.Message: @@ -74,7 +82,7 @@ def _user(text: str) -> messages.Message: class TestGenerate: @pytest.mark.asyncio async def test_basic_video_generation_base64(self) -> None: - """Simple prompt → one MP4 video back via base64.""" + """Simple prompt -> one MP4 video back via base64.""" body = _sse( { "type": "result", @@ -87,8 +95,13 @@ async def test_basic_video_generation_base64(self) -> None: def handler(req: httpx.Request) -> httpx.Response: return httpx.Response(200, text=body) - model = _video_model(httpx.MockTransport(handler)) - msg = await model.generate([_user("A cat walking on a beach")]) + client = _client(httpx.MockTransport(handler)) + msg = await generate( + client, + _VIDEO_MODEL, + [_user("A cat walking on a beach")], + params=VideoParams(), + ) assert msg.role == "assistant" assert len(msg.videos) == 1 @@ -97,7 +110,7 @@ def handler(req: httpx.Request) -> httpx.Response: @pytest.mark.asyncio async def test_video_generation_url(self) -> None: - """Video returned as URL → downloaded automatically.""" + """Video returned as URL -> downloaded automatically.""" body = _sse( { "type": "result", @@ -114,14 +127,19 @@ async def test_video_generation_url(self) -> None: def handler(req: httpx.Request) -> httpx.Response: return httpx.Response(200, text=body) - model = _video_model(httpx.MockTransport(handler)) + client = _client(httpx.MockTransport(handler)) with patch( - "vercel_ai_sdk.models.core.media.download.download", + "vercel_ai_sdk.models.core.helpers.media.download", new_callable=AsyncMock, return_value=(_MP4_HEADER, "video/mp4"), ) as mock_dl: - msg = await model.generate([_user("A sunset timelapse")]) + msg = await generate( + client, + _VIDEO_MODEL, + [_user("A sunset timelapse")], + params=VideoParams(), + ) mock_dl.assert_called_once_with("https://storage.example.com/video.mp4") assert len(msg.videos) == 1 @@ -143,8 +161,11 @@ async def test_multiple_videos(self) -> None: def handler(req: httpx.Request) -> httpx.Response: return httpx.Response(200, text=body) - msg = await _video_model(httpx.MockTransport(handler)).generate( - [_user("Two versions")], n=2 + msg = await generate( + _client(httpx.MockTransport(handler)), + _VIDEO_MODEL, + [_user("Two versions")], + params=VideoParams(n=2), ) assert len(msg.videos) == 2 assert msg.videos[0].media_type == "video/mp4" @@ -179,12 +200,13 @@ def handler(req: httpx.Request) -> httpx.Response: ), ) - model = _video_model( - httpx.MockTransport(handler), - model="google/veo-3.0-generate-001", - api_key="sk-test", + client = _client(httpx.MockTransport(handler), api_key="sk-test") + await generate( + client, + _VIDEO_MODEL, + [_user("test")], + params=VideoParams(), ) - await model.generate([_user("test")]) assert captured["authorization"] == "Bearer sk-test" assert captured["ai-video-model-specification-version"] == "3" @@ -214,23 +236,27 @@ def handler(req: httpx.Request) -> httpx.Response: ), ) - model = _video_model(httpx.MockTransport(handler)) - await model.generate( + client = _client(httpx.MockTransport(handler)) + await generate( + client, + _VIDEO_MODEL, [_user("sunset")], - n=2, - aspect_ratio="16:9", - resolution="1920x1080", - duration=5.0, - fps=30, - seed=42, - provider_options={"google": {"enhancePrompt": True}}, + params=VideoParams( + n=2, + aspect_ratio="16:9", + resolution="1920x1080", + duration=5, + fps=30, + seed=42, + provider_options={"google": {"enhancePrompt": True}}, + ), ) assert captured_body["prompt"] == "sunset" assert captured_body["n"] == 2 assert captured_body["aspectRatio"] == "16:9" assert captured_body["resolution"] == "1920x1080" - assert captured_body["duration"] == 5.0 + assert captured_body["duration"] == 5 assert captured_body["fps"] == 30 assert captured_body["seed"] == 42 assert captured_body["providerOptions"] == {"google": {"enhancePrompt": True}} @@ -257,14 +283,19 @@ def handler(req: httpx.Request) -> httpx.Response: ), ) - model = _video_model(httpx.MockTransport(handler)) - await model.generate([_user("test")]) + client = _client(httpx.MockTransport(handler)) + await generate( + client, + _VIDEO_MODEL, + [_user("test")], + params=VideoParams(), + ) assert captured_url[0] == "https://gw.test/v3/ai/video-model" @pytest.mark.asyncio async def test_image_to_video_input(self) -> None: - """Image in user message → image field in request body.""" + """Image in user message -> image field in request body.""" captured_body: dict[str, Any] = {} def handler(req: httpx.Request) -> httpx.Response: @@ -293,8 +324,8 @@ def handler(req: httpx.Request) -> httpx.Response: messages.FilePart(data=png_b64, media_type="image/png"), ], ) - model = _video_model(httpx.MockTransport(handler)) - await model.generate([user_msg]) + client = _client(httpx.MockTransport(handler)) + await generate(client, _VIDEO_MODEL, [user_msg], params=VideoParams()) assert captured_body["prompt"] == "Animate this" assert "image" in captured_body @@ -310,7 +341,7 @@ def handler(req: httpx.Request) -> httpx.Response: class TestErrors: @pytest.mark.asyncio async def test_sse_error_event(self) -> None: - """Gateway returns an SSE error event → raises.""" + """Gateway returns an SSE error event -> raises.""" body = _sse( { "type": "error", @@ -325,7 +356,12 @@ def handler(req: httpx.Request) -> httpx.Response: return httpx.Response(200, text=body) with pytest.raises(errors.GatewayInvalidRequestError, match="Content policy"): - await _video_model(httpx.MockTransport(handler)).generate([_user("test")]) + await generate( + _client(httpx.MockTransport(handler)), + _VIDEO_MODEL, + [_user("test")], + params=VideoParams(), + ) @pytest.mark.asyncio async def test_401_authentication_error(self) -> None: @@ -341,14 +377,24 @@ def handler(req: httpx.Request) -> httpx.Response: ) with pytest.raises(errors.GatewayAuthenticationError): - await _video_model(httpx.MockTransport(handler)).generate([_user("test")]) + await generate( + _client(httpx.MockTransport(handler)), + _VIDEO_MODEL, + [_user("test")], + params=VideoParams(), + ) @pytest.mark.asyncio async def test_empty_sse_stream(self) -> None: - """SSE stream with no data events → raises.""" + """SSE stream with no data events -> raises.""" def handler(req: httpx.Request) -> httpx.Response: return httpx.Response(200, text="") with pytest.raises(errors.GatewayResponseError, match="SSE stream ended"): - await _video_model(httpx.MockTransport(handler)).generate([_user("test")]) + await generate( + _client(httpx.MockTransport(handler)), + _VIDEO_MODEL, + [_user("test")], + params=VideoParams(), + ) diff --git a/tests/models/ai_gateway/test_protocol.py b/tests/models/ai_gateway/test_protocol.py index 38c27e73..512c83e3 100644 --- a/tests/models/ai_gateway/test_protocol.py +++ b/tests/models/ai_gateway/test_protocol.py @@ -1,15 +1,15 @@ """Tests for the v3 protocol serialization and deserialization. Focus areas: -- ``messages_to_v3_prompt``: the critical outgoing translation layer -- ``tools_to_v3`` / ``build_request_body``: using real ``@tool`` -- ``parse_stream_part``: the critical incoming translation layer -- ``parse_generate_result``: non-streaming response handling +- ``_messages_to_prompt``: the critical outgoing translation layer +- ``_build_request_body``: using real ``@tool`` +- ``_parse_stream_part``: the critical incoming translation layer - ``_parse_usage``: the two distinct wire formats """ from __future__ import annotations +import importlib import json from unittest.mock import AsyncMock, patch @@ -17,17 +17,20 @@ import pytest import vercel_ai_sdk as ai -from vercel_ai_sdk.models.ai_gateway import protocol -from vercel_ai_sdk.models.core import llm +from vercel_ai_sdk.models.core.helpers import streaming from vercel_ai_sdk.types import messages +# The ai_gateway __init__.py re-exports `stream` as a function, which +# shadows the module. Use importlib to get the actual module. +stream_mod = importlib.import_module("vercel_ai_sdk.models.ai_gateway.stream") + # --------------------------------------------------------------------------- -# messages_to_v3_prompt +# _messages_to_prompt # --------------------------------------------------------------------------- @pytest.mark.asyncio -class TestMessagesToV3Prompt: +class TestMessagesToPrompt: async def test_system_message(self) -> None: msgs = [ messages.Message( @@ -35,7 +38,7 @@ async def test_system_message(self) -> None: parts=[messages.TextPart(text="You are helpful.")], ) ] - result = await protocol.messages_to_v3_prompt(msgs) + result = await stream_mod._messages_to_prompt(msgs) assert result == [{"role": "system", "content": "You are helpful."}] async def test_user_message(self) -> None: @@ -45,7 +48,7 @@ async def test_user_message(self) -> None: parts=[messages.TextPart(text="Hello")], ) ] - result = await protocol.messages_to_v3_prompt(msgs) + result = await stream_mod._messages_to_prompt(msgs) assert result == [ { "role": "user", @@ -63,7 +66,7 @@ async def test_assistant_with_reasoning_and_text(self) -> None: ], ) ] - result = await protocol.messages_to_v3_prompt(msgs) + result = await stream_mod._messages_to_prompt(msgs) content = result[0]["content"] assert content[0] == {"type": "reasoning", "text": "Let me think..."} assert content[1] == {"type": "text", "text": "42"} @@ -85,7 +88,7 @@ async def test_tool_call_with_result_produces_two_messages(self) -> None: ], ) ] - result = await protocol.messages_to_v3_prompt(msgs) + result = await stream_mod._messages_to_prompt(msgs) assert len(result) == 2 # Assistant message has the tool-call @@ -114,13 +117,13 @@ async def test_tool_error_result(self) -> None: ], ) ] - result = await protocol.messages_to_v3_prompt(msgs) + result = await stream_mod._messages_to_prompt(msgs) tr = result[1]["content"][0] assert tr["output"]["type"] == "error-text" assert tr["output"]["value"] == "Connection timeout" async def test_user_message_with_image_url(self) -> None: - """FilePart with image URL → downloaded and converted to data: URL.""" + """FilePart with image URL -> downloaded and converted to data: URL.""" fake_jpeg = b"\xff\xd8\xff\xe0" msgs = [ messages.Message( @@ -134,11 +137,11 @@ async def test_user_message_with_image_url(self) -> None: ) ] with patch( - "vercel_ai_sdk.models.core.media.download.download", + "vercel_ai_sdk.models.core.helpers.media.download", new_callable=AsyncMock, return_value=(fake_jpeg, "image/jpeg"), ): - result = await protocol.messages_to_v3_prompt(msgs) + result = await stream_mod._messages_to_prompt(msgs) content = result[0]["content"] assert content[0] == {"type": "text", "text": "Look at this"} assert content[1]["type"] == "file" @@ -146,7 +149,7 @@ async def test_user_message_with_image_url(self) -> None: assert content[1]["data"].startswith("data:image/jpeg;base64,") async def test_user_message_with_file_bytes(self) -> None: - """FilePart with bytes → v3 file content part with data URL.""" + """FilePart with bytes -> v3 file content part with data URL.""" msgs = [ messages.Message( role="user", @@ -157,7 +160,7 @@ async def test_user_message_with_file_bytes(self) -> None: ], ) ] - result = await protocol.messages_to_v3_prompt(msgs) + result = await stream_mod._messages_to_prompt(msgs) part = result[0]["content"][0] assert part["type"] == "file" assert part["mediaType"] == "image/png" @@ -172,7 +175,7 @@ async def test_user_message_text_only_unchanged(self) -> None: parts=[messages.TextPart(text="Hello")], ) ] - result = await protocol.messages_to_v3_prompt(msgs) + result = await stream_mod._messages_to_prompt(msgs) assert result == [ {"role": "user", "content": [{"type": "text", "text": "Hello"}]} ] @@ -192,13 +195,13 @@ async def test_pending_tool_call_no_tool_message(self) -> None: ], ) ] - result = await protocol.messages_to_v3_prompt(msgs) + result = await stream_mod._messages_to_prompt(msgs) assert len(result) == 1 assert result[0]["role"] == "assistant" # --------------------------------------------------------------------------- -# tools_to_v3 / build_request_body — using real @tool +# _build_request_body — using real @tool # --------------------------------------------------------------------------- @@ -212,14 +215,14 @@ async def get_weather(city: str, units: str = "celsius") -> str: class TestBuildRequestBody: async def test_with_real_tool(self) -> None: """Verify @tool-produced schema round-trips through - build_request_body → JSON → gateway wire format.""" + _build_request_body -> JSON -> gateway wire format.""" msgs = [ messages.Message( role="user", parts=[messages.TextPart(text="What's the weather?")], ) ] - body = await protocol.build_request_body(msgs, tools=[get_weather]) + body = await stream_mod._build_request_body(msgs, tools=[get_weather]) assert "tools" in body tool_def = body["tools"][0] @@ -245,7 +248,7 @@ class WeatherResult(pydantic.BaseModel): parts=[messages.TextPart(text="Weather?")], ) ] - body = await protocol.build_request_body(msgs, output_type=WeatherResult) + body = await stream_mod._build_request_body(msgs, output_type=WeatherResult) assert "responseFormat" in body rf = body["responseFormat"] @@ -262,46 +265,46 @@ async def test_provider_options_passthrough(self) -> None: ) ] opts = {"gateway": {"order": ["bedrock", "openai"]}} - body = await protocol.build_request_body(msgs, provider_options=opts) + body = await stream_mod._build_request_body(msgs, provider_options=opts) assert body["providerOptions"] == opts # --------------------------------------------------------------------------- -# parse_stream_part — parametrized simple 1:1 mappings +# _parse_stream_part — parametrized simple 1:1 mappings # --------------------------------------------------------------------------- _SIMPLE_STREAM_PARTS = [ ( {"type": "text-start", "id": "t1"}, - llm.TextStart(block_id="t1"), + streaming.TextStart(block_id="t1"), ), ( {"type": "text-end", "id": "t1"}, - llm.TextEnd(block_id="t1"), + streaming.TextEnd(block_id="t1"), ), ( {"type": "reasoning-start", "id": "r1"}, - llm.ReasoningStart(block_id="r1"), + streaming.ReasoningStart(block_id="r1"), ), ( {"type": "reasoning-delta", "id": "r1", "delta": "hmm"}, - llm.ReasoningDelta(block_id="r1", delta="hmm"), + streaming.ReasoningDelta(block_id="r1", delta="hmm"), ), ( {"type": "reasoning-end", "id": "r1"}, - llm.ReasoningEnd(block_id="r1"), + streaming.ReasoningEnd(block_id="r1"), ), ( {"type": "tool-input-start", "id": "tc-1", "toolName": "search"}, - llm.ToolStart(tool_call_id="tc-1", tool_name="search"), + streaming.ToolStart(tool_call_id="tc-1", tool_name="search"), ), ( {"type": "tool-input-delta", "id": "tc-1", "delta": '{"q"'}, - llm.ToolArgsDelta(tool_call_id="tc-1", delta='{"q"'), + streaming.ToolArgsDelta(tool_call_id="tc-1", delta='{"q"'), ), ( {"type": "tool-input-end", "id": "tc-1"}, - llm.ToolEnd(tool_call_id="tc-1"), + streaming.ToolEnd(tool_call_id="tc-1"), ), ] @@ -312,9 +315,9 @@ async def test_provider_options_passthrough(self) -> None: ids=[w["type"] for w, _ in _SIMPLE_STREAM_PARTS], ) def test_parse_stream_part_simple( - wire: dict[str, object], expected: llm.StreamEvent + wire: dict[str, object], expected: streaming.StreamEvent ) -> None: - events = protocol.parse_stream_part(wire) + events = stream_mod._parse_stream_part(wire) assert len(events) == 1 assert events[0] == expected @@ -323,16 +326,16 @@ def test_parse_stream_part_simple( class TestParseStreamPartComplex: async def test_text_delta_uses_textDelta_key(self) -> None: """The gateway sends ``textDelta`` (camelCase), not ``delta``.""" - events = protocol.parse_stream_part( + events = stream_mod._parse_stream_part( {"type": "text-delta", "id": "t1", "textDelta": "Hello"} ) - assert isinstance(events[0], llm.TextDelta) + assert isinstance(events[0], streaming.TextDelta) assert events[0].delta == "Hello" async def test_tool_call_expands_to_three_events(self) -> None: """A complete ``tool-call`` part must expand into - ToolStart → ToolArgsDelta → ToolEnd.""" - events = protocol.parse_stream_part( + ToolStart -> ToolArgsDelta -> ToolEnd.""" + events = stream_mod._parse_stream_part( { "type": "tool-call", "toolCallId": "tc-1", @@ -341,14 +344,14 @@ async def test_tool_call_expands_to_three_events(self) -> None: } ) assert len(events) == 3 - assert isinstance(events[0], llm.ToolStart) + assert isinstance(events[0], streaming.ToolStart) assert events[0].tool_name == "get_weather" - assert isinstance(events[1], llm.ToolArgsDelta) + assert isinstance(events[1], streaming.ToolArgsDelta) assert json.loads(events[1].delta) == {"city": "SF"} - assert isinstance(events[2], llm.ToolEnd) + assert isinstance(events[2], streaming.ToolEnd) async def test_finish_flat_usage(self) -> None: - events = protocol.parse_stream_part( + events = stream_mod._parse_stream_part( { "type": "finish", "finishReason": "stop", @@ -359,14 +362,14 @@ async def test_finish_flat_usage(self) -> None: } ) done = events[0] - assert isinstance(done, llm.MessageDone) + assert isinstance(done, streaming.MessageDone) assert done.finish_reason == "stop" assert done.usage is not None assert done.usage.input_tokens == 10 assert done.usage.output_tokens == 20 async def test_finish_v3_nested_usage(self) -> None: - events = protocol.parse_stream_part( + events = stream_mod._parse_stream_part( { "type": "finish", "finishReason": { @@ -386,7 +389,7 @@ async def test_finish_v3_nested_usage(self) -> None: } ) done = events[0] - assert isinstance(done, llm.MessageDone) + assert isinstance(done, streaming.MessageDone) assert done.finish_reason == "tool-calls" assert done.usage is not None assert done.usage.input_tokens == 100 @@ -396,7 +399,7 @@ async def test_finish_v3_nested_usage(self) -> None: async def test_file_part(self) -> None: """A ``file`` stream part (inline image from Gemini/GPT-5) must produce a FileEvent.""" - events = protocol.parse_stream_part( + events = stream_mod._parse_stream_part( { "type": "file", "id": "f1", @@ -405,82 +408,21 @@ async def test_file_part(self) -> None: } ) assert len(events) == 1 - assert isinstance(events[0], llm.FileEvent) + assert isinstance(events[0], streaming.FileEvent) assert events[0].block_id == "f1" assert events[0].media_type == "image/png" assert events[0].data == "iVBORw0KGgo=" async def test_file_part_defaults(self) -> None: """A minimal ``file`` part uses sensible defaults.""" - events = protocol.parse_stream_part({"type": "file", "data": "somedata"}) + events = stream_mod._parse_stream_part({"type": "file", "data": "somedata"}) assert len(events) == 1 - assert isinstance(events[0], llm.FileEvent) + assert isinstance(events[0], streaming.FileEvent) assert events[0].media_type == "application/octet-stream" async def test_unknown_types_produce_no_events(self) -> None: for t in ("stream-start", "raw", "response-metadata", "banana"): - assert protocol.parse_stream_part({"type": t}) == [] - - -# --------------------------------------------------------------------------- -# parse_generate_result -# --------------------------------------------------------------------------- - - -@pytest.mark.asyncio -class TestParseGenerateResult: - async def test_text_content(self) -> None: - events = protocol.parse_generate_result( - { - "content": [{"type": "text", "text": "Hello!"}], - "finishReason": "stop", - "usage": {"prompt_tokens": 4, "completion_tokens": 10}, - } - ) - # TextStart + TextDelta + TextEnd + MessageDone - assert len(events) == 4 - assert isinstance(events[1], llm.TextDelta) - assert events[1].delta == "Hello!" - assert isinstance(events[3], llm.MessageDone) - - async def test_tool_call_content(self) -> None: - events = protocol.parse_generate_result( - { - "content": [ - { - "type": "tool-call", - "toolCallId": "tc-1", - "toolName": "search", - "input": {"query": "weather"}, - } - ], - "finishReason": "tool-calls", - } - ) - assert isinstance(events[0], llm.ToolStart) - assert isinstance(events[3], llm.MessageDone) - assert events[3].finish_reason == "tool-calls" - - async def test_file_content(self) -> None: - """A ``file`` part in non-streaming result produces a FileEvent.""" - events = protocol.parse_generate_result( - { - "content": [ - { - "type": "file", - "id": "f1", - "mediaType": "image/png", - "data": "iVBORw0KGgo=", - } - ], - "finishReason": "stop", - "usage": {"prompt_tokens": 10, "completion_tokens": 5}, - } - ) - file_events = [e for e in events if isinstance(e, llm.FileEvent)] - assert len(file_events) == 1 - assert file_events[0].media_type == "image/png" - assert isinstance(events[-1], llm.MessageDone) + assert stream_mod._parse_stream_part({"type": t}) == [] # --------------------------------------------------------------------------- @@ -491,12 +433,12 @@ async def test_file_content(self) -> None: @pytest.mark.asyncio class TestParseUsage: async def test_flat_format(self) -> None: - usage = protocol._parse_usage({"prompt_tokens": 10, "completion_tokens": 20}) + usage = stream_mod._parse_usage({"prompt_tokens": 10, "completion_tokens": 20}) assert usage.input_tokens == 10 assert usage.output_tokens == 20 async def test_v3_nested_format(self) -> None: - usage = protocol._parse_usage( + usage = stream_mod._parse_usage( { "inputTokens": { "total": 100, @@ -513,6 +455,6 @@ async def test_v3_nested_format(self) -> None: assert usage.reasoning_tokens == 10 async def test_non_dict_returns_empty(self) -> None: - usage = protocol._parse_usage("not a dict") + usage = stream_mod._parse_usage("not a dict") assert usage.input_tokens == 0 assert usage.output_tokens == 0 diff --git a/tests/models/ai_gateway/test_gateway.py b/tests/models/ai_gateway/test_stream.py similarity index 81% rename from tests/models/ai_gateway/test_gateway.py rename to tests/models/ai_gateway/test_stream.py index 2ac3e64b..784dfac1 100644 --- a/tests/models/ai_gateway/test_gateway.py +++ b/tests/models/ai_gateway/test_stream.py @@ -1,19 +1,21 @@ -"""Integration tests for ``GatewayModel``. +"""Integration tests for the AI Gateway v3 streaming adapter. -Every test exercises the real ``model.stream()`` method with an injected -``httpx.MockTransport``, so the full production code path is covered: +Every test exercises the real ``stream()`` function with a ``Client`` +wired to an ``httpx.MockTransport``, so the full production code path +is covered: - model.stream() - → build_request_body() + stream(client, model, messages) + → _build_request_body() → httpx POST (mock) → SSE line parsing - → parse_stream_part() + → _parse_stream_part() → StreamHandler → yield Message """ from __future__ import annotations +import importlib import json from typing import Any @@ -21,44 +23,49 @@ import pytest import vercel_ai_sdk as ai -from vercel_ai_sdk.models.ai_gateway import GatewayModel, errors +from vercel_ai_sdk.models.ai_gateway import errors +from vercel_ai_sdk.models.core import client as client_ +from vercel_ai_sdk.models.core import model as model_ from vercel_ai_sdk.types import messages +# The ai_gateway __init__.py re-exports `stream` as a function, which +# shadows the module. Use importlib to get the actual module. +stream_mod = importlib.import_module("vercel_ai_sdk.models.ai_gateway.stream") + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- +_TEST_MODEL = model_.Model( + id="test-provider/test-model", + adapter="ai-gateway-v3", + provider="ai-gateway", +) + def _sse(*events: dict[str, Any]) -> str: """Build SSE response text from event dicts.""" return "".join(f"data: {json.dumps(e)}\n\n" for e in events) -def _gateway( - handler: httpx.MockTransport, - *, - model: str = "test-provider/test-model", - api_key: str = "test-key", - provider_options: dict[str, Any] | None = None, -) -> GatewayModel: - """Create a ``GatewayModel`` wired to a mock transport.""" - return GatewayModel( - model=model, - api_key=api_key, - base_url="https://gw.test/v3/ai", - provider_options=provider_options, - _transport=handler, - ) +def _client( + handler: httpx.MockTransport, *, api_key: str = "test-key" +) -> client_.Client: + """Create a Client wired to a mock transport.""" + c = client_.Client(base_url="https://gw.test/v3/ai", api_key=api_key) + c._http = httpx.AsyncClient(transport=handler) + return c async def _collect( - model: GatewayModel, + client: client_.Client, msgs: list[messages.Message], + model: model_.Model = _TEST_MODEL, **kwargs: Any, ) -> list[messages.Message]: - """Drain ``model.stream()`` and return all yielded messages.""" + """Drain ``stream()`` and return all yielded messages.""" result: list[messages.Message] = [] - async for msg in model.stream(msgs, **kwargs): + async for msg in stream_mod.stream(client, model, msgs, **kwargs): result.append(msg) return result @@ -96,8 +103,8 @@ async def test_text_stream(self) -> None: def handler(req: httpx.Request) -> httpx.Response: return httpx.Response(200, text=body) - model = _gateway(httpx.MockTransport(handler)) - msgs = await _collect(model, [_user("Hi")]) + client = _client(httpx.MockTransport(handler)) + msgs = await _collect(client, [_user("Hi")]) final = msgs[-1] assert final.text == "Hello World" @@ -121,7 +128,7 @@ async def test_reasoning_then_text(self) -> None: def handler(req: httpx.Request) -> httpx.Response: return httpx.Response(200, text=body) - final = (await _collect(_gateway(httpx.MockTransport(handler)), [_user("?")]))[ + final = (await _collect(_client(httpx.MockTransport(handler)), [_user("?")]))[ -1 ] assert final.reasoning == "think" @@ -149,7 +156,7 @@ def handler(req: httpx.Request) -> httpx.Response: return httpx.Response(200, text=body) final = ( - await _collect(_gateway(httpx.MockTransport(handler)), [_user("search")]) + await _collect(_client(httpx.MockTransport(handler)), [_user("search")]) )[-1] tc = final.tool_calls assert len(tc) == 1 @@ -181,7 +188,7 @@ def handler(req: httpx.Request) -> httpx.Response: return httpx.Response(200, text=body) final = ( - await _collect(_gateway(httpx.MockTransport(handler)), [_user("draw me")]) + await _collect(_client(httpx.MockTransport(handler)), [_user("draw me")]) )[-1] assert final.text == "Here is an image:" assert len(final.images) == 1 @@ -210,7 +217,7 @@ def handler(req: httpx.Request) -> httpx.Response: return httpx.Response(200, text=body) final = ( - await _collect(_gateway(httpx.MockTransport(handler)), [_user("weather")]) + await _collect(_client(httpx.MockTransport(handler)), [_user("weather")]) )[-1] assert len(final.tool_calls) == 1 assert json.loads(final.tool_calls[0].tool_args) == {"city": "SF"} @@ -233,12 +240,13 @@ def handler(req: httpx.Request) -> httpx.Response: text=_sse({"type": "finish", "finishReason": "stop", "usage": {}}), ) - model = _gateway( - httpx.MockTransport(handler), - model="anthropic/claude-sonnet-4", - api_key="sk-test", + model = model_.Model( + id="anthropic/claude-sonnet-4", + adapter="ai-gateway-v3", + provider="ai-gateway", ) - await _collect(model, [_user("Hi")]) + client = _client(httpx.MockTransport(handler), api_key="sk-test") + await _collect(client, [_user("Hi")], model=model) assert captured["authorization"] == "Bearer sk-test" assert captured["ai-gateway-protocol-version"] == "0.0.1" @@ -258,7 +266,7 @@ def handler(req: httpx.Request) -> httpx.Response: text=_sse({"type": "finish", "finishReason": "stop", "usage": {}}), ) - await _collect(_gateway(httpx.MockTransport(handler)), [_user("Hello")]) + await _collect(_client(httpx.MockTransport(handler)), [_user("Hello")]) assert captured_body["prompt"] == [ { @@ -280,8 +288,9 @@ def handler(req: httpx.Request) -> httpx.Response: opts = {"gateway": {"order": ["bedrock", "openai"]}} await _collect( - _gateway(httpx.MockTransport(handler), provider_options=opts), + _client(httpx.MockTransport(handler)), [_user("Hi")], + provider_options=opts, ) assert captured_body["providerOptions"] == opts @@ -306,7 +315,7 @@ def handler(req: httpx.Request) -> httpx.Response: ) await _collect( - _gateway(httpx.MockTransport(handler)), + _client(httpx.MockTransport(handler)), [_user("find something")], tools=[lookup], ) @@ -343,7 +352,7 @@ def handler(req: httpx.Request) -> httpx.Response: _user("Thanks, and tomorrow?"), ] - await _collect(_gateway(httpx.MockTransport(handler)), conversation) + await _collect(_client(httpx.MockTransport(handler)), conversation) prompt = captured_body["prompt"] # user → assistant (tool-call) → tool (tool-result) → user @@ -376,7 +385,7 @@ def handler(req: httpx.Request) -> httpx.Response: ) with pytest.raises(errors.GatewayAuthenticationError): - await _collect(_gateway(httpx.MockTransport(handler)), [_user("Hi")]) + await _collect(_client(httpx.MockTransport(handler)), [_user("Hi")]) @pytest.mark.asyncio async def test_429_rate_limit_error(self) -> None: @@ -392,7 +401,7 @@ def handler(req: httpx.Request) -> httpx.Response: ) with pytest.raises(errors.GatewayRateLimitError): - await _collect(_gateway(httpx.MockTransport(handler)), [_user("Hi")]) + await _collect(_client(httpx.MockTransport(handler)), [_user("Hi")]) @pytest.mark.asyncio async def test_404_model_not_found(self) -> None: @@ -409,7 +418,7 @@ def handler(req: httpx.Request) -> httpx.Response: ) with pytest.raises(errors.GatewayModelNotFoundError) as exc_info: - await _collect(_gateway(httpx.MockTransport(handler)), [_user("Hi")]) + await _collect(_client(httpx.MockTransport(handler)), [_user("Hi")]) assert exc_info.value.model_id == "xyz" @pytest.mark.asyncio @@ -418,4 +427,4 @@ def handler(req: httpx.Request) -> httpx.Response: return httpx.Response(500, text="Not JSON") with pytest.raises(errors.GatewayResponseError): - await _collect(_gateway(httpx.MockTransport(handler)), [_user("Hi")]) + await _collect(_client(httpx.MockTransport(handler)), [_user("Hi")]) diff --git a/tests/models/anthropic/__init__.py b/tests/models/anthropic/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/models/anthropic/test_anthropic.py b/tests/models/anthropic/test_anthropic.py deleted file mode 100644 index 8c9633c7..00000000 --- a/tests/models/anthropic/test_anthropic.py +++ /dev/null @@ -1,390 +0,0 @@ -"""Anthropic provider: _messages_to_anthropic conversion tests.""" - -import base64 - -import pytest - -from vercel_ai_sdk.models.anthropic import _messages_to_anthropic -from vercel_ai_sdk.types.messages import FilePart, Message, TextPart, ToolPart - -pytestmark = pytest.mark.asyncio - - -async def test_tool_result_none_still_emits_tool_result() -> None: - """A tool that returns None must still produce a tool_result block. - - Regression: when part.result is None the converter skipped the tool_result, - leaving a tool_use without a matching tool_result. Anthropic rejects this - with: "tool_use ids were found without tool_result blocks immediately after". - """ - tool_part = ToolPart( - tool_call_id="toolu_01abc", - tool_name="send_notification", - tool_args="{}", - ) - tool_part.set_result(None) # tool returned None (fire-and-forget style) - - messages = [ - Message(role="assistant", parts=[tool_part]), - ] - - _system, anthropic_msgs = await _messages_to_anthropic(messages) - - # Should have: assistant message with tool_use, then user message with tool_result - assert len(anthropic_msgs) == 2, ( - f"Expected 2 messages (assistant + user/tool_result), " - f"got {len(anthropic_msgs)}: {anthropic_msgs}" - ) - - assistant_msg = anthropic_msgs[0] - assert assistant_msg["role"] == "assistant" - assert any(block["type"] == "tool_use" for block in assistant_msg["content"]) - - user_msg = anthropic_msgs[1] - assert user_msg["role"] == "user" - tool_results = [b for b in user_msg["content"] if b["type"] == "tool_result"] - assert len(tool_results) == 1 - assert tool_results[0]["tool_use_id"] == "toolu_01abc" - - -async def test_tool_with_normal_result() -> None: - """Baseline: a tool with a normal result produces the correct pair.""" - tool_part = ToolPart( - tool_call_id="toolu_02xyz", - tool_name="get_weather", - tool_args='{"city": "SF"}', - ) - tool_part.set_result({"temp": 62}) - - messages = [ - Message(role="assistant", parts=[tool_part]), - ] - - _system, anthropic_msgs = await _messages_to_anthropic(messages) - - assert len(anthropic_msgs) == 2 - assert anthropic_msgs[1]["content"][0]["content"] == "{'temp': 62}" - - -async def test_tool_error_produces_tool_result() -> None: - """Tool errors must also produce a tool_result block (with is_error=True).""" - tool_part = ToolPart( - tool_call_id="toolu_03err", - tool_name="failing_tool", - tool_args="{}", - ) - tool_part.set_error("Connection timeout") - - messages = [ - Message(role="assistant", parts=[tool_part]), - ] - - _system, anthropic_msgs = await _messages_to_anthropic(messages) - - assert len(anthropic_msgs) == 2 - tool_result = anthropic_msgs[1]["content"][0] - assert tool_result["type"] == "tool_result" - assert tool_result["is_error"] is True - assert tool_result["content"] == "Connection timeout" - - -async def test_multiple_tools_one_returns_none() -> None: - """When one of several tools returns None, all must have tool_results.""" - tool_a = ToolPart( - tool_call_id="toolu_a", - tool_name="tool_a", - tool_args="{}", - ) - tool_a.set_result("some result") - - tool_b = ToolPart( - tool_call_id="toolu_b", - tool_name="tool_b", - tool_args="{}", - ) - tool_b.set_result(None) # returns None - - messages = [ - Message(role="assistant", parts=[tool_a, tool_b]), - ] - - _system, anthropic_msgs = await _messages_to_anthropic(messages) - - assert len(anthropic_msgs) == 2 - - # Both tool_use blocks in assistant message - tool_uses = [b for b in anthropic_msgs[0]["content"] if b["type"] == "tool_use"] - assert len(tool_uses) == 2 - - # Both tool_result blocks in user message - tool_results = [ - b for b in anthropic_msgs[1]["content"] if b["type"] == "tool_result" - ] - assert len(tool_results) == 2 - - result_ids = {r["tool_use_id"] for r in tool_results} - assert result_ids == {"toolu_a", "toolu_b"} - - -# -- Multi-turn: consecutive user messages (tool_result + next user) ------- - - -async def test_multi_turn_no_consecutive_same_role_messages() -> None: - """Multi-turn with tools must not produce consecutive same-role messages. - - Regression: when a previous assistant turn includes a tool call (with - result), _messages_to_anthropic emits: - [assistant(tool_use)] [user(tool_result)] [user(next question)] - The two consecutive user messages violate Anthropic's alternating-role - requirement, causing: "tool_use ids were found without tool_result - blocks immediately after". - - The tool_result user message must be merged with the following user - message (or otherwise avoid consecutive same-role messages). - """ - tool = ToolPart( - tool_call_id="toolu_01abc", - tool_name="talk_to_mothership", - tool_args='{"question": "when?"}', - ) - tool.set_result({"value": "Soon."}) - - messages = [ - Message(role="user", parts=[TextPart(text="when will the robots take over?")]), - Message( - role="assistant", - parts=[ - TextPart(text="I'll check with the mothership."), - tool, - TextPart(text="The mothership has spoken: Soon."), - ], - ), - Message( - role="user", - parts=[TextPart(text="can you remember the first turn?")], - ), - ] - - _system, anthropic_msgs = await _messages_to_anthropic(messages) - - # Verify no consecutive same-role messages - for i in range(1, len(anthropic_msgs)): - assert anthropic_msgs[i]["role"] != anthropic_msgs[i - 1]["role"], ( - f"Consecutive same-role messages at indices {i - 1} and {i}: " - f"both are '{anthropic_msgs[i]['role']}'. " - f"Full messages: {anthropic_msgs}" - ) - - -async def test_multi_turn_tool_result_before_user_merged() -> None: - """When tool_result (user) is followed by a user message, they merge. - - The merged user message should contain both the tool_result blocks - and the text content from the following user message. - """ - tool = ToolPart( - tool_call_id="toolu_01abc", - tool_name="get_weather", - tool_args='{"city": "SF"}', - ) - tool.set_result("Sunny, 62F") - - messages = [ - Message(role="user", parts=[TextPart(text="what's the weather?")]), - Message(role="assistant", parts=[tool]), - Message(role="user", parts=[TextPart(text="thanks, what about tomorrow?")]), - ] - - _system, anthropic_msgs = await _messages_to_anthropic(messages) - - # Should be: user, assistant, user (tool_result + text) - assert len(anthropic_msgs) == 3 - assert anthropic_msgs[0]["role"] == "user" - assert anthropic_msgs[1]["role"] == "assistant" - assert anthropic_msgs[2]["role"] == "user" - - # The merged user message should contain the tool_result - user_content = anthropic_msgs[2]["content"] - assert isinstance(user_content, list) - tool_results = [b for b in user_content if b.get("type") == "tool_result"] - assert len(tool_results) == 1 - assert tool_results[0]["tool_use_id"] == "toolu_01abc" - - -async def test_stream_loop_second_iteration_messages() -> None: - """Simulates what stream_loop sends on the 2nd LLM call in a multi-turn. - - After the first stream_step returns a tool call, stream_loop appends - the assistant message (now with status=result after execute_tool) and - calls stream_step again. The messages must not have consecutive - same-role entries. - """ - tool = ToolPart( - tool_call_id="toolu_01abc", - tool_name="talk_to_mothership", - tool_args='{"question": "test"}', - ) - tool.set_result("answer") - - # These are the messages that stream_loop would pass to the 2nd stream_step: - # original user messages + assistant message from 1st step (with tool result) - messages = [ - Message(role="user", parts=[TextPart(text="ask the mothership")]), - Message(role="assistant", parts=[tool]), - # No user message follows — this is the loop, not a new user turn - ] - - _system, anthropic_msgs = await _messages_to_anthropic(messages) - - # Should be: user, assistant(tool_use), user(tool_result) - assert len(anthropic_msgs) == 3 - assert anthropic_msgs[0]["role"] == "user" - assert anthropic_msgs[1]["role"] == "assistant" - assert anthropic_msgs[2]["role"] == "user" - - # Verify the tool_result is present - tool_results = [ - b for b in anthropic_msgs[2]["content"] if b.get("type") == "tool_result" - ] - assert len(tool_results) == 1 - - -async def test_pending_tool_does_not_emit_tool_result() -> None: - """A tool with status='pending' must not produce a tool_result block. - - When stream_step returns a message mid-stream (before tool execution), - the ToolPart has status='pending'. The converter must emit only - the tool_use block — no tool_result. - """ - tool = ToolPart( - tool_call_id="toolu_pending", - tool_name="slow_tool", - tool_args='{"x": 1}', - ) - # Don't call set_result — status stays "pending" - - messages = [ - Message(role="user", parts=[TextPart(text="do something")]), - Message(role="assistant", parts=[tool]), - ] - - _system, anthropic_msgs = await _messages_to_anthropic(messages) - - # assistant message with tool_use, but NO user message with tool_result - assert len(anthropic_msgs) == 2 - assert anthropic_msgs[0]["role"] == "user" - assert anthropic_msgs[1]["role"] == "assistant" - assert any(b["type"] == "tool_use" for b in anthropic_msgs[1]["content"]) - - # No tool_result anywhere - for msg in anthropic_msgs: - if isinstance(msg["content"], list): - assert not any(b.get("type") == "tool_result" for b in msg["content"]) - - -# -- Multimodal user messages ------------------------------------------------ - - -async def test_user_text_only_is_plain_string() -> None: - """Text-only user messages should produce a plain content string.""" - msgs = [Message(role="user", parts=[TextPart(text="Hello")])] - _sys, result = await _messages_to_anthropic(msgs) - assert result[0]["content"] == "Hello" - - -async def test_user_image_url() -> None: - """Image URL → Anthropic image block with url source.""" - msgs = [ - Message( - role="user", - parts=[ - TextPart(text="Describe this"), - FilePart(data="https://example.com/cat.jpg", media_type="image/jpeg"), - ], - ) - ] - _sys, result = await _messages_to_anthropic(msgs) - content = result[0]["content"] - assert content[0] == {"type": "text", "text": "Describe this"} - assert content[1] == { - "type": "image", - "source": {"type": "url", "url": "https://example.com/cat.jpg"}, - } - - -async def test_user_image_base64() -> None: - """Base64 image → Anthropic image block with base64 source.""" - b64 = base64.b64encode(b"\x89PNG").decode() - msgs = [ - Message( - role="user", - parts=[FilePart(data=b64, media_type="image/png")], - ) - ] - _sys, result = await _messages_to_anthropic(msgs) - img = result[0]["content"][0] - assert img["type"] == "image" - assert img["source"]["type"] == "base64" - assert img["source"]["media_type"] == "image/png" - assert img["source"]["data"] == b64 - - -async def test_user_pdf_url() -> None: - """PDF URL → Anthropic document block with url source.""" - msgs = [ - Message( - role="user", - parts=[ - FilePart( - data="https://example.com/doc.pdf", media_type="application/pdf" - ) - ], - ) - ] - _sys, result = await _messages_to_anthropic(msgs) - doc = result[0]["content"][0] - assert doc["type"] == "document" - assert doc["source"] == {"type": "url", "url": "https://example.com/doc.pdf"} - - -async def test_user_pdf_base64() -> None: - """PDF base64 → Anthropic document block with base64 source.""" - b64 = base64.b64encode(b"%PDF-1.4").decode() - msgs = [ - Message( - role="user", - parts=[FilePart(data=b64, media_type="application/pdf")], - ) - ] - _sys, result = await _messages_to_anthropic(msgs) - doc = result[0]["content"][0] - assert doc["type"] == "document" - assert doc["source"]["type"] == "base64" - assert doc["source"]["media_type"] == "application/pdf" - - -async def test_user_text_plain_bytes() -> None: - """text/plain with bytes → Anthropic document with text source.""" - msgs = [ - Message( - role="user", - parts=[FilePart(data=b"Hello, world!", media_type="text/plain")], - ) - ] - _sys, result = await _messages_to_anthropic(msgs) - doc = result[0]["content"][0] - assert doc["type"] == "document" - assert doc["source"]["type"] == "text" - assert doc["source"]["data"] == "Hello, world!" - - -async def test_unsupported_media_type_raises() -> None: - """Unsupported media type → ValueError.""" - msgs = [ - Message( - role="user", - parts=[FilePart(data=b"\x00", media_type="video/mp4")], - ) - ] - with pytest.raises(ValueError, match="Unsupported media type"): - await _messages_to_anthropic(msgs) diff --git a/tests/models/core/media/__init__.py b/tests/models/core/media/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/models/core/media/test_data.py b/tests/models/core/media/test_data.py deleted file mode 100644 index 55783ad7..00000000 --- a/tests/models/core/media/test_data.py +++ /dev/null @@ -1,80 +0,0 @@ -"""Tests for media data-format helpers (URL detection, base-64, data URLs).""" - -from vercel_ai_sdk.models.core.media.data import ( - data_to_base64, - data_to_data_url, - is_url, - split_data_url, -) - -# -- is_url ---------------------------------------------------------------- - - -def test_is_url_http() -> None: - assert is_url("https://example.com/img.png") is True - assert is_url("http://example.com/img.png") is True - - -def test_is_url_data() -> None: - assert is_url("data:image/png;base64,abc") is True - - -def test_is_url_base64() -> None: - assert is_url("iVBORw0KGgo=") is False - - -# -- data_to_base64 ------------------------------------------------------- - - -def test_data_to_base64_bytes() -> None: - assert data_to_base64(b"\x01\x02\x03") == "AQID" - - -def test_data_to_base64_passthrough() -> None: - assert data_to_base64("AQID") == "AQID" - - -def test_data_to_base64_extracts_from_data_url() -> None: - """data: URLs must have the prefix stripped -- providers need raw base64.""" - result = data_to_base64("data:image/png;base64,AQID") - assert result == "AQID" - - -def test_data_to_base64_passthrough_http_url() -> None: - """HTTP URLs are passed through -- caller must handle.""" - url = "https://example.com/img.png" - assert data_to_base64(url) == url - - -# -- data_to_data_url ------------------------------------------------------ - - -def test_data_to_data_url_from_bytes() -> None: - result = data_to_data_url(b"\x01\x02\x03", "image/png") - assert result == "data:image/png;base64,AQID" - - -def test_data_to_data_url_passthrough_url() -> None: - url = "https://example.com/img.png" - assert data_to_data_url(url, "image/png") == url - - -# -- split_data_url -------------------------------------------------------- - - -def test_split_data_url_valid() -> None: - mt, b64 = split_data_url("data:image/png;base64,iVBOR") - assert mt == "image/png" - assert b64 == "iVBOR" - - -def test_split_data_url_non_data_url() -> None: - mt, b64 = split_data_url("https://example.com/img.png") - assert mt is None - assert b64 is None - - -def test_split_data_url_malformed() -> None: - mt, b64 = split_data_url("data:") - assert mt is None - assert b64 is None diff --git a/tests/models/core/media/test_detect_media_type.py b/tests/models/core/media/test_detect_media_type.py deleted file mode 100644 index 6199a493..00000000 --- a/tests/models/core/media/test_detect_media_type.py +++ /dev/null @@ -1,460 +0,0 @@ -"""Tests for magic-byte media type detection. - -Ported from: .reference/ai/packages/ai/src/util/detect-media-type.test.ts -""" - -from __future__ import annotations - -import base64 - -from vercel_ai_sdk.models.core.media.detect import ( - AUDIO_SIGNATURES, - IMAGE_SIGNATURES, - detect_media_type, -) - -# --------------------------------------------------------------------------- -# Image detection -# --------------------------------------------------------------------------- - - -class TestGif: - def test_detect_gif_from_bytes(self) -> None: - data = bytes([0x47, 0x49, 0x46, 0xFF, 0xFF]) - assert detect_media_type(data, IMAGE_SIGNATURES) == "image/gif" - - def test_detect_gif_from_base64(self) -> None: - assert detect_media_type("R0lGabc123", IMAGE_SIGNATURES) == "image/gif" - - -class TestPng: - def test_detect_png_from_bytes(self) -> None: - data = bytes([0x89, 0x50, 0x4E, 0x47, 0xFF, 0xFF]) - assert detect_media_type(data, IMAGE_SIGNATURES) == "image/png" - - def test_detect_png_from_base64(self) -> None: - assert detect_media_type("iVBORwabc123", IMAGE_SIGNATURES) == "image/png" - - -class TestJpeg: - def test_detect_jpeg_from_bytes(self) -> None: - data = bytes([0xFF, 0xD8, 0xFF, 0xFF]) - assert detect_media_type(data, IMAGE_SIGNATURES) == "image/jpeg" - - def test_detect_jpeg_from_base64(self) -> None: - assert detect_media_type("/9j/abc123", IMAGE_SIGNATURES) == "image/jpeg" - - -class TestWebp: - def test_detect_webp_from_bytes(self) -> None: - # RIFF + 4 bytes (file size) + WEBP + VP8 data - data = bytes( - [ - 0x52, - 0x49, - 0x46, - 0x46, # "RIFF" - 0x24, - 0x00, - 0x00, - 0x00, # file size (wildcard in sig) - 0x57, - 0x45, - 0x42, - 0x50, # "WEBP" - 0x56, - 0x50, - 0x38, - 0x20, # "VP8 " (trailing data) - ] - ) - assert detect_media_type(data, IMAGE_SIGNATURES) == "image/webp" - - def test_detect_webp_from_base64(self) -> None: - data = bytes( - [ - 0x52, - 0x49, - 0x46, - 0x46, - 0x24, - 0x00, - 0x00, - 0x00, - 0x57, - 0x45, - 0x42, - 0x50, - 0x56, - 0x50, - 0x38, - 0x20, - ] - ) - b64 = base64.b64encode(data).decode() - assert detect_media_type(b64, IMAGE_SIGNATURES) == "image/webp" - - def test_riff_audio_not_detected_as_webp_bytes(self) -> None: - """RIFF + WAVE should NOT match WebP.""" - data = bytes( - [ - 0x52, - 0x49, - 0x46, - 0x46, - 0x24, - 0x00, - 0x00, - 0x00, - 0x57, - 0x41, - 0x56, - 0x45, # "WAVE", not "WEBP" - ] - ) - assert detect_media_type(data, IMAGE_SIGNATURES) is None - - def test_riff_audio_not_detected_as_webp_base64(self) -> None: - data = bytes( - [ - 0x52, - 0x49, - 0x46, - 0x46, - 0x24, - 0x00, - 0x00, - 0x00, - 0x57, - 0x41, - 0x56, - 0x45, - ] - ) - b64 = base64.b64encode(data).decode() - assert detect_media_type(b64, IMAGE_SIGNATURES) is None - - -class TestBmp: - def test_detect_bmp_from_bytes(self) -> None: - data = bytes([0x42, 0x4D, 0xFF, 0xFF]) - assert detect_media_type(data, IMAGE_SIGNATURES) == "image/bmp" - - def test_detect_bmp_from_base64(self) -> None: - data = bytes([0x42, 0x4D, 0xFF, 0xFF]) - b64 = base64.b64encode(data).decode() - assert detect_media_type(b64, IMAGE_SIGNATURES) == "image/bmp" - - -class TestTiff: - def test_detect_tiff_le_from_bytes(self) -> None: - data = bytes([0x49, 0x49, 0x2A, 0x00, 0xFF]) - assert detect_media_type(data, IMAGE_SIGNATURES) == "image/tiff" - - def test_detect_tiff_le_from_base64(self) -> None: - assert detect_media_type("SUkqAAabc123", IMAGE_SIGNATURES) == "image/tiff" - - def test_detect_tiff_be_from_bytes(self) -> None: - data = bytes([0x4D, 0x4D, 0x00, 0x2A, 0xFF]) - assert detect_media_type(data, IMAGE_SIGNATURES) == "image/tiff" - - def test_detect_tiff_be_from_base64(self) -> None: - assert detect_media_type("TU0AKgabc123", IMAGE_SIGNATURES) == "image/tiff" - - -class TestAvif: - def test_detect_avif_from_bytes(self) -> None: - data = bytes( - [ - 0x00, - 0x00, - 0x00, - 0x20, - 0x66, - 0x74, - 0x79, - 0x70, - 0x61, - 0x76, - 0x69, - 0x66, - 0xFF, - ] - ) - assert detect_media_type(data, IMAGE_SIGNATURES) == "image/avif" - - def test_detect_avif_from_base64(self) -> None: - assert ( - detect_media_type("AAAAIGZ0eXBhdmlmabc123", IMAGE_SIGNATURES) - == "image/avif" - ) - - -class TestHeic: - def test_detect_heic_from_bytes(self) -> None: - data = bytes( - [ - 0x00, - 0x00, - 0x00, - 0x20, - 0x66, - 0x74, - 0x79, - 0x70, - 0x68, - 0x65, - 0x69, - 0x63, - 0xFF, - ] - ) - assert detect_media_type(data, IMAGE_SIGNATURES) == "image/heic" - - def test_detect_heic_from_base64(self) -> None: - assert ( - detect_media_type("AAAAIGZ0eXBoZWljabc123", IMAGE_SIGNATURES) - == "image/heic" - ) - - -# --------------------------------------------------------------------------- -# Audio detection -# --------------------------------------------------------------------------- - - -class TestMp3: - def test_detect_mp3_from_bytes(self) -> None: - data = bytes([0xFF, 0xFB]) - assert detect_media_type(data, AUDIO_SIGNATURES) == "audio/mpeg" - - def test_detect_mp3_from_base64(self) -> None: - assert detect_media_type("//s=", AUDIO_SIGNATURES) == "audio/mpeg" - - def test_detect_mp3_with_id3v2_tags_from_bytes(self) -> None: - """ID3v2 header (10 bytes tag, size=4) followed by MP3 frame.""" - data = bytes( - [ - 0x49, - 0x44, - 0x33, # "ID3" - 0x04, - 0x00, # version - 0x00, # flags - 0x00, - 0x00, - 0x00, - 0x04, # size = 4 (syncsafe) - 0x00, - 0x00, - 0x00, - 0x00, # 4 bytes of tag data - 0xFF, - 0xFB, # MP3 frame sync - 0x90, - 0x00, - 0x00, - 0x00, - 0x00, - 0x00, - 0x00, - 0x00, - ] - ) - assert detect_media_type(data, AUDIO_SIGNATURES) == "audio/mpeg" - - def test_detect_mp3_with_id3v2_tags_from_base64(self) -> None: - data = bytes( - [ - 0x49, - 0x44, - 0x33, - 0x04, - 0x00, - 0x00, - 0x00, - 0x00, - 0x00, - 0x04, - 0x00, - 0x00, - 0x00, - 0x00, - 0xFF, - 0xFB, - 0x90, - 0x00, - 0x00, - 0x00, - 0x00, - 0x00, - 0x00, - 0x00, - ] - ) - b64 = base64.b64encode(data).decode() - assert detect_media_type(b64, AUDIO_SIGNATURES) == "audio/mpeg" - - -class TestWav: - def test_detect_wav_from_bytes(self) -> None: - data = bytes( - [ - 0x52, - 0x49, - 0x46, - 0x46, # "RIFF" - 0x24, - 0x00, - 0x00, - 0x00, # file size - 0x57, - 0x41, - 0x56, - 0x45, # "WAVE" - ] - ) - assert detect_media_type(data, AUDIO_SIGNATURES) == "audio/wav" - - def test_detect_wav_from_base64(self) -> None: - data = bytes( - [ - 0x52, - 0x49, - 0x46, - 0x46, - 0x24, - 0x00, - 0x00, - 0x00, - 0x57, - 0x41, - 0x56, - 0x45, - ] - ) - b64 = base64.b64encode(data).decode() - assert detect_media_type(b64, AUDIO_SIGNATURES) == "audio/wav" - - def test_webp_not_detected_as_wav_bytes(self) -> None: - """RIFF + WEBP should NOT match WAV.""" - data = bytes( - [ - 0x52, - 0x49, - 0x46, - 0x46, - 0x24, - 0x00, - 0x00, - 0x00, - 0x57, - 0x45, - 0x42, - 0x50, # "WEBP", not "WAVE" - ] - ) - assert detect_media_type(data, AUDIO_SIGNATURES) is None - - def test_webp_not_detected_as_wav_base64(self) -> None: - data = bytes( - [ - 0x52, - 0x49, - 0x46, - 0x46, - 0x24, - 0x00, - 0x00, - 0x00, - 0x57, - 0x45, - 0x42, - 0x50, - ] - ) - b64 = base64.b64encode(data).decode() - assert detect_media_type(b64, AUDIO_SIGNATURES) is None - - -class TestOgg: - def test_detect_ogg_from_bytes(self) -> None: - data = bytes([0x4F, 0x67, 0x67, 0x53]) - assert detect_media_type(data, AUDIO_SIGNATURES) == "audio/ogg" - - def test_detect_ogg_from_base64(self) -> None: - assert detect_media_type("T2dnUw", AUDIO_SIGNATURES) == "audio/ogg" - - -class TestFlac: - def test_detect_flac_from_bytes(self) -> None: - data = bytes([0x66, 0x4C, 0x61, 0x43]) - assert detect_media_type(data, AUDIO_SIGNATURES) == "audio/flac" - - def test_detect_flac_from_base64(self) -> None: - assert detect_media_type("ZkxhQw", AUDIO_SIGNATURES) == "audio/flac" - - -class TestAac: - def test_detect_aac_from_bytes(self) -> None: - data = bytes([0x40, 0x15, 0x00, 0x00]) - assert detect_media_type(data, AUDIO_SIGNATURES) == "audio/aac" - - def test_detect_aac_from_base64(self) -> None: - data = bytes([0x40, 0x15, 0x00, 0x00]) - b64 = base64.b64encode(data).decode() - assert detect_media_type(b64, AUDIO_SIGNATURES) == "audio/aac" - - -class TestMp4Audio: - def test_detect_mp4_from_bytes(self) -> None: - data = bytes([0x66, 0x74, 0x79, 0x70]) - assert detect_media_type(data, AUDIO_SIGNATURES) == "audio/mp4" - - def test_detect_mp4_from_base64(self) -> None: - assert detect_media_type("ZnR5cA", AUDIO_SIGNATURES) == "audio/mp4" - - -class TestWebmAudio: - def test_detect_webm_from_bytes(self) -> None: - data = bytes([0x1A, 0x45, 0xDF, 0xA3]) - assert detect_media_type(data, AUDIO_SIGNATURES) == "audio/webm" - - def test_detect_webm_from_base64(self) -> None: - assert detect_media_type("GkXfow==", AUDIO_SIGNATURES) == "audio/webm" - - -# --------------------------------------------------------------------------- -# Error / edge cases -# --------------------------------------------------------------------------- - - -class TestEdgeCases: - def test_unknown_image_format(self) -> None: - data = bytes([0x00, 0x01, 0x02, 0x03]) - assert detect_media_type(data, IMAGE_SIGNATURES) is None - - def test_unknown_audio_format(self) -> None: - data = bytes([0x00, 0x01, 0x02, 0x03]) - assert detect_media_type(data, AUDIO_SIGNATURES) is None - - def test_empty_bytes_image(self) -> None: - assert detect_media_type(b"", IMAGE_SIGNATURES) is None - - def test_empty_bytes_audio(self) -> None: - assert detect_media_type(b"", AUDIO_SIGNATURES) is None - - def test_short_bytes_image(self) -> None: - """Bytes shorter than longest signature should not crash.""" - data = bytes([0x89, 0x50]) # incomplete PNG - assert detect_media_type(data, IMAGE_SIGNATURES) is None - - def test_short_bytes_audio(self) -> None: - data = bytes([0x4F, 0x67]) # incomplete OGG - assert detect_media_type(data, AUDIO_SIGNATURES) is None - - def test_invalid_base64_image(self) -> None: - assert detect_media_type("invalid123", IMAGE_SIGNATURES) is None - - def test_invalid_base64_audio(self) -> None: - assert detect_media_type("invalid123", AUDIO_SIGNATURES) is None diff --git a/tests/models/core/media/test_models.py b/tests/models/core/media/test_models.py deleted file mode 100644 index ed5a81ea..00000000 --- a/tests/models/core/media/test_models.py +++ /dev/null @@ -1,198 +0,0 @@ -"""Tests for MediaModel: extraction and message assembly.""" - -from __future__ import annotations - -from typing import Any - -import pytest - -from vercel_ai_sdk.models.core.media import MediaModel, MediaResult -from vercel_ai_sdk.types.messages import FilePart, Message, TextPart, Usage - -# --------------------------------------------------------------------------- -# Concrete stub for testing the base class -# --------------------------------------------------------------------------- - - -class _StubMediaModel(MediaModel): - """Minimal concrete implementation that just returns what we tell it to.""" - - def __init__(self, result: MediaResult) -> None: - self._result = result - - async def make_request( - self, - prompt: str, - input_files: list[FilePart], - *, - n: int = 1, - provider_options: dict[str, Any] | None = None, - ) -> MediaResult: - return self._result - - -# --------------------------------------------------------------------------- -# _extract_prompt -# --------------------------------------------------------------------------- - - -class TestExtractPrompt: - def test_user_text(self) -> None: - msgs = [Message(role="user", parts=[TextPart(text="hello world")])] - assert MediaModel._extract_prompt(msgs) == "hello world" - - def test_system_and_user(self) -> None: - msgs = [ - Message(role="system", parts=[TextPart(text="be helpful")]), - Message(role="user", parts=[TextPart(text="draw a cat")]), - ] - assert MediaModel._extract_prompt(msgs) == "be helpful draw a cat" - - def test_ignores_assistant(self) -> None: - msgs = [ - Message(role="user", parts=[TextPart(text="hello")]), - Message(role="assistant", parts=[TextPart(text="ignored")]), - ] - assert MediaModel._extract_prompt(msgs) == "hello" - - def test_multiple_text_parts(self) -> None: - msgs = [ - Message( - role="user", - parts=[TextPart(text="first"), TextPart(text="second")], - ) - ] - assert MediaModel._extract_prompt(msgs) == "first second" - - def test_skips_non_text_parts(self) -> None: - msgs = [ - Message( - role="user", - parts=[ - TextPart(text="prompt"), - FilePart(data=b"\x89PNG", media_type="image/png"), - ], - ) - ] - assert MediaModel._extract_prompt(msgs) == "prompt" - - def test_empty_messages(self) -> None: - assert MediaModel._extract_prompt([]) == "" - - -# --------------------------------------------------------------------------- -# _extract_input_files -# --------------------------------------------------------------------------- - - -class TestExtractInputFiles: - def test_user_file_parts(self) -> None: - img = FilePart(data=b"\x89PNG", media_type="image/png") - pdf = FilePart(data=b"%PDF", media_type="application/pdf") - msgs = [Message(role="user", parts=[TextPart(text="hi"), img, pdf])] - result = MediaModel._extract_input_files(msgs) - assert result == [img, pdf] - - def test_ignores_assistant_files(self) -> None: - img = FilePart(data=b"\x89PNG", media_type="image/png") - msgs = [Message(role="assistant", parts=[img])] - assert MediaModel._extract_input_files(msgs) == [] - - def test_ignores_system_files(self) -> None: - img = FilePart(data=b"\x89PNG", media_type="image/png") - msgs = [Message(role="system", parts=[img])] - assert MediaModel._extract_input_files(msgs) == [] - - def test_returns_all_media_types(self) -> None: - """Unlike the old extract_input_images, this returns ALL file parts.""" - img = FilePart(data=b"\x89PNG", media_type="image/png") - audio = FilePart(data=b"\xff\xfb", media_type="audio/mpeg") - video = FilePart(data=b"\x00\x00", media_type="video/mp4") - msgs = [Message(role="user", parts=[img, audio, video])] - result = MediaModel._extract_input_files(msgs) - assert len(result) == 3 - - def test_empty_messages(self) -> None: - assert MediaModel._extract_input_files([]) == [] - - def test_multiple_user_messages(self) -> None: - img1 = FilePart(data=b"\x89PNG", media_type="image/png") - img2 = FilePart(data=b"\xff\xd8", media_type="image/jpeg") - msgs = [ - Message(role="user", parts=[img1]), - Message(role="user", parts=[img2]), - ] - result = MediaModel._extract_input_files(msgs) - assert result == [img1, img2] - - -# --------------------------------------------------------------------------- -# _build_message -# --------------------------------------------------------------------------- - - -class TestBuildMessage: - def test_wraps_files_in_message(self) -> None: - fp = FilePart(data=b"\x89PNG", media_type="image/png") - result = MediaResult(files=[fp]) - msg = MediaModel._build_message(result) - assert msg.role == "assistant" - assert len(msg.parts) == 1 - assert msg.images[0] is fp - - def test_includes_usage(self) -> None: - fp = FilePart(data=b"\x89PNG", media_type="image/png") - usage = Usage(input_tokens=10, output_tokens=20) - result = MediaResult(files=[fp], usage=usage) - msg = MediaModel._build_message(result) - assert msg.usage is not None - assert msg.usage.input_tokens == 10 - assert msg.usage.output_tokens == 20 - - def test_no_usage(self) -> None: - result = MediaResult(files=[]) - msg = MediaModel._build_message(result) - assert msg.usage is None - - def test_empty_files(self) -> None: - result = MediaResult(files=[]) - msg = MediaModel._build_message(result) - assert msg.parts == [] - - -# --------------------------------------------------------------------------- -# Integration: generate() calls make_request() and wraps result -# --------------------------------------------------------------------------- - - -class TestGenerateIntegration: - @pytest.mark.asyncio - async def test_generate_round_trip(self) -> None: - """The base class extracts prompt/files and wraps the result.""" - fp_out = FilePart(data="b64data", media_type="image/png") - usage = Usage(input_tokens=5, output_tokens=15) - stub = _StubMediaModel(MediaResult(files=[fp_out], usage=usage)) - - # We can't call generate() directly on MediaModel since it doesn't - # define one — subclasses do. But we can verify the pipeline by - # calling the helpers manually. - prompt = stub._extract_prompt( - [Message(role="user", parts=[TextPart(text="a sunset")])] - ) - assert prompt == "a sunset" - - input_files = stub._extract_input_files( - [ - Message( - role="user", - parts=[FilePart(data=b"\x89PNG", media_type="image/png")], - ) - ] - ) - assert len(input_files) == 1 - - result = await stub.make_request(prompt, input_files) - msg = stub._build_message(result) - assert msg.role == "assistant" - assert msg.images == [fp_out] - assert msg.usage == usage diff --git a/tests/models/core/test_media.py b/tests/models/core/test_media.py new file mode 100644 index 00000000..eb77c96a --- /dev/null +++ b/tests/models/core/test_media.py @@ -0,0 +1,372 @@ +"""Tests for media data helpers and magic-byte media type detection. + +Covers ``is_url``, ``data_to_base64``, ``data_to_data_url``, +``split_data_url``, ``detect_image_media_type``, ``detect_audio_media_type``, +and edge cases. +""" + +from __future__ import annotations + +import base64 + +from vercel_ai_sdk.models.core.helpers.media import ( + data_to_base64, + data_to_data_url, + detect_audio_media_type, + detect_image_media_type, + is_url, + split_data_url, +) + +# --------------------------------------------------------------------------- +# is_url +# --------------------------------------------------------------------------- + + +class TestIsUrl: + def test_http(self) -> None: + assert is_url("https://example.com/img.png") is True + assert is_url("http://example.com/img.png") is True + + def test_data(self) -> None: + assert is_url("data:image/png;base64,iVBOR") is True + + def test_base64(self) -> None: + assert is_url("iVBORw0KGgo=") is False + + +# --------------------------------------------------------------------------- +# data_to_base64 +# --------------------------------------------------------------------------- + + +class TestDataToBase64: + def test_bytes(self) -> None: + raw = b"\x89PNG" + result = data_to_base64(raw) + assert base64.b64decode(result) == raw + + def test_passthrough(self) -> None: + b64 = base64.b64encode(b"hello").decode() + assert data_to_base64(b64) == b64 + + def test_extracts_from_data_url(self) -> None: + payload = base64.b64encode(b"hello").decode() + data_url = f"data:image/png;base64,{payload}" + assert data_to_base64(data_url) == payload + + def test_passthrough_http_url(self) -> None: + url = "https://example.com/image.png" + assert data_to_base64(url) == url + + +# --------------------------------------------------------------------------- +# data_to_data_url +# --------------------------------------------------------------------------- + + +class TestDataToDataUrl: + def test_from_bytes(self) -> None: + raw = b"\x89PNG" + result = data_to_data_url(raw, "image/png") + assert result.startswith("data:image/png;base64,") + assert base64.b64decode(result.split(",", 1)[1]) == raw + + def test_passthrough_url(self) -> None: + url = "https://example.com/image.png" + assert data_to_data_url(url, "image/png") == url + + +# --------------------------------------------------------------------------- +# split_data_url +# --------------------------------------------------------------------------- + + +class TestSplitDataUrl: + def test_valid(self) -> None: + media_type, content = split_data_url("data:image/png;base64,iVBOR") + assert media_type == "image/png" + assert content == "iVBOR" + + def test_non_data_url(self) -> None: + assert split_data_url("https://example.com") == (None, None) + + def test_malformed(self) -> None: + assert split_data_url("data:nope") == (None, None) + + +# --------------------------------------------------------------------------- +# Image detection +# --------------------------------------------------------------------------- + + +class TestGif: + def test_from_bytes(self) -> None: + assert detect_image_media_type(bytes([0x47, 0x49, 0x46])) == "image/gif" + + def test_from_base64(self) -> None: + assert ( + detect_image_media_type( + base64.b64encode(bytes([0x47, 0x49, 0x46])).decode() + ) + == "image/gif" + ) + + +class TestPng: + def test_from_bytes(self) -> None: + assert detect_image_media_type(bytes([0x89, 0x50, 0x4E, 0x47])) == "image/png" + + def test_from_base64(self) -> None: + assert ( + detect_image_media_type( + base64.b64encode(bytes([0x89, 0x50, 0x4E, 0x47])).decode() + ) + == "image/png" + ) + + +class TestJpeg: + def test_from_bytes(self) -> None: + assert detect_image_media_type(bytes([0xFF, 0xD8, 0xFF])) == "image/jpeg" + + def test_from_base64(self) -> None: + assert ( + detect_image_media_type( + base64.b64encode(bytes([0xFF, 0xD8, 0xFF])).decode() + ) + == "image/jpeg" + ) + + +class TestWebp: + _RIFF_WEBP = bytes( + [0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57, 0x45, 0x42, 0x50] + ) + + def test_from_bytes(self) -> None: + assert detect_image_media_type(self._RIFF_WEBP) == "image/webp" + + def test_from_base64(self) -> None: + assert ( + detect_image_media_type(base64.b64encode(self._RIFF_WEBP).decode()) + == "image/webp" + ) + + def test_riff_wave_not_webp_bytes(self) -> None: + riff_wave = bytes( + [0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57, 0x41, 0x56, 0x45] + ) + assert detect_image_media_type(riff_wave) is None + + def test_riff_wave_not_webp_base64(self) -> None: + riff_wave = bytes( + [0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57, 0x41, 0x56, 0x45] + ) + assert detect_image_media_type(base64.b64encode(riff_wave).decode()) is None + + +class TestBmp: + def test_from_bytes(self) -> None: + assert detect_image_media_type(bytes([0x42, 0x4D])) == "image/bmp" + + def test_from_base64(self) -> None: + assert ( + detect_image_media_type(base64.b64encode(bytes([0x42, 0x4D])).decode()) + == "image/bmp" + ) + + +class TestTiff: + def test_little_endian_from_bytes(self) -> None: + assert detect_image_media_type(bytes([0x49, 0x49, 0x2A, 0x00])) == "image/tiff" + + def test_big_endian_from_bytes(self) -> None: + assert detect_image_media_type(bytes([0x4D, 0x4D, 0x00, 0x2A])) == "image/tiff" + + def test_little_endian_from_base64(self) -> None: + assert ( + detect_image_media_type( + base64.b64encode(bytes([0x49, 0x49, 0x2A, 0x00])).decode() + ) + == "image/tiff" + ) + + def test_big_endian_from_base64(self) -> None: + assert ( + detect_image_media_type( + base64.b64encode(bytes([0x4D, 0x4D, 0x00, 0x2A])).decode() + ) + == "image/tiff" + ) + + +class TestAvif: + _AVIF = bytes( + [0x00, 0x00, 0x00, 0x20, 0x66, 0x74, 0x79, 0x70, 0x61, 0x76, 0x69, 0x66] + ) + + def test_from_bytes(self) -> None: + assert detect_image_media_type(self._AVIF) == "image/avif" + + def test_from_base64(self) -> None: + assert ( + detect_image_media_type(base64.b64encode(self._AVIF).decode()) + == "image/avif" + ) + + +class TestHeic: + _HEIC = bytes( + [0x00, 0x00, 0x00, 0x20, 0x66, 0x74, 0x79, 0x70, 0x68, 0x65, 0x69, 0x63] + ) + + def test_from_bytes(self) -> None: + assert detect_image_media_type(self._HEIC) == "image/heic" + + def test_from_base64(self) -> None: + assert ( + detect_image_media_type(base64.b64encode(self._HEIC).decode()) + == "image/heic" + ) + + +# --------------------------------------------------------------------------- +# Audio detection +# --------------------------------------------------------------------------- + + +class TestMp3: + def test_from_bytes(self) -> None: + assert detect_audio_media_type(bytes([0xFF, 0xFB])) == "audio/mpeg" + + def test_from_base64(self) -> None: + assert ( + detect_audio_media_type(base64.b64encode(bytes([0xFF, 0xFB])).decode()) + == "audio/mpeg" + ) + + def test_with_id3_tags_bytes(self) -> None: + # ID3v2 header (10 bytes) + MP3 sync bytes + id3_header = bytes([0x49, 0x44, 0x33, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) + mp3_data = id3_header + bytes([0xFF, 0xFB]) + assert detect_audio_media_type(mp3_data) == "audio/mpeg" + + def test_with_id3_tags_base64(self) -> None: + id3_header = bytes([0x49, 0x44, 0x33, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]) + mp3_data = id3_header + bytes([0xFF, 0xFB]) + assert ( + detect_audio_media_type(base64.b64encode(mp3_data).decode()) == "audio/mpeg" + ) + + +class TestWav: + _RIFF_WAVE = bytes( + [0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57, 0x41, 0x56, 0x45] + ) + + def test_from_bytes(self) -> None: + assert detect_audio_media_type(self._RIFF_WAVE) == "audio/wav" + + def test_from_base64(self) -> None: + assert ( + detect_audio_media_type(base64.b64encode(self._RIFF_WAVE).decode()) + == "audio/wav" + ) + + def test_riff_webp_not_wav_bytes(self) -> None: + riff_webp = bytes( + [0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57, 0x45, 0x42, 0x50] + ) + assert detect_audio_media_type(riff_webp) is None + + def test_riff_webp_not_wav_base64(self) -> None: + riff_webp = bytes( + [0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57, 0x45, 0x42, 0x50] + ) + assert detect_audio_media_type(base64.b64encode(riff_webp).decode()) is None + + +class TestOgg: + def test_from_bytes(self) -> None: + assert detect_audio_media_type(b"OggS") == "audio/ogg" + + def test_from_base64(self) -> None: + assert ( + detect_audio_media_type(base64.b64encode(b"OggS").decode()) == "audio/ogg" + ) + + +class TestFlac: + def test_from_bytes(self) -> None: + assert detect_audio_media_type(b"fLaC") == "audio/flac" + + def test_from_base64(self) -> None: + assert ( + detect_audio_media_type(base64.b64encode(b"fLaC").decode()) == "audio/flac" + ) + + +class TestAac: + def test_from_bytes(self) -> None: + assert detect_audio_media_type(bytes([0x40, 0x15, 0x00, 0x00])) == "audio/aac" + + def test_from_base64(self) -> None: + assert ( + detect_audio_media_type( + base64.b64encode(bytes([0x40, 0x15, 0x00, 0x00])).decode() + ) + == "audio/aac" + ) + + +class TestMp4Audio: + # The audio/mp4 signature starts at the `ftyp` atom directly (no box size prefix). + _FTYP = bytes([0x66, 0x74, 0x79, 0x70]) + + def test_from_bytes(self) -> None: + assert detect_audio_media_type(self._FTYP) == "audio/mp4" + + def test_from_base64(self) -> None: + assert ( + detect_audio_media_type(base64.b64encode(self._FTYP).decode()) + == "audio/mp4" + ) + + +class TestWebmAudio: + _WEBM = bytes([0x1A, 0x45, 0xDF, 0xA3]) + + def test_from_bytes(self) -> None: + assert detect_audio_media_type(self._WEBM) == "audio/webm" + + def test_from_base64(self) -> None: + assert ( + detect_audio_media_type(base64.b64encode(self._WEBM).decode()) + == "audio/webm" + ) + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +class TestEdgeCases: + def test_unknown_image_format(self) -> None: + assert detect_image_media_type(bytes([0x00, 0x01, 0x02, 0x03])) is None + + def test_unknown_audio_format(self) -> None: + assert detect_audio_media_type(bytes([0x00, 0x01, 0x02, 0x03])) is None + + def test_empty_bytes_image(self) -> None: + assert detect_image_media_type(b"") is None + + def test_empty_bytes_audio(self) -> None: + assert detect_audio_media_type(b"") is None + + def test_short_bytes_image(self) -> None: + assert detect_image_media_type(bytes([0x89])) is None + + def test_short_bytes_audio(self) -> None: + assert detect_audio_media_type(bytes([0xFF])) is None diff --git a/tests/models/core/test_llm.py b/tests/models/core/test_streaming.py similarity index 88% rename from tests/models/core/test_llm.py rename to tests/models/core/test_streaming.py index ba1546e3..538d3a50 100644 --- a/tests/models/core/test_llm.py +++ b/tests/models/core/test_streaming.py @@ -1,13 +1,6 @@ -"""StreamHandler: event accumulation, state transitions, message building. -LanguageModel.buffer() with structured output.""" +"""StreamHandler: event accumulation, state transitions, message building.""" -import json - -import pydantic -import pytest - -import vercel_ai_sdk as ai -from vercel_ai_sdk.models.core.llm import ( +from vercel_ai_sdk.models.core.helpers.streaming import ( FileEvent, MessageDone, ReasoningDelta, @@ -29,14 +22,6 @@ Usage, ) -from ...conftest import MockLLM, text_msg - - -class _Weather(pydantic.BaseModel): - city: str - temperature: float - - # -- Text streaming -------------------------------------------------------- @@ -224,30 +209,6 @@ def test_deltas_only_on_active_blocks() -> None: assert text_parts[1].delta == "second" # t2 is active -# -- LanguageModel.buffer() with structured output ------------------------- - - -@pytest.mark.asyncio -async def test_buffer_structured_output() -> None: - """buffer() returns a message with a validated StructuredOutputPart.""" - json_text = '{"city":"Tokyo","temperature":28.5}' - llm = MockLLM([[text_msg(json_text)]]) - - msg = await llm.buffer(ai.make_messages(user="weather?"), output_type=_Weather) - - assert isinstance(msg.output, _Weather) - assert msg.output.city == "Tokyo" - - -@pytest.mark.asyncio -async def test_buffer_structured_output_invalid_json_raises() -> None: - """Bad LLM output with output_type should raise, not silently pass.""" - llm = MockLLM([[text_msg("not json")]]) - - with pytest.raises((json.JSONDecodeError, pydantic.ValidationError)): - await llm.buffer(ai.make_messages(user="weather?"), output_type=_Weather) - - # -- File event (inline images from LLMs like Gemini/GPT-5) --------------- diff --git a/tests/models/openai/__init__.py b/tests/models/openai/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/models/openai/test_openai.py b/tests/models/openai/test_openai.py deleted file mode 100644 index 964886e7..00000000 --- a/tests/models/openai/test_openai.py +++ /dev/null @@ -1,245 +0,0 @@ -"""OpenAI provider: _messages_to_openai multimodal conversion tests.""" - -import base64 -from unittest.mock import AsyncMock, patch - -import pytest - -from vercel_ai_sdk.models.openai import _messages_to_openai -from vercel_ai_sdk.types.messages import FilePart, Message, TextPart - -# -- text-only (regression) ------------------------------------------------ - - -@pytest.mark.asyncio -async def test_user_text_only_is_plain_string() -> None: - """Text-only user messages should produce a plain content string, not array.""" - msgs = [Message(role="user", parts=[TextPart(text="Hello")])] - result = await _messages_to_openai(msgs) - assert result == [{"role": "user", "content": "Hello"}] - - -# -- images ---------------------------------------------------------------- - - -@pytest.mark.asyncio -async def test_user_image_url() -> None: - """Image URL → OpenAI image_url content part.""" - msgs = [ - Message( - role="user", - parts=[ - TextPart(text="What's this?"), - FilePart(data="https://example.com/cat.jpg", media_type="image/jpeg"), - ], - ) - ] - result = await _messages_to_openai(msgs) - content = result[0]["content"] - assert content[0] == {"type": "text", "text": "What's this?"} - assert content[1] == { - "type": "image_url", - "image_url": {"url": "https://example.com/cat.jpg"}, - } - - -@pytest.mark.asyncio -async def test_user_image_base64() -> None: - """Base64 image data → OpenAI image_url with data URL.""" - b64 = base64.b64encode(b"\x89PNG").decode() - msgs = [ - Message( - role="user", - parts=[FilePart(data=b64, media_type="image/png")], - ) - ] - result = await _messages_to_openai(msgs) - content = result[0]["content"] - assert content[0]["type"] == "image_url" - assert content[0]["image_url"]["url"] == f"data:image/png;base64,{b64}" - - -@pytest.mark.asyncio -async def test_user_image_bytes() -> None: - """Raw bytes image → OpenAI image_url with data URL.""" - raw = b"\x89PNG" - msgs = [ - Message( - role="user", - parts=[FilePart(data=raw, media_type="image/png")], - ) - ] - result = await _messages_to_openai(msgs) - url = result[0]["content"][0]["image_url"]["url"] - assert url.startswith("data:image/png;base64,") - - -@pytest.mark.asyncio -async def test_user_image_wildcard_becomes_jpeg() -> None: - """image/* media type is normalized to image/jpeg for the data URL.""" - msgs = [ - Message( - role="user", - parts=[FilePart(data="https://example.com/img", media_type="image/*")], - ) - ] - result = await _messages_to_openai(msgs) - # URL passthrough: no data URL conversion needed - assert result[0]["content"][0]["image_url"]["url"] == "https://example.com/img" - - -@pytest.mark.asyncio -async def test_user_image_data_url() -> None: - """data: URL image → base64 extracted correctly for image_url.""" - msgs = [ - Message( - role="user", - parts=[FilePart(data="data:image/png;base64,AQID", media_type="image/png")], - ) - ] - result = await _messages_to_openai(msgs) - # data: URLs pass through directly for images - assert result[0]["content"][0]["image_url"]["url"] == "data:image/png;base64,AQID" - - -# -- audio ----------------------------------------------------------------- - - -@pytest.mark.asyncio -async def test_user_audio_base64() -> None: - """Audio base64 → OpenAI input_audio part.""" - b64 = base64.b64encode(b"\xff\xfb").decode() - msgs = [ - Message( - role="user", - parts=[FilePart(data=b64, media_type="audio/wav")], - ) - ] - result = await _messages_to_openai(msgs) - part = result[0]["content"][0] - assert part["type"] == "input_audio" - assert part["input_audio"]["data"] == b64 - assert part["input_audio"]["format"] == "wav" - - -@pytest.mark.asyncio -async def test_user_audio_data_url_extracts_base64() -> None: - """Audio data: URL → base64 prefix stripped for input_audio.""" - msgs = [ - Message( - role="user", - parts=[FilePart(data="data:audio/wav;base64,AAAA", media_type="audio/wav")], - ) - ] - result = await _messages_to_openai(msgs) - part = result[0]["content"][0] - assert part["type"] == "input_audio" - assert part["input_audio"]["data"] == "AAAA" - - -@pytest.mark.asyncio -async def test_user_audio_url_downloads() -> None: - """Audio URLs are auto-downloaded since OpenAI requires base64.""" - fake_audio = b"\xff\xfb\x90\x00" - msgs = [ - Message( - role="user", - parts=[ - FilePart(data="https://example.com/clip.wav", media_type="audio/wav") - ], - ) - ] - with patch( - "vercel_ai_sdk.models.core.media.download.download", - new_callable=AsyncMock, - return_value=(fake_audio, "audio/wav"), - ): - result = await _messages_to_openai(msgs) - part = result[0]["content"][0] - assert part["type"] == "input_audio" - assert part["input_audio"]["format"] == "wav" - # Should be base64 of the downloaded bytes - assert part["input_audio"]["data"] == base64.b64encode(fake_audio).decode() - - -# -- PDF ------------------------------------------------------------------- - - -@pytest.mark.asyncio -async def test_user_pdf_base64() -> None: - """PDF base64 → OpenAI file part.""" - b64 = base64.b64encode(b"%PDF-1.4").decode() - msgs = [ - Message( - role="user", - parts=[ - FilePart(data=b64, media_type="application/pdf", filename="report.pdf") - ], - ) - ] - result = await _messages_to_openai(msgs) - part = result[0]["content"][0] - assert part["type"] == "file" - assert part["file"]["filename"] == "report.pdf" - assert part["file"]["file_data"].startswith("data:application/pdf;base64,") - - -@pytest.mark.asyncio -async def test_user_pdf_url_downloads() -> None: - """PDF URLs are auto-downloaded since OpenAI requires base64.""" - fake_pdf = b"%PDF-1.4 fake content" - msgs = [ - Message( - role="user", - parts=[ - FilePart( - data="https://example.com/doc.pdf", - media_type="application/pdf", - filename="doc.pdf", - ) - ], - ) - ] - with patch( - "vercel_ai_sdk.models.core.media.download.download", - new_callable=AsyncMock, - return_value=(fake_pdf, "application/pdf"), - ): - result = await _messages_to_openai(msgs) - part = result[0]["content"][0] - assert part["type"] == "file" - assert part["file"]["filename"] == "doc.pdf" - assert part["file"]["file_data"].startswith("data:application/pdf;base64,") - - -# -- text/* ---------------------------------------------------------------- - - -@pytest.mark.asyncio -async def test_user_text_file_bytes() -> None: - """text/* file with bytes data → decoded to text content part.""" - msgs = [ - Message( - role="user", - parts=[FilePart(data=b"Hello, world!", media_type="text/plain")], - ) - ] - result = await _messages_to_openai(msgs) - part = result[0]["content"][0] - assert part == {"type": "text", "text": "Hello, world!"} - - -# -- unsupported ----------------------------------------------------------- - - -@pytest.mark.asyncio -async def test_unsupported_media_type_raises() -> None: - """Unknown media type → ValueError.""" - msgs = [ - Message( - role="user", - parts=[FilePart(data=b"\x00", media_type="application/octet-stream")], - ) - ] - with pytest.raises(ValueError, match="Unsupported media type"): - await _messages_to_openai(msgs) diff --git a/tests/telemetry/test_otel_handler.py b/tests/telemetry/test_otel_handler.py index 5f2ff2b6..a30dcf7c 100644 --- a/tests/telemetry/test_otel_handler.py +++ b/tests/telemetry/test_otel_handler.py @@ -12,7 +12,7 @@ import vercel_ai_sdk as ai from vercel_ai_sdk.telemetry.otel import OtelHandler -from ..conftest import MockLLM, text_msg, tool_msg +from ..conftest import MOCK_MODEL, mock_llm, text_msg, tool_msg @pytest.fixture @@ -35,11 +35,10 @@ async def double(x: int) -> int: @pytest.mark.asyncio async def test_text_only_spans(spans: InMemorySpanExporter) -> None: """Text-only run produces ai.run > ai.stream span hierarchy.""" + my_agent = ai.agent(model=MOCK_MODEL, tools=[]) - async def root(llm: ai.LanguageModel) -> ai.StreamResult: - return await ai.stream_loop(llm, messages=ai.make_messages(user="Hi"), tools=[]) - - result = ai.run(root, MockLLM([[text_msg("Hello!")]])) + mock_llm([[text_msg("Hello!")]]) + result = my_agent.run(ai.make_messages(user="Hi")) [m async for m in result] finished = spans.get_finished_spans() @@ -51,8 +50,11 @@ async def root(llm: ai.LanguageModel) -> ai.StreamResult: stream_span = next(s for s in finished if s.name == "ai.stream") # ai.stream is a child of ai.run - assert stream_span.parent is not None - assert stream_span.parent.span_id == run_span.context.span_id + stream_parent = stream_span.parent + assert stream_parent is not None + run_ctx = run_span.context + assert run_ctx is not None + assert stream_parent.span_id == run_ctx.span_id # run_id attribute is set assert run_span.attributes is not None @@ -62,19 +64,15 @@ async def root(llm: ai.LanguageModel) -> ai.StreamResult: @pytest.mark.asyncio async def test_tool_call_spans(spans: InMemorySpanExporter) -> None: """Tool-calling run produces ai.tool spans with correct attributes.""" + my_agent = ai.agent(model=MOCK_MODEL, tools=[double]) - async def root(llm: ai.LanguageModel) -> ai.StreamResult: - return await ai.stream_loop( - llm, messages=ai.make_messages(user="Double 5"), tools=[double] - ) - - llm = MockLLM( + mock_llm( [ [tool_msg(tc_id="tc-1", name="double", args='{"x": 5}')], [text_msg("10")], ] ) - result = ai.run(root, llm) + result = my_agent.run(ai.make_messages(user="Double 5")) [m async for m in result] finished = spans.get_finished_spans() @@ -89,5 +87,8 @@ async def root(llm: ai.LanguageModel) -> ai.StreamResult: # ai.tool is a child of ai.run (tools execute between steps) run_span = next(s for s in finished if s.name == "ai.run") - assert tool_span.parent is not None - assert tool_span.parent.span_id == run_span.context.span_id + tool_parent = tool_span.parent + assert tool_parent is not None + run_ctx = run_span.context + assert run_ctx is not None + assert tool_parent.span_id == run_ctx.span_id diff --git a/tests/telemetry/test_telemetry.py b/tests/telemetry/test_telemetry.py index cf470334..b1b06eb3 100644 --- a/tests/telemetry/test_telemetry.py +++ b/tests/telemetry/test_telemetry.py @@ -17,7 +17,7 @@ ToolCallStartEvent, ) -from ..conftest import MockLLM, text_msg, tool_msg +from ..conftest import MOCK_MODEL, mock_llm, text_msg, tool_msg # ── Recording handler ──────────────────────────────────────────── @@ -55,11 +55,10 @@ async def double(x: int) -> int: @pytest.mark.asyncio async def test_text_only_run_events(handler: RecordingHandler) -> None: """Simplest run emits RunStart, StepStart, StepFinish, RunFinish.""" + my_agent = ai.agent(model=MOCK_MODEL, tools=[]) - async def root(llm: ai.LanguageModel) -> ai.StreamResult: - return await ai.stream_loop(llm, messages=ai.make_messages(user="Hi"), tools=[]) - - result = ai.run(root, MockLLM([[text_msg("Hello!")]])) + mock_llm([[text_msg("Hello!")]]) + result = my_agent.run(ai.make_messages(user="Hi")) [m async for m in result] types = [type(e).__name__ for e in handler.events] @@ -77,20 +76,16 @@ async def root(llm: ai.LanguageModel) -> ai.StreamResult: @pytest.mark.asyncio async def test_tool_call_events(handler: RecordingHandler) -> None: - """Tool-calling run emits tool events between steps with correct payloads.""" - - async def root(llm: ai.LanguageModel) -> ai.StreamResult: - return await ai.stream_loop( - llm, messages=ai.make_messages(user="Double 5"), tools=[double] - ) + """Tool-calling run emits tool events between steps.""" + my_agent = ai.agent(model=MOCK_MODEL, tools=[double]) - llm = MockLLM( + mock_llm( [ [tool_msg(tc_id="tc-1", name="double", args='{"x": 5}')], [text_msg("10")], ] ) - result = ai.run(root, llm) + result = my_agent.run(ai.make_messages(user="Double 5")) [m async for m in result] types = [type(e).__name__ for e in handler.events] @@ -119,7 +114,7 @@ async def root(llm: ai.LanguageModel) -> ai.StreamResult: @pytest.mark.asyncio async def test_run_id_available_during_run() -> None: - """get_run_id() returns a non-empty ID inside a handler during a run.""" + """get_run_id() returns a non-empty ID inside a handler during run.""" captured: str = "" class Capture: @@ -130,13 +125,10 @@ def handle(self, event: TelemetryEvent) -> None: ai.telemetry.enable(Capture()) try: + my_agent = ai.agent(model=MOCK_MODEL, tools=[]) - async def root(llm: ai.LanguageModel) -> ai.StreamResult: - return await ai.stream_loop( - llm, messages=ai.make_messages(user="Hi"), tools=[] - ) - - result = ai.run(root, MockLLM([[text_msg("Hello!")]])) + mock_llm([[text_msg("Hello!")]]) + result = my_agent.run(ai.make_messages(user="Hi")) [m async for m in result] assert len(captured) == 16 finally: @@ -152,17 +144,18 @@ async def test_disable_reverts_to_noop() -> None: handler = RecordingHandler() ai.telemetry.enable(handler) - async def root(llm: ai.LanguageModel) -> ai.StreamResult: - return await ai.stream_loop(llm, messages=ai.make_messages(user="Hi"), tools=[]) + my_agent = ai.agent(model=MOCK_MODEL, tools=[]) - result = ai.run(root, MockLLM([[text_msg("Hello!")]])) + mock_llm([[text_msg("Hello!")]]) + result = my_agent.run(ai.make_messages(user="Hi")) [m async for m in result] assert len(handler.of_type(RunStartEvent)) == 1 ai.telemetry.disable() handler.events.clear() - result = ai.run(root, MockLLM([[text_msg("Hello!")]])) + mock_llm([[text_msg("Hello!")]]) + result = my_agent.run(ai.make_messages(user="Hi")) [m async for m in result] assert len(handler.events) == 0 @@ -178,16 +171,20 @@ async def test_user_emitted_custom_event(handler: RecordingHandler) -> None: class CustomEvent(TelemetryEvent): message: str - async def root(llm: ai.LanguageModel) -> ai.StreamResult: + my_agent = ai.agent(model=MOCK_MODEL, tools=[]) + + @my_agent.loop + async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> ai.StreamResult: ai.telemetry.handle(CustomEvent(message="hello")) - return await ai.stream_loop(llm, messages=ai.make_messages(user="Hi"), tools=[]) + return await ai.stream_step(agent.model, msgs) - result = ai.run(root, MockLLM([[text_msg("Hello!")]])) + mock_llm([[text_msg("Hello!")]]) + result = my_agent.run(ai.make_messages(user="Hi")) [m async for m in result] - custom = [e for e in handler.events if isinstance(e, CustomEvent)] - assert len(custom) == 1 - assert custom[0].message == "hello" + custom_events = [e for e in handler.events if isinstance(e, CustomEvent)] + assert len(custom_events) == 1 + assert custom_events[0].message == "hello" # ── Error capture ──────────────────────────────────────────────── @@ -195,12 +192,15 @@ async def root(llm: ai.LanguageModel) -> ai.StreamResult: @pytest.mark.asyncio async def test_run_error_in_finish_event(handler: RecordingHandler) -> None: - """RunFinishEvent captures the error when the root function raises.""" + """RunFinishEvent captures the error when the loop function raises.""" + my_agent = ai.agent(model=MOCK_MODEL, tools=[]) - async def root(llm: ai.LanguageModel) -> None: + @my_agent.loop + async def failing(agent: ai.Agent, msgs: list[ai.Message]) -> None: raise ValueError("boom") - result = ai.run(root, MockLLM([])) + mock_llm([]) + result = my_agent.run(ai.make_messages(user="Hi")) with pytest.raises(ExceptionGroup): [m async for m in result] diff --git a/uv.lock b/uv.lock index e811f862..a79e345c 100644 --- a/uv.lock +++ b/uv.lock @@ -536,6 +536,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/79/7b/2c79738432f5c924bef5071f933bcc9efd0473bac3b4aa584a6f7c1c8df8/mypy_extensions-1.1.0-py3-none-any.whl", hash = "sha256:1be4cccdb0f2482337c4743e60421de3a356cd97508abadd57d47403e94f5505", size = 4963, upload-time = "2025-04-22T14:54:22.983Z" }, ] +[[package]] +name = "nodeenv" +version = "1.10.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/24/bf/d1bda4f6168e0b2e9e5958945e01910052158313224ada5ce1fb2e1113b8/nodeenv-1.10.0.tar.gz", hash = "sha256:996c191ad80897d076bdfba80a41994c2b47c68e224c542b48feba42ba00f8bb", size = 55611, upload-time = "2025-12-20T14:08:54.006Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/b2/d0896bdcdc8d28a7fc5717c305f1a861c26e18c05047949fb371034d98bd/nodeenv-1.10.0-py2.py3-none-any.whl", hash = "sha256:5bb13e3eed2923615535339b3c620e76779af4cb4c6a90deccc9e36b274d3827", size = 23438, upload-time = "2025-12-20T14:08:52.782Z" }, +] + [[package]] name = "openai" version = "2.14.0" @@ -754,6 +763,19 @@ crypto = [ { name = "cryptography" }, ] +[[package]] +name = "pyright" +version = "1.1.408" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "nodeenv" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/74/b2/5db700e52554b8f025faa9c3c624c59f1f6c8841ba81ab97641b54322f16/pyright-1.1.408.tar.gz", hash = "sha256:f28f2321f96852fa50b5829ea492f6adb0e6954568d1caa3f3af3a5f555eb684", size = 4400578, upload-time = "2026-01-08T08:07:38.795Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/0c/82/a2c93e32800940d9573fb28c346772a14778b84ba7524e691b324620ab89/pyright-1.1.408-py3-none-any.whl", hash = "sha256:090b32865f4fdb1e0e6cd82bf5618480d48eecd2eb2e70f960982a3d9a4c17c1", size = 6399144, upload-time = "2026-01-08T08:07:37.082Z" }, +] + [[package]] name = "pytest" version = "9.0.2" @@ -1049,7 +1071,7 @@ wheels = [ [[package]] name = "vercel-ai-sdk" -version = "0.0.1.dev8" +version = "0.0.1.dev9" source = { editable = "." } dependencies = [ { name = "anthropic" }, @@ -1065,6 +1087,7 @@ dependencies = [ dev = [ { name = "mypy" }, { name = "opentelemetry-sdk" }, + { name = "pyright" }, { name = "pytest" }, { name = "pytest-asyncio" }, { name = "python-dotenv" }, @@ -1087,6 +1110,7 @@ requires-dist = [ dev = [ { name = "mypy", specifier = ">=1.11" }, { name = "opentelemetry-sdk", specifier = ">=1.0" }, + { name = "pyright", specifier = ">=1.1.408" }, { name = "pytest", specifier = ">=8.0" }, { name = "pytest-asyncio", specifier = ">=0.24" }, { name = "python-dotenv", specifier = ">=1.2.1" },