diff --git a/.gitignore b/.gitignore index 2e7fab51..57de6beb 100644 --- a/.gitignore +++ b/.gitignore @@ -210,3 +210,5 @@ __marimo__/ .reference .vercel .env*.local + +.DS_Store diff --git a/examples/samples/media/image_edit.py b/examples/samples/media/image_edit.py new file mode 100644 index 00000000..38b0218d --- /dev/null +++ b/examples/samples/media/image_edit.py @@ -0,0 +1,63 @@ +"""Image editing with a dedicated image model. + +Demonstrates sending an input image to be edited/transformed by the +image model. The input image is passed as a FilePart in the user +message, and the model returns the edited version. + +Usage: + uv run examples/samples/media/image_edit.py +""" + +import asyncio +import base64 +import pathlib + +import vercel_ai_sdk as ai + + +async def main() -> None: + model = ai.ai_gateway.GatewayImageModel( + model="openai/gpt-image-1", + ) + + # Load an existing image to use as input for editing. + # In practice you would load a real image file: + # image_data = pathlib.Path("my_photo.png").read_bytes() + # input_image = ai.FilePart.from_bytes(image_data, media_type="image/png") + input_image = ai.FilePart.from_url( + "https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/1200px-Cat03.jpg", + media_type="image/jpeg", + ) + + # Ask the model to transform the photo into anime style + msg = await model.generate( + [ + ai.Message( + role="user", + parts=[ + ai.TextPart( + text=( + "Transform this photo into a soft watercolor " + "anime style. Turn the cat into an anime girl " + "with cat ears and a tail, sitting in the same " + "pose. Add cherry blossom petals falling gently " + "in the background." 
+ ) + ), + input_image, + ], + ) + ], + size="1024x1024", + ) + + print(f"Generated {len(msg.images)} edited image(s)") + for i, img in enumerate(msg.images): + filename = f"catgirl_edit_{i}.png" + data = img.data if isinstance(img.data, bytes) else base64.b64decode(img.data) + pathlib.Path(filename).write_bytes(data) + print(f" {filename}: {img.media_type}, {len(data)} bytes") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/samples/media/image_gen_dedicated.py b/examples/samples/media/image_gen_dedicated.py new file mode 100644 index 00000000..b394c670 --- /dev/null +++ b/examples/samples/media/image_gen_dedicated.py @@ -0,0 +1,53 @@ +"""Dedicated image generation model (Imagen 4). + +Uses the ImageModel interface to generate images via the AI Gateway's +/image-model endpoint. Unlike language models, dedicated image models +are optimized purely for image generation with parameters like size, +aspect ratio, and seed. + +Usage: + uv run examples/samples/media/image_gen_dedicated.py +""" + +import asyncio +import base64 +import pathlib + +import vercel_ai_sdk as ai + + +async def main() -> None: + model = ai.ai_gateway.GatewayImageModel( + model="google/imagen-4.0-generate-001", + ) + + # Generate two images of an anime girl character + msg = await model.generate( + ai.make_messages( + user=( + "Anime girl with twin tails and cat ears, wearing a " + "sailor school uniform, striking a victory pose in front " + "of a futuristic Tokyo skyline at night, neon lights " + "reflecting in her eyes, digital art style" + ), + ), + n=2, + aspect_ratio="16:9", + ) + + print(f"Generated {len(msg.images)} images") + for i, img in enumerate(msg.images): + filename = f"neko_girl_{i}.png" + data = img.data if isinstance(img.data, bytes) else base64.b64decode(img.data) + pathlib.Path(filename).write_bytes(data) + print(f" {filename}: {img.media_type}, {len(data)} bytes") + + if msg.usage: + print( + f"Usage: {msg.usage.input_tokens} input, " + 
f"{msg.usage.output_tokens} output tokens" + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/samples/media/image_gen_inline.py b/examples/samples/media/image_gen_inline.py new file mode 100644 index 00000000..c23b94fc --- /dev/null +++ b/examples/samples/media/image_gen_inline.py @@ -0,0 +1,63 @@ +"""Inline image generation via a language model (Gemini 3 Pro Image). + +Models like Gemini 3 Pro Image and GPT-5 can generate images alongside +text as part of their language model response. The images arrive as +FileParts in the streamed Message. + +Usage: + uv run examples/samples/media/image_gen_inline.py +""" + +import asyncio +import base64 +import pathlib + +import vercel_ai_sdk as ai + + +async def agent(llm: ai.LanguageModel, user_query: str) -> ai.StreamResult: + return await ai.stream_loop( + llm, + messages=ai.make_messages( + system=( + "You are an anime art assistant. When asked to draw or create " + "an image, generate it in a soft pastel anime style with " + "detailed backgrounds and expressive characters." + ), + user=user_query, + ), + tools=[], + ) + + +async def main() -> None: + # Gemini 3 Pro Image is a language model that can output images inline + llm = ai.ai_gateway.GatewayModel(model="google/gemini-3-pro-image") + + prompt = ( + "Draw an anime girl with long silver hair and violet eyes, " + "sitting in a field of cherry blossoms at sunset. " + "She's wearing a traditional kimono and reading a book." 
+ ) + + async for msg in ai.run(agent, llm, prompt): + if msg.text_delta: + print(msg.text_delta, end="", flush=True) + + print() + + # The final message may contain both text and images + if msg.images: + for i, img in enumerate(msg.images): + filename = f"sakura_girl_{i}.png" + data = ( + img.data if isinstance(img.data, bytes) else base64.b64decode(img.data) + ) + pathlib.Path(filename).write_bytes(data) + print(f"Saved {filename} ({img.media_type}, {len(data)} bytes)") + else: + print("No images were generated in this response.") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/samples/media/multimodal.py b/examples/samples/media/multimodal.py new file mode 100644 index 00000000..cad74e55 --- /dev/null +++ b/examples/samples/media/multimodal.py @@ -0,0 +1,42 @@ +"""Multimodal input example: send an image URL to the model. + +Usage: + uv run examples/samples/media/multimodal.py +""" + +import asyncio + +import vercel_ai_sdk as ai + +IMAGE_URL = ( + "https://4kwallpapers.com/images/wallpapers/hatsune-miku-3840x2160-15479.jpg" +) + + +async def agent(llm: ai.LanguageModel, user_query: str) -> ai.StreamResult: + return await ai.stream_loop( + llm, + messages=[ + ai.Message( + role="user", + parts=[ + ai.TextPart(text=user_query), + ai.FilePart.from_url(IMAGE_URL), + ], + ) + ], + tools=[], + ) + + +async def main() -> None: + llm = ai.ai_gateway.GatewayModel(model="anthropic/claude-opus-4.6") + + async for msg in ai.run(agent, llm, "What's in this image? Be concise."): + if msg.text_delta: + print(msg.text_delta, end="", flush=True) + print() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/samples/media/video_gen.py b/examples/samples/media/video_gen.py new file mode 100644 index 00000000..17b94875 --- /dev/null +++ b/examples/samples/media/video_gen.py @@ -0,0 +1,50 @@ +"""Video generation with Veo 3. + +Uses the VideoModel interface to generate videos via the AI Gateway's +/video-model endpoint. 
The gateway handles the long-running generation +process (which can take minutes) and returns the result via SSE. + +Usage: + uv run examples/samples/media/video_gen.py +""" + +import asyncio +import base64 +import pathlib + +import vercel_ai_sdk as ai + + +async def main() -> None: + model = ai.ai_gateway.GatewayVideoModel( + model="google/veo-3.0-generate-001", + ) + + # Generate a short anime-style video clip + print("Generating video (this may take a minute or two)...") + msg = await model.generate( + ai.make_messages( + user=( + "An anime girl with long pink hair and a flowing white " + "dress stands on a hilltop at golden hour. A warm breeze " + "lifts her hair as she releases a paper lantern into the " + "sunset sky. The camera slowly pulls back to reveal dozens " + "of lanterns rising over a countryside village below. " + "Soft cel-shaded anime art style, warm palette." + ), + ), + aspect_ratio="16:9", + duration=8, + ) + + print(f"Generated {len(msg.videos)} video(s)") + for i, vid in enumerate(msg.videos): + ext = "mp4" if "mp4" in vid.media_type else "webm" + filename = f"lantern_girl_{i}.{ext}" + data = vid.data if isinstance(vid.data, bytes) else base64.b64decode(vid.data) + pathlib.Path(filename).write_bytes(data) + print(f" {filename}: {vid.media_type}, {len(data)} bytes") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/vercel_ai_sdk/__init__.py b/src/vercel_ai_sdk/__init__.py index 3d175a1f..2bb34657 100644 --- a/src/vercel_ai_sdk/__init__.py +++ b/src/vercel_ai_sdk/__init__.py @@ -3,9 +3,11 @@ from .core.checkpoint import Checkpoint, PendingHookInfo from .core.hooks import Hook, ToolApproval, hook from .core.llm import LanguageModel +from .core.media import ImageModel, MediaModel, MediaResult, VideoModel # Re-export core types from .core.messages import ( + FilePart, HookPart, Message, Part, @@ -40,11 +42,16 @@ "ToolPart", "ToolDelta", "ReasoningPart", + "FilePart", "ToolLike", "ToolSchema", "Tool", "Usage", "LanguageModel", 
+ "MediaModel", + "MediaResult", + "ImageModel", + "VideoModel", "Runtime", "RunResult", "HookInfo", diff --git a/src/vercel_ai_sdk/ai_gateway/__init__.py b/src/vercel_ai_sdk/ai_gateway/__init__.py index 8939f87c..6cdedbd3 100644 --- a/src/vercel_ai_sdk/ai_gateway/__init__.py +++ b/src/vercel_ai_sdk/ai_gateway/__init__.py @@ -1,46 +1,60 @@ -from __future__ import annotations +"""Vercel AI Gateway provider using the v3 protocol. +Communicates directly with the gateway using the AI SDK's native wire +formats. The gateway server handles all provider-specific translation. + +Usage:: + + import vercel_ai_sdk as ai + + # Language model + llm = ai.ai_gateway.GatewayModel(model="anthropic/claude-sonnet-4") + + # Image model + img = ai.ai_gateway.GatewayImageModel(model="google/imagen-4.0-generate-001") + msg = await img.generate(ai.make_messages(user="A sunset over Tokyo")) + + # Video model + vid = ai.ai_gateway.GatewayVideoModel(model="google/veo-3.0-generate-001") + msg = await vid.generate(ai.make_messages(user="A cat on a beach")) +""" + +import base64 +import json import os from collections.abc import AsyncGenerator, Sequence -from typing import override +from typing import Any, override +import httpx import pydantic from .. import core -from ..anthropic import AnthropicModel -from ..openai import OpenAIModel +from ..core.media import data as media_data +from ..core.media import detect_media_type +from ..core.media import models as media_models +from . import errors as errors_ +from . import protocol as protocol_ -_DEFAULT_BASE_URL = "https://ai-gateway.vercel.sh" +_DEFAULT_BASE_URL = "https://ai-gateway.vercel.sh/v3/ai" +_PROTOCOL_VERSION = "0.0.1" class GatewayModel(core.llm.LanguageModel): - """Vercel AI Gateway provider. - - Pre-configured for the Vercel AI Gateway with automatic routing: - Anthropic models use the native Anthropic API through the gateway, - except when structured output is requested (which requires the - OpenAI-compatible endpoint). 
All other models use the - OpenAI-compatible endpoint. + """Vercel AI Gateway language model using the v3 protocol. - Usage:: - - import vercel_ai_sdk as ai - - llm = ai.ai_gateway.GatewayModel(model="anthropic/claude-sonnet-4") + Sends the AI SDK's native message format directly to the gateway + server and receives responses in the AI SDK's native stream-part + format. The gateway server handles all provider-specific + translation. Args: model: Model identifier in ``provider/model`` format - (e.g., ``'anthropic/claude-sonnet-4'``, ``'openai/gpt-4.1'``) - api_key: API key for the gateway. Falls back to the - ``AI_GATEWAY_API_KEY`` environment variable. - base_url: Gateway base URL. Defaults to - ``https://ai-gateway.vercel.sh``. - thinking: Enable reasoning/thinking output. - budget_tokens: Max tokens for reasoning - (mutually exclusive with *reasoning_effort*). - reasoning_effort: Effort level for reasoning — ``'none'``, - ``'minimal'``, ``'low'``, ``'medium'``, ``'high'``, ``'xhigh'`` - (mutually exclusive with *budget_tokens*; OpenAI models only). + (e.g. ``'anthropic/claude-sonnet-4'``). + api_key: API key. Falls back to ``AI_GATEWAY_API_KEY``. + base_url: Gateway base URL. + provider_options: Gateway options (``order``, ``only``, + ``models``, ``byok``, ``tags``, etc.). + headers: Extra headers for every request. 
""" def __init__( @@ -48,54 +62,106 @@ def __init__( model: str = "anthropic/claude-sonnet-4", api_key: str | None = None, base_url: str = _DEFAULT_BASE_URL, - thinking: bool = False, - budget_tokens: int | None = None, - reasoning_effort: str | None = None, + provider_options: dict[str, Any] | None = None, + headers: dict[str, str] | None = None, + *, + _transport: httpx.AsyncBaseTransport | None = None, ) -> None: self._model = model self._api_key = api_key or os.environ.get("AI_GATEWAY_API_KEY") or "" self._base_url = base_url.rstrip("/") - self._thinking = thinking - self._budget_tokens = budget_tokens - self._reasoning_effort = reasoning_effort - - def _is_anthropic_model(self) -> bool: - return self._model.startswith("anthropic/") - - def _make_openai(self) -> OpenAIModel: - return OpenAIModel( - model=self._model, - base_url=f"{self._base_url}/v1", - api_key=self._api_key, - thinking=self._thinking, - budget_tokens=self._budget_tokens, - reasoning_effort=self._reasoning_effort, + self._provider_options = provider_options + self._extra_headers = headers or {} + self._transport = _transport + + # -- Internals ----------------------------------------------------------- + + def _headers(self, *, streaming: bool) -> dict[str, str]: + h: dict[str, str] = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self._api_key}", + "ai-gateway-protocol-version": _PROTOCOL_VERSION, + "ai-language-model-specification-version": "3", + "ai-language-model-id": self._model, + "ai-language-model-streaming": str(streaming).lower(), + } + if self._api_key: + h["ai-gateway-auth-method"] = "api-key" + h.update(self._extra_headers) + return h + + async def _raise_for_status(self, response: httpx.Response) -> None: + """Raise a typed :class:`GatewayError` for HTTP >= 400.""" + try: + body: Any = response.json() + except Exception: + body = response.text + raise errors_.create_gateway_error( + response_body=body, + status_code=response.status_code, + 
api_key_provided=bool(self._api_key), ) - def _make_anthropic(self) -> AnthropicModel: - return AnthropicModel( - model=self._model, - base_url=self._base_url, - api_key=self._api_key, - thinking=self._thinking, - budget_tokens=self._budget_tokens or 10000, + # -- Stream events ------------------------------------------------------- + + async def stream_events( + self, + messages: list[core.messages.Message], + tools: Sequence[core.tools.ToolLike] | None = None, + output_type: type[pydantic.BaseModel] | None = None, + ) -> AsyncGenerator[core.llm.StreamEvent]: + """Yield ``StreamEvent`` objects from the gateway SSE stream.""" + body = await protocol_.build_request_body( + messages, + tools=tools, + output_type=output_type, + provider_options=self._provider_options, ) + url = f"{self._base_url}/language-model" + try: + async with ( + httpx.AsyncClient(transport=self._transport) as client, + client.stream( + "POST", + url, + json=body, + headers=self._headers(streaming=True), + timeout=httpx.Timeout(timeout=300.0, connect=10.0), + ) as response, + ): + if response.status_code >= 400: + await response.aread() + await self._raise_for_status(response) - def _resolve( - self, output_type: type[pydantic.BaseModel] | None - ) -> core.llm.LanguageModel: - """Pick delegate based on model and feature requirements. - - - Anthropic models without structured output use the native - Anthropic API (richer reasoning support, native tool format). - - Anthropic models *with* structured output use OpenAI-compat - (structured output via the Anthropic-native gateway endpoint - is not currently supported). - - All other models use OpenAI-compat. 
- """ - if self._is_anthropic_model() and output_type is None: - return self._make_anthropic() - return self._make_openai() + async for line in response.aiter_lines(): + line = line.strip() + if not line.startswith("data: "): + continue + payload = line[len("data: ") :] + if payload == "[DONE]": + break + try: + data = json.loads(payload) + except json.JSONDecodeError: + continue + for event in protocol_.parse_stream_part(data): + yield event + + except errors_.GatewayError: + raise + except httpx.TimeoutException as exc: + raise errors_.GatewayTimeoutError( + cause=exc, + ) from exc + except Exception as exc: + raise errors_.GatewayResponseError( + message=( + f"Invalid error response format: Gateway request failed: {exc}" + ), + cause=exc, + ) from exc + + # -- LanguageModel interface --------------------------------------------- @override async def stream( @@ -104,6 +170,385 @@ async def stream( tools: Sequence[core.tools.ToolLike] | None = None, output_type: type[pydantic.BaseModel] | None = None, ) -> AsyncGenerator[core.messages.Message]: - delegate = self._resolve(output_type) - async for msg in delegate.stream(messages, tools, output_type): + handler = core.llm.StreamHandler() + msg: core.messages.Message | None = None + async for event in self.stream_events(messages, tools, output_type): + msg = handler.handle_event(event) yield msg + + if output_type is not None and msg is not None and msg.text: + data = json.loads(msg.text) + output_type.model_validate(data) + part = core.messages.StructuredOutputPart( + data=data, + output_type_name=( + f"{output_type.__module__}.{output_type.__qualname__}" + ), + ) + msg = msg.model_copy() + msg.parts = [*msg.parts, part] + yield msg + + +# --------------------------------------------------------------------------- +# Shared helpers for image/video models +# --------------------------------------------------------------------------- + + +def _base_headers(api_key: str, extra: dict[str, str]) -> dict[str, str]: + 
"""Build common gateway headers.""" + h: dict[str, str] = { + "Content-Type": "application/json", + "Authorization": f"Bearer {api_key}", + "ai-gateway-protocol-version": _PROTOCOL_VERSION, + } + if api_key: + h["ai-gateway-auth-method"] = "api-key" + h.update(extra) + return h + + +async def _raise_for_status(response: httpx.Response, *, api_key: str) -> None: + """Raise a typed :class:`GatewayError` for HTTP >= 400.""" + try: + body: Any = response.json() + except Exception: + body = response.text + raise errors_.create_gateway_error( + response_body=body, + status_code=response.status_code, + api_key_provided=bool(api_key), + ) + + +def _file_part_to_wire(part: core.messages.FilePart) -> dict[str, Any]: + """Convert a :class:`FilePart` to the gateway wire format for input files.""" + data = part.data + if isinstance(data, str) and media_data.is_url(data): + return {"type": "url", "url": data} + if isinstance(data, bytes): + b64 = base64.b64encode(data).decode("ascii") + elif isinstance(data, str): + # Assume raw base64 + b64 = data + else: + b64 = str(data) + return {"type": "file", "data": b64, "mediaType": part.media_type} + + +# --------------------------------------------------------------------------- +# GatewayImageModel +# --------------------------------------------------------------------------- + + +class GatewayImageModel(media_models.ImageModel): + """Vercel AI Gateway image model. + + Sends requests to ``/v3/ai/image-model`` and returns a :class:`Message` + with :class:`FilePart`\\s for each generated image. + + Args: + model: Model identifier (e.g. ``'google/imagen-4.0-generate-001'``). + api_key: API key. Falls back to ``AI_GATEWAY_API_KEY``. + base_url: Gateway base URL. + headers: Extra headers for every request. 
+ """ + + def __init__( + self, + model: str = "google/imagen-4.0-generate-001", + api_key: str | None = None, + base_url: str = _DEFAULT_BASE_URL, + headers: dict[str, str] | None = None, + *, + _transport: httpx.AsyncBaseTransport | None = None, + ) -> None: + self._model = model + self._api_key = api_key or os.environ.get("AI_GATEWAY_API_KEY") or "" + self._base_url = base_url.rstrip("/") + self._extra_headers = headers or {} + self._transport = _transport + + def _headers(self) -> dict[str, str]: + return _base_headers( + self._api_key, + { + "ai-image-model-specification-version": "3", + "ai-model-id": self._model, + **self._extra_headers, + }, + ) + + @override + async def make_request( + self, + prompt: str, + input_files: list[core.messages.FilePart], + *, + n: int = 1, + size: str | None = None, + aspect_ratio: str | None = None, + seed: int | None = None, + provider_options: dict[str, Any] | None = None, + ) -> media_models.MediaResult: + body: dict[str, Any] = { + "prompt": prompt, + "n": n, + "providerOptions": provider_options or {}, + } + if size is not None: + body["size"] = size + if aspect_ratio is not None: + body["aspectRatio"] = aspect_ratio + if seed is not None: + body["seed"] = seed + if input_files: + body["files"] = [_file_part_to_wire(f) for f in input_files] + + url = f"{self._base_url}/image-model" + try: + async with httpx.AsyncClient(transport=self._transport) as client: + response = await client.post( + url, + json=body, + headers=self._headers(), + timeout=httpx.Timeout(timeout=300.0, connect=10.0), + ) + if response.status_code >= 400: + await _raise_for_status(response, api_key=self._api_key) + + data = response.json() + + except errors_.GatewayError: + raise + except httpx.TimeoutException as exc: + raise errors_.GatewayTimeoutError(cause=exc) from exc + except Exception as exc: + raise errors_.GatewayResponseError( + message=f"Gateway image request failed: {exc}", + cause=exc, + ) from exc + + # Parse response: {images: string[], 
warnings?, usage?} + raw_images: list[str] = data.get("images", []) + usage_data = data.get("usage") + usage = None + if usage_data: + usage = core.messages.Usage( + input_tokens=usage_data.get("inputTokens") or 0, + output_tokens=usage_data.get("outputTokens") or 0, + ) + + files: list[core.messages.FilePart] = [] + for img_b64 in raw_images: + media_type = detect_media_type.detect_image_media_type(img_b64) + files.append( + core.messages.FilePart( + data=img_b64, + media_type=media_type or "image/png", + ) + ) + + return media_models.MediaResult(files=files, usage=usage) + + +# --------------------------------------------------------------------------- +# GatewayVideoModel +# --------------------------------------------------------------------------- + + +class GatewayVideoModel(media_models.VideoModel): + """Vercel AI Gateway video model. + + Sends requests to ``/v3/ai/video-model`` (with SSE response) and returns + a :class:`Message` with :class:`FilePart`\\s for each generated video. + + Args: + model: Model identifier (e.g. ``'google/veo-3.0-generate-001'``). + api_key: API key. Falls back to ``AI_GATEWAY_API_KEY``. + base_url: Gateway base URL. + headers: Extra headers for every request. 
+ """ + + def __init__( + self, + model: str = "google/veo-3.0-generate-001", + api_key: str | None = None, + base_url: str = _DEFAULT_BASE_URL, + headers: dict[str, str] | None = None, + *, + _transport: httpx.AsyncBaseTransport | None = None, + ) -> None: + self._model = model + self._api_key = api_key or os.environ.get("AI_GATEWAY_API_KEY") or "" + self._base_url = base_url.rstrip("/") + self._extra_headers = headers or {} + self._transport = _transport + + def _headers(self) -> dict[str, str]: + return _base_headers( + self._api_key, + { + "ai-video-model-specification-version": "3", + "ai-model-id": self._model, + "accept": "text/event-stream", + **self._extra_headers, + }, + ) + + @override + async def make_request( + self, + prompt: str, + input_files: list[core.messages.FilePart], + *, + n: int = 1, + aspect_ratio: str | None = None, + resolution: str | None = None, + duration: float | None = None, + fps: int | None = None, + seed: int | None = None, + provider_options: dict[str, Any] | None = None, + ) -> media_models.MediaResult: + image_wire: dict[str, Any] | None = None + if input_files: + image_wire = _file_part_to_wire(input_files[0]) + + body: dict[str, Any] = { + "prompt": prompt, + "n": n, + "providerOptions": provider_options or {}, + } + if aspect_ratio is not None: + body["aspectRatio"] = aspect_ratio + if resolution is not None: + body["resolution"] = resolution + if duration is not None: + body["duration"] = duration + if fps is not None: + body["fps"] = fps + if seed is not None: + body["seed"] = seed + if image_wire is not None: + body["image"] = image_wire + + url = f"{self._base_url}/video-model" + try: + async with ( + httpx.AsyncClient(transport=self._transport) as client, + client.stream( + "POST", + url, + json=body, + headers=self._headers(), + timeout=httpx.Timeout(timeout=600.0, connect=10.0), + ) as response, + ): + if response.status_code >= 400: + await response.aread() + await _raise_for_status(response, api_key=self._api_key) + 
+ event_data = await self._read_first_sse_event(response) + + except errors_.GatewayError: + raise + except httpx.TimeoutException as exc: + raise errors_.GatewayTimeoutError(cause=exc) from exc + except Exception as exc: + raise errors_.GatewayResponseError( + message=f"Gateway video request failed: {exc}", + cause=exc, + ) from exc + + # Handle error event + if event_data.get("type") == "error": + status = event_data.get("statusCode", 500) + message = event_data.get("message", "Video generation failed") + error_type = event_data.get("errorType", "") + if status == 400 or error_type == "invalid_request_error": + raise errors_.GatewayInvalidRequestError( + message=message, status_code=status + ) + raise errors_.GatewayResponseError(message=message, status_code=status) + + # Handle result event + raw_videos: list[dict[str, Any]] = event_data.get("videos", []) + files: list[core.messages.FilePart] = [] + for video_data in raw_videos: + file_part = await self._video_data_to_file_part(video_data) + files.append(file_part) + + return media_models.MediaResult(files=files) + + @staticmethod + async def _read_first_sse_event(response: httpx.Response) -> dict[str, Any]: + """Read and parse the first SSE data event from the response.""" + async for line in response.aiter_lines(): + line = line.strip() + if not line.startswith("data: "): + continue + payload = line[len("data: ") :] + if payload == "[DONE]": + break + try: + result: dict[str, Any] = json.loads(payload) + return result + except json.JSONDecodeError: + continue + raise errors_.GatewayResponseError( + message="SSE stream ended without a data event", + ) + + @staticmethod + async def _video_data_to_file_part( + video_data: dict[str, Any], + ) -> core.messages.FilePart: + """Convert a gateway video result to a :class:`FilePart`. + + Handles ``{type: "url", url, mediaType}`` (downloads the video) + and ``{type: "base64", data, mediaType}``. 
+ """ + vtype = video_data.get("type", "base64") + media_type = video_data.get("mediaType", "video/mp4") + + if vtype == "url": + video_url = video_data["url"] + downloaded_bytes, content_type = await core.media.download.download( + video_url + ) + # Prefer provider mediaType, then download content-type, then detect + if media_type == "video/mp4" and content_type: + media_type = content_type + detected = detect_media_type.detect_media_type( + downloaded_bytes, detect_media_type.VIDEO_SIGNATURES + ) + if detected: + media_type = detected + return core.messages.FilePart( + data=downloaded_bytes, + media_type=media_type, + ) + + # base64 + data = video_data.get("data", "") + detected = detect_media_type.detect_media_type( + data, detect_media_type.VIDEO_SIGNATURES + ) + if detected: + media_type = detected + return core.messages.FilePart( + data=data, + media_type=media_type, + ) + + +# --------------------------------------------------------------------------- +# Stubs for future model types +# --------------------------------------------------------------------------- + + +class GatewayEmbeddingModel: + """Stub -- not yet implemented.""" + + def __init__(self, model: str, **kwargs: Any) -> None: + raise NotImplementedError("GatewayEmbeddingModel is not yet implemented.") diff --git a/src/vercel_ai_sdk/ai_gateway/errors.py b/src/vercel_ai_sdk/ai_gateway/errors.py new file mode 100644 index 00000000..d0dade24 --- /dev/null +++ b/src/vercel_ai_sdk/ai_gateway/errors.py @@ -0,0 +1,305 @@ +"""Vercel AI Gateway error hierarchy. + +Maps HTTP error responses from the gateway server to typed Python exceptions. +Each error class corresponds to a specific ``error.type`` value in the +gateway's JSON error response format:: + + { + "error": { + "message": "...", + "type": "authentication_error" | "invalid_request_error" | ..., + "param": ..., + "code": ... + }, + "generationId": "..." 
+ } +""" + +import json +from typing import Any, Self + +_KEY_URL = "https://vercel.com/d?to=%2F%5Bteam%5D%2F%7E%2Fai%2Fapi-keys" + + +# --------------------------------------------------------------------------- +# Base class +# --------------------------------------------------------------------------- + + +class GatewayError(Exception): + """Base class for all Vercel AI Gateway errors.""" + + type: str = "gateway_error" + + def __init__( + self, + message: str = "", + *, + status_code: int = 500, + cause: BaseException | None = None, + generation_id: str | None = None, + ) -> None: + display = f"{message} [{generation_id}]" if generation_id else message + super().__init__(display) + self.status_code = status_code + self.generation_id = generation_id + if cause is not None: + self.__cause__ = cause + + +# --------------------------------------------------------------------------- +# Concrete errors — thin subclasses that set type + default status_code +# --------------------------------------------------------------------------- + + +class GatewayAuthenticationError(GatewayError): + """Authentication failed (HTTP 401).""" + + type = "authentication_error" + + def __init__( + self, + message: str = "Authentication failed", + *, + status_code: int = 401, + cause: BaseException | None = None, + generation_id: str | None = None, + ) -> None: + super().__init__( + message, + status_code=status_code, + cause=cause, + generation_id=generation_id, + ) + + @classmethod + def create_contextual( + cls, + *, + api_key_provided: bool, + status_code: int = 401, + cause: BaseException | None = None, + generation_id: str | None = None, + ) -> Self: + """Build a helpful message based on which auth method was used.""" + if api_key_provided: + msg = ( + "AI Gateway authentication failed: Invalid API key.\n\n" + f"Create a new API key: {_KEY_URL}\n\n" + "Provide via 'api_key' option or " + "'AI_GATEWAY_API_KEY' environment variable." 
+ ) + else: + msg = ( + "AI Gateway authentication failed: " + "No authentication provided.\n\n" + f"Create an API key: {_KEY_URL}\n" + "Provide via 'api_key' option or " + "'AI_GATEWAY_API_KEY' environment variable." + ) + return cls( + msg, + status_code=status_code, + cause=cause, + generation_id=generation_id, + ) + + +class GatewayInvalidRequestError(GatewayError): + """Malformed or invalid request (HTTP 400).""" + + type = "invalid_request_error" + + def __init__( + self, + message: str = "Invalid request", + *, + status_code: int = 400, + **kwargs: Any, + ) -> None: + super().__init__(message, status_code=status_code, **kwargs) + + +class GatewayRateLimitError(GatewayError): + """Rate limit exceeded (HTTP 429).""" + + type = "rate_limit_exceeded" + + def __init__( + self, + message: str = "Rate limit exceeded", + *, + status_code: int = 429, + **kwargs: Any, + ) -> None: + super().__init__(message, status_code=status_code, **kwargs) + + +class GatewayModelNotFoundError(GatewayError): + """Requested model was not found (HTTP 404).""" + + type = "model_not_found" + + def __init__( + self, + message: str = "Model not found", + *, + status_code: int = 404, + model_id: str | None = None, + cause: BaseException | None = None, + generation_id: str | None = None, + ) -> None: + super().__init__( + message, + status_code=status_code, + cause=cause, + generation_id=generation_id, + ) + self.model_id = model_id + + +class GatewayInternalServerError(GatewayError): + """Internal error on the gateway server (HTTP 500).""" + + type = "internal_server_error" + + def __init__( + self, + message: str = "Internal server error", + *, + status_code: int = 500, + **kwargs: Any, + ) -> None: + super().__init__(message, status_code=status_code, **kwargs) + + +class GatewayResponseError(GatewayError): + """Malformed or unparseable response (HTTP 502).""" + + type = "response_error" + + def __init__( + self, + message: str = "Invalid response", + *, + status_code: int = 502, + 
def create_gateway_error(
    *,
    response_body: Any,
    status_code: int,
    api_key_provided: bool = False,
    cause: BaseException | None = None,
) -> GatewayError:
    """Create a typed error from a gateway JSON error response.

    The expected body shape is
    ``{"error": {"message": <str>, "type": <str | None>, ...}}`` with an
    optional top-level ``generationId``.

    Args:
        response_body: Raw body (``str`` / ``bytes``, parsed as JSON) or an
            already-decoded object.
        status_code: HTTP status of the failed response; forwarded to the
            created error.
        api_key_provided: Whether the caller authenticated with an API key;
            only used to word authentication errors.
        cause: Optional underlying exception to chain.

    Returns:
        The error class registered for ``error.type``;
        :class:`GatewayInternalServerError` for unknown or missing types;
        :class:`GatewayResponseError` when the body is not valid JSON or
        does not match the expected shape.
    """

    def malformed(payload: Any, reason: str) -> GatewayResponseError:
        # Single construction site for every "unexpected shape" outcome.
        return GatewayResponseError(
            message=_MALFORMED,
            status_code=status_code,
            response=payload,
            validation_error=reason,
            cause=cause,
        )

    body: Any = response_body
    if isinstance(body, (str, bytes)):
        try:
            body = json.loads(body)
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError (undecodable bytes)
            # are both ValueError subclasses.
            return malformed(response_body, "Response body is not valid JSON")

    error_obj = body.get("error") if isinstance(body, dict) else None
    if not isinstance(error_obj, dict):
        return malformed(body, "Missing 'error' field in response")
    if "message" not in error_obj:
        return malformed(body, "Missing 'message' field in error object")

    message = error_obj["message"]
    if not isinstance(message, str):
        # A null / object message is a malformed response, not an error text.
        return malformed(body, "Field 'message' in error object is not a string")

    # A non-string type (e.g. a list) must not reach the dict lookup below:
    # unhashable values would raise TypeError while building the error.
    raw_type = error_obj.get("type")
    error_type: str | None = raw_type if isinstance(raw_type, str) else None
    generation_id: str | None = body.get("generationId")

    if error_type == "authentication_error":
        return GatewayAuthenticationError.create_contextual(
            api_key_provided=api_key_provided,
            status_code=status_code,
            cause=cause,
            generation_id=generation_id,
        )

    if error_type == "model_not_found":
        param = error_obj.get("param")
        model_id = param.get("modelId") if isinstance(param, dict) else None
        return GatewayModelNotFoundError(
            message=message,
            status_code=status_code,
            model_id=model_id,
            cause=cause,
            generation_id=generation_id,
        )

    error_cls = _TYPE_MAP.get(error_type or "", GatewayInternalServerError)
    return error_cls(
        message=message,
        status_code=status_code,
        cause=cause,
        generation_id=generation_id,
    )
async def _file_part_to_v3(part: core.messages.FilePart) -> dict[str, Any]:
    """Serialize an internal :class:`FilePart` as a v3 ``file`` content part.

    ``http(s)`` URLs are downloaded first; the payload is then passed through
    ``data_to_data_url`` together with the part's media type.  ``filename``
    is included only when the part has one.
    """
    payload = part.data
    needs_fetch = isinstance(payload, str) and core.media.data.is_downloadable_url(
        payload
    )
    if needs_fetch:
        # The content type reported by the server is not used; the part's
        # own media type is what gets sent.
        payload, _content_type = await core.media.download.download(payload)

    v3_part: dict[str, Any] = {
        "type": "file",
        "mediaType": part.media_type,
        "data": core.media.data.data_to_data_url(payload, part.media_type),
    }
    if part.filename is None:
        return v3_part
    v3_part["filename"] = part.filename
    return v3_part
+ + The v3 prompt format is an array of messages, each with a ``role`` and + typed ``content`` parts:: + + [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": [{"type": "text", "text": "Hi"}]}, + {"role": "assistant", "content": [ + {"type": "text", "text": "Hello!"}, + {"type": "reasoning", "text": "..."}, + {"type": "tool-call", "toolCallId": "tc-1", ...}, + ]}, + {"role": "tool", "content": [ + {"type": "tool-result", "toolCallId": "tc-1", ...}, + ]}, + ] + """ + result: list[dict[str, Any]] = [] + for msg in messages: + match msg.role: + case "system": + text = "".join( + p.text for p in msg.parts if isinstance(p, core.messages.TextPart) + ) + result.append({"role": "system", "content": text}) + + case "user": + content: list[dict[str, Any]] = [] + for p in msg.parts: + if isinstance(p, core.messages.TextPart): + content.append({"type": "text", "text": p.text}) + elif isinstance(p, core.messages.FilePart): + content.append(await _file_part_to_v3(p)) + result.append({"role": "user", "content": content}) + + case "assistant": + assistant_content: list[dict[str, Any]] = [] + tool_results: list[dict[str, Any]] = [] + + for part in msg.parts: + match part: + case core.messages.ReasoningPart(text=text): + assistant_content.append( + {"type": "reasoning", "text": text} + ) + + case core.messages.TextPart(text=text): + assistant_content.append({"type": "text", "text": text}) + + case core.messages.ToolPart() as tp: + tool_input: Any = ( + json.loads(tp.tool_args) if tp.tool_args else {} + ) + assistant_content.append( + { + "type": "tool-call", + "toolCallId": tp.tool_call_id, + "toolName": tp.tool_name, + "input": tool_input, + } + ) + if tp.status in ("result", "error"): + output = ( + { + "type": "error-text", + "value": ( + str(tp.result) + if tp.result is not None + else "" + ), + } + if tp.status == "error" + else { + "type": "json", + "value": tp.result, + } + ) + tool_results.append( + { + "type": "tool-result", + "toolCallId": 
tp.tool_call_id, + "toolName": tp.tool_name, + "output": output, + } + ) + + result.append( + { + "role": "assistant", + "content": assistant_content, + } + ) + if tool_results: + result.append( + { + "role": "tool", + "content": tool_results, + } + ) + + return result + + +# --------------------------------------------------------------------------- +# Request body serialization +# --------------------------------------------------------------------------- + + +async def build_request_body( + messages: list[core.messages.Message], + tools: Sequence[core.tools.ToolLike] | None = None, + output_type: type[Any] | None = None, + provider_options: dict[str, Any] | None = None, +) -> dict[str, Any]: + """Build the full ``LanguageModelV3CallOptions`` request body.""" + body: dict[str, Any] = { + "prompt": await messages_to_v3_prompt(messages), + } + if tools: + body["tools"] = [ + { + "type": "function", + "name": tool.name, + "description": tool.description, + "inputSchema": tool.param_schema, + } + for tool in tools + ] + if output_type is not None: + import pydantic + + if issubclass(output_type, pydantic.BaseModel): + body["responseFormat"] = { + "type": "json", + "schema": output_type.model_json_schema(), + "name": output_type.__name__, + } + if provider_options: + body["providerOptions"] = provider_options + return body + + +# --------------------------------------------------------------------------- +# v3 stream parts -> internal StreamEvent (incoming SSE response) +# --------------------------------------------------------------------------- + + +def parse_stream_part( + data: dict[str, Any], +) -> list[core.llm.StreamEvent]: + """Convert a ``LanguageModelV3StreamPart`` to internal events. + + Most parts map 1:1. A ``tool-call`` part (complete, non-streaming) + expands to Start + ArgsDelta + End. Lifecycle events + (``stream-start``, ``response-metadata``, ``raw``) are silently + dropped. 
+ """ + match data.get("type", ""): + case "text-start": + return [ + core.llm.TextStart( + block_id=data.get("id", "text"), + ) + ] + + case "text-delta": + return [ + core.llm.TextDelta( + block_id=data.get("id", "text"), + delta=data.get("textDelta", data.get("delta", "")), + ) + ] + + case "text-end": + return [ + core.llm.TextEnd( + block_id=data.get("id", "text"), + ) + ] + + case "reasoning-start": + return [ + core.llm.ReasoningStart( + block_id=data.get("id", "reasoning"), + ) + ] + + case "reasoning-delta": + return [ + core.llm.ReasoningDelta( + block_id=data.get("id", "reasoning"), + delta=data.get("delta", ""), + ) + ] + + case "reasoning-end": + return [ + core.llm.ReasoningEnd( + block_id=data.get("id", "reasoning"), + ) + ] + + case "tool-input-start": + return [ + core.llm.ToolStart( + tool_call_id=data.get("id", ""), + tool_name=data.get("toolName", ""), + ) + ] + + case "tool-input-delta": + return [ + core.llm.ToolArgsDelta( + tool_call_id=data.get("id", ""), + delta=data.get("delta", ""), + ) + ] + + case "tool-input-end": + return [ + core.llm.ToolEnd( + tool_call_id=data.get("id", ""), + ) + ] + + case "tool-call": + return _expand_tool_call(data) + + case "file": + return [ + core.llm.FileEvent( + block_id=data.get("id", f"file-{len(data)}"), + media_type=data.get("mediaType", "application/octet-stream"), + data=data.get("data", ""), + ) + ] + + case "finish": + return [_parse_finish(data)] + + case _: + return [] + + +# --------------------------------------------------------------------------- +# Non-streaming response -> internal StreamEvents +# --------------------------------------------------------------------------- + + +def parse_generate_result( + data: dict[str, Any], +) -> list[core.llm.StreamEvent]: + """Convert a ``LanguageModelV3GenerateResult`` into events. + + Synthesises Start/Delta/End events from the content, then a final + ``MessageDone``. 
def _expand_tool_call(
    data: dict[str, Any],
) -> list[core.llm.StreamEvent]:
    """Turn one complete ``tool-call`` part into Start, ArgsDelta and End.

    The ``input`` value is used verbatim when it is already a string and
    JSON-encoded otherwise.  Missing ``toolCallId``, ``toolName`` and
    ``input`` all default to the empty string.
    """
    call_id = data.get("toolCallId", "")
    name = data.get("toolName", "")

    raw_input = data.get("input", "")
    if isinstance(raw_input, str):
        serialized = raw_input
    else:
        serialized = json.dumps(raw_input)

    events: list[core.llm.StreamEvent] = []
    events.append(core.llm.ToolStart(tool_call_id=call_id, tool_name=name))
    events.append(core.llm.ToolArgsDelta(tool_call_id=call_id, delta=serialized))
    events.append(core.llm.ToolEnd(tool_call_id=call_id))
    return events
def _parse_usage(data: Any) -> core.messages.Usage:
    """Build an internal ``Usage`` from a gateway usage payload.

    Two layouts are recognised:

    * nested -- ``inputTokens`` and/or ``outputTokens`` is a dict; totals come
      from its ``total`` key, and ``reasoning`` / ``cacheRead`` /
      ``cacheWrite`` are passed through as-is;
    * flat -- ``prompt_tokens`` / ``completion_tokens``, falling back to
      scalar ``inputTokens`` / ``outputTokens``.

    Anything that is not a dict yields a default ``Usage()``.
    """
    if not isinstance(data, dict):
        return core.messages.Usage()

    raw_in = data.get("inputTokens")
    raw_out = data.get("outputTokens")
    in_is_nested = isinstance(raw_in, dict)
    out_is_nested = isinstance(raw_out, dict)

    if not (in_is_nested or out_is_nested):
        # Flat layout: missing or falsy counts collapse to 0.
        return core.messages.Usage(
            input_tokens=(data.get("prompt_tokens") or raw_in or 0),
            output_tokens=(data.get("completion_tokens") or raw_out or 0),
            raw=data,
        )

    inp = raw_in if in_is_nested else {}
    out = raw_out if out_is_nested else {}
    return core.messages.Usage(
        input_tokens=inp.get("total") or 0,
        output_tokens=out.get("total") or 0,
        reasoning_tokens=out.get("reasoning"),
        cache_read_tokens=inp.get("cacheRead"),
        cache_write_tokens=inp.get("cacheWrite"),
        raw=data,
    )
def _file_part_to_anthropic(part: core.messages.FilePart) -> dict[str, Any]:
    """Convert a :class:`FilePart` to an Anthropic content block.

    * ``image/*`` -> ``{"type": "image", "source": ...}``
    * ``application/pdf`` -> ``{"type": "document", "source": ...}``
    * ``text/plain`` -> ``{"type": "document", "source": {"type": "text", ...}}``
    * anything else -> ``ValueError``

    Only ``http(s)`` URLs are emitted as ``{"type": "url"}`` sources.  A
    ``data:`` URL already carries its payload inline, so it is unpacked to
    base64 (or decoded text) like any other inline data instead of being
    forwarded as a URL.

    Raises:
        ValueError: If the media type is not one of the supported kinds.
    """
    mt = part.media_type
    data = part.data

    # NOTE(review): assumes Anthropic URL sources must be fetchable http(s)
    # URLs -- confirm against the Messages API reference.
    remote_url: str | None = None
    if isinstance(data, str) and core.media.data.is_downloadable_url(data):
        remote_url = data

    if mt.startswith("image/"):
        if remote_url is not None:
            return {"type": "image", "source": {"type": "url", "url": remote_url}}
        return {
            "type": "image",
            "source": {
                "type": "base64",
                # The wildcard type is sent as JPEG, as before.
                "media_type": "image/jpeg" if mt == "image/*" else mt,
                "data": core.media.data.data_to_base64(data),
            },
        }

    if mt == "application/pdf":
        if remote_url is not None:
            return {"type": "document", "source": {"type": "url", "url": remote_url}}
        return {
            "type": "document",
            "source": {
                "type": "base64",
                "media_type": "application/pdf",
                "data": core.media.data.data_to_base64(data),
            },
        }

    if mt == "text/plain":
        if remote_url is not None:
            return {"type": "document", "source": {"type": "url", "url": remote_url}}
        if isinstance(data, bytes):
            text_data = data.decode("utf-8")
        else:
            import base64 as _b64

            # data_to_base64 extracts the payload of a ``data:`` URL and
            # returns plain base64 strings unchanged.
            text_data = _b64.b64decode(core.media.data.data_to_base64(data)).decode(
                "utf-8"
            )
        return {
            "type": "document",
            "source": {
                "type": "text",
                "media_type": "text/plain",
                "data": text_data,
            },
        }

    raise ValueError(f"Unsupported media type for Anthropic: {mt}")
@dataclasses.dataclass
class FileEvent:
    """Stream event carrying one complete generated file.

    Emitted when the model returns a whole file in a single part (for
    example an inline image) rather than as incremental deltas.
    """

    # Identifier of the content block this file belongs to.
    block_id: str
    # Media type of the payload, e.g. "image/png".
    media_type: str
    # Payload as delivered by the gateway: a base64 string or a data-URL.
    data: str
import data, detect_media_type, download, models +from .models import ImageModel, MediaModel, MediaResult, VideoModel + +__all__ = [ + "data", + "detect_media_type", + "download", + "models", + "ImageModel", + "MediaModel", + "MediaResult", + "VideoModel", +] diff --git a/src/vercel_ai_sdk/core/media/data.py b/src/vercel_ai_sdk/core/media/data.py new file mode 100644 index 00000000..e92fb5e2 --- /dev/null +++ b/src/vercel_ai_sdk/core/media/data.py @@ -0,0 +1,100 @@ +"""Data-format helpers for multimodal content. + +URL detection, ``data:`` URL parsing, base-64 encoding/decoding, and +media-type inference utilities used by :class:`~vercel_ai_sdk.core.messages.FilePart` +and the provider converters. +""" + +from __future__ import annotations + +import base64 +import mimetypes + +# -- URL helpers ----------------------------------------------------------- + + +def is_url(data: str) -> bool: + """Return True if *data* looks like a URL rather than raw base-64.""" + return data.startswith(("http://", "https://", "data:")) + + +def is_downloadable_url(data: str) -> bool: + """Return True if *data* is an ``http(s)://`` URL that can be fetched.""" + return data.startswith(("http://", "https://")) + + +def split_data_url(url: str) -> tuple[str | None, str | None]: + """Parse a ``data:`` URL into ``(media_type, base64_content)``. + + Returns ``(None, None)`` if the input is not a valid ``data:`` URL. + + Example:: + + >>> split_data_url("data:image/png;base64,iVBOR...") + ("image/png", "iVBOR...") + """ + if not url.startswith("data:"): + return None, None + try: + header, b64_content = url.split(",", 1) + # header = "data:image/png;base64" + mt = header.split(";")[0].split(":", 1)[1] + return (mt or None), (b64_content or None) + except (ValueError, IndexError): + return None, None + + +# -- encoding helpers ------------------------------------------------------ + + +def data_to_base64(data: str | bytes) -> str: + """Ensure *data* is a base-64 encoded string. 
def infer_media_type(url: str) -> str:
    """Work out the IANA media type a URL refers to.

    * ``data:image/png;base64,...`` -> ``"image/png"`` (read from the URL)
    * ``https://example.com/cat.jpg`` -> ``"image/jpeg"`` (via :mod:`mimetypes`)

    Raises:
        ValueError: If no media type can be determined, including a
            ``data:`` URL whose media-type field is empty.
    """
    if not url.startswith("data:"):
        media_type = mimetypes.guess_type(url)[0]
    else:
        # data:[<mediatype>][;base64],<data>
        metadata = url[len("data:") :].partition(",")[0]
        media_type = metadata.partition(";")[0]

    if media_type:
        return media_type
    raise ValueError(
        f"Cannot infer media_type from URL: {url!r}. Provide media_type explicitly."
    )
+""" + +from __future__ import annotations + +import base64 as _b64 + +# --------------------------------------------------------------------------- +# Signature definitions +# --------------------------------------------------------------------------- + +# Each signature is a tuple of (media_type, byte_prefix) where byte_prefix +# is a tuple of ``int | None`` values. ``None`` is a wildcard that matches +# any byte (mirrors the TS SDK's ``null`` sentinel). + +_Signature = tuple[str, tuple[int | None, ...]] + +IMAGE_SIGNATURES: list[_Signature] = [ + ("image/gif", (0x47, 0x49, 0x46)), + ("image/png", (0x89, 0x50, 0x4E, 0x47)), + ("image/jpeg", (0xFF, 0xD8)), + ( + "image/webp", + (0x52, 0x49, 0x46, 0x46, None, None, None, None, 0x57, 0x45, 0x42, 0x50), + ), + ("image/bmp", (0x42, 0x4D)), + ("image/tiff", (0x49, 0x49, 0x2A, 0x00)), # little-endian + ("image/tiff", (0x4D, 0x4D, 0x00, 0x2A)), # big-endian + ( + "image/avif", + (0x00, 0x00, 0x00, 0x20, 0x66, 0x74, 0x79, 0x70, 0x61, 0x76, 0x69, 0x66), + ), + ( + "image/heic", + (0x00, 0x00, 0x00, 0x20, 0x66, 0x74, 0x79, 0x70, 0x68, 0x65, 0x69, 0x63), + ), +] + +AUDIO_SIGNATURES: list[_Signature] = [ + ("audio/mpeg", (0xFF, 0xFB)), + ("audio/mpeg", (0xFF, 0xFA)), + ("audio/mpeg", (0xFF, 0xF3)), + ("audio/mpeg", (0xFF, 0xF2)), + ("audio/mpeg", (0xFF, 0xE3)), + ("audio/mpeg", (0xFF, 0xE2)), + ( + "audio/wav", + (0x52, 0x49, 0x46, 0x46, None, None, None, None, 0x57, 0x41, 0x56, 0x45), + ), + ("audio/ogg", (0x4F, 0x67, 0x67, 0x53)), + ("audio/flac", (0x66, 0x4C, 0x61, 0x43)), + ("audio/aac", (0x40, 0x15, 0x00, 0x00)), + ("audio/mp4", (0x66, 0x74, 0x79, 0x70)), + ("audio/webm", (0x1A, 0x45, 0xDF, 0xA3)), +] + +VIDEO_SIGNATURES: list[_Signature] = [ + ("video/mp4", (0x00, 0x00, 0x00, None, 0x66, 0x74, 0x79, 0x70)), + ("video/webm", (0x1A, 0x45, 0xDF, 0xA3)), + ( + "video/quicktime", + (0x00, 0x00, 0x00, 0x14, 0x66, 0x74, 0x79, 0x70, 0x71, 0x74), + ), + ("video/x-msvideo", (0x52, 0x49, 0x46, 0x46)), +] + + +# 
# ---------------------------------------------------------------------------
# ID3 tag stripping (for MP3 files that start with ID3v2 metadata)
# ---------------------------------------------------------------------------

_ID3_HEADER = bytes([0x49, 0x44, 0x33])  # "ID3"
_ID3_BASE64 = "SUQz"  # base64("ID3")


def _strip_id3_tags(data: bytes) -> bytes:
    """Strip a leading ID3v2 tag if present, returning the remaining bytes.

    The input is returned unchanged when it is shorter than the 10-byte tag
    header, does not start with ``"ID3"``, or when the declared tag size
    reaches or exceeds the end of the data.
    """
    if len(data) < 10 or data[:3] != _ID3_HEADER:
        return data
    # Bytes 6-9 hold the tag size as a syncsafe integer: 4 bytes, 7 bits each.
    size = (
        (data[6] & 0x7F) << 21
        | (data[7] & 0x7F) << 14
        | (data[8] & 0x7F) << 7
        | (data[9] & 0x7F)
    )
    offset = size + 10  # tag body plus the 10-byte header
    return data[offset:] if offset < len(data) else data


def _strip_id3_tags_base64(data: str) -> str:
    """Strip a leading ID3v2 tag from base64-encoded data if present.

    Delegates the tag parsing to :func:`_strip_id3_tags` so the size logic
    lives in one place.  The input string is returned unchanged when it does
    not start with the base64 form of ``"ID3"``, cannot be decoded, or
    contains nothing to strip.
    """
    if not data.startswith(_ID3_BASE64):
        return data
    try:
        decoded = _b64.b64decode(data)
    except ValueError:
        # binascii.Error (bad padding / alphabet) is a ValueError subclass.
        return data

    stripped = _strip_id3_tags(decoded)
    if len(stripped) == len(decoded):
        return data
    return _b64.b64encode(stripped).decode("ascii")
def detect_media_type(
    data: bytes | str,
    signatures: list[_Signature],
) -> str | None:
    """Detect media type from magic bytes.

    Args:
        data: Raw bytes or a base-64 encoded string.
        signatures: List of ``(media_type, byte_prefix)`` tuples to match
            against.  ``None`` entries in a prefix match any byte.

    Returns:
        The media type of the first matching signature, or ``None`` if no
        signature matches or the data cannot be converted to bytes.
    """
    # A leading ID3v2 tag hides the audio frame's magic bytes.  Decide from
    # the table's contents rather than its identity, so a copied or combined
    # audio signature list gets the same treatment as the built-in one.
    if any(media_type.startswith("audio/") for media_type, _ in signatures):
        if isinstance(data, bytes):
            data = _strip_id3_tags(data)
        else:
            data = _strip_id3_tags_base64(data)

    head = _to_bytes(data)
    if not head:
        return None

    for media_type, prefix in signatures:
        if len(head) < len(prefix):
            continue
        if all(
            expected is None or head[i] == expected
            for i, expected in enumerate(prefix)
        ):
            return media_type

    return None
async def download(
    url: str,
    *,
    max_bytes: int = DEFAULT_MAX_BYTES,
) -> tuple[bytes, str | None]:
    """Download *url* and return ``(data, content_type)``.

    The body is streamed and the size limit is enforced while reading, so an
    oversized response is abandoned as soon as it crosses *max_bytes* instead
    of being buffered in full first.

    Args:
        url: The URL to fetch (must be ``http`` or ``https``).
        max_bytes: Maximum response size in bytes.

    Returns:
        A tuple of ``(raw_bytes, content_type_or_None)``.  Parameters such
        as ``charset`` are stripped from the content type.

    Raises:
        DownloadError: On any failure (unsupported scheme, network error,
            HTTP status >= 400, size limit exceeded).
    """
    _validate_url(url)

    def too_large(size: int) -> DownloadError:
        return DownloadError(
            url,
            status_text=(
                f"Response exceeds maximum size ({size} > {max_bytes} bytes)"
            ),
        )

    try:
        async with httpx.AsyncClient(follow_redirects=True) as client:
            async with client.stream("GET", url) as resp:
                # Headers are available here and the body is not read yet.
                # NOTE: redirects have already been followed at this point;
                # this check only rejects a non-HTTP(S) final target before
                # its body is consumed.
                if resp.url is not None and str(resp.url) != url:
                    _validate_url(str(resp.url))

                if resp.status_code >= 400:
                    raise DownloadError(
                        url,
                        status_code=resp.status_code,
                        status_text=resp.reason_phrase or "",
                    )

                # Cheap early rejection when the server declares its size.
                declared = resp.headers.get("content-length")
                if declared is not None and declared.isdigit():
                    if int(declared) > max_bytes:
                        raise too_large(int(declared))

                chunks: list[bytes] = []
                received = 0
                async for chunk in resp.aiter_bytes():
                    received += len(chunk)
                    if received > max_bytes:
                        raise too_large(received)
                    chunks.append(chunk)
                data = b"".join(chunks)

                content_type = resp.headers.get("content-type")
                # Strip charset/parameters: "image/png; charset=..." -> "image/png"
                if content_type:
                    content_type = content_type.split(";")[0].strip()

        return data, content_type or None

    except DownloadError:
        raise
    except Exception as exc:
        # Boundary wrapper: every other failure surfaces as DownloadError.
        raise DownloadError(url, cause=exc) from exc
# ---------------------------------------------------------------------------
# Result type
# ---------------------------------------------------------------------------


@dataclasses.dataclass
class MediaResult:
    """What an adapter's ``make_request()`` hands back.

    ``generate()`` converts this into an assistant :class:`Message`.
    """

    # Generated media, one FilePart per output.
    files: list[messages_.FilePart]
    # Optional usage record, forwarded unchanged to the resulting message.
    usage: messages_.Usage | None = None


# ---------------------------------------------------------------------------
# MediaModel -- shared base
# ---------------------------------------------------------------------------


class MediaModel(abc.ABC):
    """Common base of the media generation models.

    Holds the message handling shared by every adapter:

    * **Input** -- pull the text prompt and the input files out of
      :class:`Message` objects.
    * **Output** -- wrap the adapter's :class:`MediaResult` in a
      :class:`Message` with ``role="assistant"``.

    Subclasses expose a ``generate()`` with media-specific parameters and
    forward to the adapter's ``make_request()``.
    """

    @staticmethod
    def _extract_prompt(messages: list[messages_.Message]) -> str:
        """Space-join the text of every ``TextPart`` in user/system messages."""
        prompt_roles = ("user", "system")
        return " ".join(
            part.text
            for message in messages
            if message.role in prompt_roles
            for part in message.parts
            if isinstance(part, messages_.TextPart)
        )

    @staticmethod
    def _extract_input_files(
        messages: list[messages_.Message],
    ) -> list[messages_.FilePart]:
        """Return every ``FilePart`` found in user messages, in order."""
        return [
            part
            for message in messages
            if message.role == "user"
            for part in message.parts
            if isinstance(part, messages_.FilePart)
        ]

    @staticmethod
    def _build_message(result: MediaResult) -> messages_.Message:
        """Turn adapter output into an assistant :class:`Message`."""
        # Copy the list so the message does not alias the result's list.
        return messages_.Message(
            role="assistant",
            parts=[*result.files],
            usage=result.usage,
        )

    @abc.abstractmethod
    async def make_request(
        self,
        prompt: str,
        input_files: list[messages_.FilePart],
        *,
        n: int = 1,
        provider_options: dict[str, Any] | None = None,
    ) -> MediaResult:
        """Run the adapter-specific generation.

        Called by ``generate()`` with inputs already extracted from the
        messages; the returned :class:`MediaResult` is then wrapped into
        a :class:`Message`.
        """
        ...


# ---------------------------------------------------------------------------
# ImageModel
# ---------------------------------------------------------------------------


class ImageModel(MediaModel):
    """Abstract image generation model.

    Takes messages in and returns one assistant :class:`Message` whose
    parts are the generated images. Adapters only implement
    :meth:`make_request`.
    """

    async def generate(
        self,
        messages: list[messages_.Message],
        *,
        n: int = 1,
        size: str | None = None,
        aspect_ratio: str | None = None,
        seed: int | None = None,
        provider_options: dict[str, Any] | None = None,
    ) -> messages_.Message:
        """Generate images from the given messages.

        Args:
            messages: Input messages holding the text prompt and,
                optionally, input images (``FilePart`` objects) to edit.
            n: Number of images to generate.
            size: Image dimensions (e.g. ``"1024x1024"``).
            aspect_ratio: Aspect ratio (e.g. ``"16:9"``).
            seed: Random seed for reproducible generation.
            provider_options: Provider-specific options.

        Returns:
            A :class:`Message` with ``role="assistant"`` holding one
            ``FilePart`` per generated image.
        """
        adapter_output = await self.make_request(
            self._extract_prompt(messages),
            self._extract_input_files(messages),
            n=n,
            size=size,
            aspect_ratio=aspect_ratio,
            seed=seed,
            provider_options=provider_options,
        )
        return self._build_message(adapter_output)

    @override
    @abc.abstractmethod
    async def make_request(
        self,
        prompt: str,
        input_files: list[messages_.FilePart],
        *,
        n: int = 1,
        size: str | None = None,
        aspect_ratio: str | None = None,
        seed: int | None = None,
        provider_options: dict[str, Any] | None = None,
    ) -> MediaResult:
        """Run the adapter-specific image generation.

        Args:
            prompt: Text prompt extracted from messages.
            input_files: File parts from user messages (for editing).
            n: Number of images to generate.
            size: Image dimensions (e.g. ``"1024x1024"``).
            aspect_ratio: Aspect ratio (e.g. ``"16:9"``).
            seed: Random seed for reproducible generation.
            provider_options: Provider-specific options.

        Returns:
            A :class:`MediaResult` with generated image files.
        """
        ...
+ + +# --------------------------------------------------------------------------- +# VideoModel +# --------------------------------------------------------------------------- + + +class VideoModel(MediaModel): + """Abstract video generation model. + + Accepts :class:`Message`\\s as input and returns a :class:`Message` + containing generated videos as :class:`FilePart`\\s. + + Adapter authors implement :meth:`make_request`; the framework handles + parsing messages and assembling the response :class:`Message`. + """ + + async def generate( + self, + messages: list[messages_.Message], + *, + n: int = 1, + aspect_ratio: str | None = None, + resolution: str | None = None, + duration: float | None = None, + fps: int | None = None, + seed: int | None = None, + provider_options: dict[str, Any] | None = None, + ) -> messages_.Message: + """Generate videos from the given messages. + + Args: + messages: Input messages containing the text prompt and + optional input image (as a :class:`FilePart`) for + image-to-video. + n: Number of videos to generate. + aspect_ratio: Aspect ratio (e.g. ``"16:9"``). + resolution: Video resolution (e.g. ``"1920x1080"``). + duration: Duration in seconds. + fps: Frames per second. + seed: Random seed for reproducible generation. + provider_options: Provider-specific options. + + Returns: + A :class:`Message` with ``role="assistant"`` containing + :class:`FilePart`\\s for each generated video. 
+ """ + prompt = self._extract_prompt(messages) + input_files = self._extract_input_files(messages) + result = await self.make_request( + prompt, + input_files, + n=n, + aspect_ratio=aspect_ratio, + resolution=resolution, + duration=duration, + fps=fps, + seed=seed, + provider_options=provider_options, + ) + return self._build_message(result) + + @override + @abc.abstractmethod + async def make_request( + self, + prompt: str, + input_files: list[messages_.FilePart], + *, + n: int = 1, + aspect_ratio: str | None = None, + resolution: str | None = None, + duration: float | None = None, + fps: int | None = None, + seed: int | None = None, + provider_options: dict[str, Any] | None = None, + ) -> MediaResult: + """Adapter-specific video generation. + + Args: + prompt: Text prompt extracted from messages. + input_files: File parts from user messages (e.g. input + image for image-to-video). + n: Number of videos to generate. + aspect_ratio: Aspect ratio (e.g. ``"16:9"``). + resolution: Video resolution (e.g. ``"1920x1080"``). + duration: Duration in seconds. + fps: Frames per second. + seed: Random seed for reproducible generation. + provider_options: Provider-specific options. + + Returns: + A :class:`MediaResult` with generated video files. + """ + ... diff --git a/src/vercel_ai_sdk/core/messages.py b/src/vercel_ai_sdk/core/messages.py index 88a5cb85..c06fc022 100644 --- a/src/vercel_ai_sdk/core/messages.py +++ b/src/vercel_ai_sdk/core/messages.py @@ -6,6 +6,8 @@ import pydantic +from . import media + # Streaming state for parts PartState = Literal["streaming", "done"] @@ -109,8 +111,65 @@ def value(self) -> Any: return self._hydrated +class FilePart(pydantic.BaseModel): + """File, image, or audio content part. + + Covers images (``image/*``), documents (``application/pdf``, ``text/*``), + and audio (``audio/*``). The ``media_type`` field tells provider + converters how to format this part for each API. 
+ + ``data`` accepts: + + * **str** -- a URL (``http(s)://...`` or ``data:...``) *or* raw base-64 text. + * **bytes** -- raw binary data (will be base-64 encoded when serialized + to JSON for providers that need it). + """ + + data: str | bytes + media_type: str # IANA media type, e.g. "image/png", "audio/wav" + filename: str | None = None + type: Literal["file"] = "file" + + @classmethod + def from_url(cls, url: str, *, media_type: str | None = None) -> FilePart: + """Create from a URL, inferring ``media_type`` from the URL if omitted. + + Inference handles ``data:`` URLs (the media type is embedded in the + prefix) and ``http(s)://`` URLs (via :func:`mimetypes.guess_type`). + Raises :class:`ValueError` if inference fails and no explicit + ``media_type`` is provided. + """ + if media_type is None: + media_type = media.data.infer_media_type(url) + return cls(data=url, media_type=media_type) + + @classmethod + def from_bytes( + cls, + data: bytes, + *, + media_type: str | None = None, + filename: str | None = None, + ) -> FilePart: + """Create from raw bytes, detecting ``media_type`` via magic bytes. + + Attempts image detection first, then audio. Raises + :class:`ValueError` if no ``media_type`` is provided and + detection fails. + """ + if media_type is None: + media_type = media.detect_media_type.detect_image_media_type( + data + ) or media.detect_media_type.detect_audio_media_type(data) + if media_type is None: + raise ValueError( + "Cannot detect media_type from bytes. Provide media_type explicitly." 
async def _file_part_to_openai(part: core.messages.FilePart) -> dict[str, Any]:
    """Convert a :class:`FilePart` to an OpenAI content-array element.

    Follows the OpenAI chat-completions content part formats:

    * ``image/*`` → ``image_url`` (URL or ``data:`` URL)
    * ``audio/*`` → ``input_audio`` (base-64 only; URLs auto-downloaded)
    * ``application/pdf`` → ``file`` (base-64 only; URLs auto-downloaded)
    * ``text/*`` → ``text`` (decoded to string)
    * anything else → ``ValueError``

    OpenAI does not accept URLs for audio ``input_audio`` or PDF ``file``
    parts. When URL data is provided for these types, it is downloaded
    automatically (matching the TS SDK's ``downloadAssets`` behaviour).

    The ``input_audio.format`` value is OpenAI's short format name, not
    the MIME subtype: ``audio/mpeg`` is sent as ``"mp3"`` and the WAV
    aliases (``audio/x-wav``, ``audio/wave``, ``audio/vnd.wave``) as
    ``"wav"``. Unrecognised subtypes are forwarded unchanged.

    Args:
        part: The file part to convert.

    Returns:
        A dict in the shape OpenAI expects for one content-array entry.

    Raises:
        ValueError: If ``part.media_type`` is not one of the supported
            families listed above.
    """
    mt = part.media_type
    data = part.data

    if mt.startswith("image/"):
        # A wildcard image type is sent as JPEG.
        media_type = "image/jpeg" if mt == "image/*" else mt
        url = core.media.data.data_to_data_url(data, media_type)
        return {"type": "image_url", "image_url": {"url": url}}

    if mt.startswith("audio/"):
        # OpenAI input_audio requires raw base-64 — download http(s) URLs.
        if isinstance(data, str) and core.media.data.is_downloadable_url(data):
            downloaded, _ = await core.media.download.download(data)
            data = downloaded
        # MIME subtype → OpenAI format name; unknown subtypes pass through.
        audio_format_aliases = {
            "mpeg": "mp3",
            "mp3": "mp3",
            "wav": "wav",
            "x-wav": "wav",
            "wave": "wav",
            "vnd.wave": "wav",
        }
        subtype = mt.split("/", 1)[1]
        fmt = audio_format_aliases.get(subtype.lower(), subtype)
        b64 = core.media.data.data_to_base64(data)
        return {"type": "input_audio", "input_audio": {"data": b64, "format": fmt}}

    if mt == "application/pdf":
        # OpenAI file parts require base-64 — download http(s) URLs.
        if isinstance(data, str) and core.media.data.is_downloadable_url(data):
            downloaded, _ = await core.media.download.download(data)
            data = downloaded
        data_url = core.media.data.data_to_data_url(data, mt)
        filename = part.filename or "document.pdf"
        return {"type": "file", "file": {"filename": filename, "file_data": data_url}}

    if mt.startswith("text/"):
        # Decode text content — URLs are passed through as text,
        # bytes/base-64 are decoded to UTF-8 string.
        if isinstance(data, bytes):
            text_content = data.decode("utf-8")
        elif core.media.data.is_url(data):
            text_content = data
        else:
            import base64 as _b64

            text_content = _b64.b64decode(data).decode("utf-8")
        return {"type": "text", "text": text_content}

    raise ValueError(f"Unsupported media type for OpenAI: {mt}")
Converts to the OpenAI wire format: @@ -85,12 +144,28 @@ def _messages_to_openai(messages: list[core.messages.Message]) -> list[dict[str, # Emit tool results as separate messages (OpenAI API format) result.extend(tool_results) - else: - # User/system messages + elif msg.role == "system": content = "".join( p.text for p in msg.parts if isinstance(p, core.messages.TextPart) ) - result.append({"role": msg.role, "content": content}) + result.append({"role": "system", "content": content}) + else: + # User messages — may contain multimodal FileParts + has_files = any(isinstance(p, core.messages.FilePart) for p in msg.parts) + if not has_files: + # Text-only: keep simple string format (cheaper, no content array) + text = "".join( + p.text for p in msg.parts if isinstance(p, core.messages.TextPart) + ) + result.append({"role": "user", "content": text}) + else: + parts: list[dict[str, Any]] = [] + for p in msg.parts: + if isinstance(p, core.messages.TextPart): + parts.append({"type": "text", "text": p.text}) + elif isinstance(p, core.messages.FilePart): + parts.append(await _file_part_to_openai(p)) + result.append({"role": "user", "content": parts}) return result @@ -141,7 +216,7 @@ async def stream_events( output_type: type[pydantic.BaseModel] | None = None, ) -> AsyncGenerator[core.llm.StreamEvent]: """Yield raw stream events from OpenAI API.""" - openai_messages = _messages_to_openai(messages) + openai_messages = await _messages_to_openai(messages) openai_tools = _tools_to_openai(tools) if tools else None kwargs: dict[str, Any] = { diff --git a/tests/ai_gateway/__init__.py b/tests/ai_gateway/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/ai_gateway/test_errors.py b/tests/ai_gateway/test_errors.py new file mode 100644 index 00000000..6be48ae7 --- /dev/null +++ b/tests/ai_gateway/test_errors.py @@ -0,0 +1,139 @@ +"""Tests for the gateway error factory. 
def _error_body(message: str, error_type: str, **extra: object) -> dict[str, object]:
    """Build a gateway error payload: ``{"error": {message, type, ...extra}}``."""
    return {"error": {"message": message, "type": error_type, **extra}}


class TestGatewayErrorBase:
    """Behaviour every concrete gateway error gets from the base class."""

    def test_isinstance_hierarchy(self) -> None:
        rate_limited = errors.GatewayRateLimitError("nope")
        for expected_base in (errors.GatewayError, Exception):
            assert isinstance(rate_limited, expected_base)

    def test_generation_id_in_message(self) -> None:
        failure = errors.GatewayInternalServerError("boom", generation_id="gen-123")
        assert "[gen-123]" in str(failure)
        assert failure.generation_id == "gen-123"

    def test_cause_chained(self) -> None:
        root_cause = ValueError("original")
        failure = errors.GatewayInternalServerError("boom", cause=root_cause)
        assert failure.__cause__ is root_cause


class TestCreateGatewayError:
    """``create_gateway_error`` must pick the class from ``error.type``."""

    def test_authentication_error_from_json_string(self) -> None:
        raw_json = json.dumps(_error_body("Invalid API key", "authentication_error"))
        built = errors.create_gateway_error(
            response_body=raw_json,
            status_code=401,
            api_key_provided=True,
        )
        assert isinstance(built, errors.GatewayAuthenticationError)
        assert built.status_code == 401
        # contextual message includes the key URL
        assert "vercel.com/d?to=" in str(built)

    def test_invalid_request_error(self) -> None:
        built = errors.create_gateway_error(
            response_body=_error_body("Bad format", "invalid_request_error"),
            status_code=400,
        )
        assert isinstance(built, errors.GatewayInvalidRequestError)
        assert built.status_code == 400

    def test_rate_limit_error(self) -> None:
        built = errors.create_gateway_error(
            response_body=_error_body("Rate limit exceeded", "rate_limit_exceeded"),
            status_code=429,
        )
        assert isinstance(built, errors.GatewayRateLimitError)

    def test_model_not_found_extracts_model_id(self) -> None:
        payload = _error_body(
            "Model xyz not found",
            "model_not_found",
            param={"modelId": "xyz"},
        )
        built = errors.create_gateway_error(response_body=payload, status_code=404)
        assert isinstance(built, errors.GatewayModelNotFoundError)
        assert built.model_id == "xyz"

    def test_model_not_found_without_param(self) -> None:
        built = errors.create_gateway_error(
            response_body=_error_body("Not found", "model_not_found"),
            status_code=404,
        )
        assert isinstance(built, errors.GatewayModelNotFoundError)
        assert built.model_id is None

    def test_internal_server_error(self) -> None:
        built = errors.create_gateway_error(
            response_body=_error_body("Database down", "internal_server_error"),
            status_code=500,
        )
        assert isinstance(built, errors.GatewayInternalServerError)

    def test_unknown_type_falls_back_to_internal(self) -> None:
        built = errors.create_gateway_error(
            response_body=_error_body("Something weird", "alien_error"),
            status_code=500,
        )
        assert isinstance(built, errors.GatewayInternalServerError)

    def test_malformed_json_string(self) -> None:
        built = errors.create_gateway_error(response_body="Not JSON", status_code=500)
        assert isinstance(built, errors.GatewayResponseError)

    def test_missing_error_field(self) -> None:
        # Top-level key is deliberately misspelled so "error" is absent.
        payload = {"ferror": {"message": "oops"}}
        built = errors.create_gateway_error(response_body=payload, status_code=404)
        assert isinstance(built, errors.GatewayResponseError)

    def test_generation_id_extracted(self) -> None:
        payload = _error_body("Rate limit", "rate_limit_exceeded")
        payload["generationId"] = "gen-abc"
        built = errors.create_gateway_error(response_body=payload, status_code=429)
        assert built.generation_id == "gen-abc"
class TestStreaming:
    """Text, reasoning, tool-call and inline-file parts of the SSE stream."""

    @staticmethod
    async def _last_message(sse_text: str, prompt: str) -> messages.Message:
        """Serve *sse_text* with status 200 and return the final yielded message."""
        transport = httpx.MockTransport(
            lambda _req: httpx.Response(200, text=sse_text)
        )
        yielded = await _collect(_gateway(transport), [_user(prompt)])
        return yielded[-1]

    @pytest.mark.asyncio
    async def test_text_stream(self) -> None:
        stream = _sse(
            {"type": "text-start", "id": "t1"},
            {"type": "text-delta", "id": "t1", "textDelta": "Hello"},
            {"type": "text-delta", "id": "t1", "textDelta": " World"},
            {"type": "text-end", "id": "t1"},
            {
                "type": "finish",
                "finishReason": "stop",
                "usage": {
                    "prompt_tokens": 5,
                    "completion_tokens": 2,
                },
            },
        )

        last = await self._last_message(stream, "Hi")

        assert last.text == "Hello World"
        assert last.is_done
        assert last.usage is not None
        assert last.usage.input_tokens == 5
        assert last.usage.output_tokens == 2

    @pytest.mark.asyncio
    async def test_reasoning_then_text(self) -> None:
        stream = _sse(
            {"type": "reasoning-start", "id": "r1"},
            {"type": "reasoning-delta", "id": "r1", "delta": "think"},
            {"type": "reasoning-end", "id": "r1"},
            {"type": "text-start", "id": "t1"},
            {"type": "text-delta", "id": "t1", "textDelta": "42"},
            {"type": "text-end", "id": "t1"},
            {"type": "finish", "finishReason": "stop", "usage": {}},
        )

        last = await self._last_message(stream, "?")

        assert last.reasoning == "think"
        assert last.text == "42"

    @pytest.mark.asyncio
    async def test_streaming_tool_call(self) -> None:
        stream = _sse(
            {
                "type": "tool-input-start",
                "id": "tc-1",
                "toolName": "search",
            },
            {"type": "tool-input-delta", "id": "tc-1", "delta": '{"q":'},
            {"type": "tool-input-delta", "id": "tc-1", "delta": '"hi"}'},
            {"type": "tool-input-end", "id": "tc-1"},
            {
                "type": "finish",
                "finishReason": "tool-calls",
                "usage": {},
            },
        )

        last = await self._last_message(stream, "search")

        calls = last.tool_calls
        assert len(calls) == 1
        assert calls[0].tool_name == "search"
        assert calls[0].tool_args == '{"q":"hi"}'

    @pytest.mark.asyncio
    async def test_inline_file_stream(self) -> None:
        """Models like Gemini-3-pro-image return inline file parts
        alongside text in the language model stream."""
        stream = _sse(
            {"type": "text-start", "id": "t1"},
            {"type": "text-delta", "id": "t1", "textDelta": "Here is an image:"},
            {"type": "text-end", "id": "t1"},
            {
                "type": "file",
                "id": "f1",
                "mediaType": "image/png",
                "data": "iVBORw0KGgo=",
            },
            {
                "type": "finish",
                "finishReason": "stop",
                "usage": {"prompt_tokens": 10, "completion_tokens": 20},
            },
        )

        last = await self._last_message(stream, "draw me")

        assert last.text == "Here is an image:"
        assert len(last.images) == 1
        picture = last.images[0]
        assert picture.media_type == "image/png"
        assert picture.data == "iVBORw0KGgo="
        assert last.is_done

    @pytest.mark.asyncio
    async def test_complete_tool_call_part(self) -> None:
        """Non-streaming ``tool-call`` part (one shot) must also work."""
        stream = _sse(
            {
                "type": "tool-call",
                "toolCallId": "tc-1",
                "toolName": "get_weather",
                "input": {"city": "SF"},
            },
            {
                "type": "finish",
                "finishReason": "tool-calls",
                "usage": {},
            },
        )

        last = await self._last_message(stream, "weather")

        assert len(last.tool_calls) == 1
        assert json.loads(last.tool_calls[0].tool_args) == {"city": "SF"}
"finishReason": "stop", "usage": {}}), + ) + + opts = {"gateway": {"order": ["bedrock", "openai"]}} + await _collect( + _gateway(httpx.MockTransport(handler), provider_options=opts), + [_user("Hi")], + ) + + assert captured_body["providerOptions"] == opts + + @pytest.mark.asyncio + async def test_real_tool_in_request_body(self) -> None: + """A real ``@tool``-decorated function must appear correctly + in the request body sent to the gateway.""" + + @ai.tool + async def lookup(query: str) -> str: + """Search the database.""" + return "result" + + captured_body: dict[str, Any] = {} + + def handler(req: httpx.Request) -> httpx.Response: + captured_body.update(json.loads(req.content)) + return httpx.Response( + 200, + text=_sse({"type": "finish", "finishReason": "stop", "usage": {}}), + ) + + await _collect( + _gateway(httpx.MockTransport(handler)), + [_user("find something")], + tools=[lookup], + ) + + assert "tools" in captured_body + td = captured_body["tools"][0] + assert td["name"] == "lookup" + assert td["type"] == "function" + assert "query" in td["inputSchema"]["properties"] + + @pytest.mark.asyncio + async def test_multi_turn_request_body(self) -> None: + """A multi-turn conversation including a tool result must + serialize correctly into the v3 prompt format.""" + captured_body: dict[str, Any] = {} + + def handler(req: httpx.Request) -> httpx.Response: + captured_body.update(json.loads(req.content)) + return httpx.Response( + 200, + text=_sse({"type": "finish", "finishReason": "stop", "usage": {}}), + ) + + tool_part = messages.ToolPart( + tool_call_id="tc-1", + tool_name="search", + tool_args='{"q": "weather"}', + status="result", + result={"temp": 72}, + ) + conversation = [ + _user("What's the weather?"), + messages.Message(role="assistant", parts=[tool_part]), + _user("Thanks, and tomorrow?"), + ] + + await _collect(_gateway(httpx.MockTransport(handler)), conversation) + + prompt = captured_body["prompt"] + # user → assistant (tool-call) → tool 
class TestErrors:
    """HTTP error responses must surface as the matching gateway error."""

    @staticmethod
    def _model_replying(status: int, **response_kwargs: Any) -> GatewayModel:
        """Return a gateway whose transport always answers with *status*."""
        return _gateway(
            httpx.MockTransport(
                lambda _req: httpx.Response(status, **response_kwargs)
            )
        )

    @pytest.mark.asyncio
    async def test_401_authentication_error(self) -> None:
        model = self._model_replying(
            401,
            json={
                "error": {
                    "message": "Invalid API key",
                    "type": "authentication_error",
                }
            },
        )

        with pytest.raises(errors.GatewayAuthenticationError):
            await _collect(model, [_user("Hi")])

    @pytest.mark.asyncio
    async def test_429_rate_limit_error(self) -> None:
        model = self._model_replying(
            429,
            json={
                "error": {
                    "message": "Rate limit exceeded",
                    "type": "rate_limit_exceeded",
                }
            },
        )

        with pytest.raises(errors.GatewayRateLimitError):
            await _collect(model, [_user("Hi")])

    @pytest.mark.asyncio
    async def test_404_model_not_found(self) -> None:
        model = self._model_replying(
            404,
            json={
                "error": {
                    "message": "Model xyz not found",
                    "type": "model_not_found",
                    "param": {"modelId": "xyz"},
                }
            },
        )

        with pytest.raises(errors.GatewayModelNotFoundError) as caught:
            await _collect(model, [_user("Hi")])
        assert caught.value.model_id == "xyz"

    @pytest.mark.asyncio
    async def test_500_malformed_response(self) -> None:
        model = self._model_replying(500, text="Not JSON")

        with pytest.raises(errors.GatewayResponseError):
            await _collect(model, [_user("Hi")])
--------------------------------------------------------------------------- + + +class TestGenerate: + @pytest.mark.asyncio + async def test_basic_image_generation(self) -> None: + """Simple prompt → one PNG image back.""" + + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response( + 200, + json={"images": [_PNG_B64]}, + ) + + model = _image_model(httpx.MockTransport(handler)) + msg = await model.generate([_user("A sunset over Tokyo")]) + + assert msg.role == "assistant" + assert len(msg.images) == 1 + assert msg.images[0].data == _PNG_B64 + assert msg.images[0].media_type == "image/png" + + @pytest.mark.asyncio + async def test_multiple_images(self) -> None: + """Request n=3 images.""" + + def handler(req: httpx.Request) -> httpx.Response: + body = json.loads(req.content) + assert body["n"] == 3 + return httpx.Response( + 200, + json={"images": [_PNG_B64, _JPEG_B64, _PNG_B64]}, + ) + + model = _image_model(httpx.MockTransport(handler)) + msg = await model.generate([_user("Three cats")], n=3) + + assert len(msg.images) == 3 + assert msg.images[0].media_type == "image/png" + assert msg.images[1].media_type == "image/jpeg" + assert msg.images[2].media_type == "image/png" + + @pytest.mark.asyncio + async def test_usage_parsing(self) -> None: + """Usage data from response surfaces on the Message.""" + + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response( + 200, + json={ + "images": [_PNG_B64], + "usage": {"inputTokens": 50, "outputTokens": 100}, + }, + ) + + model = _image_model(httpx.MockTransport(handler)) + msg = await model.generate([_user("a dog")]) + + assert msg.usage is not None + assert msg.usage.input_tokens == 50 + assert msg.usage.output_tokens == 100 + + +# --------------------------------------------------------------------------- +# Request format +# --------------------------------------------------------------------------- + + +class TestRequest: + @pytest.mark.asyncio + async def test_protocol_headers(self) 
-> None: + captured: dict[str, str] = {} + + def handler(req: httpx.Request) -> httpx.Response: + captured.update(dict(req.headers)) + return httpx.Response(200, json={"images": [_PNG_B64]}) + + model = _image_model( + httpx.MockTransport(handler), + model="openai/gpt-image-1", + api_key="sk-test", + ) + await model.generate([_user("Hi")]) + + assert captured["authorization"] == "Bearer sk-test" + assert captured["ai-image-model-specification-version"] == "3" + assert captured["ai-model-id"] == "openai/gpt-image-1" + assert captured["ai-gateway-auth-method"] == "api-key" + + @pytest.mark.asyncio + async def test_parameters_forwarded(self) -> None: + captured_body: dict[str, Any] = {} + + def handler(req: httpx.Request) -> httpx.Response: + captured_body.update(json.loads(req.content)) + return httpx.Response(200, json={"images": [_PNG_B64]}) + + model = _image_model(httpx.MockTransport(handler)) + await model.generate( + [_user("landscape")], + n=2, + size="1024x1024", + aspect_ratio="16:9", + seed=42, + provider_options={"google": {"style": "vivid"}}, + ) + + assert captured_body["prompt"] == "landscape" + assert captured_body["n"] == 2 + assert captured_body["size"] == "1024x1024" + assert captured_body["aspectRatio"] == "16:9" + assert captured_body["seed"] == 42 + assert captured_body["providerOptions"] == {"google": {"style": "vivid"}} + + @pytest.mark.asyncio + async def test_input_images_forwarded(self) -> None: + """Input images from user messages → files in request body.""" + captured_body: dict[str, Any] = {} + + def handler(req: httpx.Request) -> httpx.Response: + captured_body.update(json.loads(req.content)) + return httpx.Response(200, json={"images": [_PNG_B64]}) + + user_msg = messages.Message( + role="user", + parts=[ + messages.TextPart(text="Edit this"), + messages.FilePart(data=_PNG_B64, media_type="image/png"), + ], + ) + model = _image_model(httpx.MockTransport(handler)) + await model.generate([user_msg]) + + assert captured_body["prompt"] == 
"Edit this" + assert "files" in captured_body + assert len(captured_body["files"]) == 1 + assert captured_body["files"][0]["type"] == "file" + assert captured_body["files"][0]["mediaType"] == "image/png" + + @pytest.mark.asyncio + async def test_url_posts_to_image_model_endpoint(self) -> None: + captured_url: list[str] = [] + + def handler(req: httpx.Request) -> httpx.Response: + captured_url.append(str(req.url)) + return httpx.Response(200, json={"images": [_PNG_B64]}) + + model = _image_model(httpx.MockTransport(handler)) + await model.generate([_user("test")]) + + assert captured_url[0] == "https://gw.test/v3/ai/image-model" + + +# --------------------------------------------------------------------------- +# Error handling +# --------------------------------------------------------------------------- + + +class TestErrors: + @pytest.mark.asyncio + async def test_401_authentication_error(self) -> None: + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response( + 401, + json={ + "error": { + "message": "Invalid API key", + "type": "authentication_error", + } + }, + ) + + with pytest.raises(errors.GatewayAuthenticationError): + await _image_model(httpx.MockTransport(handler)).generate([_user("test")]) + + @pytest.mark.asyncio + async def test_429_rate_limit_error(self) -> None: + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response( + 429, + json={ + "error": { + "message": "Rate limited", + "type": "rate_limit_exceeded", + } + }, + ) + + with pytest.raises(errors.GatewayRateLimitError): + await _image_model(httpx.MockTransport(handler)).generate([_user("test")]) + + @pytest.mark.asyncio + async def test_empty_images_returns_empty_message(self) -> None: + """Gateway returns empty images array → message with no parts.""" + + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response(200, json={"images": []}) + + msg = await _image_model(httpx.MockTransport(handler)).generate([_user("test")]) + assert 
len(msg.images) == 0 diff --git a/tests/ai_gateway/test_gateway_video.py b/tests/ai_gateway/test_gateway_video.py new file mode 100644 index 00000000..e4ce682e --- /dev/null +++ b/tests/ai_gateway/test_gateway_video.py @@ -0,0 +1,354 @@ +"""Integration tests for ``GatewayVideoModel``. + +Every test exercises the real ``model.generate()`` method with an injected +``httpx.MockTransport``, so the full production code path is covered: + + model.generate() + → extract prompt/image from messages + → httpx POST (mock) to /video-model with SSE accept + → SSE event parsing + → video data handling (base64 or URL download) + → return Message with FileParts +""" + +from __future__ import annotations + +import base64 +import json +from typing import Any +from unittest.mock import AsyncMock, patch + +import httpx +import pytest + +from vercel_ai_sdk.ai_gateway import GatewayVideoModel, errors +from vercel_ai_sdk.core import messages + +# MP4 magic bytes (ftyp box) +_MP4_HEADER = bytes( + [0x00, 0x00, 0x00, 0x20, 0x66, 0x74, 0x79, 0x70, 0x69, 0x73, 0x6F, 0x6D] +) +_MP4_B64 = base64.b64encode(_MP4_HEADER).decode() + +# WebM magic bytes +_WEBM_HEADER = bytes([0x1A, 0x45, 0xDF, 0xA3]) +_WEBM_B64 = base64.b64encode(_WEBM_HEADER).decode() + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _sse(*events: dict[str, Any]) -> str: + """Build SSE response text from event dicts.""" + return "".join(f"data: {json.dumps(e)}\n\n" for e in events) + + +def _video_model( + handler: httpx.MockTransport, + *, + model: str = "google/veo-3.0-generate-001", + api_key: str = "test-key", +) -> GatewayVideoModel: + return GatewayVideoModel( + model=model, + api_key=api_key, + base_url="https://gw.test/v3/ai", + _transport=handler, + ) + + +def _user(text: str) -> messages.Message: + return messages.Message( + role="user", + parts=[messages.TextPart(text=text)], + ) + + +# 
--------------------------------------------------------------------------- +# Basic generation +# --------------------------------------------------------------------------- + + +class TestGenerate: + @pytest.mark.asyncio + async def test_basic_video_generation_base64(self) -> None: + """Simple prompt → one MP4 video back via base64.""" + body = _sse( + { + "type": "result", + "videos": [ + {"type": "base64", "data": _MP4_B64, "mediaType": "video/mp4"} + ], + } + ) + + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response(200, text=body) + + model = _video_model(httpx.MockTransport(handler)) + msg = await model.generate([_user("A cat walking on a beach")]) + + assert msg.role == "assistant" + assert len(msg.videos) == 1 + assert msg.videos[0].data == _MP4_B64 + assert msg.videos[0].media_type == "video/mp4" + + @pytest.mark.asyncio + async def test_video_generation_url(self) -> None: + """Video returned as URL → downloaded automatically.""" + body = _sse( + { + "type": "result", + "videos": [ + { + "type": "url", + "url": "https://storage.example.com/video.mp4", + "mediaType": "video/mp4", + } + ], + } + ) + + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response(200, text=body) + + model = _video_model(httpx.MockTransport(handler)) + + with patch( + "vercel_ai_sdk.core.media.download.download", + new_callable=AsyncMock, + return_value=(_MP4_HEADER, "video/mp4"), + ) as mock_dl: + msg = await model.generate([_user("A sunset timelapse")]) + + mock_dl.assert_called_once_with("https://storage.example.com/video.mp4") + assert len(msg.videos) == 1 + assert msg.videos[0].data == _MP4_HEADER + assert msg.videos[0].media_type == "video/mp4" + + @pytest.mark.asyncio + async def test_multiple_videos(self) -> None: + body = _sse( + { + "type": "result", + "videos": [ + {"type": "base64", "data": _MP4_B64, "mediaType": "video/mp4"}, + {"type": "base64", "data": _WEBM_B64, "mediaType": "video/webm"}, + ], + } + ) + + def handler(req: 
httpx.Request) -> httpx.Response: + return httpx.Response(200, text=body) + + msg = await _video_model(httpx.MockTransport(handler)).generate( + [_user("Two versions")], n=2 + ) + assert len(msg.videos) == 2 + assert msg.videos[0].media_type == "video/mp4" + assert msg.videos[1].media_type == "video/webm" + + +# --------------------------------------------------------------------------- +# Request format +# --------------------------------------------------------------------------- + + +class TestRequest: + @pytest.mark.asyncio + async def test_protocol_headers(self) -> None: + captured: dict[str, str] = {} + + def handler(req: httpx.Request) -> httpx.Response: + captured.update(dict(req.headers)) + return httpx.Response( + 200, + text=_sse( + { + "type": "result", + "videos": [ + { + "type": "base64", + "data": _MP4_B64, + "mediaType": "video/mp4", + } + ], + } + ), + ) + + model = _video_model( + httpx.MockTransport(handler), + model="google/veo-3.0-generate-001", + api_key="sk-test", + ) + await model.generate([_user("test")]) + + assert captured["authorization"] == "Bearer sk-test" + assert captured["ai-video-model-specification-version"] == "3" + assert captured["ai-model-id"] == "google/veo-3.0-generate-001" + assert captured["accept"] == "text/event-stream" + assert captured["ai-gateway-auth-method"] == "api-key" + + @pytest.mark.asyncio + async def test_parameters_forwarded(self) -> None: + captured_body: dict[str, Any] = {} + + def handler(req: httpx.Request) -> httpx.Response: + captured_body.update(json.loads(req.content)) + return httpx.Response( + 200, + text=_sse( + { + "type": "result", + "videos": [ + { + "type": "base64", + "data": _MP4_B64, + "mediaType": "video/mp4", + } + ], + } + ), + ) + + model = _video_model(httpx.MockTransport(handler)) + await model.generate( + [_user("sunset")], + n=2, + aspect_ratio="16:9", + resolution="1920x1080", + duration=5.0, + fps=30, + seed=42, + provider_options={"google": {"enhancePrompt": True}}, + ) + + 
assert captured_body["prompt"] == "sunset" + assert captured_body["n"] == 2 + assert captured_body["aspectRatio"] == "16:9" + assert captured_body["resolution"] == "1920x1080" + assert captured_body["duration"] == 5.0 + assert captured_body["fps"] == 30 + assert captured_body["seed"] == 42 + assert captured_body["providerOptions"] == {"google": {"enhancePrompt": True}} + + @pytest.mark.asyncio + async def test_url_posts_to_video_model_endpoint(self) -> None: + captured_url: list[str] = [] + + def handler(req: httpx.Request) -> httpx.Response: + captured_url.append(str(req.url)) + return httpx.Response( + 200, + text=_sse( + { + "type": "result", + "videos": [ + { + "type": "base64", + "data": _MP4_B64, + "mediaType": "video/mp4", + } + ], + } + ), + ) + + model = _video_model(httpx.MockTransport(handler)) + await model.generate([_user("test")]) + + assert captured_url[0] == "https://gw.test/v3/ai/video-model" + + @pytest.mark.asyncio + async def test_image_to_video_input(self) -> None: + """Image in user message → image field in request body.""" + captured_body: dict[str, Any] = {} + + def handler(req: httpx.Request) -> httpx.Response: + captured_body.update(json.loads(req.content)) + return httpx.Response( + 200, + text=_sse( + { + "type": "result", + "videos": [ + { + "type": "base64", + "data": _MP4_B64, + "mediaType": "video/mp4", + } + ], + } + ), + ) + + png_b64 = base64.b64encode(b"\x89PNG").decode() + user_msg = messages.Message( + role="user", + parts=[ + messages.TextPart(text="Animate this"), + messages.FilePart(data=png_b64, media_type="image/png"), + ], + ) + model = _video_model(httpx.MockTransport(handler)) + await model.generate([user_msg]) + + assert captured_body["prompt"] == "Animate this" + assert "image" in captured_body + assert captured_body["image"]["type"] == "file" + assert captured_body["image"]["mediaType"] == "image/png" + + +# --------------------------------------------------------------------------- +# Error handling +# 
--------------------------------------------------------------------------- + + +class TestErrors: + @pytest.mark.asyncio + async def test_sse_error_event(self) -> None: + """Gateway returns an SSE error event → raises.""" + body = _sse( + { + "type": "error", + "message": "Content policy violation", + "errorType": "content_filter", + "statusCode": 400, + "param": None, + } + ) + + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response(200, text=body) + + with pytest.raises(errors.GatewayInvalidRequestError, match="Content policy"): + await _video_model(httpx.MockTransport(handler)).generate([_user("test")]) + + @pytest.mark.asyncio + async def test_401_authentication_error(self) -> None: + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response( + 401, + json={ + "error": { + "message": "Bad key", + "type": "authentication_error", + } + }, + ) + + with pytest.raises(errors.GatewayAuthenticationError): + await _video_model(httpx.MockTransport(handler)).generate([_user("test")]) + + @pytest.mark.asyncio + async def test_empty_sse_stream(self) -> None: + """SSE stream with no data events → raises.""" + + def handler(req: httpx.Request) -> httpx.Response: + return httpx.Response(200, text="") + + with pytest.raises(errors.GatewayResponseError, match="SSE stream ended"): + await _video_model(httpx.MockTransport(handler)).generate([_user("test")]) diff --git a/tests/ai_gateway/test_protocol.py b/tests/ai_gateway/test_protocol.py new file mode 100644 index 00000000..7c244c2f --- /dev/null +++ b/tests/ai_gateway/test_protocol.py @@ -0,0 +1,517 @@ +"""Tests for the v3 protocol serialization and deserialization. 
+ +Focus areas: +- ``messages_to_v3_prompt``: the critical outgoing translation layer +- ``tools_to_v3`` / ``build_request_body``: using real ``@tool`` +- ``parse_stream_part``: the critical incoming translation layer +- ``parse_generate_result``: non-streaming response handling +- ``_parse_usage``: the two distinct wire formats +""" + +from __future__ import annotations + +import json +from unittest.mock import AsyncMock, patch + +import pydantic +import pytest + +import vercel_ai_sdk as ai +from vercel_ai_sdk.ai_gateway import protocol +from vercel_ai_sdk.core import llm, messages + +# --------------------------------------------------------------------------- +# messages_to_v3_prompt +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestMessagesToV3Prompt: + async def test_system_message(self) -> None: + msgs = [ + messages.Message( + role="system", + parts=[messages.TextPart(text="You are helpful.")], + ) + ] + result = await protocol.messages_to_v3_prompt(msgs) + assert result == [{"role": "system", "content": "You are helpful."}] + + async def test_user_message(self) -> None: + msgs = [ + messages.Message( + role="user", + parts=[messages.TextPart(text="Hello")], + ) + ] + result = await protocol.messages_to_v3_prompt(msgs) + assert result == [ + { + "role": "user", + "content": [{"type": "text", "text": "Hello"}], + } + ] + + async def test_assistant_with_reasoning_and_text(self) -> None: + msgs = [ + messages.Message( + role="assistant", + parts=[ + messages.ReasoningPart(text="Let me think..."), + messages.TextPart(text="42"), + ], + ) + ] + result = await protocol.messages_to_v3_prompt(msgs) + content = result[0]["content"] + assert content[0] == {"type": "reasoning", "text": "Let me think..."} + assert content[1] == {"type": "text", "text": "42"} + + async def test_tool_call_with_result_produces_two_messages(self) -> None: + """A completed tool call must produce an assistant message + (with 
the tool-call) AND a tool message (with the result).""" + msgs = [ + messages.Message( + role="assistant", + parts=[ + messages.ToolPart( + tool_call_id="tc-1", + tool_name="get_weather", + tool_args='{"city": "SF"}', + status="result", + result={"temp": 72}, + ) + ], + ) + ] + result = await protocol.messages_to_v3_prompt(msgs) + assert len(result) == 2 + + # Assistant message has the tool-call + tc = result[0]["content"][0] + assert tc["type"] == "tool-call" + assert tc["toolCallId"] == "tc-1" + assert tc["input"] == {"city": "SF"} + + # Tool message has the result + tr = result[1]["content"][0] + assert tr["type"] == "tool-result" + assert tr["output"] == {"type": "json", "value": {"temp": 72}} + + async def test_tool_error_result(self) -> None: + msgs = [ + messages.Message( + role="assistant", + parts=[ + messages.ToolPart( + tool_call_id="tc-1", + tool_name="get_weather", + tool_args="{}", + status="error", + result="Connection timeout", + ) + ], + ) + ] + result = await protocol.messages_to_v3_prompt(msgs) + tr = result[1]["content"][0] + assert tr["output"]["type"] == "error-text" + assert tr["output"]["value"] == "Connection timeout" + + async def test_user_message_with_image_url(self) -> None: + """FilePart with image URL → downloaded and converted to data: URL.""" + fake_jpeg = b"\xff\xd8\xff\xe0" + msgs = [ + messages.Message( + role="user", + parts=[ + messages.TextPart(text="Look at this"), + messages.FilePart( + data="https://example.com/cat.jpg", media_type="image/jpeg" + ), + ], + ) + ] + with patch( + "vercel_ai_sdk.core.media.download.download", + new_callable=AsyncMock, + return_value=(fake_jpeg, "image/jpeg"), + ): + result = await protocol.messages_to_v3_prompt(msgs) + content = result[0]["content"] + assert content[0] == {"type": "text", "text": "Look at this"} + assert content[1]["type"] == "file" + assert content[1]["mediaType"] == "image/jpeg" + assert content[1]["data"].startswith("data:image/jpeg;base64,") + + async def 
test_user_message_with_file_bytes(self) -> None: + """FilePart with bytes → v3 file content part with data URL.""" + msgs = [ + messages.Message( + role="user", + parts=[ + messages.FilePart( + data=b"\x89PNG", media_type="image/png", filename="pic.png" + ), + ], + ) + ] + result = await protocol.messages_to_v3_prompt(msgs) + part = result[0]["content"][0] + assert part["type"] == "file" + assert part["mediaType"] == "image/png" + assert part["data"].startswith("data:image/png;base64,") + assert part["filename"] == "pic.png" + + async def test_user_message_text_only_unchanged(self) -> None: + """Regression: text-only user messages still work.""" + msgs = [ + messages.Message( + role="user", + parts=[messages.TextPart(text="Hello")], + ) + ] + result = await protocol.messages_to_v3_prompt(msgs) + assert result == [ + {"role": "user", "content": [{"type": "text", "text": "Hello"}]} + ] + + async def test_pending_tool_call_no_tool_message(self) -> None: + """A pending tool call should NOT produce a tool-result message.""" + msgs = [ + messages.Message( + role="assistant", + parts=[ + messages.ToolPart( + tool_call_id="tc-1", + tool_name="search", + tool_args="{}", + status="pending", + ) + ], + ) + ] + result = await protocol.messages_to_v3_prompt(msgs) + assert len(result) == 1 + assert result[0]["role"] == "assistant" + + +# --------------------------------------------------------------------------- +# tools_to_v3 / build_request_body — using real @tool +# --------------------------------------------------------------------------- + + +@ai.tool +async def get_weather(city: str, units: str = "celsius") -> str: + """Get the current weather for a city.""" + return f"Sunny in {city}" + + +@pytest.mark.asyncio +class TestBuildRequestBody: + async def test_with_real_tool(self) -> None: + """Verify @tool-produced schema round-trips through + build_request_body → JSON → gateway wire format.""" + msgs = [ + messages.Message( + role="user", + 
parts=[messages.TextPart(text="What's the weather?")], + ) + ] + body = await protocol.build_request_body(msgs, tools=[get_weather]) + + assert "tools" in body + tool_def = body["tools"][0] + assert tool_def["type"] == "function" + assert tool_def["name"] == "get_weather" + assert tool_def["description"] == ("Get the current weather for a city.") + # The schema comes from pydantic — verify structure, not exact dict + schema = tool_def["inputSchema"] + assert "properties" in schema + assert "city" in schema["properties"] + assert "units" in schema["properties"] + # 'city' is required (no default), 'units' is not (has default) + assert "city" in schema.get("required", []) + + async def test_with_output_type(self) -> None: + class WeatherResult(pydantic.BaseModel): + temp: float + condition: str + + msgs = [ + messages.Message( + role="user", + parts=[messages.TextPart(text="Weather?")], + ) + ] + body = await protocol.build_request_body(msgs, output_type=WeatherResult) + + assert "responseFormat" in body + rf = body["responseFormat"] + assert rf["type"] == "json" + assert rf["name"] == "WeatherResult" + assert "properties" in rf["schema"] + assert "temp" in rf["schema"]["properties"] + + async def test_provider_options_passthrough(self) -> None: + msgs = [ + messages.Message( + role="user", + parts=[messages.TextPart(text="Hi")], + ) + ] + opts = {"gateway": {"order": ["bedrock", "openai"]}} + body = await protocol.build_request_body(msgs, provider_options=opts) + assert body["providerOptions"] == opts + + +# --------------------------------------------------------------------------- +# parse_stream_part — parametrized simple 1:1 mappings +# --------------------------------------------------------------------------- + +_SIMPLE_STREAM_PARTS = [ + ( + {"type": "text-start", "id": "t1"}, + llm.TextStart(block_id="t1"), + ), + ( + {"type": "text-end", "id": "t1"}, + llm.TextEnd(block_id="t1"), + ), + ( + {"type": "reasoning-start", "id": "r1"}, + 
llm.ReasoningStart(block_id="r1"), + ), + ( + {"type": "reasoning-delta", "id": "r1", "delta": "hmm"}, + llm.ReasoningDelta(block_id="r1", delta="hmm"), + ), + ( + {"type": "reasoning-end", "id": "r1"}, + llm.ReasoningEnd(block_id="r1"), + ), + ( + {"type": "tool-input-start", "id": "tc-1", "toolName": "search"}, + llm.ToolStart(tool_call_id="tc-1", tool_name="search"), + ), + ( + {"type": "tool-input-delta", "id": "tc-1", "delta": '{"q"'}, + llm.ToolArgsDelta(tool_call_id="tc-1", delta='{"q"'), + ), + ( + {"type": "tool-input-end", "id": "tc-1"}, + llm.ToolEnd(tool_call_id="tc-1"), + ), +] + + +@pytest.mark.parametrize( + ("wire", "expected"), + _SIMPLE_STREAM_PARTS, + ids=[w["type"] for w, _ in _SIMPLE_STREAM_PARTS], +) +def test_parse_stream_part_simple( + wire: dict[str, object], expected: llm.StreamEvent +) -> None: + events = protocol.parse_stream_part(wire) + assert len(events) == 1 + assert events[0] == expected + + +@pytest.mark.asyncio +class TestParseStreamPartComplex: + async def test_text_delta_uses_textDelta_key(self) -> None: + """The gateway sends ``textDelta`` (camelCase), not ``delta``.""" + events = protocol.parse_stream_part( + {"type": "text-delta", "id": "t1", "textDelta": "Hello"} + ) + assert isinstance(events[0], llm.TextDelta) + assert events[0].delta == "Hello" + + async def test_tool_call_expands_to_three_events(self) -> None: + """A complete ``tool-call`` part must expand into + ToolStart → ToolArgsDelta → ToolEnd.""" + events = protocol.parse_stream_part( + { + "type": "tool-call", + "toolCallId": "tc-1", + "toolName": "get_weather", + "input": {"city": "SF"}, + } + ) + assert len(events) == 3 + assert isinstance(events[0], llm.ToolStart) + assert events[0].tool_name == "get_weather" + assert isinstance(events[1], llm.ToolArgsDelta) + assert json.loads(events[1].delta) == {"city": "SF"} + assert isinstance(events[2], llm.ToolEnd) + + async def test_finish_flat_usage(self) -> None: + events = protocol.parse_stream_part( + { + "type": 
"finish", + "finishReason": "stop", + "usage": { + "prompt_tokens": 10, + "completion_tokens": 20, + }, + } + ) + done = events[0] + assert isinstance(done, llm.MessageDone) + assert done.finish_reason == "stop" + assert done.usage is not None + assert done.usage.input_tokens == 10 + assert done.usage.output_tokens == 20 + + async def test_finish_v3_nested_usage(self) -> None: + events = protocol.parse_stream_part( + { + "type": "finish", + "finishReason": { + "unified": "tool-calls", + "raw": "tool_calls", + }, + "usage": { + "inputTokens": { + "total": 100, + "cacheRead": 50, + }, + "outputTokens": { + "total": 200, + "reasoning": 30, + }, + }, + } + ) + done = events[0] + assert isinstance(done, llm.MessageDone) + assert done.finish_reason == "tool-calls" + assert done.usage is not None + assert done.usage.input_tokens == 100 + assert done.usage.cache_read_tokens == 50 + assert done.usage.reasoning_tokens == 30 + + async def test_file_part(self) -> None: + """A ``file`` stream part (inline image from Gemini/GPT-5) + must produce a FileEvent.""" + events = protocol.parse_stream_part( + { + "type": "file", + "id": "f1", + "mediaType": "image/png", + "data": "iVBORw0KGgo=", + } + ) + assert len(events) == 1 + assert isinstance(events[0], llm.FileEvent) + assert events[0].block_id == "f1" + assert events[0].media_type == "image/png" + assert events[0].data == "iVBORw0KGgo=" + + async def test_file_part_defaults(self) -> None: + """A minimal ``file`` part uses sensible defaults.""" + events = protocol.parse_stream_part({"type": "file", "data": "somedata"}) + assert len(events) == 1 + assert isinstance(events[0], llm.FileEvent) + assert events[0].media_type == "application/octet-stream" + + async def test_unknown_types_produce_no_events(self) -> None: + for t in ("stream-start", "raw", "response-metadata", "banana"): + assert protocol.parse_stream_part({"type": t}) == [] + + +# --------------------------------------------------------------------------- +# 
parse_generate_result +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestParseGenerateResult: + async def test_text_content(self) -> None: + events = protocol.parse_generate_result( + { + "content": [{"type": "text", "text": "Hello!"}], + "finishReason": "stop", + "usage": {"prompt_tokens": 4, "completion_tokens": 10}, + } + ) + # TextStart + TextDelta + TextEnd + MessageDone + assert len(events) == 4 + assert isinstance(events[1], llm.TextDelta) + assert events[1].delta == "Hello!" + assert isinstance(events[3], llm.MessageDone) + + async def test_tool_call_content(self) -> None: + events = protocol.parse_generate_result( + { + "content": [ + { + "type": "tool-call", + "toolCallId": "tc-1", + "toolName": "search", + "input": {"query": "weather"}, + } + ], + "finishReason": "tool-calls", + } + ) + assert isinstance(events[0], llm.ToolStart) + assert isinstance(events[3], llm.MessageDone) + assert events[3].finish_reason == "tool-calls" + + async def test_file_content(self) -> None: + """A ``file`` part in non-streaming result produces a FileEvent.""" + events = protocol.parse_generate_result( + { + "content": [ + { + "type": "file", + "id": "f1", + "mediaType": "image/png", + "data": "iVBORw0KGgo=", + } + ], + "finishReason": "stop", + "usage": {"prompt_tokens": 10, "completion_tokens": 5}, + } + ) + file_events = [e for e in events if isinstance(e, llm.FileEvent)] + assert len(file_events) == 1 + assert file_events[0].media_type == "image/png" + assert isinstance(events[-1], llm.MessageDone) + + +# --------------------------------------------------------------------------- +# _parse_usage +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestParseUsage: + async def test_flat_format(self) -> None: + usage = protocol._parse_usage({"prompt_tokens": 10, "completion_tokens": 20}) + assert usage.input_tokens == 10 + assert usage.output_tokens == 
20 + + async def test_v3_nested_format(self) -> None: + usage = protocol._parse_usage( + { + "inputTokens": { + "total": 100, + "cacheRead": 30, + "cacheWrite": 5, + }, + "outputTokens": {"total": 50, "reasoning": 10}, + } + ) + assert usage.input_tokens == 100 + assert usage.output_tokens == 50 + assert usage.cache_read_tokens == 30 + assert usage.cache_write_tokens == 5 + assert usage.reasoning_tokens == 10 + + async def test_non_dict_returns_empty(self) -> None: + usage = protocol._parse_usage("not a dict") + assert usage.input_tokens == 0 + assert usage.output_tokens == 0 diff --git a/tests/ai_sdk_ui/test_adapter.py b/tests/ai_sdk_ui/test_adapter.py index d5197534..8a767800 100644 --- a/tests/ai_sdk_ui/test_adapter.py +++ b/tests/ai_sdk_ui/test_adapter.py @@ -471,6 +471,36 @@ def test_ui_tool_part_with_dict_input() -> None: assert tool_part.status == "pending" # input-available maps to pending +def test_ui_file_part_converted_to_core_file_part() -> None: + """UIFilePart from the frontend is converted to a core FilePart.""" + raw_message = { + "id": "msg-1", + "role": "user", + "parts": [ + {"type": "text", "text": "What's in this image?"}, + { + "type": "file", + "mediaType": "image/png", + "url": "https://example.com/photo.png", + "filename": "photo.png", + }, + ], + } + ui_msg = ui_message.UIMessage.model_validate(raw_message) + internal = adapter.to_messages([ui_msg]) + + assert len(internal) == 1 + msg = internal[0] + assert msg.role == "user" + assert len(msg.parts) == 2 + assert isinstance(msg.parts[0], messages.TextPart) + assert isinstance(msg.parts[1], messages.FilePart) + fp = msg.parts[1] + assert fp.data == "https://example.com/photo.png" + assert fp.media_type == "image/png" + assert fp.filename == "photo.png" + + def test_ui_skips_unsupported_parts() -> None: """Test that unsupported part types are skipped gracefully.""" raw_message = { diff --git a/tests/anthropic/test_anthropic.py b/tests/anthropic/test_anthropic.py index 53d5d69b..ddae8f92 
100644 --- a/tests/anthropic/test_anthropic.py +++ b/tests/anthropic/test_anthropic.py @@ -1,10 +1,16 @@ """Anthropic provider: _messages_to_anthropic conversion tests.""" +import base64 + +import pytest + from vercel_ai_sdk.anthropic import _messages_to_anthropic -from vercel_ai_sdk.core.messages import Message, TextPart, ToolPart +from vercel_ai_sdk.core.messages import FilePart, Message, TextPart, ToolPart + +pytestmark = pytest.mark.asyncio -def test_tool_result_none_still_emits_tool_result() -> None: +async def test_tool_result_none_still_emits_tool_result() -> None: """A tool that returns None must still produce a tool_result block. Regression: when part.result is None the converter skipped the tool_result, @@ -22,7 +28,7 @@ def test_tool_result_none_still_emits_tool_result() -> None: Message(role="assistant", parts=[tool_part]), ] - _system, anthropic_msgs = _messages_to_anthropic(messages) + _system, anthropic_msgs = await _messages_to_anthropic(messages) # Should have: assistant message with tool_use, then user message with tool_result assert len(anthropic_msgs) == 2, ( @@ -41,7 +47,7 @@ def test_tool_result_none_still_emits_tool_result() -> None: assert tool_results[0]["tool_use_id"] == "toolu_01abc" -def test_tool_with_normal_result() -> None: +async def test_tool_with_normal_result() -> None: """Baseline: a tool with a normal result produces the correct pair.""" tool_part = ToolPart( tool_call_id="toolu_02xyz", @@ -54,13 +60,13 @@ def test_tool_with_normal_result() -> None: Message(role="assistant", parts=[tool_part]), ] - _system, anthropic_msgs = _messages_to_anthropic(messages) + _system, anthropic_msgs = await _messages_to_anthropic(messages) assert len(anthropic_msgs) == 2 assert anthropic_msgs[1]["content"][0]["content"] == "{'temp': 62}" -def test_tool_error_produces_tool_result() -> None: +async def test_tool_error_produces_tool_result() -> None: """Tool errors must also produce a tool_result block (with is_error=True).""" tool_part = ToolPart( 
tool_call_id="toolu_03err", @@ -73,7 +79,7 @@ def test_tool_error_produces_tool_result() -> None: Message(role="assistant", parts=[tool_part]), ] - _system, anthropic_msgs = _messages_to_anthropic(messages) + _system, anthropic_msgs = await _messages_to_anthropic(messages) assert len(anthropic_msgs) == 2 tool_result = anthropic_msgs[1]["content"][0] @@ -82,7 +88,7 @@ def test_tool_error_produces_tool_result() -> None: assert tool_result["content"] == "Connection timeout" -def test_multiple_tools_one_returns_none() -> None: +async def test_multiple_tools_one_returns_none() -> None: """When one of several tools returns None, all must have tool_results.""" tool_a = ToolPart( tool_call_id="toolu_a", @@ -102,7 +108,7 @@ def test_multiple_tools_one_returns_none() -> None: Message(role="assistant", parts=[tool_a, tool_b]), ] - _system, anthropic_msgs = _messages_to_anthropic(messages) + _system, anthropic_msgs = await _messages_to_anthropic(messages) assert len(anthropic_msgs) == 2 @@ -123,7 +129,7 @@ def test_multiple_tools_one_returns_none() -> None: # -- Multi-turn: consecutive user messages (tool_result + next user) ------- -def test_multi_turn_no_consecutive_same_role_messages() -> None: +async def test_multi_turn_no_consecutive_same_role_messages() -> None: """Multi-turn with tools must not produce consecutive same-role messages. 
Regression: when a previous assistant turn includes a tool call (with @@ -159,7 +165,7 @@ def test_multi_turn_no_consecutive_same_role_messages() -> None: ), ] - _system, anthropic_msgs = _messages_to_anthropic(messages) + _system, anthropic_msgs = await _messages_to_anthropic(messages) # Verify no consecutive same-role messages for i in range(1, len(anthropic_msgs)): @@ -170,7 +176,7 @@ def test_multi_turn_no_consecutive_same_role_messages() -> None: ) -def test_multi_turn_tool_result_before_user_merged() -> None: +async def test_multi_turn_tool_result_before_user_merged() -> None: """When tool_result (user) is followed by a user message, they merge. The merged user message should contain both the tool_result blocks @@ -189,7 +195,7 @@ def test_multi_turn_tool_result_before_user_merged() -> None: Message(role="user", parts=[TextPart(text="thanks, what about tomorrow?")]), ] - _system, anthropic_msgs = _messages_to_anthropic(messages) + _system, anthropic_msgs = await _messages_to_anthropic(messages) # Should be: user, assistant, user (tool_result + text) assert len(anthropic_msgs) == 3 @@ -205,7 +211,7 @@ def test_multi_turn_tool_result_before_user_merged() -> None: assert tool_results[0]["tool_use_id"] == "toolu_01abc" -def test_stream_loop_second_iteration_messages() -> None: +async def test_stream_loop_second_iteration_messages() -> None: """Simulates what stream_loop sends on the 2nd LLM call in a multi-turn. 
After the first stream_step returns a tool call, stream_loop appends @@ -228,7 +234,7 @@ def test_stream_loop_second_iteration_messages() -> None: # No user message follows — this is the loop, not a new user turn ] - _system, anthropic_msgs = _messages_to_anthropic(messages) + _system, anthropic_msgs = await _messages_to_anthropic(messages) # Should be: user, assistant(tool_use), user(tool_result) assert len(anthropic_msgs) == 3 @@ -243,7 +249,7 @@ def test_stream_loop_second_iteration_messages() -> None: assert len(tool_results) == 1 -def test_pending_tool_does_not_emit_tool_result() -> None: +async def test_pending_tool_does_not_emit_tool_result() -> None: """A tool with status='pending' must not produce a tool_result block. When stream_step returns a message mid-stream (before tool execution), @@ -262,7 +268,7 @@ def test_pending_tool_does_not_emit_tool_result() -> None: Message(role="assistant", parts=[tool]), ] - _system, anthropic_msgs = _messages_to_anthropic(messages) + _system, anthropic_msgs = await _messages_to_anthropic(messages) # assistant message with tool_use, but NO user message with tool_result assert len(anthropic_msgs) == 2 @@ -274,3 +280,111 @@ def test_pending_tool_does_not_emit_tool_result() -> None: for msg in anthropic_msgs: if isinstance(msg["content"], list): assert not any(b.get("type") == "tool_result" for b in msg["content"]) + + +# -- Multimodal user messages ------------------------------------------------ + + +async def test_user_text_only_is_plain_string() -> None: + """Text-only user messages should produce a plain content string.""" + msgs = [Message(role="user", parts=[TextPart(text="Hello")])] + _sys, result = await _messages_to_anthropic(msgs) + assert result[0]["content"] == "Hello" + + +async def test_user_image_url() -> None: + """Image URL → Anthropic image block with url source.""" + msgs = [ + Message( + role="user", + parts=[ + TextPart(text="Describe this"), + FilePart(data="https://example.com/cat.jpg", 
media_type="image/jpeg"), + ], + ) + ] + _sys, result = await _messages_to_anthropic(msgs) + content = result[0]["content"] + assert content[0] == {"type": "text", "text": "Describe this"} + assert content[1] == { + "type": "image", + "source": {"type": "url", "url": "https://example.com/cat.jpg"}, + } + + +async def test_user_image_base64() -> None: + """Base64 image → Anthropic image block with base64 source.""" + b64 = base64.b64encode(b"\x89PNG").decode() + msgs = [ + Message( + role="user", + parts=[FilePart(data=b64, media_type="image/png")], + ) + ] + _sys, result = await _messages_to_anthropic(msgs) + img = result[0]["content"][0] + assert img["type"] == "image" + assert img["source"]["type"] == "base64" + assert img["source"]["media_type"] == "image/png" + assert img["source"]["data"] == b64 + + +async def test_user_pdf_url() -> None: + """PDF URL → Anthropic document block with url source.""" + msgs = [ + Message( + role="user", + parts=[ + FilePart( + data="https://example.com/doc.pdf", media_type="application/pdf" + ) + ], + ) + ] + _sys, result = await _messages_to_anthropic(msgs) + doc = result[0]["content"][0] + assert doc["type"] == "document" + assert doc["source"] == {"type": "url", "url": "https://example.com/doc.pdf"} + + +async def test_user_pdf_base64() -> None: + """PDF base64 → Anthropic document block with base64 source.""" + b64 = base64.b64encode(b"%PDF-1.4").decode() + msgs = [ + Message( + role="user", + parts=[FilePart(data=b64, media_type="application/pdf")], + ) + ] + _sys, result = await _messages_to_anthropic(msgs) + doc = result[0]["content"][0] + assert doc["type"] == "document" + assert doc["source"]["type"] == "base64" + assert doc["source"]["media_type"] == "application/pdf" + + +async def test_user_text_plain_bytes() -> None: + """text/plain with bytes → Anthropic document with text source.""" + msgs = [ + Message( + role="user", + parts=[FilePart(data=b"Hello, world!", media_type="text/plain")], + ) + ] + _sys, result = 
await _messages_to_anthropic(msgs) + doc = result[0]["content"][0] + assert doc["type"] == "document" + assert doc["source"]["type"] == "text" + assert doc["source"]["data"] == "Hello, world!" + + +async def test_unsupported_media_type_raises() -> None: + """Unsupported media type → ValueError.""" + msgs = [ + Message( + role="user", + parts=[FilePart(data=b"\x00", media_type="video/mp4")], + ) + ] + with pytest.raises(ValueError, match="Unsupported media type"): + await _messages_to_anthropic(msgs) diff --git a/tests/core/media/__init__.py b/tests/core/media/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/core/media/test_data.py b/tests/core/media/test_data.py new file mode 100644 index 00000000..e2809f3a --- /dev/null +++ b/tests/core/media/test_data.py @@ -0,0 +1,80 @@ +"""Tests for media data-format helpers (URL detection, base-64, data URLs).""" + +from vercel_ai_sdk.core.media.data import ( + data_to_base64, + data_to_data_url, + is_url, + split_data_url, +) + +# -- is_url ---------------------------------------------------------------- + + +def test_is_url_http() -> None: + assert is_url("https://example.com/img.png") is True + assert is_url("http://example.com/img.png") is True + + +def test_is_url_data() -> None: + assert is_url("data:image/png;base64,abc") is True + + +def test_is_url_base64() -> None: + assert is_url("iVBORw0KGgo=") is False + + +# -- data_to_base64 ------------------------------------------------------- + + +def test_data_to_base64_bytes() -> None: + assert data_to_base64(b"\x01\x02\x03") == "AQID" + + +def test_data_to_base64_passthrough() -> None: + assert data_to_base64("AQID") == "AQID" + + +def test_data_to_base64_extracts_from_data_url() -> None: + """data: URLs must have the prefix stripped -- providers need raw base64.""" + result = data_to_base64("data:image/png;base64,AQID") + assert result == "AQID" + + +def test_data_to_base64_passthrough_http_url() -> None: + """HTTP URLs are passed through -- 
caller must handle.""" + url = "https://example.com/img.png" + assert data_to_base64(url) == url + + +# -- data_to_data_url ------------------------------------------------------ + + +def test_data_to_data_url_from_bytes() -> None: + result = data_to_data_url(b"\x01\x02\x03", "image/png") + assert result == "data:image/png;base64,AQID" + + +def test_data_to_data_url_passthrough_url() -> None: + url = "https://example.com/img.png" + assert data_to_data_url(url, "image/png") == url + + +# -- split_data_url -------------------------------------------------------- + + +def test_split_data_url_valid() -> None: + mt, b64 = split_data_url("data:image/png;base64,iVBOR") + assert mt == "image/png" + assert b64 == "iVBOR" + + +def test_split_data_url_non_data_url() -> None: + mt, b64 = split_data_url("https://example.com/img.png") + assert mt is None + assert b64 is None + + +def test_split_data_url_malformed() -> None: + mt, b64 = split_data_url("data:") + assert mt is None + assert b64 is None diff --git a/tests/core/media/test_detect_media_type.py b/tests/core/media/test_detect_media_type.py new file mode 100644 index 00000000..11f1a5e2 --- /dev/null +++ b/tests/core/media/test_detect_media_type.py @@ -0,0 +1,460 @@ +"""Tests for magic-byte media type detection. 
+ +Ported from: .reference/ai/packages/ai/src/util/detect-media-type.test.ts +""" + +from __future__ import annotations + +import base64 + +from vercel_ai_sdk.core.media.detect_media_type import ( + AUDIO_SIGNATURES, + IMAGE_SIGNATURES, + detect_media_type, +) + +# --------------------------------------------------------------------------- +# Image detection +# --------------------------------------------------------------------------- + + +class TestGif: + def test_detect_gif_from_bytes(self) -> None: + data = bytes([0x47, 0x49, 0x46, 0xFF, 0xFF]) + assert detect_media_type(data, IMAGE_SIGNATURES) == "image/gif" + + def test_detect_gif_from_base64(self) -> None: + assert detect_media_type("R0lGabc123", IMAGE_SIGNATURES) == "image/gif" + + +class TestPng: + def test_detect_png_from_bytes(self) -> None: + data = bytes([0x89, 0x50, 0x4E, 0x47, 0xFF, 0xFF]) + assert detect_media_type(data, IMAGE_SIGNATURES) == "image/png" + + def test_detect_png_from_base64(self) -> None: + assert detect_media_type("iVBORwabc123", IMAGE_SIGNATURES) == "image/png" + + +class TestJpeg: + def test_detect_jpeg_from_bytes(self) -> None: + data = bytes([0xFF, 0xD8, 0xFF, 0xFF]) + assert detect_media_type(data, IMAGE_SIGNATURES) == "image/jpeg" + + def test_detect_jpeg_from_base64(self) -> None: + assert detect_media_type("/9j/abc123", IMAGE_SIGNATURES) == "image/jpeg" + + +class TestWebp: + def test_detect_webp_from_bytes(self) -> None: + # RIFF + 4 bytes (file size) + WEBP + VP8 data + data = bytes( + [ + 0x52, + 0x49, + 0x46, + 0x46, # "RIFF" + 0x24, + 0x00, + 0x00, + 0x00, # file size (wildcard in sig) + 0x57, + 0x45, + 0x42, + 0x50, # "WEBP" + 0x56, + 0x50, + 0x38, + 0x20, # "VP8 " (trailing data) + ] + ) + assert detect_media_type(data, IMAGE_SIGNATURES) == "image/webp" + + def test_detect_webp_from_base64(self) -> None: + data = bytes( + [ + 0x52, + 0x49, + 0x46, + 0x46, + 0x24, + 0x00, + 0x00, + 0x00, + 0x57, + 0x45, + 0x42, + 0x50, + 0x56, + 0x50, + 0x38, + 0x20, + ] + ) + b64 = 
base64.b64encode(data).decode() + assert detect_media_type(b64, IMAGE_SIGNATURES) == "image/webp" + + def test_riff_audio_not_detected_as_webp_bytes(self) -> None: + """RIFF + WAVE should NOT match WebP.""" + data = bytes( + [ + 0x52, + 0x49, + 0x46, + 0x46, + 0x24, + 0x00, + 0x00, + 0x00, + 0x57, + 0x41, + 0x56, + 0x45, # "WAVE", not "WEBP" + ] + ) + assert detect_media_type(data, IMAGE_SIGNATURES) is None + + def test_riff_audio_not_detected_as_webp_base64(self) -> None: + data = bytes( + [ + 0x52, + 0x49, + 0x46, + 0x46, + 0x24, + 0x00, + 0x00, + 0x00, + 0x57, + 0x41, + 0x56, + 0x45, + ] + ) + b64 = base64.b64encode(data).decode() + assert detect_media_type(b64, IMAGE_SIGNATURES) is None + + +class TestBmp: + def test_detect_bmp_from_bytes(self) -> None: + data = bytes([0x42, 0x4D, 0xFF, 0xFF]) + assert detect_media_type(data, IMAGE_SIGNATURES) == "image/bmp" + + def test_detect_bmp_from_base64(self) -> None: + data = bytes([0x42, 0x4D, 0xFF, 0xFF]) + b64 = base64.b64encode(data).decode() + assert detect_media_type(b64, IMAGE_SIGNATURES) == "image/bmp" + + +class TestTiff: + def test_detect_tiff_le_from_bytes(self) -> None: + data = bytes([0x49, 0x49, 0x2A, 0x00, 0xFF]) + assert detect_media_type(data, IMAGE_SIGNATURES) == "image/tiff" + + def test_detect_tiff_le_from_base64(self) -> None: + assert detect_media_type("SUkqAAabc123", IMAGE_SIGNATURES) == "image/tiff" + + def test_detect_tiff_be_from_bytes(self) -> None: + data = bytes([0x4D, 0x4D, 0x00, 0x2A, 0xFF]) + assert detect_media_type(data, IMAGE_SIGNATURES) == "image/tiff" + + def test_detect_tiff_be_from_base64(self) -> None: + assert detect_media_type("TU0AKgabc123", IMAGE_SIGNATURES) == "image/tiff" + + +class TestAvif: + def test_detect_avif_from_bytes(self) -> None: + data = bytes( + [ + 0x00, + 0x00, + 0x00, + 0x20, + 0x66, + 0x74, + 0x79, + 0x70, + 0x61, + 0x76, + 0x69, + 0x66, + 0xFF, + ] + ) + assert detect_media_type(data, IMAGE_SIGNATURES) == "image/avif" + + def 
test_detect_avif_from_base64(self) -> None: + assert ( + detect_media_type("AAAAIGZ0eXBhdmlmabc123", IMAGE_SIGNATURES) + == "image/avif" + ) + + +class TestHeic: + def test_detect_heic_from_bytes(self) -> None: + data = bytes( + [ + 0x00, + 0x00, + 0x00, + 0x20, + 0x66, + 0x74, + 0x79, + 0x70, + 0x68, + 0x65, + 0x69, + 0x63, + 0xFF, + ] + ) + assert detect_media_type(data, IMAGE_SIGNATURES) == "image/heic" + + def test_detect_heic_from_base64(self) -> None: + assert ( + detect_media_type("AAAAIGZ0eXBoZWljabc123", IMAGE_SIGNATURES) + == "image/heic" + ) + + +# --------------------------------------------------------------------------- +# Audio detection +# --------------------------------------------------------------------------- + + +class TestMp3: + def test_detect_mp3_from_bytes(self) -> None: + data = bytes([0xFF, 0xFB]) + assert detect_media_type(data, AUDIO_SIGNATURES) == "audio/mpeg" + + def test_detect_mp3_from_base64(self) -> None: + assert detect_media_type("//s=", AUDIO_SIGNATURES) == "audio/mpeg" + + def test_detect_mp3_with_id3v2_tags_from_bytes(self) -> None: + """ID3v2 header (10 bytes tag, size=4) followed by MP3 frame.""" + data = bytes( + [ + 0x49, + 0x44, + 0x33, # "ID3" + 0x04, + 0x00, # version + 0x00, # flags + 0x00, + 0x00, + 0x00, + 0x04, # size = 4 (syncsafe) + 0x00, + 0x00, + 0x00, + 0x00, # 4 bytes of tag data + 0xFF, + 0xFB, # MP3 frame sync + 0x90, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + ] + ) + assert detect_media_type(data, AUDIO_SIGNATURES) == "audio/mpeg" + + def test_detect_mp3_with_id3v2_tags_from_base64(self) -> None: + data = bytes( + [ + 0x49, + 0x44, + 0x33, + 0x04, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x04, + 0x00, + 0x00, + 0x00, + 0x00, + 0xFF, + 0xFB, + 0x90, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + 0x00, + ] + ) + b64 = base64.b64encode(data).decode() + assert detect_media_type(b64, AUDIO_SIGNATURES) == "audio/mpeg" + + +class TestWav: + def test_detect_wav_from_bytes(self) -> None: + data 
= bytes( + [ + 0x52, + 0x49, + 0x46, + 0x46, # "RIFF" + 0x24, + 0x00, + 0x00, + 0x00, # file size + 0x57, + 0x41, + 0x56, + 0x45, # "WAVE" + ] + ) + assert detect_media_type(data, AUDIO_SIGNATURES) == "audio/wav" + + def test_detect_wav_from_base64(self) -> None: + data = bytes( + [ + 0x52, + 0x49, + 0x46, + 0x46, + 0x24, + 0x00, + 0x00, + 0x00, + 0x57, + 0x41, + 0x56, + 0x45, + ] + ) + b64 = base64.b64encode(data).decode() + assert detect_media_type(b64, AUDIO_SIGNATURES) == "audio/wav" + + def test_webp_not_detected_as_wav_bytes(self) -> None: + """RIFF + WEBP should NOT match WAV.""" + data = bytes( + [ + 0x52, + 0x49, + 0x46, + 0x46, + 0x24, + 0x00, + 0x00, + 0x00, + 0x57, + 0x45, + 0x42, + 0x50, # "WEBP", not "WAVE" + ] + ) + assert detect_media_type(data, AUDIO_SIGNATURES) is None + + def test_webp_not_detected_as_wav_base64(self) -> None: + data = bytes( + [ + 0x52, + 0x49, + 0x46, + 0x46, + 0x24, + 0x00, + 0x00, + 0x00, + 0x57, + 0x45, + 0x42, + 0x50, + ] + ) + b64 = base64.b64encode(data).decode() + assert detect_media_type(b64, AUDIO_SIGNATURES) is None + + +class TestOgg: + def test_detect_ogg_from_bytes(self) -> None: + data = bytes([0x4F, 0x67, 0x67, 0x53]) + assert detect_media_type(data, AUDIO_SIGNATURES) == "audio/ogg" + + def test_detect_ogg_from_base64(self) -> None: + assert detect_media_type("T2dnUw", AUDIO_SIGNATURES) == "audio/ogg" + + +class TestFlac: + def test_detect_flac_from_bytes(self) -> None: + data = bytes([0x66, 0x4C, 0x61, 0x43]) + assert detect_media_type(data, AUDIO_SIGNATURES) == "audio/flac" + + def test_detect_flac_from_base64(self) -> None: + assert detect_media_type("ZkxhQw", AUDIO_SIGNATURES) == "audio/flac" + + +class TestAac: + def test_detect_aac_from_bytes(self) -> None: + data = bytes([0x40, 0x15, 0x00, 0x00]) + assert detect_media_type(data, AUDIO_SIGNATURES) == "audio/aac" + + def test_detect_aac_from_base64(self) -> None: + data = bytes([0x40, 0x15, 0x00, 0x00]) + b64 = base64.b64encode(data).decode() + assert 
detect_media_type(b64, AUDIO_SIGNATURES) == "audio/aac" + + +class TestMp4Audio: + def test_detect_mp4_from_bytes(self) -> None: + data = bytes([0x66, 0x74, 0x79, 0x70]) + assert detect_media_type(data, AUDIO_SIGNATURES) == "audio/mp4" + + def test_detect_mp4_from_base64(self) -> None: + assert detect_media_type("ZnR5cA", AUDIO_SIGNATURES) == "audio/mp4" + + +class TestWebmAudio: + def test_detect_webm_from_bytes(self) -> None: + data = bytes([0x1A, 0x45, 0xDF, 0xA3]) + assert detect_media_type(data, AUDIO_SIGNATURES) == "audio/webm" + + def test_detect_webm_from_base64(self) -> None: + assert detect_media_type("GkXfow==", AUDIO_SIGNATURES) == "audio/webm" + + +# --------------------------------------------------------------------------- +# Error / edge cases +# --------------------------------------------------------------------------- + + +class TestEdgeCases: + def test_unknown_image_format(self) -> None: + data = bytes([0x00, 0x01, 0x02, 0x03]) + assert detect_media_type(data, IMAGE_SIGNATURES) is None + + def test_unknown_audio_format(self) -> None: + data = bytes([0x00, 0x01, 0x02, 0x03]) + assert detect_media_type(data, AUDIO_SIGNATURES) is None + + def test_empty_bytes_image(self) -> None: + assert detect_media_type(b"", IMAGE_SIGNATURES) is None + + def test_empty_bytes_audio(self) -> None: + assert detect_media_type(b"", AUDIO_SIGNATURES) is None + + def test_short_bytes_image(self) -> None: + """Bytes shorter than longest signature should not crash.""" + data = bytes([0x89, 0x50]) # incomplete PNG + assert detect_media_type(data, IMAGE_SIGNATURES) is None + + def test_short_bytes_audio(self) -> None: + data = bytes([0x4F, 0x67]) # incomplete OGG + assert detect_media_type(data, AUDIO_SIGNATURES) is None + + def test_invalid_base64_image(self) -> None: + assert detect_media_type("invalid123", IMAGE_SIGNATURES) is None + + def test_invalid_base64_audio(self) -> None: + assert detect_media_type("invalid123", AUDIO_SIGNATURES) is None diff --git 
a/tests/core/media/test_models.py b/tests/core/media/test_models.py new file mode 100644 index 00000000..4fce1ea3 --- /dev/null +++ b/tests/core/media/test_models.py @@ -0,0 +1,198 @@ +"""Tests for MediaModel: extraction and message assembly.""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from vercel_ai_sdk.core.media import MediaModel, MediaResult +from vercel_ai_sdk.core.messages import FilePart, Message, TextPart, Usage + +# --------------------------------------------------------------------------- +# Concrete stub for testing the base class +# --------------------------------------------------------------------------- + + +class _StubMediaModel(MediaModel): + """Minimal concrete implementation that just returns what we tell it to.""" + + def __init__(self, result: MediaResult) -> None: + self._result = result + + async def make_request( + self, + prompt: str, + input_files: list[FilePart], + *, + n: int = 1, + provider_options: dict[str, Any] | None = None, + ) -> MediaResult: + return self._result + + +# --------------------------------------------------------------------------- +# _extract_prompt +# --------------------------------------------------------------------------- + + +class TestExtractPrompt: + def test_user_text(self) -> None: + msgs = [Message(role="user", parts=[TextPart(text="hello world")])] + assert MediaModel._extract_prompt(msgs) == "hello world" + + def test_system_and_user(self) -> None: + msgs = [ + Message(role="system", parts=[TextPart(text="be helpful")]), + Message(role="user", parts=[TextPart(text="draw a cat")]), + ] + assert MediaModel._extract_prompt(msgs) == "be helpful draw a cat" + + def test_ignores_assistant(self) -> None: + msgs = [ + Message(role="user", parts=[TextPart(text="hello")]), + Message(role="assistant", parts=[TextPart(text="ignored")]), + ] + assert MediaModel._extract_prompt(msgs) == "hello" + + def test_multiple_text_parts(self) -> None: + msgs = [ + Message( + 
role="user", + parts=[TextPart(text="first"), TextPart(text="second")], + ) + ] + assert MediaModel._extract_prompt(msgs) == "first second" + + def test_skips_non_text_parts(self) -> None: + msgs = [ + Message( + role="user", + parts=[ + TextPart(text="prompt"), + FilePart(data=b"\x89PNG", media_type="image/png"), + ], + ) + ] + assert MediaModel._extract_prompt(msgs) == "prompt" + + def test_empty_messages(self) -> None: + assert MediaModel._extract_prompt([]) == "" + + +# --------------------------------------------------------------------------- +# _extract_input_files +# --------------------------------------------------------------------------- + + +class TestExtractInputFiles: + def test_user_file_parts(self) -> None: + img = FilePart(data=b"\x89PNG", media_type="image/png") + pdf = FilePart(data=b"%PDF", media_type="application/pdf") + msgs = [Message(role="user", parts=[TextPart(text="hi"), img, pdf])] + result = MediaModel._extract_input_files(msgs) + assert result == [img, pdf] + + def test_ignores_assistant_files(self) -> None: + img = FilePart(data=b"\x89PNG", media_type="image/png") + msgs = [Message(role="assistant", parts=[img])] + assert MediaModel._extract_input_files(msgs) == [] + + def test_ignores_system_files(self) -> None: + img = FilePart(data=b"\x89PNG", media_type="image/png") + msgs = [Message(role="system", parts=[img])] + assert MediaModel._extract_input_files(msgs) == [] + + def test_returns_all_media_types(self) -> None: + """Unlike the old extract_input_images, this returns ALL file parts.""" + img = FilePart(data=b"\x89PNG", media_type="image/png") + audio = FilePart(data=b"\xff\xfb", media_type="audio/mpeg") + video = FilePart(data=b"\x00\x00", media_type="video/mp4") + msgs = [Message(role="user", parts=[img, audio, video])] + result = MediaModel._extract_input_files(msgs) + assert len(result) == 3 + + def test_empty_messages(self) -> None: + assert MediaModel._extract_input_files([]) == [] + + def test_multiple_user_messages(self) 
-> None: + img1 = FilePart(data=b"\x89PNG", media_type="image/png") + img2 = FilePart(data=b"\xff\xd8", media_type="image/jpeg") + msgs = [ + Message(role="user", parts=[img1]), + Message(role="user", parts=[img2]), + ] + result = MediaModel._extract_input_files(msgs) + assert result == [img1, img2] + + +# --------------------------------------------------------------------------- +# _build_message +# --------------------------------------------------------------------------- + + +class TestBuildMessage: + def test_wraps_files_in_message(self) -> None: + fp = FilePart(data=b"\x89PNG", media_type="image/png") + result = MediaResult(files=[fp]) + msg = MediaModel._build_message(result) + assert msg.role == "assistant" + assert len(msg.parts) == 1 + assert msg.images[0] is fp + + def test_includes_usage(self) -> None: + fp = FilePart(data=b"\x89PNG", media_type="image/png") + usage = Usage(input_tokens=10, output_tokens=20) + result = MediaResult(files=[fp], usage=usage) + msg = MediaModel._build_message(result) + assert msg.usage is not None + assert msg.usage.input_tokens == 10 + assert msg.usage.output_tokens == 20 + + def test_no_usage(self) -> None: + result = MediaResult(files=[]) + msg = MediaModel._build_message(result) + assert msg.usage is None + + def test_empty_files(self) -> None: + result = MediaResult(files=[]) + msg = MediaModel._build_message(result) + assert msg.parts == [] + + +# --------------------------------------------------------------------------- +# Integration: generate() calls make_request() and wraps result +# --------------------------------------------------------------------------- + + +class TestGenerateIntegration: + @pytest.mark.asyncio + async def test_generate_round_trip(self) -> None: + """The base class extracts prompt/files and wraps the result.""" + fp_out = FilePart(data="b64data", media_type="image/png") + usage = Usage(input_tokens=5, output_tokens=15) + stub = _StubMediaModel(MediaResult(files=[fp_out], usage=usage)) + + # 
We can't call generate() directly on MediaModel since it doesn't + # define one — subclasses do. But we can verify the pipeline by + # calling the helpers manually. + prompt = stub._extract_prompt( + [Message(role="user", parts=[TextPart(text="a sunset")])] + ) + assert prompt == "a sunset" + + input_files = stub._extract_input_files( + [ + Message( + role="user", + parts=[FilePart(data=b"\x89PNG", media_type="image/png")], + ) + ] + ) + assert len(input_files) == 1 + + result = await stub.make_request(prompt, input_files) + msg = stub._build_message(result) + assert msg.role == "assistant" + assert msg.images == [fp_out] + assert msg.usage == usage diff --git a/tests/core/test_llm.py b/tests/core/test_llm.py index f15a97e4..f219f959 100644 --- a/tests/core/test_llm.py +++ b/tests/core/test_llm.py @@ -8,6 +8,7 @@ import vercel_ai_sdk as ai from vercel_ai_sdk.core.llm import ( + FileEvent, MessageDone, ReasoningDelta, ReasoningEnd, @@ -20,7 +21,13 @@ ToolEnd, ToolStart, ) -from vercel_ai_sdk.core.messages import ReasoningPart, TextPart, ToolPart, Usage +from vercel_ai_sdk.core.messages import ( + FilePart, + ReasoningPart, + TextPart, + ToolPart, + Usage, +) from ..conftest import MockLLM, text_msg @@ -239,3 +246,50 @@ async def test_buffer_structured_output_invalid_json_raises() -> None: with pytest.raises((json.JSONDecodeError, pydantic.ValidationError)): await llm.buffer(ai.make_messages(user="weather?"), output_type=_Weather) + + +# -- File event (inline images from LLMs like Gemini/GPT-5) --------------- + + +def test_file_event_accumulates() -> None: + """FileEvent should produce a FilePart in the message.""" + h = StreamHandler(message_id="m1") + m = h.handle_event( + FileEvent(block_id="f1", media_type="image/png", data="iVBORw0KGgo=") + ) + file_parts = [p for p in m.parts if isinstance(p, FilePart)] + assert len(file_parts) == 1 + assert file_parts[0].media_type == "image/png" + assert file_parts[0].data == "iVBORw0KGgo=" + + +def 
test_file_event_with_text() -> None: + """A message can have both text and file parts (e.g. Gemini image gen).""" + h = StreamHandler(message_id="m1") + h.handle_event(TextStart(block_id="t1")) + h.handle_event(TextDelta(block_id="t1", delta="Here is your image:")) + h.handle_event(TextEnd(block_id="t1")) + h.handle_event( + FileEvent(block_id="f1", media_type="image/png", data="iVBORw0KGgo=") + ) + m = h.handle_event(MessageDone(finish_reason="stop")) + + assert len(m.parts) == 2 + assert isinstance(m.parts[0], TextPart) + assert m.parts[0].text == "Here is your image:" + assert isinstance(m.parts[1], FilePart) + assert m.parts[1].media_type == "image/png" + assert m.is_done + + +def test_multiple_file_events() -> None: + """Multiple FileEvents produce multiple FileParts.""" + h = StreamHandler(message_id="m1") + h.handle_event(FileEvent(block_id="f1", media_type="image/png", data="png_data")) + m = h.handle_event( + FileEvent(block_id="f2", media_type="image/jpeg", data="jpeg_data") + ) + file_parts = [p for p in m.parts if isinstance(p, FilePart)] + assert len(file_parts) == 2 + assert file_parts[0].media_type == "image/png" + assert file_parts[1].media_type == "image/jpeg" diff --git a/tests/core/test_messages.py b/tests/core/test_messages.py index 96d443bf..c11924c5 100644 --- a/tests/core/test_messages.py +++ b/tests/core/test_messages.py @@ -1,10 +1,11 @@ """Message model: properties, ToolPart.set_result/set_error, make_messages, -StructuredOutputPart.""" +StructuredOutputPart, FilePart.""" import pydantic import pytest from vercel_ai_sdk.core.messages import ( + FilePart, HookPart, Message, ReasoningPart, @@ -317,3 +318,107 @@ def test_usage_add_merges_optional_fields() -> None: # raw is intentionally not merged assert total.raw is None + + +# -- FilePart -------------------------------------------------------------- + + +def test_file_part_creation() -> None: + """FilePart stores data, media_type, and optional filename.""" + fp = FilePart(data=b"\x89PNG", 
media_type="image/png", filename="pic.png") + assert fp.type == "file" + assert fp.data == b"\x89PNG" + assert fp.media_type == "image/png" + assert fp.filename == "pic.png" + + +def test_file_part_in_part_union() -> None: + """FilePart round-trips through the Part discriminated union.""" + msg = Message( + id="m1", + role="user", + parts=[ + TextPart(text="look at this"), + FilePart(data="https://example.com/cat.jpg", media_type="image/jpeg"), + ], + ) + dumped = msg.model_dump() + restored = Message.model_validate(dumped) + assert len(restored.parts) == 2 + assert isinstance(restored.parts[1], FilePart) + assert restored.parts[1].media_type == "image/jpeg" + + +# -- FilePart.from_url ----------------------------------------------------- + + +def test_from_url_infers_jpeg() -> None: + fp = FilePart.from_url("https://example.com/cat.jpg") + assert fp.media_type == "image/jpeg" + assert fp.data == "https://example.com/cat.jpg" + + +def test_from_url_infers_png() -> None: + fp = FilePart.from_url("https://example.com/photo.png") + assert fp.media_type == "image/png" + + +def test_from_url_infers_pdf() -> None: + fp = FilePart.from_url("https://example.com/doc.pdf") + assert fp.media_type == "application/pdf" + + +def test_from_url_infers_from_data_url() -> None: + fp = FilePart.from_url("data:audio/wav;base64,AAAA") + assert fp.media_type == "audio/wav" + + +def test_from_url_explicit_media_type_overrides() -> None: + fp = FilePart.from_url("https://example.com/img", media_type="image/webp") + assert fp.media_type == "image/webp" + + +def test_from_url_unknown_extension_raises() -> None: + with pytest.raises(ValueError, match="Cannot infer media_type"): + FilePart.from_url("https://example.com/blob") + + +# -- FilePart.from_bytes -------------------------------------------------- + + +def test_from_bytes_detects_png() -> None: + data = bytes([0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A]) + fp = FilePart.from_bytes(data) + assert fp.media_type == "image/png" + assert fp.data 
== data + + +def test_from_bytes_detects_jpeg() -> None: + data = bytes([0xFF, 0xD8, 0xFF, 0xE0]) + fp = FilePart.from_bytes(data) + assert fp.media_type == "image/jpeg" + + +def test_from_bytes_detects_wav() -> None: + data = bytes( + [0x52, 0x49, 0x46, 0x46, 0x00, 0x00, 0x00, 0x00, 0x57, 0x41, 0x56, 0x45] + ) + fp = FilePart.from_bytes(data) + assert fp.media_type == "audio/wav" + + +def test_from_bytes_explicit_overrides() -> None: + """Explicit media_type should bypass detection.""" + fp = FilePart.from_bytes(b"\x00\x00", media_type="video/mp4") + assert fp.media_type == "video/mp4" + + +def test_from_bytes_preserves_filename() -> None: + data = bytes([0x89, 0x50, 0x4E, 0x47]) + fp = FilePart.from_bytes(data, filename="photo.png") + assert fp.filename == "photo.png" + + +def test_from_bytes_unknown_raises() -> None: + with pytest.raises(ValueError, match="Cannot detect media_type"): + FilePart.from_bytes(b"\x00\x01\x02\x03") diff --git a/tests/openai/__init__.py b/tests/openai/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/openai/test_openai.py b/tests/openai/test_openai.py new file mode 100644 index 00000000..22fd86d0 --- /dev/null +++ b/tests/openai/test_openai.py @@ -0,0 +1,245 @@ +"""OpenAI provider: _messages_to_openai multimodal conversion tests.""" + +import base64 +from unittest.mock import AsyncMock, patch + +import pytest + +from vercel_ai_sdk.core.messages import FilePart, Message, TextPart +from vercel_ai_sdk.openai import _messages_to_openai + +# -- text-only (regression) ------------------------------------------------ + + +@pytest.mark.asyncio +async def test_user_text_only_is_plain_string() -> None: + """Text-only user messages should produce a plain content string, not array.""" + msgs = [Message(role="user", parts=[TextPart(text="Hello")])] + result = await _messages_to_openai(msgs) + assert result == [{"role": "user", "content": "Hello"}] + + +# -- images 
---------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_user_image_url() -> None: + """Image URL → OpenAI image_url content part.""" + msgs = [ + Message( + role="user", + parts=[ + TextPart(text="What's this?"), + FilePart(data="https://example.com/cat.jpg", media_type="image/jpeg"), + ], + ) + ] + result = await _messages_to_openai(msgs) + content = result[0]["content"] + assert content[0] == {"type": "text", "text": "What's this?"} + assert content[1] == { + "type": "image_url", + "image_url": {"url": "https://example.com/cat.jpg"}, + } + + +@pytest.mark.asyncio +async def test_user_image_base64() -> None: + """Base64 image data → OpenAI image_url with data URL.""" + b64 = base64.b64encode(b"\x89PNG").decode() + msgs = [ + Message( + role="user", + parts=[FilePart(data=b64, media_type="image/png")], + ) + ] + result = await _messages_to_openai(msgs) + content = result[0]["content"] + assert content[0]["type"] == "image_url" + assert content[0]["image_url"]["url"] == f"data:image/png;base64,{b64}" + + +@pytest.mark.asyncio +async def test_user_image_bytes() -> None: + """Raw bytes image → OpenAI image_url with data URL.""" + raw = b"\x89PNG" + msgs = [ + Message( + role="user", + parts=[FilePart(data=raw, media_type="image/png")], + ) + ] + result = await _messages_to_openai(msgs) + url = result[0]["content"][0]["image_url"]["url"] + assert url.startswith("data:image/png;base64,") + + +@pytest.mark.asyncio +async def test_user_image_wildcard_becomes_jpeg() -> None: + """image/* media type with a plain URL → URL passed through unchanged (no data URL is built).""" + msgs = [ + Message( + role="user", + parts=[FilePart(data="https://example.com/img", media_type="image/*")], + ) + ] + result = await _messages_to_openai(msgs) + # URL passthrough: no data URL conversion needed + assert result[0]["content"][0]["image_url"]["url"] == "https://example.com/img" + + +@pytest.mark.asyncio +async def test_user_image_data_url() -> None: + """data: URL 
image → passed through unchanged as the image_url URL.""" + msgs = [ + Message( + role="user", + parts=[FilePart(data="data:image/png;base64,AQID", media_type="image/png")], + ) + ] + result = await _messages_to_openai(msgs) + # data: URLs pass through directly for images + assert result[0]["content"][0]["image_url"]["url"] == "data:image/png;base64,AQID" + + +# -- audio ----------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_user_audio_base64() -> None: + """Audio base64 → OpenAI input_audio part.""" + b64 = base64.b64encode(b"\xff\xfb").decode() + msgs = [ + Message( + role="user", + parts=[FilePart(data=b64, media_type="audio/wav")], + ) + ] + result = await _messages_to_openai(msgs) + part = result[0]["content"][0] + assert part["type"] == "input_audio" + assert part["input_audio"]["data"] == b64 + assert part["input_audio"]["format"] == "wav" + + +@pytest.mark.asyncio +async def test_user_audio_data_url_extracts_base64() -> None: + """Audio data: URL → data-URL prefix stripped, bare base64 used for input_audio.""" + msgs = [ + Message( + role="user", + parts=[FilePart(data="data:audio/wav;base64,AAAA", media_type="audio/wav")], + ) + ] + result = await _messages_to_openai(msgs) + part = result[0]["content"][0] + assert part["type"] == "input_audio" + assert part["input_audio"]["data"] == "AAAA" + + +@pytest.mark.asyncio +async def test_user_audio_url_downloads() -> None: + """Audio URLs are auto-downloaded since OpenAI requires base64.""" + fake_audio = b"\xff\xfb\x90\x00" + msgs = [ + Message( + role="user", + parts=[ + FilePart(data="https://example.com/clip.wav", media_type="audio/wav") + ], + ) + ] + with patch( + "vercel_ai_sdk.core.media.download.download", + new_callable=AsyncMock, + return_value=(fake_audio, "audio/wav"), + ): + result = await _messages_to_openai(msgs) + part = result[0]["content"][0] + assert part["type"] == "input_audio" + assert part["input_audio"]["format"] == "wav" + # Should be base64 of the downloaded 
bytes + assert part["input_audio"]["data"] == base64.b64encode(fake_audio).decode() + + +# -- PDF ------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_user_pdf_base64() -> None: + """PDF base64 → OpenAI file part.""" + b64 = base64.b64encode(b"%PDF-1.4").decode() + msgs = [ + Message( + role="user", + parts=[ + FilePart(data=b64, media_type="application/pdf", filename="report.pdf") + ], + ) + ] + result = await _messages_to_openai(msgs) + part = result[0]["content"][0] + assert part["type"] == "file" + assert part["file"]["filename"] == "report.pdf" + assert part["file"]["file_data"].startswith("data:application/pdf;base64,") + + +@pytest.mark.asyncio +async def test_user_pdf_url_downloads() -> None: + """PDF URLs are auto-downloaded since OpenAI requires base64.""" + fake_pdf = b"%PDF-1.4 fake content" + msgs = [ + Message( + role="user", + parts=[ + FilePart( + data="https://example.com/doc.pdf", + media_type="application/pdf", + filename="doc.pdf", + ) + ], + ) + ] + with patch( + "vercel_ai_sdk.core.media.download.download", + new_callable=AsyncMock, + return_value=(fake_pdf, "application/pdf"), + ): + result = await _messages_to_openai(msgs) + part = result[0]["content"][0] + assert part["type"] == "file" + assert part["file"]["filename"] == "doc.pdf" + assert part["file"]["file_data"].startswith("data:application/pdf;base64,") + + +# -- text/* ---------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_user_text_file_bytes() -> None: + """text/* file with bytes data → decoded to text content part.""" + msgs = [ + Message( + role="user", + parts=[FilePart(data=b"Hello, world!", media_type="text/plain")], + ) + ] + result = await _messages_to_openai(msgs) + part = result[0]["content"][0] + assert part == {"type": "text", "text": "Hello, world!"} + + +# -- unsupported ----------------------------------------------------------- + + +@pytest.mark.asyncio +async 
def test_unsupported_media_type_raises() -> None: + """Unknown media type → ValueError.""" + msgs = [ + Message( + role="user", + parts=[FilePart(data=b"\x00", media_type="application/octet-stream")], + ) + ] + with pytest.raises(ValueError, match="Unsupported media type"): + await _messages_to_openai(msgs) diff --git a/uv.lock b/uv.lock index 160c9eac..4e171024 100644 --- a/uv.lock +++ b/uv.lock @@ -1049,7 +1049,7 @@ wheels = [ [[package]] name = "vercel-ai-sdk" -version = "0.0.1.dev5" +version = "0.0.1.dev6" source = { editable = "." } dependencies = [ { name = "anthropic" },