diff --git a/CLAUDE.md b/CLAUDE.md index 0605ce5b..88971d7d 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -4,7 +4,10 @@ 1. use `uv` to manage the project; `uv add` and `uv remove` to manage dependencies, `uv run` to run 2. after making changes run lint and typecheck: `uv run ruff check --fix src tests` and `uv run mypy src tests` -3. import by module (except `typing`) to improve readability via namespacing +3. imports: + - import by module, using the shortest unambiguous relative path. `from ..core import helpers`, `from . import streaming` + - UNLESS it's `typing` — then `from typing import Foo` (there are too many of them). + - if the module name shadows a local variable in the same file, add a trailing underscore to the import: `from ..types import messages as messages_`. do not add trailing underscores preemptively — only when there is an actual collision. 4. treat `stream_step` and `stream_loop` as user code. they are convenience functions that could be reimplemented by the user, they *must* stay clean. ## design principles diff --git a/examples/fastapi-vite/backend/agent.py b/examples/fastapi-vite/backend/agent.py index 7250fe19..c5255588 100644 --- a/examples/fastapi-vite/backend/agent.py +++ b/examples/fastapi-vite/backend/agent.py @@ -27,11 +27,12 @@ async def talk_to_mothership(question: str) -> str: async def _execute_with_approval( tc: ai.ToolPart, message: ai.Message | None = None -) -> None: +) -> ai.ToolPart: """Execute a tool call only after the user grants approval. Creates a ToolApproval hook that suspends execution until the frontend responds with an approve/reject decision. + Returns the updated (immutable) ToolPart with the result. """ approval = await ai.ToolApproval.create( # type: ignore[attr-defined] f"approve_{tc.tool_call_id}", @@ -39,9 +40,8 @@ async def _execute_with_approval( ) if approval.granted: - await ai.execute_tool(tc, message=message) - else: - tc.set_error("Tool call was denied by the user.") + return await ai.execute_tool(tc, message=message) + return tc.with_error("Tool call was denied by the user.") chat_agent = ai.agent( @@ -73,8 +73,11 @@ async def graph( last_msg = result.last_message assert last_msg is not None - local_messages.append(last_msg) - await asyncio.gather( + updated_parts = await asyncio.gather( *(_execute_with_approval(tc, message=last_msg) for tc in result.tool_calls) ) + updated_msg = last_msg + for updated_tc in updated_parts: + updated_msg = updated_msg.replace(updated_tc) + local_messages.append(updated_msg) diff --git a/examples/models/buffer.py b/examples/models/buffer.py index 4020affd..d60a4d2e 100644 --- a/examples/models/buffer.py +++ b/examples/models/buffer.py @@ -2,8 +2,8 @@ import asyncio +import vercel_ai_sdk as ai from vercel_ai_sdk import models as m -from vercel_ai_sdk.types import messages as messages_ model = m.Model( id="anthropic/claude-sonnet-4", @@ -11,12 +11,7 @@ provider="ai-gateway", ) -messages = [ - messages_.Message( - role="user", - parts=[messages_.TextPart(text="What is 2 + 2?")], - ), -] +messages = [ai.user_message("What is 2 + 2?")] async def main() -> None: diff --git a/examples/models/direct_adapter.py b/examples/models/direct_adapter.py index df386a0b..99c093e9 100644 --- a/examples/models/direct_adapter.py +++ b/examples/models/direct_adapter.py @@ -3,9 +3,9 @@ import asyncio import os +import vercel_ai_sdk as ai from vercel_ai_sdk import models as m from vercel_ai_sdk.models import ai_gateway as ai_gateway_v3 -from vercel_ai_sdk.types import messages as messages_ model = m.Model( id="anthropic/claude-sonnet-4", @@ -18,12 +18,7 @@ api_key=os.environ["AI_GATEWAY_API_KEY"], ) -messages = [ - messages_.Message( - role="user", - parts=[messages_.TextPart(text="Say hello in three languages.")], - ), -] +messages = [ai.user_message("Say hello in three languages.")] async def main() -> None: diff --git a/examples/models/explicit_client.py b/examples/models/explicit_client.py index 6c3d7c6e..747cc4e0 100644 --- a/examples/models/explicit_client.py +++ b/examples/models/explicit_client.py @@ -3,8 +3,8 @@ import asyncio import os +import vercel_ai_sdk as ai from vercel_ai_sdk import models as m -from vercel_ai_sdk.types import messages as messages_ model = m.Model( id="anthropic/claude-sonnet-4", @@ -19,12 +19,7 @@ headers={"X-Custom-Header": "example"}, ) -messages = [ - messages_.Message( - role="user", - parts=[messages_.TextPart(text="Hello!")], - ), -] +messages = [ai.user_message("Hello!")] async def main() -> None: diff --git a/examples/models/image_generation.py b/examples/models/image_generation.py index 63b70d7d..9fa5bdde 100644 --- a/examples/models/image_generation.py +++ b/examples/models/image_generation.py @@ -4,8 +4,8 @@ import base64 import pathlib +import vercel_ai_sdk as ai from vercel_ai_sdk import models as m -from vercel_ai_sdk.types import messages as messages_ model = m.Model( id="google/imagen-4.0-generate-001", @@ -15,18 +15,11 @@ ) messages = [ - messages_.Message( - role="user", - parts=[ - messages_.TextPart( - text=( - "Anime girl with twin tails and cat ears, wearing a " - "sailor school uniform, striking a victory pose in front " - "of a futuristic Tokyo skyline at night, neon lights " - "reflecting in her eyes, digital art style" - ) - ), - ], + ai.user_message( + "Anime girl with twin tails and cat ears, wearing a " + "sailor school uniform, striking a victory pose in front " + "of a futuristic Tokyo skyline at night, neon lights " + "reflecting in her eyes, digital art style" ), ] diff --git a/examples/models/inline_image.py b/examples/models/inline_image.py index 91777e87..cec37009 100644 --- a/examples/models/inline_image.py +++ b/examples/models/inline_image.py @@ -9,8 +9,8 @@ import base64 import pathlib +import vercel_ai_sdk as ai from vercel_ai_sdk import models as m -from vercel_ai_sdk.types import messages as messages_ # This is a language model that can also output images inline. model = m.Model( @@ -21,33 +21,19 @@ ) messages = [ - messages_.Message( - role="system", - parts=[ - messages_.TextPart( - text=( - "You are an anime art assistant. When asked to draw or create " - "an image, generate it in a soft pastel anime style." - ) - ), - ], + ai.system_message( + "You are an anime art assistant. When asked to draw or create " + "an image, generate it in a soft pastel anime style." ), - messages_.Message( - role="user", - parts=[ - messages_.TextPart( - text=( - "Draw an anime girl with long silver hair and violet eyes, " - "sitting in a field of cherry blossoms at sunset." - ) - ), - ], + ai.user_message( + "Draw an anime girl with long silver hair and violet eyes, " + "sitting in a field of cherry blossoms at sunset." ), ] async def main() -> None: - last_msg: messages_.Message | None = None + last_msg: ai.Message | None = None # Stream — text deltas arrive as usual, images arrive as FileParts async for msg in m.stream(model, messages): diff --git a/examples/models/multimodal_input.py b/examples/models/multimodal_input.py index f5a11a14..4103f50e 100644 --- a/examples/models/multimodal_input.py +++ b/examples/models/multimodal_input.py @@ -3,8 +3,8 @@ import asyncio import pathlib +import vercel_ai_sdk as ai from vercel_ai_sdk import models as m -from vercel_ai_sdk.types import messages as messages_ model = m.Model( id="anthropic/claude-sonnet-4", @@ -17,12 +17,9 @@ image_data = image_path.read_bytes() messages = [ - messages_.Message( - role="user", - parts=[ - messages_.TextPart(text="Describe this image in detail."), - messages_.FilePart(data=image_data, media_type="image/jpeg"), - ], + ai.user_message( + "Describe this image in detail.", + ai.file_part(image_data, media_type="image/jpeg"), ), ] diff --git a/examples/models/stream.py b/examples/models/stream.py index 1183fb05..0e9bb4d5 100644 --- a/examples/models/stream.py +++ b/examples/models/stream.py @@ -2,8 +2,8 @@ import asyncio +import vercel_ai_sdk as ai from vercel_ai_sdk import models as m -from vercel_ai_sdk.types import messages as messages_ model = m.Model( id="anthropic/claude-sonnet-4", @@ -12,13 +12,8 @@ ) messages = [ - messages_.Message(role="system", parts=[messages_.TextPart(text="Be concise.")]), - messages_.Message( - role="user", - parts=[ - messages_.TextPart(text="Explain why the sky is blue in two sentences.") - ], - ), + ai.system_message("Be concise."), + ai.user_message("Explain why the sky is blue in two sentences."), ] diff --git a/examples/models/structured_output.py b/examples/models/structured_output.py index 172d7201..ebe5f757 100644 --- a/examples/models/structured_output.py +++ b/examples/models/structured_output.py @@ -4,8 +4,8 @@ import pydantic +import vercel_ai_sdk as ai from vercel_ai_sdk import models as m -from vercel_ai_sdk.types import messages as messages_ model = m.Model( id="anthropic/claude-sonnet-4", @@ -21,12 +21,7 @@ class Recipe(pydantic.BaseModel): prep_time_minutes: int -messages = [ - messages_.Message( - role="user", - parts=[messages_.TextPart(text="Give me a simple pancake recipe.")], - ), -] +messages = [ai.user_message("Give me a simple pancake recipe.")] async def main() -> None: diff --git a/examples/models/tools.py b/examples/models/tools.py index 3e3c5d81..d7615575 100644 --- a/examples/models/tools.py +++ b/examples/models/tools.py @@ -2,8 +2,8 @@ import asyncio +import vercel_ai_sdk as ai from vercel_ai_sdk import models as m -from vercel_ai_sdk.types import messages as messages_ from vercel_ai_sdk.types import tools as tools_ model = m.Model( @@ -26,12 +26,7 @@ return_type=str, ) -messages = [ - messages_.Message( - role="user", - parts=[messages_.TextPart(text="What's the weather in Tokyo?")], - ), -] +messages = [ai.user_message("What's the weather in Tokyo?")] async def main() -> None: diff --git a/examples/models/video_generation.py b/examples/models/video_generation.py index b5f5c8d3..77b46c7a 100644 --- a/examples/models/video_generation.py +++ b/examples/models/video_generation.py @@ -4,8 +4,8 @@ import base64 import pathlib +import vercel_ai_sdk as ai from vercel_ai_sdk import models as m -from vercel_ai_sdk.types import messages as messages_ model = m.Model( id="google/veo-3.0-generate-001", @@ -15,18 +15,11 @@ ) messages = [ - messages_.Message( - role="user", - parts=[ - messages_.TextPart( - text=( - "An anime girl with long pink hair and a flowing white " - "dress stands on a hilltop at golden hour. A warm breeze " - "lifts her hair as she releases a paper lantern into the " - "sunset sky. Soft cel-shaded anime art style, warm palette." - ) - ), - ], + ai.user_message( + "An anime girl with long pink hair and a flowing white " + "dress stands on a hilltop at golden hour. A warm breeze " + "lifts her hair as she releases a paper lantern into the " + "sunset sky. Soft cel-shaded anime art style, warm palette." ), ] diff --git a/examples/multiagent-textual/server.py b/examples/multiagent-textual/server.py index a49c8602..3ecf57ba 100644 --- a/examples/multiagent-textual/server.py +++ b/examples/multiagent-textual/server.py @@ -91,6 +91,9 @@ async def mothership_loop( if not result.tool_calls: break + last_msg = result.last_message + assert last_msg is not None + for tc in result.tool_calls: if tc.tool_name == "contact_mothership": # TODO: mypy doesn't support class decorators that change the @@ -101,14 +104,14 @@ async def mothership_loop( metadata={"branch": "mothership", "tool": tc.tool_name}, ) if approval.granted: - await ai.execute_tool(tc, message=result.last_message) + updated_tc = await ai.execute_tool(tc, message=last_msg) else: - tc.set_error(f"Denied: {approval.reason}") + updated_tc = tc.with_error(f"Denied: {approval.reason}") else: - await ai.execute_tool(tc, message=result.last_message) + updated_tc = await ai.execute_tool(tc, message=last_msg) + last_msg = last_msg.replace(updated_tc) - if result.last_message is not None: - local_messages.append(result.last_message) + local_messages.append(last_msg) return result @@ -135,6 +138,9 @@ async def data_center_loop( if not result.tool_calls: break + last_msg = result.last_message + assert last_msg is not None + for tc in result.tool_calls: if tc.tool_name == "contact_data_centers": # TODO: mypy doesn't support class decorators that change the @@ -145,14 +151,14 @@ async def data_center_loop( metadata={"branch": "data_centers", "tool": tc.tool_name}, ) if approval.granted: - await ai.execute_tool(tc, message=result.last_message) + updated_tc = await ai.execute_tool(tc, message=last_msg) else: - tc.set_error(f"Access denied: {approval.reason}") + updated_tc = tc.with_error(f"Access denied: {approval.reason}") else: - await ai.execute_tool(tc, message=result.last_message) + updated_tc = await ai.execute_tool(tc, message=last_msg) + last_msg = last_msg.replace(updated_tc) - if result.last_message is not None: - local_messages.append(result.last_message) + local_messages.append(last_msg) return result @@ -174,8 +180,8 @@ async def multiagent_loop( # Fan out: run both sub-agent loops within this runtime r1, r2 = await asyncio.gather( - mothership_loop(mothership_agent, ai.make_messages(user=query)), - data_center_loop(data_center_agent, ai.make_messages(user=query)), + mothership_loop(mothership_agent, [ai.user_message(query)]), + data_center_loop(data_center_agent, [ai.user_message(query)]), ) combined = ( @@ -184,10 +190,12 @@ async def multiagent_loop( return await ai.stream_step( agent.model, - ai.make_messages( - system="You are assistant 3. Summarise the results from the other assistants.", - user=combined, - ), + [ + ai.system_message( + "You are assistant 3. Summarise the results from the other assistants." + ), + ai.user_message(combined), + ], label="summary", ) @@ -215,7 +223,7 @@ async def ws_endpoint(websocket: fastapi.WebSocket) -> None: await websocket.accept() print("Client connected") - result = orchestrator.run(ai.make_messages(user="When will the robots take over?")) + result = orchestrator.run([ai.user_message("When will the robots take over?")]) # Background task: read hook resolutions from the client. async def read_resolutions() -> None: diff --git a/examples/samples/custom_loop.py b/examples/samples/custom_loop.py index 923768fd..0e87ec37 100644 --- a/examples/samples/custom_loop.py +++ b/examples/samples/custom_loop.py @@ -74,9 +74,11 @@ async def custom(agent: ai.Agent, messages: list[ai.Message]) -> ai.StreamResult ) async for msg in my_agent.run( - ai.make_messages( - user="What's the weather and population of New York and Los Angeles?" - ) + [ + ai.user_message( + "What's the weather and population of New York and Los Angeles?" + ) + ] ): if msg.text_delta: print(msg.text_delta, end="", flush=True) diff --git a/examples/samples/hooks.py b/examples/samples/hooks.py index 33bbe282..8e11a5c3 100644 --- a/examples/samples/hooks.py +++ b/examples/samples/hooks.py @@ -44,6 +44,9 @@ async def with_approval( if not result.tool_calls: break + last_msg = result.last_message + assert last_msg is not None + for tc in result.tool_calls: if tc.tool_name == "contact_mothership": # Blocks until resolved (long-running) or cancelled (serverless) @@ -55,20 +58,18 @@ async def with_approval( metadata={"tool": tc.tool_name}, ) if approval.granted: - await ai.execute_tool(tc, message=result.last_message) + updated_tc = await ai.execute_tool(tc, message=last_msg) else: - tc.set_error(f"Rejected: {approval.reason}") + updated_tc = tc.with_error(f"Rejected: {approval.reason}") else: - await ai.execute_tool(tc, message=result.last_message) + updated_tc = await ai.execute_tool(tc, message=last_msg) + last_msg = last_msg.replace(updated_tc) - if result.last_message is not None: - local_messages.append(result.last_message) + local_messages.append(last_msg) return result - async for msg in my_agent.run( - ai.make_messages(user="When will the robots take over?") - ): + async for msg in my_agent.run([ai.user_message("When will the robots take over?")]): # Hook parts arrive as pending, waiting for resolution if (hook := msg.get_hook_part()) and hook.status == "pending": answer = input(f"Approve {hook.hook_id}? [y/n] ") diff --git a/examples/samples/mcp_tools.py b/examples/samples/mcp_tools.py index aa020560..1b7968e0 100644 --- a/examples/samples/mcp_tools.py +++ b/examples/samples/mcp_tools.py @@ -29,7 +29,7 @@ async def main() -> None: ) async for msg in my_agent.run( - ai.make_messages(user="How do I create middleware in Next.js?") + [ai.user_message("How do I create middleware in Next.js?")] ): rich.print(msg) diff --git a/examples/samples/media/image_gen_dedicated.py b/examples/samples/media/image_gen_dedicated.py index b394c670..aef16fcf 100644 --- a/examples/samples/media/image_gen_dedicated.py +++ b/examples/samples/media/image_gen_dedicated.py @@ -23,14 +23,14 @@ async def main() -> None: # Generate two images of an anime girl character msg = await model.generate( - ai.make_messages( - user=( + [ + ai.user_message( "Anime girl with twin tails and cat ears, wearing a " "sailor school uniform, striking a victory pose in front " "of a futuristic Tokyo skyline at night, neon lights " "reflecting in her eyes, digital art style" - ), - ), + ) + ], n=2, aspect_ratio="16:9", ) diff --git a/examples/samples/media/image_gen_inline.py b/examples/samples/media/image_gen_inline.py index 190ef936..4273a642 100644 --- a/examples/samples/media/image_gen_inline.py +++ b/examples/samples/media/image_gen_inline.py @@ -39,7 +39,7 @@ async def main() -> None: tools=[], ) - async for msg in my_agent.run(ai.make_messages(user=prompt)): + async for msg in my_agent.run([ai.user_message(prompt)]): if msg.text_delta: print(msg.text_delta, end="", flush=True) diff --git a/examples/samples/media/video_gen.py b/examples/samples/media/video_gen.py index 17b94875..550ad7de 100644 --- a/examples/samples/media/video_gen.py +++ b/examples/samples/media/video_gen.py @@ -23,16 +23,16 @@ async def main() -> None: # Generate a short anime-style video clip print("Generating video (this may take a minute or two)...") msg = await model.generate( - ai.make_messages( - user=( + [ + ai.user_message( "An anime girl with long pink hair and a flowing white " "dress stands on a hilltop at golden hour. A warm breeze " "lifts her hair as she releases a paper lantern into the " "sunset sky. The camera slowly pulls back to reveal dozens " "of lanterns rising over a countryside village below. " "Soft cel-shaded anime art style, warm palette." - ), - ), + ) + ], aspect_ratio="16:9", duration=8, ) diff --git a/examples/samples/multiagent.py b/examples/samples/multiagent.py index aaf8bc8e..f51cf20b 100644 --- a/examples/samples/multiagent.py +++ b/examples/samples/multiagent.py @@ -45,13 +45,13 @@ async def multi(agent: ai.Agent, messages: list[ai.Message]) -> ai.StreamResult: result1, result2 = await asyncio.gather( ai.stream_step( agent1.model, - ai.make_messages(system=agent1.system, user=user_query), + [ai.system_message(agent1.system), ai.user_message(user_query)], agent1.tools, label="a1", ), ai.stream_step( agent2.model, - ai.make_messages(system=agent2.system, user=user_query), + [ai.system_message(agent2.system), ai.user_message(user_query)], agent2.tools, label="a2", ), @@ -61,14 +61,14 @@ async def multi(agent: ai.Agent, messages: list[ai.Message]) -> ai.StreamResult: return await ai.stream_step( agent.model, - ai.make_messages( - system="Summarize the results from the other assistants.", - user=combined, - ), + [ + ai.system_message("Summarize the results from the other assistants."), + ai.user_message(combined), + ], label="summary", ) - async for msg in orchestrator.run(ai.make_messages(user="Process the number 5")): + async for msg in orchestrator.run([ai.user_message("Process the number 5")]): if msg.text_delta: prefix = f"[{msg.label}] " if msg.label else "" print(f"{prefix}{msg.text_delta}", end="", flush=True) diff --git a/examples/samples/simple.py b/examples/samples/simple.py index afd0f676..dc93e13b 100644 --- a/examples/samples/simple.py +++ b/examples/samples/simple.py @@ -21,9 +21,7 @@ async def main() -> None: tools=[talk_to_mothership], ) - async for msg in my_agent.run( - ai.make_messages(user="When will the robots take over?") - ): + async for msg in my_agent.run([ai.user_message("When will the robots take over?")]): if msg.text_delta: print(msg.text_delta, end="", flush=True) print() diff --git a/examples/samples/streaming_tool.py b/examples/samples/streaming_tool.py index 109d5368..a29fd23f 100644 --- a/examples/samples/streaming_tool.py +++ b/examples/samples/streaming_tool.py @@ -34,9 +34,7 @@ async def main() -> None: tools=[talk_to_mothership], ) - async for msg in my_agent.run( - ai.make_messages(user="When will the robots take over?") - ): + async for msg in my_agent.run([ai.user_message("When will the robots take over?")]): if msg.label == "tool_progress": print(f" [{msg.text}]") elif msg.text_delta: diff --git a/examples/samples/structured_output.py b/examples/samples/structured_output.py index b9875b2b..73e0267f 100644 --- a/examples/samples/structured_output.py +++ b/examples/samples/structured_output.py @@ -16,10 +16,12 @@ class WeatherForecast(pydantic.BaseModel): async def main() -> None: llm = ai.ai_gateway.GatewayModel(model="anthropic/claude-opus-4.6") - messages = ai.make_messages( - system="You are a weather assistant. Respond with realistic weather data.", - user="What's the weather like in San Francisco right now?", - ) + messages = [ + ai.system_message( + "You are a weather assistant. Respond with realistic weather data." + ), + ai.user_message("What's the weather like in San Francisco right now?"), + ] # Streaming: watch the JSON arrive incrementally, get validated output at the end print("--- Streaming ---") diff --git a/examples/temporal-durable/workflow.py b/examples/temporal-durable/workflow.py index 571792d4..4bbd2b10 100644 --- a/examples/temporal-durable/workflow.py +++ b/examples/temporal-durable/workflow.py @@ -96,10 +96,10 @@ async def agent(llm: Any, user_query: str) -> ai.StreamResult: are no longer part of the public API. This example needs a custom models adapter to work with the new Agent API. """ - messages = ai.make_messages( - system="Answer questions using the weather and population tools.", - user=user_query, - ) + messages = [ + ai.system_message("Answer questions using the weather and population tools."), + ai.user_message(user_query), + ] # Manually implement the loop since we can't use Agent with LanguageModel tools = [get_weather, get_population] diff --git a/src/vercel_ai_sdk/adapters/ai_sdk_ui/adapter.py b/src/vercel_ai_sdk/adapters/ai_sdk_ui/adapter.py index c68ea0d6..ba0ec8e4 100644 --- a/src/vercel_ai_sdk/adapters/ai_sdk_ui/adapter.py +++ b/src/vercel_ai_sdk/adapters/ai_sdk_ui/adapter.py @@ -7,7 +7,6 @@ import dataclasses import json import logging -import uuid from collections.abc import AsyncGenerator, AsyncIterable from typing import Any, Literal @@ -28,11 +27,6 @@ def _to_camel_case(snake_str: str) -> str: return components[0] + "".join(x.title() for x in components[1:]) -def _generate_id(prefix: str = "id") -> str: - """Generate a unique ID with prefix.""" - return f"{prefix}_{uuid.uuid4().hex[:12]}" - - def serialize_part(part: protocol.UIMessageStreamPart) -> str: """Serialize a stream part to JSON with camelCase keys.""" d = dataclasses.asdict(part) @@ -188,7 +182,7 @@ async def to_ui_message_stream( # Handle reasoning streaming (deltas) - reasoning comes before text if delta := msg.reasoning_delta: if not state.reasoning_id: - state.reasoning_id = _generate_id("reasoning") + state.reasoning_id = messages_.generate_id("reasoning") yield protocol.ReasoningStartPart(id=state.reasoning_id) yield protocol.ReasoningDeltaPart(id=state.reasoning_id, delta=delta) @@ -200,7 +194,7 @@ async def to_ui_message_stream( state.reasoning_id = None if not state.text_id: - state.text_id = _generate_id("text") + state.text_id = messages_.generate_id("text") yield protocol.TextStartPart(id=state.text_id) yield protocol.TextDeltaPart(id=state.text_id, delta=delta) @@ -253,7 +247,7 @@ async def to_ui_message_stream( and not has_new_pending_tools and not has_new_tool_results ): - text_id = _generate_id("text") + text_id = messages_.generate_id("text") yield protocol.TextStartPart(id=text_id) yield protocol.TextEndPart(id=text_id) case messages_.ToolPart( @@ -492,7 +486,7 @@ def to_messages( ) # The UI sends one assistant message per conversation turn, but a - # single turn may span multiple stream_loop iterations (e.g. + # single turn may span multiple default-loop iterations (e.g. # [text, tool(done), text, tool(done), text]). LLM APIs expect # one message per iteration, so split at completed-tool boundaries. if ui_msg.role == "assistant": @@ -525,7 +519,7 @@ def _split_assistant_parts( ) -> list[messages_.Message]: """Split assistant parts at completed-tool → non-tool boundaries. - Returns one ``Message`` per ``stream_loop`` iteration so that LLM + Returns one ``Message`` per default-loop iteration so that LLM adapters receive correctly-shaped single-iteration messages. """ messages: list[messages_.Message] = [] diff --git a/src/vercel_ai_sdk/adapters/ai_sdk_ui/ui_message.py b/src/vercel_ai_sdk/adapters/ai_sdk_ui/ui_message.py index f1a15a34..b886444b 100644 --- a/src/vercel_ai_sdk/adapters/ai_sdk_ui/ui_message.py +++ b/src/vercel_ai_sdk/adapters/ai_sdk_ui/ui_message.py @@ -9,15 +9,11 @@ from __future__ import annotations -import uuid from typing import Any, Literal, cast import pydantic - -def _generate_id(prefix: str = "id") -> str: - """Generate a unique ID with prefix.""" - return f"{prefix}_{uuid.uuid4().hex[:12]}" +from ...types import messages as messages_ class UITextPart(pydantic.BaseModel): @@ -211,7 +207,7 @@ class UIMessage(pydantic.BaseModel): model_config = pydantic.ConfigDict(populate_by_name=True) - id: str = pydantic.Field(default_factory=lambda: _generate_id("msg")) + id: str = pydantic.Field(default_factory=lambda: messages_.generate_id("msg")) role: Literal["user", "assistant", "system"] parts: list[UIMessagePart] = pydantic.Field(default_factory=list) diff --git a/src/vercel_ai_sdk/agents/agent.py b/src/vercel_ai_sdk/agents/agent.py index 1a5b1ec7..5d9e2688 100644 --- a/src/vercel_ai_sdk/agents/agent.py +++ b/src/vercel_ai_sdk/agents/agent.py @@ -32,22 +32,20 @@ from .. import models from ..types import messages as messages_ from . import checkpoint as checkpoint_ -from . import context as context_ -from . import runtime as runtime_ -from . import streams as streams_ +from . import context, runtime, streams from . import tools as tools_ # ── Types ───────────────────────────────────────────────────────── LoopFn = Callable[ - ["Agent", list[messages_.Message]], Awaitable[streams_.StreamResult | None] + ["Agent", list[messages_.Message]], Awaitable[streams.StreamResult | None] ] # ── Composition primitives ──────────────────────────────────────── -@streams_.stream +@streams.stream async def stream_step( model: models.Model, messages: list[messages_.Message], @@ -66,8 +64,7 @@ async def stream_step( async for msg in models.stream( model, messages, tools=tools, output_type=output_type, **kwargs ): - msg.label = label - yield msg + yield msg.model_copy(update={"label": label}) if label is not None else msg # ── AgentRun ────────────────────────────────────────────────────── @@ -92,26 +89,26 @@ class AgentRun: print(result.text) """ - def __init__(self, inner: runtime_.RunResult) -> None: + def __init__(self, inner: runtime.RunResult) -> None: self._inner = inner async def __aiter__(self) -> AsyncGenerator[messages_.Message]: async for msg in self._inner: yield msg - async def collect(self) -> streams_.StreamResult: + async def collect(self) -> streams.StreamResult: """Drain the stream and return a :class:`StreamResult`.""" msgs: list[messages_.Message] = [] async for msg in self._inner: msgs.append(msg) - return streams_.StreamResult(messages=msgs) + return streams.StreamResult(messages=msgs) @property def checkpoint(self) -> checkpoint_.Checkpoint: return self._inner.checkpoint @property - def pending_hooks(self) -> dict[str, runtime_.HookInfo]: + def pending_hooks(self) -> dict[str, runtime.HookInfo]: return self._inner.pending_hooks @@ -183,7 +180,7 @@ async def custom( async def _default_loop( self, messages: list[messages_.Message] - ) -> streams_.StreamResult: + ) -> streams.StreamResult: """Built-in loop: stream LLM, execute tools, repeat.""" local_messages = list(messages) @@ -194,15 +191,19 @@ async def _default_loop( return result last_msg = result.last_message - if last_msg is not None: - local_messages.append(last_msg) + if last_msg is None: + return result - await asyncio.gather( + updated_parts = await asyncio.gather( *( - runtime_.execute_tool(tc, message=last_msg) + runtime.execute_tool(tc, message=last_msg) for tc in result.tool_calls ) ) + updated_msg = last_msg + for updated_tc in updated_parts: + updated_msg = updated_msg.replace(updated_tc) + local_messages.append(updated_msg) def run( self, @@ -230,15 +231,15 @@ def run( ) full_messages.extend(messages) - ctx = context_.Context(tools=self.tools) + ctx = context.Context(tools=self.tools) - # Build the graph function that runtime_.run() expects - async def _graph() -> streams_.StreamResult | None: + # Build the graph function that runtime.run() expects + async def _graph() -> streams.StreamResult | None: if self._custom_loop: return await self._custom_loop(self, full_messages) return await self._default_loop(full_messages) - inner = runtime_.run( + inner = runtime.run( _graph, checkpoint=checkpoint, context=ctx, diff --git a/src/vercel_ai_sdk/agents/checkpoint.py b/src/vercel_ai_sdk/agents/checkpoint.py index c3d079bc..30c9bda6 100644 --- a/src/vercel_ai_sdk/agents/checkpoint.py +++ b/src/vercel_ai_sdk/agents/checkpoint.py @@ -5,7 +5,7 @@ import pydantic from ..types import messages as messages_ -from . import streams as streams_ +from . import streams class StepEvent(pydantic.BaseModel): @@ -14,8 +14,8 @@ class StepEvent(pydantic.BaseModel): index: int messages: list[messages_.Message] - def to_stream_result(self) -> streams_.StreamResult: - return streams_.StreamResult(messages=list(self.messages)) + def to_stream_result(self) -> streams.StreamResult: + return streams.StreamResult(messages=list(self.messages)) class ToolEvent(pydantic.BaseModel): diff --git a/src/vercel_ai_sdk/agents/mcp/client.py b/src/vercel_ai_sdk/agents/mcp/client.py index def1f0e6..17d7bb3a 100644 --- a/src/vercel_ai_sdk/agents/mcp/client.py +++ b/src/vercel_ai_sdk/agents/mcp/client.py @@ -14,8 +14,7 @@ import mcp.client.streamable_http import mcp.types -from .. import context as context_ -from .. import tools as tools_ +from .. import context, tools __all__ = [ "get_stdio_tools", @@ -140,7 +139,7 @@ async def get_stdio_tools( env: dict[str, str] | None = None, cwd: str | None = None, tool_prefix: str | None = None, -) -> list[tools_.Tool[..., Any]]: +) -> list[tools.Tool[..., Any]]: """ Get tools from an MCP server running as a subprocess. @@ -155,7 +154,7 @@ async def get_stdio_tools( tool_prefix: Optional prefix to add to all tool names. Returns: - List of Tool objects that can be passed to stream_loop. + List of Tool objects that can be passed to an agent or custom loop. Example: tools = await ai.mcp.get_stdio_tools( @@ -188,7 +187,7 @@ async def get_http_tools( *, headers: dict[str, str] | None = None, tool_prefix: str | None = None, -) -> list[tools_.Tool[..., Any]]: +) -> list[tools.Tool[..., Any]]: """ Get tools from an MCP server over HTTP (Streamable HTTP transport). @@ -201,7 +200,7 @@ async def get_http_tools( tool_prefix: Optional prefix to add to all tool names. Returns: - List of Tool objects that can be passed to stream_loop. + List of Tool objects that can be passed to an agent or custom loop. Example: tools = await ai.mcp.get_http_tools( @@ -231,13 +230,13 @@ def _mcp_tool_to_native( connection_key: str, transport_factory: Callable[[], contextlib.AbstractAsyncContextManager[Any]], tool_prefix: str | None, -) -> tools_.Tool[..., Any]: +) -> tools.Tool[..., Any]: """Convert an MCP tool to a native Tool.""" name = mcp_tool.name if tool_prefix: name = f"{tool_prefix}_{name}" - schema = tools_.ToolSchema( + schema = tools.ToolSchema( name=name, description=mcp_tool.description or "", param_schema=mcp_tool.inputSchema, @@ -246,29 +245,29 @@ def _mcp_tool_to_native( # Determine source provenance from connection key if connection_key.startswith("http:"): - source = context_.ToolSource( + source = context.ToolSource( kind="mcp_http", uri=connection_key.removeprefix("http:"), ) elif connection_key.startswith("stdio:"): - source = context_.ToolSource( + source = context.ToolSource( kind="mcp_stdio", server_command=connection_key.removeprefix("stdio:"), ) else: - source = context_.ToolSource(kind="mcp") + source = context.ToolSource(kind="mcp") - t = tools_.Tool( + t = tools.Tool( fn=_make_tool_fn(connection_key, mcp_tool.name, transport_factory), schema=schema, source=source, ) # Register on active Context if available, else fall back to global - ctx = context_._context.get(None) + ctx = context._context.get(None) if ctx is not None: ctx.register_tool(t) - tools_._tool_registry[name] = t + tools._tool_registry[name] = t return t diff --git a/src/vercel_ai_sdk/agents/runtime.py b/src/vercel_ai_sdk/agents/runtime.py index 724267b3..19098aa7 100644 --- a/src/vercel_ai_sdk/agents/runtime.py +++ b/src/vercel_ai_sdk/agents/runtime.py @@ -10,14 +10,11 @@ import pydantic -from ..telemetry import events as telemetry_ +from ..telemetry import events as telemetry from ..types import messages as messages_ from . import checkpoint as checkpoint_ from . import context as context_ -from . import hooks as hooks_ -from . import mcp -from . import streams as streams_ -from . import tools as tools_ +from . import hooks, mcp, streams, tools logger = logging.getLogger(__name__) @@ -60,7 +57,7 @@ def __init__(self, checkpoint: checkpoint_.Checkpoint | None = None) -> None: def step_index(self) -> int: return self._step_index - def try_replay_step(self) -> streams_.StreamResult | None: + def try_replay_step(self) -> streams.StreamResult | None: if self._step_index < len(self._checkpoint.steps): event = self._checkpoint.steps[self._step_index] self._step_index += 1 @@ -68,7 +65,7 @@ def try_replay_step(self) -> streams_.StreamResult | None: return event.to_stream_result() return None - def record_step(self, result: streams_.StreamResult) -> None: + def record_step(self, result: streams.StreamResult) -> None: event = checkpoint_.StepEvent( index=self._step_index, messages=list(result.messages), @@ -156,7 +153,7 @@ class _Sentinel: def __init__(self) -> None: self._step_queue: asyncio.Queue[ - tuple[streams_.Stream, asyncio.Future[streams_.StreamResult]] + tuple[streams.Stream, asyncio.Future[streams.StreamResult]] | HookSuspension | LoopExecutor._Sentinel ] = asyncio.Queue() @@ -172,7 +169,7 @@ def __init__(self) -> None: # ── Producers (called by graph code) ────────────────────── async def put_step( - self, step_fn: streams_.Stream, future: asyncio.Future[streams_.StreamResult] + self, step_fn: streams.Stream, future: asyncio.Future[streams.StreamResult] ) -> None: await self._step_queue.put((step_fn, future)) @@ -190,7 +187,7 @@ async def done(self) -> None: async def next( self, timeout: float = 0.1 ) -> ( - tuple[streams_.Stream, asyncio.Future[streams_.StreamResult]] + tuple[streams.Stream, asyncio.Future[streams.StreamResult]] | HookSuspension | None ): @@ -286,12 +283,12 @@ def _find_runtime_param(fn: Callable[..., Any]) -> str | None: async def execute_tool( tool_call: messages_.ToolPart, message: messages_.Message | None = None, -) -> Any: +) -> messages_.ToolPart: """Execute a single tool call with replay support. Looks up the tool by name — first from the active Context (if any), - then from the global registry. Executes it and updates the ToolPart - (and parent Message) with the result. Emits the updated message to + then from the global registry. Returns an updated (immutable) + ToolPart with the result filled in. Emits the updated message to the LoopExecutor queue so the UI sees the transition from status="pending" to status="result" (or "error"). @@ -300,63 +297,61 @@ async def execute_tool( """ rt = _runtime.get(None) - # Replay: return cached result if available + # Replay: return updated part from cache if rt: cached = rt.log.try_replay_tool(tool_call.tool_call_id) if cached is not None: if cached.status == "error": - tool_call.set_error(cached.result) - else: - tool_call.set_result(cached.result) - return cached.result + return tool_call.with_error(cached.result) + return tool_call.with_result(cached.result) - telemetry_.handle( - telemetry_.ToolCallStartEvent( + telemetry.handle( + telemetry.ToolCallStartEvent( tool_name=tool_call.tool_name, tool_call_id=tool_call.tool_call_id, args=tool_call.tool_args, ) ) - t0 = telemetry_.time_ms() + t0 = telemetry.time_ms() # Fresh execution — resolve from Context first, then global registry - tool: tools_.Tool[..., Any] | None = None + tool: tools.Tool[..., Any] | None = None ctx = context_._context.get(None) if ctx is not None: tool = ctx.get_tool(tool_call.tool_name) if tool is None: - tool = tools_.get_tool(tool_call.tool_name) + tool = tools.get_tool(tool_call.tool_name) if tool is None: raise ValueError(f"Tool not found in registry: {tool_call.tool_name}") error_str: str | None = None try: result = await tool.validate_and_call(tool_call.tool_args, rt) - tool_call.set_result(result) + updated = tool_call.with_result(result) except (json.JSONDecodeError, pydantic.ValidationError) as exc: result = f"{type(exc).__name__}: {exc}" error_str = result - tool_call.set_error(result) + updated = tool_call.with_error(result) - telemetry_.handle( - telemetry_.ToolCallFinishEvent( + telemetry.handle( + telemetry.ToolCallFinishEvent( tool_name=tool_call.tool_name, tool_call_id=tool_call.tool_call_id, result=result, error=error_str, - duration_ms=telemetry_.time_ms() - t0, + duration_ms=telemetry.time_ms() - t0, ) ) # Record for checkpoint if rt: - rt.log.record_tool(tool_call.tool_call_id, result, status=tool_call.status) + rt.log.record_tool(tool_call.tool_call_id, result, status=updated.status) # Emit updated message so UI sees status change if rt and message: - await rt.executor.put_message(message.model_copy(deep=True)) + await rt.executor.put_message(message.replace(updated)) - return result + return updated # ── RunResult ───────────────────────────────────────────────────── @@ -450,7 +445,7 @@ def run( if checkpoint and checkpoint.pending_hooks: pending_labels = [ph.label for ph in checkpoint.pending_hooks] has_resolution = any( - label in hooks_._pending_resolutions for label in pending_labels + label in hooks._pending_resolutions for label in pending_labels ) if not has_resolution: logger.info( @@ -474,9 +469,9 @@ async def _generate() -> AsyncGenerator[messages_.Message]: ctx = context or context_.Context() token_context = context_._context.set(ctx) - token_run_id = telemetry_.start_run() + token_run_id = telemetry.start_run() - telemetry_.handle(telemetry_.RunStartEvent()) + telemetry.handle(telemetry.RunStartEvent()) mcp_pool: dict[str, mcp.client._Connection] = {} mcp_token = mcp.client._pool.set(mcp_pool) @@ -542,8 +537,8 @@ async def _generate() -> AsyncGenerator[messages_.Message]: # ── Regular step ─────────────────────────── step_fn, future = item - telemetry_.handle( - telemetry_.StepStartEvent( + telemetry.handle( + telemetry.StepStartEvent( step_index=rt.log.step_index, ) ) @@ -554,18 +549,17 @@ async def _generate() -> AsyncGenerator[messages_.Message]: result_messages: list[messages_.Message] = [] async for msg in step_fn(): - msg_copy = msg.model_copy(deep=True) - yield msg_copy + yield msg result_messages.append(msg) for tool_msg in rt.executor.drain_messages(): yield tool_msg - step_result = streams_.StreamResult(messages=result_messages) + step_result = streams.StreamResult(messages=result_messages) future.set_result(step_result) - telemetry_.handle( - telemetry_.StepFinishEvent( + telemetry.handle( + telemetry.StepFinishEvent( step_index=rt.log.step_index, result=step_result, ) @@ -589,15 +583,15 @@ async def _generate() -> AsyncGenerator[messages_.Message]: raise finally: - telemetry_.handle( - telemetry_.RunFinishEvent( + telemetry.handle( + telemetry.RunFinishEvent( usage=total_usage, error=run_error, ) ) - telemetry_.end_run(token_run_id) + telemetry.end_run(token_run_id) - hooks_._cleanup_run(rt.executor._hook_labels) + hooks._cleanup_run(rt.executor._hook_labels) if mcp_token is not None: await mcp.client.close_connections() diff --git a/src/vercel_ai_sdk/models/ai_gateway/_common.py b/src/vercel_ai_sdk/models/ai_gateway/_common.py index 0031661f..02333799 100644 --- a/src/vercel_ai_sdk/models/ai_gateway/_common.py +++ b/src/vercel_ai_sdk/models/ai_gateway/_common.py @@ -20,10 +20,10 @@ import httpx +from ...types import media from ...types import messages as messages_ from ..core import client as client_ from ..core import model as model_ -from ..core.helpers import media as media_ _PROTOCOL_VERSION = "0.0.1" @@ -64,7 +64,7 @@ def extract_input_files(messages: list[messages_.Message]) -> list[messages_.Fil def file_part_to_wire(part: messages_.FilePart) -> dict[str, Any]: """Convert a :class:`FilePart` to the gateway wire format for input files.""" data = part.data - if isinstance(data, str) and media_.is_url(data): + if isinstance(data, str) and media.is_url(data): return {"type": "url", "url": data} if isinstance(data, bytes): b64 = base64.b64encode(data).decode("ascii") diff --git a/src/vercel_ai_sdk/models/ai_gateway/generate.py b/src/vercel_ai_sdk/models/ai_gateway/generate.py index ab460b02..569fb8cd 100644 --- a/src/vercel_ai_sdk/models/ai_gateway/generate.py +++ b/src/vercel_ai_sdk/models/ai_gateway/generate.py @@ -12,12 +12,12 @@ import httpx import pydantic +from ...types import media from ...types import messages as messages_ from ..core import client as client_ from ..core import model as model_ -from ..core.helpers import media as media_ -from . import _common -from . import errors as errors_ +from ..core.helpers import files +from . import _common, errors # --------------------------------------------------------------------------- # Parameter types @@ -90,7 +90,7 @@ async def _generate_image( response = await client.http.post(url, json=body, headers=headers) if response.status_code >= 400: - raise errors_.create_gateway_error( + raise errors.create_gateway_error( response_body=response.text, status_code=response.status_code, api_key_provided=bool(client.api_key), @@ -106,12 +106,12 @@ async def _generate_image( output_tokens=usage_data.get("outputTokens") or 0, ) - files: list[messages_.Part] = [] + parts: list[messages_.Part] = [] for img_b64 in raw_images: - media_type = media_.detect_image_media_type(img_b64) or "image/png" - files.append(messages_.FilePart(data=img_b64, media_type=media_type)) + media_type = media.detect_image_media_type(img_b64) or "image/png" + parts.append(messages_.FilePart(data=img_b64, media_type=media_type)) - return messages_.Message(role="assistant", parts=files, usage=usage) + return messages_.Message(role="assistant", parts=parts, usage=usage) # --------------------------------------------------------------------------- @@ -149,7 +149,7 @@ async def _generate_video( ) as response: if response.status_code >= 400: await response.aread() - raise errors_.create_gateway_error( + raise errors.create_gateway_error( response_body=response.text, status_code=response.status_code, api_key_provided=bool(client.api_key), @@ -162,34 +162,34 @@ async def _generate_video( break if not event_data: - raise errors_.GatewayResponseError( + raise errors.GatewayResponseError( "SSE stream ended without any data events", ) if event_data.get("type") == "error": - raise errors_.GatewayInvalidRequestError( + raise errors.GatewayInvalidRequestError( message=event_data.get("message", "unknown error"), status_code=event_data.get("statusCode", 400), ) raw_videos: list[dict[str, Any]] = event_data.get("videos", []) - files: list[messages_.Part] = [] + parts: list[messages_.Part] = [] for video_data in raw_videos: vtype = video_data.get("type", "base64") media_type = video_data.get("mediaType", "video/mp4") if vtype == "url": - downloaded_bytes, content_type = await media_.download(video_data["url"]) + downloaded_bytes, content_type = await files.download(video_data["url"]) if content_type: media_type = content_type - files.append( + parts.append( messages_.FilePart(data=downloaded_bytes, media_type=media_type) ) else: raw_data = video_data.get("data", "") - files.append(messages_.FilePart(data=raw_data, media_type=media_type)) + parts.append(messages_.FilePart(data=raw_data, media_type=media_type)) - return messages_.Message(role="assistant", parts=files) + return messages_.Message(role="assistant", parts=parts) # --------------------------------------------------------------------------- diff --git a/src/vercel_ai_sdk/models/ai_gateway/stream.py b/src/vercel_ai_sdk/models/ai_gateway/stream.py index 92a63266..f50a8658 100644 --- a/src/vercel_ai_sdk/models/ai_gateway/stream.py +++ b/src/vercel_ai_sdk/models/ai_gateway/stream.py @@ -12,14 +12,13 @@ import httpx import pydantic +from ...types import media from ...types import messages as messages_ from ...types import tools as tools_ from ..core import client as client_ from ..core import model as model_ -from ..core.helpers import media as media_ -from ..core.helpers import streaming as streaming_ -from . import _common -from . import errors as errors_ +from ..core.helpers import files, streaming +from . import _common, errors # --------------------------------------------------------------------------- # Request building — Message list → v3 prompt @@ -29,14 +28,14 @@ async def _file_part_to_v3(part: messages_.FilePart) -> dict[str, Any]: """Convert a :class:`FilePart` to a v3 ``file`` content part.""" data = part.data - if isinstance(data, str) and media_.is_downloadable_url(data): - downloaded, _ = await media_.download(data) + if isinstance(data, str) and media.is_downloadable_url(data): + downloaded, _ = await files.download(data) data = downloaded entry: dict[str, Any] = { "type": "file", "mediaType": part.media_type, - "data": media_.data_to_data_url(data, part.media_type), + "data": media.data_to_data_url(data, part.media_type), } if part.filename is not None: entry["filename"] = part.filename @@ -160,16 +159,16 @@ async def _build_request_body( # --------------------------------------------------------------------------- -def _expand_tool_call(data: dict[str, Any]) -> list[streaming_.StreamEvent]: +def _expand_tool_call(data: dict[str, Any]) -> list[streaming.StreamEvent]: """Expand a complete ``tool-call`` part into Start + ArgsDelta + End.""" tc_id = data.get("toolCallId", "") tool_name = data.get("toolName", "") tool_input = data.get("input", "") args_str = tool_input if isinstance(tool_input, str) else json.dumps(tool_input) return [ - streaming_.ToolStart(tool_call_id=tc_id, tool_name=tool_name), - streaming_.ToolArgsDelta(tool_call_id=tc_id, delta=args_str), - streaming_.ToolEnd(tool_call_id=tc_id), + streaming.ToolStart(tool_call_id=tc_id, tool_name=tool_name), + streaming.ToolArgsDelta(tool_call_id=tc_id, delta=args_str), + streaming.ToolEnd(tool_call_id=tc_id), ] @@ -200,40 +199,40 @@ def _parse_usage(data: Any) -> messages_.Usage: ) -def _parse_stream_part(data: dict[str, Any]) -> list[streaming_.StreamEvent]: +def _parse_stream_part(data: dict[str, Any]) -> list[streaming.StreamEvent]: """Convert a ``LanguageModelV3StreamPart`` to internal events.""" match data.get("type", ""): case "text-start": - return [streaming_.TextStart(block_id=data.get("id", "text"))] + return [streaming.TextStart(block_id=data.get("id", "text"))] case "text-delta": return [ - streaming_.TextDelta( + streaming.TextDelta( block_id=data.get("id", "text"), delta=data.get("textDelta", data.get("delta", "")), ) ] case "text-end": - return [streaming_.TextEnd(block_id=data.get("id", "text"))] + return [streaming.TextEnd(block_id=data.get("id", "text"))] case "reasoning-start": - return [streaming_.ReasoningStart(block_id=data.get("id", "reasoning"))] + return [streaming.ReasoningStart(block_id=data.get("id", "reasoning"))] case "reasoning-delta": return [ - streaming_.ReasoningDelta( + streaming.ReasoningDelta( block_id=data.get("id", "reasoning"), delta=data.get("delta", ""), ) ] case "reasoning-end": - return [streaming_.ReasoningEnd(block_id=data.get("id", "reasoning"))] + return [streaming.ReasoningEnd(block_id=data.get("id", "reasoning"))] case "tool-input-start": return [ - streaming_.ToolStart( + streaming.ToolStart( tool_call_id=data.get("id", ""), tool_name=data.get("toolName", ""), ) @@ -241,21 +240,21 @@ def _parse_stream_part(data: dict[str, Any]) -> list[streaming_.StreamEvent]: case "tool-input-delta": return [ - streaming_.ToolArgsDelta( + streaming.ToolArgsDelta( tool_call_id=data.get("id", ""), delta=data.get("delta", ""), ) ] case "tool-input-end": - return [streaming_.ToolEnd(tool_call_id=data.get("id", ""))] + return [streaming.ToolEnd(tool_call_id=data.get("id", ""))] case "tool-call": return _expand_tool_call(data) case "file": return [ - streaming_.FileEvent( + streaming.FileEvent( block_id=data.get("id", f"file-{len(data)}"), media_type=data.get("mediaType", "application/octet-stream"), data=data.get("data", ""), @@ -272,7 +271,7 @@ def _parse_stream_part(data: dict[str, Any]) -> list[streaming_.StreamEvent]: finish_reason = s case _: finish_reason = "stop" - return [streaming_.MessageDone(finish_reason=finish_reason, usage=usage)] + return [streaming.MessageDone(finish_reason=finish_reason, usage=usage)] case _: return [] @@ -306,7 +305,7 @@ async def stream( ) url = f"{client.base_url.rstrip('/')}/language-model" - handler = streaming_.StreamHandler() + handler = streaming.StreamHandler() try: async with client.http.stream( @@ -317,7 +316,7 @@ async def stream( ) as response: if response.status_code >= 400: await response.aread() - raise errors_.create_gateway_error( + raise errors.create_gateway_error( response_body=response.text, status_code=response.status_code, api_key_provided=bool(client.api_key), @@ -327,12 +326,12 @@ async def stream( for event in _parse_stream_part(data): msg = handler.handle_event(event) yield msg - except errors_.GatewayError: + except errors.GatewayError: raise except httpx.TimeoutException as exc: - raise errors_.GatewayTimeoutError(cause=exc) from exc + raise errors.GatewayTimeoutError(cause=exc) from exc except Exception as exc: - raise errors_.GatewayResponseError( + raise errors.GatewayResponseError( message=f"Unexpected error during streaming: {exc}", cause=exc, ) from exc diff --git a/src/vercel_ai_sdk/models/anthropic/adapter.py b/src/vercel_ai_sdk/models/anthropic/adapter.py index 7ad3d25c..23c1f01b 100644 --- a/src/vercel_ai_sdk/models/anthropic/adapter.py +++ b/src/vercel_ai_sdk/models/anthropic/adapter.py @@ -13,12 +13,12 @@ import anthropic import pydantic +from ...types import media from ...types import messages as messages_ from ...types import tools as tools_ from ..core import client as client_ from ..core import model as model_ -from ..core.helpers import media as media_ -from ..core.helpers import streaming as streaming_ +from ..core.helpers import streaming # --------------------------------------------------------------------------- # Message / tool conversion — internal types → Anthropic wire format @@ -53,7 +53,7 @@ def _file_part_to_anthropic( if mt.startswith("image/"): media_type = "image/jpeg" if mt == "image/*" else mt - if isinstance(part.data, str) and media_.is_url(part.data): + if isinstance(part.data, str) and media.is_url(part.data): return { "type": "image", "source": {"type": "url", "url": part.data}, @@ -63,12 +63,12 @@ def _file_part_to_anthropic( "source": { "type": "base64", "media_type": media_type, - "data": media_.data_to_base64(part.data), + "data": media.data_to_base64(part.data), }, } if mt == "application/pdf": - if isinstance(part.data, str) and media_.is_url(part.data): + if isinstance(part.data, str) and media.is_url(part.data): return { "type": "document", "source": {"type": "url", "url": part.data}, @@ -78,14 +78,14 @@ def _file_part_to_anthropic( "source": { "type": "base64", "media_type": "application/pdf", - "data": media_.data_to_base64(part.data), + "data": media.data_to_base64(part.data), }, } if mt == "text/plain": if isinstance(part.data, bytes): text_data = part.data.decode("utf-8") - elif media_.is_url(part.data): + elif media.is_url(part.data): return { "type": "document", "source": {"type": "url", "url": part.data}, @@ -284,7 +284,7 @@ async def stream( if output_type is not None: api_kwargs["output_format"] = output_type - handler = streaming_.StreamHandler() + handler = streaming.StreamHandler() block_types: dict[int, str] = {} tool_ids: dict[int, str] = {} @@ -304,16 +304,16 @@ async def stream( match block.type: case "text": yield handler.handle_event( - streaming_.TextStart(block_id=str(idx)) + streaming.TextStart(block_id=str(idx)) ) case "thinking": yield handler.handle_event( - streaming_.ReasoningStart(block_id=str(idx)) + streaming.ReasoningStart(block_id=str(idx)) ) case "tool_use": tool_ids[idx] = block.id yield handler.handle_event( - streaming_.ToolStart( + streaming.ToolStart( tool_call_id=block.id, tool_name=block.name, ) @@ -326,14 +326,14 @@ async def stream( match delta.type: case "text_delta": yield handler.handle_event( - streaming_.TextDelta( + streaming.TextDelta( block_id=str(idx), delta=delta.text, ) ) case "thinking_delta": yield handler.handle_event( - streaming_.ReasoningDelta( + streaming.ReasoningDelta( block_id=str(idx), delta=delta.thinking, ) @@ -346,7 +346,7 @@ async def stream( tool_id = tool_ids.get(idx) if tool_id: yield handler.handle_event( - streaming_.ToolArgsDelta( + streaming.ToolArgsDelta( tool_call_id=tool_id, delta=delta.partial_json, ) @@ -357,11 +357,11 @@ async def stream( match block_types.get(idx): case "text": yield handler.handle_event( - streaming_.TextEnd(block_id=str(idx)) + streaming.TextEnd(block_id=str(idx)) ) case "thinking": yield handler.handle_event( - streaming_.ReasoningEnd( + streaming.ReasoningEnd( block_id=str(idx), signature=signature_buffer.get(idx), ) @@ -370,7 +370,7 @@ async def stream( tool_id = tool_ids.get(idx) if tool_id: yield handler.handle_event( - streaming_.ToolEnd(tool_call_id=tool_id) + streaming.ToolEnd(tool_call_id=tool_id) ) snapshot = sdk_stream.current_message_snapshot @@ -384,6 +384,6 @@ async def stream( ), raw=sdk_usage.model_dump(exclude_none=True) or None, ) - yield handler.handle_event(streaming_.MessageDone(usage=usage)) + yield handler.handle_event(streaming.MessageDone(usage=usage)) finally: await sdk_client.close() diff --git a/src/vercel_ai_sdk/models/core/helpers/files.py b/src/vercel_ai_sdk/models/core/helpers/files.py new file mode 100644 index 00000000..3d463afd --- /dev/null +++ b/src/vercel_ai_sdk/models/core/helpers/files.py @@ -0,0 +1,103 @@ +"""Network IO for media — download with size limits and SSRF prevention. + +Pure media utilities (detection, encoding, inference) live in +:mod:`vercel_ai_sdk.types.media`. +""" + +from __future__ import annotations + +import httpx + +DEFAULT_MAX_BYTES = 100 * 1024 * 1024 # 100 MiB (matches TS SDK) +_ALLOWED_SCHEMES = frozenset({"http", "https"}) + + +class DownloadError(Exception): + """Raised when a URL download fails.""" + + def __init__( + self, + url: str, + *, + status_code: int | None = None, + status_text: str | None = None, + cause: BaseException | None = None, + ) -> None: + parts = [f"Failed to download {url!r}"] + if status_code is not None: + parts.append(f"status={status_code}") + if status_text: + parts.append(status_text) + super().__init__(": ".join(parts)) + self.url = url + self.status_code = status_code + if cause is not None: + self.__cause__ = cause + + +def _validate_url(url: str) -> None: + """Reject non-HTTP(S) URLs (SSRF prevention).""" + from urllib.parse import urlparse + + parsed = urlparse(url) + if parsed.scheme not in _ALLOWED_SCHEMES: + raise DownloadError( + url, status_text=f"Unsupported URL scheme: {parsed.scheme!r}" + ) + + +async def download( + url: str, + *, + max_bytes: int = DEFAULT_MAX_BYTES, +) -> tuple[bytes, str | None]: + """Download *url* and return ``(data, content_type)``. + + Args: + url: The URL to fetch (must be ``http`` or ``https``). + max_bytes: Maximum response size. Defaults to 100 MiB. + + Returns: + A tuple of ``(raw_bytes, content_type_or_None)``. + + Raises: + DownloadError: On any failure (network, HTTP status, size, etc.). + """ + _validate_url(url) + + try: + async with httpx.AsyncClient(follow_redirects=True) as client: + resp = await client.get(url) + + # Validate redirect target + if resp.url is not None and str(resp.url) != url: + _validate_url(str(resp.url)) + + if resp.status_code >= 400: + raise DownloadError( + url, + status_code=resp.status_code, + status_text=resp.reason_phrase or "", + ) + + data = resp.content + if len(data) > max_bytes: + raise DownloadError( + url, + status_text=( + f"Response exceeds maximum size " + f"({len(data)} > {max_bytes} bytes)" + ), + ) + + content_type = resp.headers.get("content-type") + # Strip charset/parameters: "image/png; charset=..." → "image/png" + if content_type: + content_type = content_type.split(";")[0].strip() + + return data, content_type or None + + except DownloadError: + raise + except Exception as exc: + raise DownloadError(url, cause=exc) from exc diff --git a/src/vercel_ai_sdk/models/core/helpers/streaming.py b/src/vercel_ai_sdk/models/core/helpers/streaming.py index 11d27006..fb776295 100644 --- a/src/vercel_ai_sdk/models/core/helpers/streaming.py +++ b/src/vercel_ai_sdk/models/core/helpers/streaming.py @@ -97,7 +97,7 @@ class StreamHandler: This is the normalization layer between LLM adapters and the rest of the system. """ - message_id: str = dataclasses.field(default_factory=messages_._gen_id) + message_id: str = dataclasses.field(default_factory=messages_.generate_id) # Accumulators _text_blocks: dict[str, str] = dataclasses.field(default_factory=dict) @@ -192,6 +192,7 @@ def _build_message( is_active = bid == self._active_reasoning_id parts.append( messages_.ReasoningPart( + id=bid, text=text, signature=sig, state="streaming" if is_active else "done", @@ -204,6 +205,7 @@ def _build_message( is_active = bid == self._active_text_id parts.append( messages_.TextPart( + id=bid, text=text, state="streaming" if is_active else "done", delta=text_delta if is_active else None, @@ -215,6 +217,7 @@ def _build_message( is_active = tcid in self._active_tool_ids parts.append( messages_.ToolPart( + id=tcid, tool_call_id=tcid, tool_name=name, tool_args=args, @@ -224,8 +227,8 @@ def _build_message( ) # File parts (inline images/videos from LLMs like Gemini, GPT-5) - for _bid, (media_type, data) in self._files.items(): - parts.append(messages_.FilePart(data=data, media_type=media_type)) + for bid, (media_type, data) in self._files.items(): + parts.append(messages_.FilePart(id=bid, data=data, media_type=media_type)) return messages_.Message( id=self.message_id, @@ -259,6 +262,5 @@ async def events_to_messages( data=data, output_type_name=f"{output_type.__module__}.{output_type.__qualname__}", ) - msg = msg.model_copy() - msg.parts = [*msg.parts, part] + msg = msg.model_copy(update={"parts": [*msg.parts, part]}) yield msg diff --git a/src/vercel_ai_sdk/models/openai/adapter.py b/src/vercel_ai_sdk/models/openai/adapter.py index 8f63c244..c0d6261d 100644 --- a/src/vercel_ai_sdk/models/openai/adapter.py +++ b/src/vercel_ai_sdk/models/openai/adapter.py @@ -12,12 +12,12 @@ import openai import pydantic +from ...types import media from ...types import messages as messages_ from ...types import tools as tools_ from ..core import client as client_ from ..core import model as model_ -from ..core.helpers import media as media_ -from ..core.helpers import streaming as streaming_ +from ..core.helpers import files, streaming # --------------------------------------------------------------------------- # Message / tool conversion — internal types → OpenAI wire format @@ -57,25 +57,25 @@ async def _file_part_to_openai( if mt.startswith("image/"): media_type = "image/jpeg" if mt == "image/*" else mt - url = media_.data_to_data_url(data, media_type) + url = media.data_to_data_url(data, media_type) return {"type": "image_url", "image_url": {"url": url}} if mt.startswith("audio/"): - if isinstance(data, str) and media_.is_downloadable_url(data): - downloaded, _ = await media_.download(data) + if isinstance(data, str) and media.is_downloadable_url(data): + downloaded, _ = await files.download(data) data = downloaded fmt = mt.split("/", 1)[1] if "/" in mt else mt - b64 = media_.data_to_base64(data) + b64 = media.data_to_base64(data) return { "type": "input_audio", "input_audio": {"data": b64, "format": fmt}, } if mt == "application/pdf": - if isinstance(data, str) and media_.is_downloadable_url(data): - downloaded, _ = await media_.download(data) + if isinstance(data, str) and media.is_downloadable_url(data): + downloaded, _ = await files.download(data) data = downloaded - data_url = media_.data_to_data_url(data, mt) + data_url = media.data_to_data_url(data, mt) filename = part.filename or "document.pdf" return { "type": "file", @@ -85,7 +85,7 @@ async def _file_part_to_openai( if mt.startswith("text/"): if isinstance(data, bytes): text_content = data.decode("utf-8") - elif media_.is_url(data): + elif media.is_url(data): text_content = data else: import base64 as _b64 @@ -253,7 +253,7 @@ async def stream( reasoning_config["effort"] = reasoning_effort api_kwargs["extra_body"] = {"reasoning": reasoning_config} - handler = streaming_.StreamHandler() + handler = streaming.StreamHandler() try: sdk_stream = await sdk_client.chat.completions.create(**api_kwargs) @@ -308,10 +308,10 @@ async def stream( if not reasoning_started: reasoning_started = True yield handler.handle_event( - streaming_.ReasoningStart(block_id="reasoning") + streaming.ReasoningStart(block_id="reasoning") ) yield handler.handle_event( - streaming_.ReasoningDelta( + streaming.ReasoningDelta( block_id="reasoning", delta=reasoning_value ) ) @@ -319,15 +319,15 @@ async def stream( if delta.content: if reasoning_started: yield handler.handle_event( - streaming_.ReasoningEnd(block_id="reasoning") + streaming.ReasoningEnd(block_id="reasoning") ) reasoning_started = False if not text_started: text_started = True - yield handler.handle_event(streaming_.TextStart(block_id="text")) + yield handler.handle_event(streaming.TextStart(block_id="text")) yield handler.handle_event( - streaming_.TextDelta(block_id="text", delta=delta.content) + streaming.TextDelta(block_id="text", delta=delta.content) ) if delta.tool_calls: @@ -351,7 +351,7 @@ async def stream( if not tc_state[idx]["started"] and tid: tc_state[idx]["started"] = True yield handler.handle_event( - streaming_.ToolStart( + streaming.ToolStart( tool_call_id=tid, tool_name=tname, ) @@ -359,7 +359,7 @@ async def stream( if tid: yield handler.handle_event( - streaming_.ToolArgsDelta( + streaming.ToolArgsDelta( tool_call_id=tid, delta=tc.function.arguments, ) @@ -369,18 +369,18 @@ async def stream( finish_reason = choice.finish_reason if reasoning_started: yield handler.handle_event( - streaming_.ReasoningEnd(block_id="reasoning") + streaming.ReasoningEnd(block_id="reasoning") ) if text_started: - yield handler.handle_event(streaming_.TextEnd(block_id="text")) + yield handler.handle_event(streaming.TextEnd(block_id="text")) for tc in tc_state.values(): if tc["started"] and tc["id"]: yield handler.handle_event( - streaming_.ToolEnd(tool_call_id=tc["id"]) + streaming.ToolEnd(tool_call_id=tc["id"]) ) yield handler.handle_event( - streaming_.MessageDone(finish_reason=finish_reason, usage=usage) + streaming.MessageDone(finish_reason=finish_reason, usage=usage) ) finally: await sdk_client.close() diff --git a/src/vercel_ai_sdk/telemetry/events.py b/src/vercel_ai_sdk/telemetry/events.py index f12fbd47..3064ad9f 100644 --- a/src/vercel_ai_sdk/telemetry/events.py +++ b/src/vercel_ai_sdk/telemetry/events.py @@ -29,7 +29,7 @@ def handle(self, event: ai.telemetry.TelemetryEvent) -> None: import uuid from typing import Any, Protocol, runtime_checkable -from ..agents import streams as streams_ +from ..agents import streams from ..types import messages as messages_ # ── Protocol ─────────────────────────────────────────────────────── @@ -69,7 +69,7 @@ class StepFinishEvent(TelemetryEvent): """Emitted when a ``@stream``-decorated step finishes.""" step_index: int - result: streams_.StreamResult + result: streams.StreamResult @dataclasses.dataclass(frozen=True, slots=True) diff --git a/src/vercel_ai_sdk/types/__init__.py b/src/vercel_ai_sdk/types/__init__.py index 7bf404f1..803f8b2a 100644 --- a/src/vercel_ai_sdk/types/__init__.py +++ b/src/vercel_ai_sdk/types/__init__.py @@ -16,6 +16,7 @@ ToolDelta, ToolPart, Usage, + generate_id, make_messages, ) from .tools import ToolLike, ToolSchema @@ -34,5 +35,6 @@ "ToolLike", "ToolSchema", "Usage", + "generate_id", "make_messages", ] diff --git a/src/vercel_ai_sdk/types/builders.py b/src/vercel_ai_sdk/types/builders.py new file mode 100644 index 00000000..0a4d9f62 --- /dev/null +++ b/src/vercel_ai_sdk/types/builders.py @@ -0,0 +1,93 @@ +"""Composable message construction helpers. + +Convenience functions for building Message objects without manually +constructing Part lists. Each ``*_message`` function accepts a mix of +plain strings (auto-wrapped in :class:`TextPart`) and existing +:class:`Part` objects, returning a single :class:`Message`. +""" + +from __future__ import annotations + +from .messages import ( + FilePart, + HookPart, + Message, + Part, + ReasoningPart, + StructuredOutputPart, + TextPart, + ToolPart, +) + +_PART_TYPES = ( + TextPart, + ToolPart, + ReasoningPart, + HookPart, + StructuredOutputPart, + FilePart, +) + +# A value that can appear as message content: bare strings become TextPart. +PartLike = str | Part + + +def _coerce_parts(args: tuple[PartLike, ...]) -> list[Part]: + parts: list[Part] = [] + for arg in args: + if isinstance(arg, str): + parts.append(TextPart(text=arg)) + elif isinstance(arg, _PART_TYPES): + parts.append(arg) + else: + raise TypeError(f"Expected str or Part, got {type(arg).__name__}") + return parts + + +def system_message(*content: PartLike) -> Message: + """Create a system message. + + >>> ai.system_message("You are a helpful robot.") + """ + return Message(role="system", parts=_coerce_parts(content)) + + +def user_message(*content: PartLike) -> Message: + """Create a user message from strings and/or Part objects. + + >>> ai.user_message("Describe this image:", ai.file_part(url)) + """ + return Message(role="user", parts=_coerce_parts(content)) + + +def assistant_message(*content: PartLike) -> Message: + """Create an assistant message from strings and/or Part objects. + + >>> ai.assistant_message(ai.thinking("hmm"), "Hello!") + """ + return Message(role="assistant", parts=_coerce_parts(content)) + + +def file_part( + data: str | bytes, + *, + media_type: str | None = None, + filename: str | None = None, +) -> FilePart: + """Create a :class:`FilePart` from a URL string or raw bytes. + + Dispatches to :meth:`FilePart.from_url` (for ``str``) or + :meth:`FilePart.from_bytes` (for ``bytes``), with automatic + media-type detection. + """ + if isinstance(data, str): + return FilePart.from_url(data, media_type=media_type) + return FilePart.from_bytes(data, media_type=media_type, filename=filename) + + +def thinking(text: str, *, signature: str | None = None) -> ReasoningPart: + """Create a :class:`ReasoningPart`. + + Useful for replaying conversation history that includes model reasoning. + """ + return ReasoningPart(text=text, signature=signature) diff --git a/src/vercel_ai_sdk/models/core/helpers/media.py b/src/vercel_ai_sdk/types/media.py similarity index 75% rename from src/vercel_ai_sdk/models/core/helpers/media.py rename to src/vercel_ai_sdk/types/media.py index 3fc3e793..0d343183 100644 --- a/src/vercel_ai_sdk/models/core/helpers/media.py +++ b/src/vercel_ai_sdk/types/media.py @@ -1,11 +1,14 @@ +"""Media type detection, inference, and encoding utilities. + +Pure functions with no IO dependencies — safe to use from any layer. +""" + from __future__ import annotations import base64 import base64 as _b64 import mimetypes -import httpx - # -- URL helpers ----------------------------------------------------------- @@ -273,98 +276,3 @@ def detect_image_media_type(data: bytes | str) -> str | None: def detect_audio_media_type(data: bytes | str) -> str | None: """Detect audio format from magic bytes.""" return detect_media_type(data, AUDIO_SIGNATURES) - - -DEFAULT_MAX_BYTES = 100 * 1024 * 1024 # 100 MiB (matches TS SDK) -_ALLOWED_SCHEMES = frozenset({"http", "https"}) - - -class DownloadError(Exception): - """Raised when a URL download fails.""" - - def __init__( - self, - url: str, - *, - status_code: int | None = None, - status_text: str | None = None, - cause: BaseException | None = None, - ) -> None: - parts = [f"Failed to download {url!r}"] - if status_code is not None: - parts.append(f"status={status_code}") - if status_text: - parts.append(status_text) - super().__init__(": ".join(parts)) - self.url = url - self.status_code = status_code - if cause is not None: - self.__cause__ = cause - - -def _validate_url(url: str) -> None: - """Reject non-HTTP(S) URLs (SSRF prevention).""" - from urllib.parse import urlparse - - parsed = urlparse(url) - if parsed.scheme not in _ALLOWED_SCHEMES: - raise DownloadError( - url, status_text=f"Unsupported URL scheme: {parsed.scheme!r}" - ) - - -async def download( - url: str, - *, - max_bytes: int = DEFAULT_MAX_BYTES, -) -> tuple[bytes, str | None]: - """Download *url* and return ``(data, content_type)``. - - Args: - url: The URL to fetch (must be ``http`` or ``https``). - max_bytes: Maximum response size. Defaults to 100 MiB. - - Returns: - A tuple of ``(raw_bytes, content_type_or_None)``. - - Raises: - DownloadError: On any failure (network, HTTP status, size, etc.). - """ - _validate_url(url) - - try: - async with httpx.AsyncClient(follow_redirects=True) as client: - resp = await client.get(url) - - # Validate redirect target - if resp.url is not None and str(resp.url) != url: - _validate_url(str(resp.url)) - - if resp.status_code >= 400: - raise DownloadError( - url, - status_code=resp.status_code, - status_text=resp.reason_phrase or "", - ) - - data = resp.content - if len(data) > max_bytes: - raise DownloadError( - url, - status_text=( - f"Response exceeds maximum size " - f"({len(data)} > {max_bytes} bytes)" - ), - ) - - content_type = resp.headers.get("content-type") - # Strip charset/parameters: "image/png; charset=..." → "image/png" - if content_type: - content_type = content_type.split(";")[0].strip() - - return data, content_type or None - - except DownloadError: - raise - except Exception as exc: - raise DownloadError(url, cause=exc) from exc diff --git a/src/vercel_ai_sdk/types/messages.py b/src/vercel_ai_sdk/types/messages.py index 683c7469..ef6279fa 100644 --- a/src/vercel_ai_sdk/types/messages.py +++ b/src/vercel_ai_sdk/types/messages.py @@ -2,15 +2,25 @@ import importlib import uuid -from typing import Annotated, Any, Literal +from typing import Annotated, Any, Literal, overload import pydantic + +def generate_id(prefix: str | None = None) -> str: + """Generate a short random ID for messages and parts.""" + raw = uuid.uuid4().hex[:12] + return f"{prefix}_{raw}" if prefix else raw + + # Streaming state for parts PartState = Literal["streaming", "done"] class TextPart(pydantic.BaseModel): + model_config = pydantic.ConfigDict(frozen=True) + + id: str = pydantic.Field(default_factory=generate_id) text: str type: Literal["text"] = "text" # Streaming state @@ -19,6 +29,9 @@ class TextPart(pydantic.BaseModel): class ToolPart(pydantic.BaseModel): + model_config = pydantic.ConfigDict(frozen=True) + + id: str = pydantic.Field(default_factory=generate_id) tool_call_id: str tool_name: str tool_args: str @@ -29,18 +42,27 @@ class ToolPart(pydantic.BaseModel): state: PartState | None = None args_delta: str | None = None # Delta for tool_args - def set_result(self, result: Any) -> None: - """Set the tool result and mark as completed.""" - self.status = "result" - self.result = result + def with_result(self, result: Any) -> ToolPart: + """Return a copy with status='result' and the given result.""" + if self.status != "pending": + raise ValueError( + f"Tool call '{self.tool_call_id}' already has status '{self.status}'" + ) + return self.model_copy(update={"status": "result", "result": result}) - def set_error(self, message: str) -> None: - """Set a tool error and mark as failed.""" - self.status = "error" - self.result = message + def with_error(self, message: str) -> ToolPart: + """Return a copy with status='error' and the error message.""" + if self.status != "pending": + raise ValueError( + f"Tool call '{self.tool_call_id}' already has status '{self.status}'" + ) + return self.model_copy(update={"status": "error", "result": message}) class ReasoningPart(pydantic.BaseModel): + model_config = pydantic.ConfigDict(frozen=True) + + id: str = pydantic.Field(default_factory=generate_id) text: str type: Literal["reasoning"] = "reasoning" # Anthropic's thinking blocks include a signature for cache/verification. @@ -54,6 +76,9 @@ class ReasoningPart(pydantic.BaseModel): class HookPart(pydantic.BaseModel): """Part representing a hook suspension point in the agent's turn.""" + model_config = pydantic.ConfigDict(frozen=True) + + id: str = pydantic.Field(default_factory=generate_id) hook_id: str hook_type: str status: Literal[ @@ -94,6 +119,9 @@ class StructuredOutputPart(pydantic.BaseModel): Pydantic model can be lazily rehydrated via the ``value`` property. """ + model_config = pydantic.ConfigDict(frozen=True) + + id: str = pydantic.Field(default_factory=generate_id) data: dict[str, Any] output_type_name: str type: Literal["structured_output"] = "structured_output" @@ -123,6 +151,9 @@ class FilePart(pydantic.BaseModel): to JSON for providers that need it). """ + model_config = pydantic.ConfigDict(frozen=True) + + id: str = pydantic.Field(default_factory=generate_id) data: str | bytes media_type: str # IANA media type, e.g. "image/png", "audio/wav" filename: str | None = None @@ -138,7 +169,7 @@ def from_url(cls, url: str, *, media_type: str | None = None) -> FilePart: ``media_type`` is provided. """ if media_type is None: - from ..models.core.helpers import media as media_helpers + from . import media as media_helpers media_type = media_helpers.infer_media_type(url) return cls(data=url, media_type=media_type) @@ -158,7 +189,7 @@ def from_bytes( detection fails. """ if media_type is None: - from ..models.core.helpers import media as media_helpers + from . import media as media_helpers media_type = media_helpers.detect_image_media_type( data @@ -184,6 +215,8 @@ class Usage(pydantic.BaseModel): can distinguish "not reported" from "zero tokens used". """ + model_config = pydantic.ConfigDict(frozen=True) + input_tokens: int = 0 output_tokens: int = 0 @@ -226,23 +259,62 @@ def _add_optional(a: int | None, b: int | None) -> int | None: ) -def _gen_id() -> str: - return uuid.uuid4().hex[:12] - - class ToolDelta(pydantic.BaseModel): + model_config = pydantic.ConfigDict(frozen=True) + tool_call_id: str tool_name: str args_delta: str class Message(pydantic.BaseModel): + model_config = pydantic.ConfigDict(frozen=True) + role: Literal["user", "assistant", "system"] parts: list[Part] - id: str = pydantic.Field(default_factory=_gen_id) + id: str = pydantic.Field(default_factory=generate_id) label: str | None = None usage: Usage | None = None + @overload + def replace(self, new: Part, /) -> Message: ... + @overload + def replace(self, old: Part, new: Part, /) -> Message: ... + def replace(self, *args: Part) -> Message: + """Return a copy with a part replaced. + + Single arg: ``msg.replace(updated_part)`` — matches by ``id``. + Two args: ``msg.replace(old, new)`` — matches by identity. + + Raises ValueError if the target part is not found. + """ + if len(args) == 1: + (new,) = args + match_id: str | None = new.id + match_ref = None + elif len(args) == 2: + old, new = args + match_id = None + match_ref = old + else: + raise TypeError(f"replace() takes 1 or 2 arguments ({len(args)} given)") + found = False + new_parts: list[Part] = [] + for p in self.parts: + if not found and ( + (match_id is not None and p.id == match_id) + or (match_ref is not None and p is match_ref) + ): + found = True + new_parts.append(new) + else: + new_parts.append(p) + if not found: + if match_id is not None: + raise ValueError(f"No part with id '{match_id}' in message") + raise ValueError("Part not found in message") + return self.model_copy(update={"parts": new_parts}) + @property def output(self) -> Any: """Return the validated structured output, or None.""" @@ -335,12 +407,6 @@ def tool_calls(self) -> list[ToolPart]: # TODO properly validate args? return [part for part in self.parts if isinstance(part, ToolPart)] - def get_tool_part(self, tool_call_id: str) -> ToolPart | None: - for part in self.parts: - if isinstance(part, ToolPart) and part.tool_call_id == tool_call_id: - return part - return None - def get_hook_part(self, hook_id: str | None = None) -> HookPart | None: """Find a HookPart by hook_id, or return the first HookPart if no id given.""" for part in self.parts: diff --git a/tests/adapters/ai_sdk_ui/test_adapter.py b/tests/adapters/ai_sdk_ui/test_adapter.py index 1bbe2756..86e2a415 100644 --- a/tests/adapters/ai_sdk_ui/test_adapter.py +++ b/tests/adapters/ai_sdk_ui/test_adapter.py @@ -654,15 +654,14 @@ async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: last_msg = result.last_message assert last_msg is not None - async def approve_and_execute(tc: ai.ToolPart) -> None: + async def approve_and_execute(tc: ai.ToolPart) -> ai.ToolPart: approval = await ai.ToolApproval.create( # type: ignore[attr-defined] f"approve_{tc.tool_call_id}", metadata={"tool_name": tc.tool_name}, ) if approval.granted: - await ai.execute_tool(tc, message=last_msg) - else: - tc.set_error("denied") + return await ai.execute_tool(tc, message=last_msg) + return tc.with_error("denied") await asyncio.gather(*(approve_and_execute(tc) for tc in result.tool_calls)) diff --git a/tests/agents/test_runtime.py b/tests/agents/test_runtime.py index 6eb5b0be..8cd2e8cd 100644 --- a/tests/agents/test_runtime.py +++ b/tests/agents/test_runtime.py @@ -187,12 +187,12 @@ async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: assert isinstance(received_rt, Runtime) -# -- execute_tool: result updates ToolPart in message ---------------------- +# -- execute_tool: returns updated ToolPart -------------------------------- @pytest.mark.asyncio -async def test_execute_tool_updates_message() -> None: - """After execute_tool, the ToolPart in the message has status=result.""" +async def test_execute_tool_returns_updated_part() -> None: + """execute_tool returns an updated ToolPart; the original is unchanged.""" my_agent = ai.agent(model=MOCK_MODEL, tools=[double]) @my_agent.loop @@ -200,12 +200,14 @@ async def custom(agent: ai.Agent, msgs: list[ai.Message]) -> None: result = await ai.stream_step(agent.model, msgs, agent.tools) if result.tool_calls: msg = result.last_message - for tc in result.tool_calls: - await ai.execute_tool(tc, message=msg) - # Verify the tool part was mutated assert msg is not None - assert msg.tool_calls[0].status == "result" - assert msg.tool_calls[0].result == 10 + for tc in result.tool_calls: + updated_tc = await ai.execute_tool(tc, message=msg) + # Returned part has the result + assert updated_tc.status == "result" + assert updated_tc.result == 10 + # Original message is unchanged (immutable) + assert msg.tool_calls[0].status == "pending" call = [tool_msg(tc_id="tc-1", name="double", args='{"x": 5}')] mock_llm([call]) diff --git a/tests/models/ai_gateway/test_generate_video.py b/tests/models/ai_gateway/test_generate_video.py index 06dc6b91..d8208b33 100644 --- a/tests/models/ai_gateway/test_generate_video.py +++ b/tests/models/ai_gateway/test_generate_video.py @@ -130,7 +130,7 @@ def handler(req: httpx.Request) -> httpx.Response: client = _client(httpx.MockTransport(handler)) with patch( - "vercel_ai_sdk.models.core.helpers.media.download", + "vercel_ai_sdk.models.core.helpers.files.download", new_callable=AsyncMock, return_value=(_MP4_HEADER, "video/mp4"), ) as mock_dl: diff --git a/tests/models/ai_gateway/test_protocol.py b/tests/models/ai_gateway/test_protocol.py index 512c83e3..8695b3ec 100644 --- a/tests/models/ai_gateway/test_protocol.py +++ b/tests/models/ai_gateway/test_protocol.py @@ -137,7 +137,7 @@ async def test_user_message_with_image_url(self) -> None: ) ] with patch( - "vercel_ai_sdk.models.core.helpers.media.download", + "vercel_ai_sdk.models.core.helpers.files.download", new_callable=AsyncMock, return_value=(fake_jpeg, "image/jpeg"), ): diff --git a/tests/models/core/test_media.py b/tests/models/core/test_media.py index eb77c96a..b18fa64b 100644 --- a/tests/models/core/test_media.py +++ b/tests/models/core/test_media.py @@ -9,7 +9,7 @@ import base64 -from vercel_ai_sdk.models.core.helpers.media import ( +from vercel_ai_sdk.types.media import ( data_to_base64, data_to_data_url, detect_audio_media_type, diff --git a/tests/types/test_builders.py b/tests/types/test_builders.py new file mode 100644 index 00000000..00dc6557 --- /dev/null +++ b/tests/types/test_builders.py @@ -0,0 +1,140 @@ +"""Tests for message builder helpers.""" + +import pytest + +from vercel_ai_sdk.types.builders import ( + assistant_message, + file_part, + system_message, + thinking, + user_message, +) +from vercel_ai_sdk.types.messages import ( + FilePart, + ReasoningPart, + TextPart, + ToolPart, +) + +# -- system_message -------------------------------------------------------- + + +def test_system_message_from_string() -> None: + msg = system_message("You are helpful.") + assert msg.role == "system" + assert len(msg.parts) == 1 + assert isinstance(msg.parts[0], TextPart) + assert msg.parts[0].text == "You are helpful." + + +def test_system_message_empty() -> None: + msg = system_message() + assert msg.role == "system" + assert msg.parts == [] + + +# -- user_message ---------------------------------------------------------- + + +def test_user_message_single_string() -> None: + msg = user_message("Hello") + assert msg.role == "user" + assert len(msg.parts) == 1 + assert msg.parts[0].text == "Hello" # type: ignore[union-attr] + + +def test_user_message_multiple_strings_stay_separate() -> None: + msg = user_message("foo", "bar") + assert len(msg.parts) == 2 + assert msg.parts[0].text == "foo" # type: ignore[union-attr] + assert msg.parts[1].text == "bar" # type: ignore[union-attr] + + +def test_user_message_mixed_content() -> None: + fp = FilePart(data="https://example.com/img.png", media_type="image/png") + msg = user_message("Describe this:", fp, "Thanks") + assert len(msg.parts) == 3 + assert isinstance(msg.parts[0], TextPart) + assert isinstance(msg.parts[1], FilePart) + assert isinstance(msg.parts[2], TextPart) + + +def test_user_message_part_passthrough() -> None: + tp = TextPart(text="already a part") + msg = user_message(tp) + assert msg.parts[0] is tp + + +# -- assistant_message ----------------------------------------------------- + + +def test_assistant_message_with_thinking() -> None: + r = thinking("hmm let me think") + msg = assistant_message(r, "Here's my answer.") + assert msg.role == "assistant" + assert len(msg.parts) == 2 + assert isinstance(msg.parts[0], ReasoningPart) + assert isinstance(msg.parts[1], TextPart) + + +def test_assistant_message_with_tool_part() -> None: + tool = ToolPart(tool_call_id="tc-1", tool_name="test", tool_args="{}") + msg = assistant_message("calling tool", tool) + assert len(msg.parts) == 2 + assert isinstance(msg.parts[1], ToolPart) + + +# -- file_part ------------------------------------------------------------- + + +def test_file_part_from_url() -> None: + fp = file_part("https://example.com/image.png") + assert isinstance(fp, FilePart) + assert fp.data == "https://example.com/image.png" + assert fp.media_type == "image/png" + + +def test_file_part_from_bytes_with_explicit_media_type() -> None: + fp = file_part(b"\x00\x00", media_type="application/octet-stream") + assert isinstance(fp, FilePart) + assert fp.media_type == "application/octet-stream" + + +def test_file_part_from_bytes_auto_detect_png() -> None: + # PNG magic bytes + png_header = b"\x89PNG\r\n\x1a\n" + b"\x00" * 100 + fp = file_part(png_header) + assert fp.media_type == "image/png" + + +def test_file_part_from_bytes_unknown_raises() -> None: + with pytest.raises(ValueError, match="Cannot detect media_type"): + file_part(b"\x00\x00\x00") + + +def test_file_part_with_filename() -> None: + fp = file_part(b"\x89PNG\r\n\x1a\n" + b"\x00" * 100, filename="cat.png") + assert fp.filename == "cat.png" + + +# -- thinking -------------------------------------------------------------- + + +def test_thinking_basic() -> None: + r = thinking("deep thoughts") + assert isinstance(r, ReasoningPart) + assert r.text == "deep thoughts" + assert r.signature is None + + +def test_thinking_with_signature() -> None: + r = thinking("deep thoughts", signature="sig123") + assert r.signature == "sig123" + + +# -- type coercion edge cases ---------------------------------------------- + + +def test_invalid_type_raises() -> None: + with pytest.raises(TypeError): + user_message(42) # type: ignore[arg-type] diff --git a/tests/types/test_messages.py b/tests/types/test_messages.py index 4c8de27b..791c89ba 100644 --- a/tests/types/test_messages.py +++ b/tests/types/test_messages.py @@ -1,5 +1,5 @@ -"""Message model: properties, ToolPart.set_result/set_error, make_messages, -StructuredOutputPart, FilePart.""" +"""Message model: properties, immutability, part IDs, ToolPart.with_result/with_error, +Message.replace, make_messages, StructuredOutputPart, FilePart.""" import pydantic import pytest @@ -134,7 +134,7 @@ def test_tool_deltas() -> None: assert deltas[0].args_delta == '"te' -# -- tool_calls / get_tool_part ------------------------------------------- +# -- tool_calls ------------------------------------------------------------ def test_tool_calls() -> None: @@ -151,22 +151,6 @@ def test_tool_calls() -> None: assert m.tool_calls[0].tool_call_id == "tc1" -def test_get_tool_part_found() -> None: - m = Message( - id="m1", - role="assistant", - parts=[ToolPart(tool_call_id="tc1", tool_name="t", tool_args="{}")], - ) - tp = m.get_tool_part("tc1") - assert tp is not None - assert tp.tool_name == "t" - - -def test_get_tool_part_missing() -> None: - m = Message(id="m1", role="assistant", parts=[TextPart(text="no tools")]) - assert m.get_tool_part("tc-nope") is None - - # -- get_hook_part --------------------------------------------------------- @@ -193,25 +177,126 @@ def test_get_hook_part_missing() -> None: assert m.get_hook_part("h-nope") is None -# -- ToolPart.set_result / set_error --------------------------------------- +# -- Immutability ---------------------------------------------------------- -def test_set_result() -> None: +def test_frozen_rejects_field_mutation() -> None: + """Frozen models reject direct attribute assignment.""" tp = ToolPart(tool_call_id="tc1", tool_name="t", tool_args="{}") + with pytest.raises(pydantic.ValidationError): + tp.status = "result" # type: ignore[misc] + + m = Message(id="m1", role="assistant", parts=[TextPart(text="hi")]) + with pytest.raises(pydantic.ValidationError): + m.label = "test" # type: ignore[misc] + + +# -- ToolPart.with_result / with_error ------------------------------------ + + +def test_with_result() -> None: + tp = ToolPart(tool_call_id="tc1", tool_name="t", tool_args="{}") + assert tp.status == "pending" + updated = tp.with_result({"answer": 42}) + assert updated.status == "result" + assert updated.result == {"answer": 42} + assert updated.id == tp.id # id preserved + # Original unchanged assert tp.status == "pending" - tp.set_result({"answer": 42}) - # mypy narrows status to Literal["pending"] from the constructor default and - # can't track that set_result() mutates it to "result" - assert tp.status == "result" # type: ignore[comparison-overlap] - assert tp.result == {"answer": 42} + assert tp.result is None -def test_set_error() -> None: +def test_with_error() -> None: tp = ToolPart(tool_call_id="tc1", tool_name="t", tool_args="{}") + updated = tp.with_error("Something went wrong") + assert updated.status == "error" + assert updated.result == "Something went wrong" + assert updated.id == tp.id # id preserved + # Original unchanged assert tp.status == "pending" - tp.set_error("Something went wrong") - assert tp.status == "error" # type: ignore[comparison-overlap] - assert tp.result == "Something went wrong" + + +def test_with_result_rejects_non_pending() -> None: + """with_result / with_error reject non-pending tool calls.""" + tp = ToolPart( + tool_call_id="tc1", tool_name="t", tool_args="{}", status="result", result=42 + ) + with pytest.raises(ValueError, match="already has status"): + tp.with_result("new result") + with pytest.raises(ValueError, match="already has status"): + tp.with_error("oops") + + +# -- Part ids -------------------------------------------------------------- + + +def test_parts_have_auto_generated_ids() -> None: + """All parts get an auto-generated id.""" + text = TextPart(text="hi") + tool = ToolPart(tool_call_id="tc1", tool_name="t", tool_args="{}") + reasoning = ReasoningPart(text="thinking") + assert text.id # non-empty + assert tool.id + assert reasoning.id + # All different + assert len({text.id, tool.id, reasoning.id}) == 3 + + +def test_part_id_explicit() -> None: + """Parts accept an explicit id.""" + tp = TextPart(id="my-id", text="hi") + assert tp.id == "my-id" + + +# -- Message.replace ------------------------------------------------------- + + +def test_replace() -> None: + tp = ToolPart(id="p1", tool_call_id="tc1", tool_name="t", tool_args="{}") + m = Message( + id="m1", + role="assistant", + parts=[TextPart(id="p0", text="hi"), tp], + ) + updated_tp = tp.with_result({"answer": 42}) + updated_m = m.replace(updated_tp) + tc = next(p for p in updated_m.parts if isinstance(p, ToolPart)) + assert tc.status == "result" + assert tc.result == {"answer": 42} + # Original unchanged + orig_tc = next(p for p in m.parts if isinstance(p, ToolPart)) + assert orig_tc.status == "pending" + + +def test_replace_missing_id() -> None: + m = Message(id="m1", role="assistant", parts=[TextPart(id="p0", text="hi")]) + orphan = TextPart(id="no-such-id", text="x") + with pytest.raises(ValueError, match="in message"): + m.replace(orphan) + + +def test_replace_two_arg() -> None: + """replace(old, new) matches by identity, ignores id.""" + old_text = TextPart(id="p0", text="hello") + m = Message(id="m1", role="assistant", parts=[old_text]) + # new part has a different id — doesn't matter, old is matched by identity + new_text = TextPart(id="different", text="world") + updated = m.replace(old_text, new_text) + part = updated.parts[0] + assert isinstance(part, TextPart) + assert part.text == "world" + assert part.id == "different" + # Original unchanged + orig = m.parts[0] + assert isinstance(orig, TextPart) + assert orig.text == "hello" + + +def test_replace_two_arg_missing() -> None: + m = Message(id="m1", role="assistant", parts=[TextPart(id="p0", text="hi")]) + stranger = TextPart(id="p0", text="hi") # same content, different object + with pytest.raises(ValueError, match="not found in message"): + m.replace(stranger, TextPart(text="new")) # -- make_messages ---------------------------------------------------------