diff --git a/src/ai/models/ai_gateway/__init__.py b/src/ai/models/ai_gateway/__init__.py index 96d0bdf5..a4b22420 100644 --- a/src/ai/models/ai_gateway/__init__.py +++ b/src/ai/models/ai_gateway/__init__.py @@ -7,99 +7,13 @@ model = ai_gateway("anthropic/claude-sonnet-4") ids = await ai_gateway.list() -Heavy adapter modules (``.generate``, ``.stream``) are loaded lazily so that -``import ai`` does not pull in ``httpx`` and other I/O libraries at import -time. This matters for sandboxed runtimes (e.g. Temporal workflow workers). +The heavy ``.adapter`` module is loaded lazily so that ``import ai`` does +not pull in ``httpx`` and other I/O libraries at import time. This matters +for sandboxed runtimes (e.g. Temporal workflow workers). """ -from __future__ import annotations - -import os -from typing import TYPE_CHECKING, Any - -from ..core import client as client_ -from ..core.model import Model from . import errors - -if TYPE_CHECKING: - pass - -_BASE_URL = "https://ai-gateway.vercel.sh/v3/ai" -_API_KEY_ENV = "AI_GATEWAY_API_KEY" -_PROTOCOL_VERSION = "0.0.1" - - -class _AIGateway: - """Callable provider factory for the Vercel AI Gateway. - - Satisfies the :class:`~ai.models.core.proto.Provider` protocol. - """ - - @property - def api_key_env(self) -> str: - return _API_KEY_ENV - - @property - def base_url(self) -> str: - return _BASE_URL - - @property - def adapter(self) -> str: - return "ai-gateway-v3" - - @property - def name(self) -> str: - return "ai-gateway" - - def client(self) -> client_.Client: - """Create a :class:`Client` from env-var credentials.""" - return client_.Client( - base_url=_BASE_URL, - api_key=os.environ.get(_API_KEY_ENV), - ) - - async def check(self, client: client_.Client, model: Model) -> bool: - """Delegate to :func:`ai_gateway.check.check`.""" - from . import check as check_ - - return await check_.check(client, model) - - def __call__( - self, - model_id: str, - *, - base_url: str | None = None, - client: client_.Client | None = None, - ) -> Model: - return Model( - id=model_id, - adapter=self.adapter, - provider=self, - client=client, - ) - - async def list(self, *, client: client_.Client | None = None) -> list[str]: - """List available model IDs from the AI Gateway.""" - c = client or self.client() - base_url = c.base_url.rstrip("/") - headers: dict[str, str] = { - "ai-gateway-protocol-version": _PROTOCOL_VERSION, - } - if c.api_key: - headers["Authorization"] = f"Bearer {c.api_key}" - headers["ai-gateway-auth-method"] = "api-key" - - config_url = f"{base_url}/config" - response = await c.http.get(config_url, headers=headers) - response.raise_for_status() - data: dict[str, Any] = response.json() - return sorted(str(m["id"]) for m in data.get("models", [])) - - def __repr__(self) -> str: - return "ai_gateway" - - -ai_gateway = _AIGateway() +from .provider import ai_gateway __all__ = [ "ai_gateway", @@ -109,11 +23,11 @@ def __repr__(self) -> str: def __getattr__(name: str) -> object: if name == "generate": - from .generate import generate + from .adapter import generate return generate if name == "stream": - from .stream import stream + from .adapter import stream return stream raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/src/ai/models/ai_gateway/_common.py b/src/ai/models/ai_gateway/_common.py deleted file mode 100644 index 02333799..00000000 --- a/src/ai/models/ai_gateway/_common.py +++ /dev/null @@ -1,145 +0,0 @@ -"""Shared helpers for the AI Gateway v3 adapter. - -Contains utilities used by both the streaming (language-model) and generation -(image-model, video-model) endpoints. - -.. note:: - - Several helpers here are candidates for lifting to framework-level: - - - ``extract_prompt`` / ``extract_input_files`` → ``Message`` methods - - ``parse_sse_lines`` → ``core/helpers/sse.py`` -""" - -from __future__ import annotations - -import base64 -import json -from collections.abc import AsyncGenerator -from typing import Any - -import httpx - -from ...types import media -from ...types import messages as messages_ -from ..core import client as client_ -from ..core import model as model_ - -_PROTOCOL_VERSION = "0.0.1" - - -# --------------------------------------------------------------------------- -# Message extraction helpers -# --------------------------------------------------------------------------- -# TODO: lift to Message methods — these are universally useful. - - -def extract_prompt(messages: list[messages_.Message]) -> str: - """Concatenate all text from user/system messages into a single prompt string.""" - parts: list[str] = [] - for msg in messages: - if msg.role in ("user", "system"): - for p in msg.parts: - if isinstance(p, messages_.TextPart): - parts.append(p.text) - return " ".join(parts) - - -def extract_input_files(messages: list[messages_.Message]) -> list[messages_.FilePart]: - """Collect all file parts from user messages.""" - files: list[messages_.FilePart] = [] - for msg in messages: - if msg.role == "user": - for p in msg.parts: - if isinstance(p, messages_.FilePart): - files.append(p) - return files - - -# --------------------------------------------------------------------------- -# Wire format helpers -# --------------------------------------------------------------------------- - - -def file_part_to_wire(part: messages_.FilePart) -> dict[str, Any]: - """Convert a :class:`FilePart` to the gateway wire format for input files.""" - data = part.data - if isinstance(data, str) and media.is_url(data): - return {"type": "url", "url": data} - if isinstance(data, bytes): - b64 = base64.b64encode(data).decode("ascii") - elif isinstance(data, str): - b64 = data - else: - b64 = str(data) - return {"type": "file", "data": b64, "mediaType": part.media_type} - - -# --------------------------------------------------------------------------- -# Request headers -# --------------------------------------------------------------------------- - - -def request_headers( - client: client_.Client, - model: model_.Model, - *, - model_type: str = "language", - streaming: bool = False, -) -> dict[str, str]: - """Build gateway-specific request headers. - - Args: - client: The HTTP client (provides api_key). - model: The model (provides id). - model_type: One of ``"language"``, ``"image"``, ``"video"``. - streaming: Whether this is a streaming request (language-model only). - """ - h: dict[str, str] = { - "Content-Type": "application/json", - "ai-gateway-protocol-version": _PROTOCOL_VERSION, - } - - if model_type == "language": - h["ai-language-model-specification-version"] = "3" - h["ai-language-model-id"] = model.id - h["ai-language-model-streaming"] = str(streaming).lower() - elif model_type == "image": - h["ai-image-model-specification-version"] = "3" - h["ai-model-id"] = model.id - elif model_type == "video": - h["ai-video-model-specification-version"] = "3" - h["ai-model-id"] = model.id - - if client.api_key: - h["Authorization"] = f"Bearer {client.api_key}" - h["ai-gateway-auth-method"] = "api-key" - - return h - - -# --------------------------------------------------------------------------- -# SSE parsing -# --------------------------------------------------------------------------- -# TODO: lift to core/helpers/sse.py — any SSE-based adapter will need this. - - -async def parse_sse_lines( - response: httpx.Response, -) -> AsyncGenerator[dict[str, Any]]: - """Yield parsed JSON dicts from an SSE response stream. - - Handles the ``data: `` / ``data: [DONE]`` protocol used by the - AI Gateway's streaming endpoints. - """ - async for line in response.aiter_lines(): - line = line.strip() - if not line.startswith("data: "): - continue - payload = line[len("data: ") :] - if payload == "[DONE]": - break - try: - yield json.loads(payload) - except json.JSONDecodeError: - continue diff --git a/src/ai/models/ai_gateway/stream.py b/src/ai/models/ai_gateway/adapter.py similarity index 52% rename from src/ai/models/ai_gateway/stream.py rename to src/ai/models/ai_gateway/adapter.py index aecb3c53..8ebbba25 100644 --- a/src/ai/models/ai_gateway/stream.py +++ b/src/ai/models/ai_gateway/adapter.py @@ -1,10 +1,9 @@ -"""AI Gateway v3 streaming adapter — language-model endpoint. +"""AI Gateway v3 adapter. -Handles text, tool-call, reasoning, and inline file streaming via SSE. -""" - -from __future__ import annotations +Converts internal messages to AI Gateway wire payloads and maps gateway +responses back to public event/message types.""" +import base64 import json from collections.abc import AsyncGenerator, Sequence from typing import Any @@ -12,32 +11,69 @@ import httpx import pydantic -from ...types import events as events_ -from ...types import media -from ...types import messages as messages_ -from ...types import proto as proto_ -from ...types import usage as usage_ -from ..core import client as client_ -from ..core import model as model_ -from ..core.helpers import files -from . import _common, errors +from ... import types +from .. import core +from . import errors, sdk + +# --------------------------------------------------------------------------- +# Shared request helpers +# --------------------------------------------------------------------------- + + +def _extract_prompt(messages: list[types.Message]) -> str: + """Concatenate all text from user/system messages into one prompt.""" + parts: list[str] = [] + for msg in messages: + if msg.role in ("user", "system"): + for p in msg.parts: + if isinstance(p, types.TextPart): + parts.append(p.text) + return " ".join(parts) + + +def _extract_input_files( + messages: list[types.Message], +) -> list[types.FilePart]: + """Collect all file parts from user messages.""" + files_: list[types.FilePart] = [] + for msg in messages: + if msg.role == "user": + for p in msg.parts: + if isinstance(p, types.FilePart): + files_.append(p) + return files_ + + +def _file_part_to_wire(part: types.FilePart) -> dict[str, Any]: + """Convert a :class:`FilePart` to the gateway wire format for input files.""" + data = part.data + if isinstance(data, str) and types.media.is_url(data): + return {"type": "url", "url": data} + if isinstance(data, bytes): + b64 = base64.b64encode(data).decode("ascii") + elif isinstance(data, str): + b64 = data + else: + b64 = str(data) + return {"type": "file", "data": b64, "mediaType": part.media_type} + # --------------------------------------------------------------------------- -# Request building — Message list → v3 prompt +# Streaming request building — Message list → v3 prompt # --------------------------------------------------------------------------- -async def _file_part_to_v3(part: messages_.FilePart) -> dict[str, Any]: +async def _file_part_to_v3(part: types.FilePart) -> dict[str, Any]: """Convert a :class:`FilePart` to a v3 ``file`` content part.""" data = part.data - if isinstance(data, str) and media.is_downloadable_url(data): - downloaded, _ = await files.download(data) + if isinstance(data, str) and types.media.is_downloadable_url(data): + downloaded, _ = await core.helpers.files.download(data) data = downloaded entry: dict[str, Any] = { "type": "file", "mediaType": part.media_type, - "data": media.data_to_data_url(data, part.media_type), + "data": types.media.data_to_data_url(data, part.media_type), } if part.filename is not None: entry["filename"] = part.filename @@ -45,7 +81,7 @@ async def _file_part_to_v3(part: messages_.FilePart) -> dict[str, Any]: async def _messages_to_prompt( - messages: list[messages_.Message], + messages: list[types.Message], ) -> list[dict[str, Any]]: """Convert ``Message`` list to the v3 prompt wire format.""" result: list[dict[str, Any]] = [] @@ -54,16 +90,16 @@ async def _messages_to_prompt( match msg.role: case "system": text = "".join( - p.text for p in msg.parts if isinstance(p, messages_.TextPart) + p.text for p in msg.parts if isinstance(p, types.TextPart) ) result.append({"role": "system", "content": text}) case "user": content: list[dict[str, Any]] = [] for p in msg.parts: - if isinstance(p, messages_.TextPart): + if isinstance(p, types.TextPart): content.append({"type": "text", "text": p.text}) - elif isinstance(p, messages_.FilePart): + elif isinstance(p, types.FilePart): content.append(await _file_part_to_v3(p)) result.append({"role": "user", "content": content}) @@ -71,13 +107,13 @@ async def _messages_to_prompt( assistant_content: list[dict[str, Any]] = [] for part in msg.parts: match part: - case messages_.ReasoningPart(text=text): + case types.ReasoningPart(text=text): assistant_content.append( {"type": "reasoning", "text": text} ) - case messages_.TextPart(text=text): + case types.TextPart(text=text): assistant_content.append({"type": "text", "text": text}) - case messages_.ToolCallPart() as tp: + case types.ToolCallPart() as tp: tool_input: Any = ( json.loads(tp.tool_args) if tp.tool_args else {} ) @@ -94,7 +130,7 @@ async def _messages_to_prompt( case "tool": tool_results: list[dict[str, Any]] = [] for part in msg.parts: - if isinstance(part, messages_.ToolResultPart): + if isinstance(part, types.ToolResultPart): output = ( { "type": "error-text", @@ -123,8 +159,8 @@ async def _messages_to_prompt( async def _build_request_body( - messages: list[messages_.Message], - tools: Sequence[proto_.ToolLike] | None = None, + messages: list[types.Message], + tools: Sequence[types.proto.ToolLike] | None = None, output_type: type[Any] | None = None, **kwargs: Any, ) -> dict[str, Any]: @@ -154,13 +190,13 @@ async def _build_request_body( # --------------------------------------------------------------------------- -# SSE response parsing — v3 stream parts → public Event +# Streaming response parsing — v3 stream parts → public Event # --------------------------------------------------------------------------- def _expand_tool_call( data: dict[str, Any], streamed_tool_ids: set[str] -) -> list[events_.Event]: +) -> list[types.events.Event]: """Expand a complete ``tool-call`` part into Start + Delta + End. Returns empty when the tool was already streamed via ``tool-input-*``. @@ -172,16 +208,16 @@ def _expand_tool_call( tool_input = data.get("input", "") args_str = tool_input if isinstance(tool_input, str) else json.dumps(tool_input) return [ - events_.ToolStart(tool_call_id=tc_id, tool_name=tool_name), - events_.ToolDelta(tool_call_id=tc_id, chunk=args_str), - events_.ToolEnd(tool_call_id=tc_id), + types.events.ToolStart(tool_call_id=tc_id, tool_name=tool_name), + types.events.ToolDelta(tool_call_id=tc_id, chunk=args_str), + types.events.ToolEnd(tool_call_id=tc_id), ] -def _parse_usage(data: Any) -> usage_.Usage: +def _parse_usage(data: Any) -> types.Usage: """Parse v3 usage data into an internal ``Usage``.""" if not isinstance(data, dict): - return usage_.Usage() + return types.Usage() input_tokens_obj = data.get("inputTokens") output_tokens_obj = data.get("outputTokens") @@ -189,7 +225,7 @@ def _parse_usage(data: Any) -> usage_.Usage: if isinstance(input_tokens_obj, dict) or isinstance(output_tokens_obj, dict): inp = input_tokens_obj if isinstance(input_tokens_obj, dict) else {} out = output_tokens_obj if isinstance(output_tokens_obj, dict) else {} - return usage_.Usage( + return types.Usage( input_tokens=inp.get("total") or 0, output_tokens=out.get("total") or 0, reasoning_tokens=out.get("reasoning"), @@ -198,7 +234,7 @@ def _parse_usage(data: Any) -> usage_.Usage: raw=data, ) - return usage_.Usage( + return types.Usage( input_tokens=data.get("prompt_tokens") or data.get("inputTokens") or 0, output_tokens=(data.get("completion_tokens") or data.get("outputTokens") or 0), raw=data, @@ -207,42 +243,42 @@ def _parse_usage(data: Any) -> usage_.Usage: def _parse_stream_part( data: dict[str, Any], streamed_tool_ids: set[str] -) -> list[events_.Event]: +) -> list[types.events.Event]: """Convert a ``LanguageModelV3StreamPart`` to public events.""" match data.get("type", ""): case "text-start": - return [events_.TextStart(block_id=data.get("id", "text"))] + return [types.events.TextStart(block_id=data.get("id", "text"))] case "text-delta": return [ - events_.TextDelta( + types.events.TextDelta( block_id=data.get("id", "text"), chunk=data.get("textDelta", data.get("delta", "")), ) ] case "text-end": - return [events_.TextEnd(block_id=data.get("id", "text"))] + return [types.events.TextEnd(block_id=data.get("id", "text"))] case "reasoning-start": - return [events_.ReasoningStart(block_id=data.get("id", "reasoning"))] + return [types.events.ReasoningStart(block_id=data.get("id", "reasoning"))] case "reasoning-delta": return [ - events_.ReasoningDelta( + types.events.ReasoningDelta( block_id=data.get("id", "reasoning"), chunk=data.get("delta", ""), ) ] case "reasoning-end": - return [events_.ReasoningEnd(block_id=data.get("id", "reasoning"))] + return [types.events.ReasoningEnd(block_id=data.get("id", "reasoning"))] case "tool-input-start": tcid = data.get("id", "") streamed_tool_ids.add(tcid) return [ - events_.ToolStart( + types.events.ToolStart( tool_call_id=tcid, tool_name=data.get("toolName", ""), ) @@ -250,21 +286,21 @@ def _parse_stream_part( case "tool-input-delta": return [ - events_.ToolDelta( + types.events.ToolDelta( tool_call_id=data.get("id", ""), chunk=data.get("delta", ""), ) ] case "tool-input-end": - return [events_.ToolEnd(tool_call_id=data.get("id", ""))] + return [types.events.ToolEnd(tool_call_id=data.get("id", ""))] case "tool-call": return _expand_tool_call(data, streamed_tool_ids) case "file": return [ - events_.FileEvent( + types.events.FileEvent( block_id=data.get("id", ""), media_type=data.get("mediaType", "application/octet-stream"), data=data.get("data", ""), @@ -274,58 +310,37 @@ def _parse_stream_part( case "finish": usage_data = data.get("usage") usage = _parse_usage(usage_data) if usage_data else None - return [events_.StreamEnd(usage=usage)] + return [types.events.StreamEnd(usage=usage)] case _: return [] -# --------------------------------------------------------------------------- -# Public adapter function -# --------------------------------------------------------------------------- - - async def stream( - client: client_.Client, - model: model_.Model, - messages: list[messages_.Message], + client: core.client.Client, + model: core.model.Model, + messages: list[types.Message], *, - tools: Sequence[proto_.ToolLike] | None = None, + tools: Sequence[types.proto.ToolLike] | None = None, output_type: type[pydantic.BaseModel] | None = None, **kwargs: Any, -) -> AsyncGenerator[events_.Event]: - """Stream an LLM response through the AI Gateway v3 protocol. - - Yields :class:`~ai.types.events.Event` objects as the response streams in. - Pure delta emitter — the :class:`~ai.models.Stream` wrapper aggregates - parts into the final :class:`~ai.types.Message`. - """ +) -> AsyncGenerator[types.events.Event]: + """Stream an LLM response through the AI Gateway v3 protocol.""" body = await _build_request_body( messages, tools=tools, output_type=output_type, **kwargs ) - headers = _common.request_headers( - client, model, model_type="language", streaming=True - ) - url = f"{client.base_url.rstrip('/')}/language-model" + gateway = sdk.GatewayClient(client, model) try: - async with client.http.stream( - "POST", - url, - json=body, - headers=headers, + async with gateway.stream( + "language-model", + body, + model_type="language", + streaming=True, ) as response: - if response.status_code >= 400: - await response.aread() - raise errors.create_gateway_error( - response_body=response.text, - status_code=response.status_code, - api_key_provided=bool(client.api_key), - ) - - yield events_.StreamStart() + yield types.events.StreamStart() streamed_tool_ids: set[str] = set() - async for data in _common.parse_sse_lines(response): + async for data in gateway.iter_sse(response): for event in _parse_stream_part(data, streamed_tool_ids): yield event except errors.GatewayError: @@ -337,3 +352,120 @@ async def stream( message=f"Unexpected error during streaming: {exc}", cause=exc, ) from exc + + +# --------------------------------------------------------------------------- +# Media generation +# --------------------------------------------------------------------------- + + +async def _generate_image( + client: core.client.Client, + model: core.model.Model, + messages: list[types.Message], + params: core.ImageParams, +) -> types.Message: + """Hit ``/image-model`` and return a Message with FileParts.""" + prompt = _extract_prompt(messages) + input_files = _extract_input_files(messages) + + body: dict[str, Any] = { + "prompt": prompt, + **params.model_dump(by_alias=True, exclude_none=True), + } + if input_files: + body["files"] = [_file_part_to_wire(f) for f in input_files] + + gateway = sdk.GatewayClient(client, model) + response = await gateway.post_json("image-model", body, model_type="image") + + data = response.json() + raw_images: list[str] = data.get("images", []) + usage_data = data.get("usage") + usage = None + if usage_data: + usage = types.Usage( + input_tokens=usage_data.get("inputTokens") or 0, + output_tokens=usage_data.get("outputTokens") or 0, + ) + + parts: list[types.Part] = [] + for img_b64 in raw_images: + media_type = types.media.detect_image_media_type(img_b64) or "image/png" + parts.append(types.FilePart(data=img_b64, media_type=media_type)) + + return types.Message(role="assistant", parts=parts, usage=usage) + + +async def _generate_video( + client: core.client.Client, + model: core.model.Model, + messages: list[types.Message], + params: core.VideoParams, +) -> types.Message: + """Hit ``/video-model`` (SSE) and return a Message with FileParts.""" + prompt = _extract_prompt(messages) + input_files = _extract_input_files(messages) + + body: dict[str, Any] = { + "prompt": prompt, + **params.model_dump(by_alias=True, exclude_none=True), + } + if input_files: + body["image"] = _file_part_to_wire(input_files[0]) + + gateway = sdk.GatewayClient(client, model) + + async with gateway.stream( + "video-model", + body, + model_type="video", + accept="text/event-stream", + timeout=httpx.Timeout(timeout=600.0, connect=10.0), + ) as response: + event_data: dict[str, Any] = {} + async for parsed in gateway.iter_sse(response): + event_data = parsed + break + + if not event_data: + raise errors.GatewayResponseError( + "SSE stream ended without any data events", + ) + + if event_data.get("type") == "error": + raise errors.GatewayInvalidRequestError( + message=event_data.get("message", "unknown error"), + status_code=event_data.get("statusCode", 400), + ) + + raw_videos: list[dict[str, Any]] = event_data.get("videos", []) + parts: list[types.Part] = [] + for video_data in raw_videos: + vtype = video_data.get("type", "base64") + media_type = video_data.get("mediaType", "video/mp4") + + if vtype == "url": + downloaded_bytes, content_type = await core.helpers.files.download( + video_data["url"] + ) + if content_type: + media_type = content_type + parts.append(types.FilePart(data=downloaded_bytes, media_type=media_type)) + else: + raw_data = video_data.get("data", "") + parts.append(types.FilePart(data=raw_data, media_type=media_type)) + + return types.Message(role="assistant", parts=parts) + + +async def generate( + client: core.client.Client, + model: core.model.Model, + messages: list[types.Message], + params: core.GenerateParams, +) -> types.Message: + """Generate media through the AI Gateway.""" + if isinstance(params, core.VideoParams): + return await _generate_video(client, model, messages, params) + return await _generate_image(client, model, messages, params) diff --git a/src/ai/models/ai_gateway/check.py b/src/ai/models/ai_gateway/check.py index d3ccac74..1ff68c1b 100644 --- a/src/ai/models/ai_gateway/check.py +++ b/src/ai/models/ai_gateway/check.py @@ -1,7 +1,7 @@ """AI Gateway connection check. -Verifies **both** that the client's credentials are valid and that the -model exists in the gateway's catalogue. +Verifies that the client's credentials are valid and that the model +exists in the gateway's catalogue. * Auth is validated via ``GET {origin}/v1/credits`` which requires a valid API key (returns 401/403 otherwise). @@ -11,47 +11,28 @@ Both endpoints are free — no tokens or credits are consumed. """ -from __future__ import annotations - from typing import Any -from urllib.parse import urlparse - -from ..core import client as client_ -from ..core import model as model_ -_PROTOCOL_VERSION = "0.0.1" +from .. import core +from . import sdk # HTTP status codes that indicate bad auth. _FAIL_STATUSES = frozenset({401, 403}) -def _origin(base_url: str) -> str: - """Extract the origin (scheme + host + port) from *base_url*.""" - parsed = urlparse(base_url) - return f"{parsed.scheme}://{parsed.netloc}" - - -async def check(client: client_.Client, model: model_.Model) -> bool: +async def check(client: core.client.Client, model: core.model.Model) -> bool: """Return ``True`` if *client* can reach the gateway and *model* exists.""" - base_url = client.base_url.rstrip("/") - headers: dict[str, str] = { - "ai-gateway-protocol-version": _PROTOCOL_VERSION, - } - if client.api_key: - headers["Authorization"] = f"Bearer {client.api_key}" - headers["ai-gateway-auth-method"] = "api-key" + gateway = sdk.GatewayClient(client, model) # 1. Verify credentials via /v1/credits (requires valid auth). - credits_url = f"{_origin(base_url)}/v1/credits" - auth_resp = await client.http.get(credits_url, headers=headers) + auth_resp = await gateway.get("v1/credits", origin=True) if auth_resp.status_code in _FAIL_STATUSES: return False if auth_resp.status_code != 200: auth_resp.raise_for_status() # 2. Verify model existence via /config (public catalogue). - config_url = f"{base_url}/config" - config_resp = await client.http.get(config_url, headers=headers) + config_resp = await gateway.get("config") if config_resp.status_code != 200: config_resp.raise_for_status() return False # pragma: no cover diff --git a/src/ai/models/ai_gateway/generate.py b/src/ai/models/ai_gateway/generate.py deleted file mode 100644 index fec98f57..00000000 --- a/src/ai/models/ai_gateway/generate.py +++ /dev/null @@ -1,169 +0,0 @@ -"""AI Gateway v3 generation adapter — image-model and video-model endpoints. - -Unified :func:`generate` entry point that dispatches based on param type. -""" - -from __future__ import annotations - -from typing import Any - -import httpx - -from ...types import media -from ...types import messages as messages_ -from ..core import client as client_ -from ..core import model as model_ -from ..core.helpers import files -from ..core.params import GenerateParams as GenerateParams -from ..core.params import ImageParams as ImageParams -from ..core.params import VideoParams as VideoParams -from . import _common, errors - -# --------------------------------------------------------------------------- -# Image generation — /image-model -# --------------------------------------------------------------------------- - - -async def _generate_image( - client: client_.Client, - model: model_.Model, - messages: list[messages_.Message], - params: ImageParams, -) -> messages_.Message: - """Hit ``/image-model`` and return a Message with FileParts.""" - prompt = _common.extract_prompt(messages) - input_files = _common.extract_input_files(messages) - - body: dict[str, Any] = { - "prompt": prompt, - **params.model_dump(by_alias=True, exclude_none=True), - } - if input_files: - body["files"] = [_common.file_part_to_wire(f) for f in input_files] - - url = f"{client.base_url.rstrip('/')}/image-model" - headers = _common.request_headers(client, model, model_type="image") - - response = await client.http.post(url, json=body, headers=headers) - if response.status_code >= 400: - raise errors.create_gateway_error( - response_body=response.text, - status_code=response.status_code, - api_key_provided=bool(client.api_key), - ) - - data = response.json() - raw_images: list[str] = data.get("images", []) - usage_data = data.get("usage") - usage = None - if usage_data: - usage = messages_.Usage( - input_tokens=usage_data.get("inputTokens") or 0, - output_tokens=usage_data.get("outputTokens") or 0, - ) - - parts: list[messages_.Part] = [] - for img_b64 in raw_images: - media_type = media.detect_image_media_type(img_b64) or "image/png" - parts.append(messages_.FilePart(data=img_b64, media_type=media_type)) - - return messages_.Message(role="assistant", parts=parts, usage=usage) - - -# --------------------------------------------------------------------------- -# Video generation — /video-model (SSE response) -# --------------------------------------------------------------------------- - - -async def _generate_video( - client: client_.Client, - model: model_.Model, - messages: list[messages_.Message], - params: VideoParams, -) -> messages_.Message: - """Hit ``/video-model`` (SSE) and return a Message with FileParts.""" - prompt = _common.extract_prompt(messages) - input_files = _common.extract_input_files(messages) - - body: dict[str, Any] = { - "prompt": prompt, - **params.model_dump(by_alias=True, exclude_none=True), - } - if input_files: - body["image"] = _common.file_part_to_wire(input_files[0]) - - url = f"{client.base_url.rstrip('/')}/video-model" - headers = _common.request_headers(client, model, model_type="video") - headers["accept"] = "text/event-stream" - - async with client.http.stream( - "POST", - url, - json=body, - headers=headers, - timeout=httpx.Timeout(timeout=600.0, connect=10.0), - ) as response: - if response.status_code >= 400: - await response.aread() - raise errors.create_gateway_error( - response_body=response.text, - status_code=response.status_code, - api_key_provided=bool(client.api_key), - ) - - # Read first SSE data event — the gateway sends a single result event. - event_data: dict[str, Any] = {} - async for parsed in _common.parse_sse_lines(response): - event_data = parsed - break - - if not event_data: - raise errors.GatewayResponseError( - "SSE stream ended without any data events", - ) - - if event_data.get("type") == "error": - raise errors.GatewayInvalidRequestError( - message=event_data.get("message", "unknown error"), - status_code=event_data.get("statusCode", 400), - ) - - raw_videos: list[dict[str, Any]] = event_data.get("videos", []) - parts: list[messages_.Part] = [] - for video_data in raw_videos: - vtype = video_data.get("type", "base64") - media_type = video_data.get("mediaType", "video/mp4") - - if vtype == "url": - downloaded_bytes, content_type = await files.download(video_data["url"]) - if content_type: - media_type = content_type - parts.append( - messages_.FilePart(data=downloaded_bytes, media_type=media_type) - ) - else: - raw_data = video_data.get("data", "") - parts.append(messages_.FilePart(data=raw_data, media_type=media_type)) - - return messages_.Message(role="assistant", parts=parts) - - -# --------------------------------------------------------------------------- -# Public adapter function -# --------------------------------------------------------------------------- - - -async def generate( - client: client_.Client, - model: model_.Model, - messages: list[messages_.Message], - params: GenerateParams, -) -> messages_.Message: - """Generate media (images or video) through the AI Gateway. - - Dispatches to ``/image-model`` or ``/video-model`` based on ``params`` - type. - """ - if isinstance(params, VideoParams): - return await _generate_video(client, model, messages, params) - return await _generate_image(client, model, messages, params) diff --git a/src/ai/models/ai_gateway/provider.py b/src/ai/models/ai_gateway/provider.py new file mode 100644 index 00000000..ccab0874 --- /dev/null +++ b/src/ai/models/ai_gateway/provider.py @@ -0,0 +1,81 @@ +"""AI Gateway provider. + +Defines the callable :data:`ai_gateway` provider, which satisfies the +:class:`~ai.models.core.proto.Provider` protocol.""" + +import os +from typing import Any + +from .. import core + +_BASE_URL = "https://ai-gateway.vercel.sh/v3/ai" +_API_KEY_ENV = "AI_GATEWAY_API_KEY" + + +class _AIGateway: + """Callable provider factory for the Vercel AI Gateway. + + Satisfies the :class:`~ai.models.core.proto.Provider` protocol. + """ + + @property + def api_key_env(self) -> str: + return _API_KEY_ENV + + @property + def base_url(self) -> str: + return _BASE_URL + + @property + def adapter(self) -> str: + return "ai-gateway-v3" + + @property + def name(self) -> str: + return "ai-gateway" + + def client(self) -> core.client.Client: + """Create a :class:`Client` from env-var credentials.""" + return core.client.Client( + base_url=_BASE_URL, + api_key=os.environ.get(_API_KEY_ENV), + ) + + async def check(self, client: core.client.Client, model: core.model.Model) -> bool: + """Delegate to :func:`ai_gateway.check.check`.""" + from . import check as check_ + + return await check_.check(client, model) + + def __call__( + self, + model_id: str, + *, + base_url: str | None = None, + client: core.client.Client | None = None, + ) -> core.model.Model: + return core.model.Model( + id=model_id, + adapter=self.adapter, + provider=self, + client=client, + ) + + async def list(self, *, client: core.client.Client | None = None) -> list[str]: + """List available model IDs from the AI Gateway.""" + from . import sdk + + c = client or self.client() + gateway = sdk.GatewayClient(c) + response = await gateway.get("config") + response.raise_for_status() + data: dict[str, Any] = response.json() + return sorted(str(m["id"]) for m in data.get("models", [])) + + def __repr__(self) -> str: + return "ai_gateway" + + +ai_gateway = _AIGateway() + +__all__ = ["ai_gateway"] diff --git a/src/ai/models/ai_gateway/sdk.py b/src/ai/models/ai_gateway/sdk.py new file mode 100644 index 00000000..3a81ced8 --- /dev/null +++ b/src/ai/models/ai_gateway/sdk.py @@ -0,0 +1,167 @@ +"""AI Gateway v3 HTTP API""" + +import json +from collections.abc import AsyncGenerator, AsyncIterator +from contextlib import asynccontextmanager +from typing import Any, Literal +from urllib.parse import urlparse + +import httpx + +from .. import core +from . import errors + +_PROTOCOL_VERSION = "0.0.1" + +ModelType = Literal["language", "image", "video"] + + +class GatewayClient: + def __init__( + self, + client: core.client.Client, + model: core.model.Model | None = None, + ) -> None: + self._client = client + self._model = model + + @property + def base_url(self) -> str: + return self._client.base_url.rstrip("/") + + def url(self, path: str) -> str: + return f"{self.base_url}/{path.lstrip('/')}" + + def origin_url(self, path: str) -> str: + parsed = urlparse(self.base_url) + return f"{parsed.scheme}://{parsed.netloc}/{path.lstrip('/')}" + + def protocol_headers(self) -> dict[str, str]: + headers: dict[str, str] = { + "ai-gateway-protocol-version": _PROTOCOL_VERSION, + } + if self._client.api_key: + headers["Authorization"] = f"Bearer {self._client.api_key}" + headers["ai-gateway-auth-method"] = "api-key" + return headers + + def model_headers( + self, + model_type: ModelType, + *, + streaming: bool = False, + accept: str | None = None, + ) -> dict[str, str]: + if self._model is None: + raise ValueError("Gateway model headers require a model.") + + headers = { + "Content-Type": "application/json", + **self.protocol_headers(), + } + + if model_type == "language": + headers["ai-language-model-specification-version"] = "3" + headers["ai-language-model-id"] = self._model.id + headers["ai-language-model-streaming"] = str(streaming).lower() + elif model_type == "image": + headers["ai-image-model-specification-version"] = "3" + headers["ai-model-id"] = self._model.id + elif model_type == "video": + headers["ai-video-model-specification-version"] = "3" + headers["ai-model-id"] = self._model.id + + if accept is not None: + headers["accept"] = accept + + return headers + + async def get( + self, + path: str, + *, + origin: bool = False, + headers: dict[str, str] | None = None, + ) -> httpx.Response: + url = self.origin_url(path) if origin else self.url(path) + return await self._client.http.get( + url, + headers=headers or self.protocol_headers(), + ) + + async def post_json( + self, + path: str, + body: dict[str, Any], + *, + model_type: ModelType, + timeout: httpx.Timeout | float | None = None, + ) -> httpx.Response: + kwargs: dict[str, Any] = {} + if timeout is not None: + kwargs["timeout"] = timeout + + response = await self._client.http.post( + self.url(path), + json=body, + headers=self.model_headers(model_type), + **kwargs, + ) + await self.raise_for_error(response) + return response + + @asynccontextmanager + async def stream( + self, + path: str, + body: dict[str, Any], + *, + model_type: ModelType, + streaming: bool = False, + accept: str | None = None, + timeout: httpx.Timeout | float | None = None, + ) -> AsyncIterator[httpx.Response]: + kwargs: dict[str, Any] = {} + if timeout is not None: + kwargs["timeout"] = timeout + + async with self._client.http.stream( + "POST", + self.url(path), + json=body, + headers=self.model_headers( + model_type, + streaming=streaming, + accept=accept, + ), + **kwargs, + ) as response: + await self.raise_for_error(response) + yield response + + async def raise_for_error(self, response: httpx.Response) -> None: + if response.status_code < 400: + return + + await response.aread() + raise errors.create_gateway_error( + response_body=response.text, + status_code=response.status_code, + api_key_provided=bool(self._client.api_key), + ) + + async def iter_sse( + self, + response: httpx.Response, + ) -> AsyncGenerator[dict[str, Any]]: + async for line in response.aiter_lines(): + line = line.strip() + if not line.startswith("data: "): + continue + payload = line[len("data: ") :] + if payload == "[DONE]": + break + try: + yield json.loads(payload) + except json.JSONDecodeError: + continue diff --git a/src/ai/models/anthropic/__init__.py b/src/ai/models/anthropic/__init__.py index c99b6bfd..09f94f96 100644 --- a/src/ai/models/anthropic/__init__.py +++ b/src/ai/models/anthropic/__init__.py @@ -11,90 +11,7 @@ SDK at import time. """ -from __future__ import annotations - -import os -from typing import TYPE_CHECKING - -from ..core import client as client_ -from ..core.model import Model - -if TYPE_CHECKING: - pass - -_BASE_URL = "https://api.anthropic.com" -_API_KEY_ENV = "ANTHROPIC_API_KEY" -_ANTHROPIC_VERSION = "2023-06-01" - - -class _Anthropic: - """Callable provider factory for Anthropic. - - Satisfies the :class:`~ai.models.core.proto.Provider` protocol. - """ - - @property - def api_key_env(self) -> str: - return _API_KEY_ENV - - @property - def base_url(self) -> str: - return _BASE_URL - - @property - def adapter(self) -> str: - return "anthropic" - - @property - def name(self) -> str: - return "anthropic" - - def client(self) -> client_.Client: - """Create a :class:`Client` from env-var credentials.""" - return client_.Client( - base_url=_BASE_URL, - api_key=os.environ.get(_API_KEY_ENV), - ) - - async def check(self, client: client_.Client, model: Model) -> bool: - """Delegate to :func:`anthropic.check.check`.""" - from . import check as check_ - - return await check_.check(client, model) - - def __call__( - self, - model_id: str, - *, - base_url: str | None = None, - client: client_.Client | None = None, - ) -> Model: - return Model( - id=model_id, - adapter=self.adapter, - provider=self, - client=client, - ) - - async def list(self, *, client: client_.Client | None = None) -> list[str]: - """List available model IDs from the Anthropic API.""" - c = client or self.client() - headers = { - "x-api-key": c.api_key or "", - "anthropic-version": _ANTHROPIC_VERSION, - } - response = await c.http.get( - f"{c.base_url.rstrip('/')}/v1/models", headers=headers - ) - response.raise_for_status() - data: list[dict[str, object]] = response.json().get("data", []) - return sorted(str(m["id"]) for m in data) - - def __repr__(self) -> str: - return "anthropic" - - -anthropic = _Anthropic() +from .provider import anthropic __all__ = ["anthropic"] diff --git a/src/ai/models/anthropic/check.py b/src/ai/models/anthropic/check.py index 2e7a4159..55393c96 100644 --- a/src/ai/models/anthropic/check.py +++ b/src/ai/models/anthropic/check.py @@ -9,10 +9,7 @@ This endpoint is free — no tokens or credits are consumed. """ -from __future__ import annotations - -from ..core import client as client_ -from ..core import model as model_ +from .. import core _ANTHROPIC_VERSION = "2023-06-01" @@ -20,7 +17,7 @@ _FAIL_STATUSES = frozenset({401, 403, 404}) -async def check(client: client_.Client, model: model_.Model) -> bool: +async def check(client: core.client.Client, model: core.model.Model) -> bool: """Return ``True`` if *client* can reach Anthropic and *model* exists.""" if not client.api_key: return False diff --git a/src/ai/models/anthropic/provider.py b/src/ai/models/anthropic/provider.py new file mode 100644 index 00000000..76e53f07 --- /dev/null +++ b/src/ai/models/anthropic/provider.py @@ -0,0 +1,84 @@ +"""Anthropic provider. + +Defines the callable :data:`anthropic` provider, which satisfies the +:class:`~ai.models.core.proto.Provider` protocol.""" + +import os + +from .. import core + +_BASE_URL = "https://api.anthropic.com" +_API_KEY_ENV = "ANTHROPIC_API_KEY" +_ANTHROPIC_VERSION = "2023-06-01" + + +class _Anthropic: + """Callable provider factory for Anthropic. + + Satisfies the :class:`~ai.models.core.proto.Provider` protocol. + """ + + @property + def api_key_env(self) -> str: + return _API_KEY_ENV + + @property + def base_url(self) -> str: + return _BASE_URL + + @property + def adapter(self) -> str: + return "anthropic" + + @property + def name(self) -> str: + return "anthropic" + + def client(self) -> core.client.Client: + """Create a :class:`Client` from env-var credentials.""" + return core.client.Client( + base_url=_BASE_URL, + api_key=os.environ.get(_API_KEY_ENV), + ) + + async def check(self, client: core.client.Client, model: core.model.Model) -> bool: + """Delegate to :func:`anthropic.check.check`.""" + from . import check as check_ + + return await check_.check(client, model) + + def __call__( + self, + model_id: str, + *, + base_url: str | None = None, + client: core.client.Client | None = None, + ) -> core.model.Model: + return core.model.Model( + id=model_id, + adapter=self.adapter, + provider=self, + client=client, + ) + + async def list(self, *, client: core.client.Client | None = None) -> list[str]: + """List available model IDs from the Anthropic API.""" + c = client or self.client() + headers = { + "x-api-key": c.api_key or "", + "anthropic-version": _ANTHROPIC_VERSION, + } + response = await c.http.get( + f"{c.base_url.rstrip('/')}/v1/models", headers=headers + ) + response.raise_for_status() + data: list[dict[str, object]] = response.json().get("data", []) + return sorted(str(m["id"]) for m in data) + + def __repr__(self) -> str: + return "anthropic" + + +anthropic = _Anthropic() + +__all__ = ["anthropic"] diff --git a/src/ai/models/core/__init__.py b/src/ai/models/core/__init__.py index e7db1ace..732003df 100644 --- a/src/ai/models/core/__init__.py +++ b/src/ai/models/core/__init__.py @@ -1,5 +1,6 @@ """Core types for models.""" +from . import helpers from .adapters import register_generate, register_stream from .api import ( Executor, @@ -38,4 +39,5 @@ "register_generate", "register_stream", "stream", + "helpers", ] diff --git a/src/ai/models/core/adapters.py b/src/ai/models/core/adapters.py index 4496c3cd..35669c7f 100644 --- a/src/ai/models/core/adapters.py +++ b/src/ai/models/core/adapters.py @@ -11,8 +11,6 @@ to ``model.provider.check()``. """ -from __future__ import annotations - from . import proto # --------------------------------------------------------------------------- @@ -31,8 +29,8 @@ def _ensure_adapters() -> None: return _adapters_loaded = True - from ..ai_gateway.generate import generate as ai_gw_generate - from ..ai_gateway.stream import stream as ai_gw_stream + from ..ai_gateway.adapter import generate as ai_gw_generate + from ..ai_gateway.adapter import stream as ai_gw_stream from ..anthropic.adapter import stream as anthropic_stream from ..openai.adapter import stream as openai_stream diff --git a/src/ai/models/core/api.py b/src/ai/models/core/api.py index 3d40de04..3f4e5c78 100644 --- a/src/ai/models/core/api.py +++ b/src/ai/models/core/api.py @@ -5,7 +5,7 @@ import pydantic from ... import types -from ...types import integrity as integrity_ +from ...types import integrity from . import adapters, params from . import client as client_ from . import model as model_ @@ -171,7 +171,7 @@ def stream( executor: StreamExecutor = _default_executor, ) -> Stream: """Stream an LLM response.""" - messages = integrity_.prepare_messages(messages) + messages = integrity.prepare_messages(messages) request = StreamRequest(model, messages, tools, output_type) return Stream(executor._do_stream(request)) @@ -184,7 +184,7 @@ async def generate( executor: GenerateExecutor = _default_executor, ) -> types.Message: """Generate a non-streaming response (images, video, etc.).""" - messages = integrity_.prepare_messages(messages) + messages = integrity.prepare_messages(messages) request = GenerateRequest(model, messages, params) return await executor._do_generate(request) diff --git a/src/ai/models/core/helpers/__init__.py b/src/ai/models/core/helpers/__init__.py new file mode 100644 index 00000000..bddc5b6a --- /dev/null +++ b/src/ai/models/core/helpers/__init__.py @@ -0,0 +1,3 @@ +from . import files + +__all__ = ["files"] diff --git a/src/ai/models/core/helpers/files.py b/src/ai/models/core/helpers/files.py index e3357f3f..147b448a 100644 --- a/src/ai/models/core/helpers/files.py +++ b/src/ai/models/core/helpers/files.py @@ -4,8 +4,6 @@ :mod:`ai.types.media`. """ -from __future__ import annotations - import httpx DEFAULT_MAX_BYTES = 100 * 1024 * 1024 # 100 MiB (matches TS SDK) diff --git a/src/ai/models/core/model.py b/src/ai/models/core/model.py index 769330ed..dcbd2a6c 100644 --- a/src/ai/models/core/model.py +++ b/src/ai/models/core/model.py @@ -1,7 +1,5 @@ """Model metadata types.""" -from __future__ import annotations - import dataclasses from .client import Client diff --git a/src/ai/models/core/proto.py b/src/ai/models/core/proto.py index 987d2f5e..7e7290d6 100644 --- a/src/ai/models/core/proto.py +++ b/src/ai/models/core/proto.py @@ -14,9 +14,7 @@ import pydantic -from ...types import events as events_ -from ...types import messages as messages_ -from ...types import proto as types_proto_ +from ... import types if TYPE_CHECKING: from .client import Client @@ -98,12 +96,12 @@ def __call__( self, client: Client, model: Model, - messages: list[messages_.Message], + messages: list[types.Message], *, - tools: Sequence[types_proto_.ToolLike] | None = None, + tools: Sequence[types.ToolLike] | None = None, output_type: type[pydantic.BaseModel] | None = None, **kwargs: Any, - ) -> AsyncGenerator[events_.Event]: ... + ) -> AsyncGenerator[types.Event]: ... @runtime_checkable @@ -119,9 +117,9 @@ async def __call__( self, client: Client, model: Model, - messages: list[messages_.Message], + messages: list[types.Message], params: Any, - ) -> messages_.Message: ... + ) -> types.Message: ... @runtime_checkable diff --git a/src/ai/models/openai/__init__.py b/src/ai/models/openai/__init__.py index 9dfb113b..47380505 100644 --- a/src/ai/models/openai/__init__.py +++ b/src/ai/models/openai/__init__.py @@ -11,84 +11,7 @@ SDK at import time. """ -from __future__ import annotations - -import os -from typing import TYPE_CHECKING - -from ..core import client as client_ -from ..core.model import Model - -if TYPE_CHECKING: - pass - -_BASE_URL = "https://api.openai.com/v1" -_API_KEY_ENV = "OPENAI_API_KEY" - - -class _OpenAI: - """Callable provider — ``openai("gpt-5.4")`` returns a :class:`Model`. - - Satisfies the :class:`~ai.models.core.proto.Provider` protocol. - """ - - @property - def api_key_env(self) -> str: - return _API_KEY_ENV - - @property - def base_url(self) -> str: - return _BASE_URL - - @property - def adapter(self) -> str: - return "openai" - - @property - def name(self) -> str: - return "openai" - - def client(self) -> client_.Client: - """Create a :class:`Client` from env-var credentials.""" - return client_.Client( - base_url=_BASE_URL, - api_key=os.environ.get(_API_KEY_ENV), - ) - - async def check(self, client: client_.Client, model: Model) -> bool: - """Delegate to :func:`openai.check.check`.""" - from . import check as check_ - - return await check_.check(client, model) - - def __call__( - self, - model_id: str, - *, - base_url: str | None = None, - client: client_.Client | None = None, - ) -> Model: - return Model( - id=model_id, - adapter=self.adapter, - provider=self, - client=client, - ) - - async def list(self, *, client: client_.Client | None = None) -> list[str]: - """List available model IDs from the OpenAI API.""" - c = client or self.client() - headers = {"Authorization": f"Bearer {c.api_key}"} - response = await c.http.get(f"{c.base_url.rstrip('/')}/models", headers=headers) - response.raise_for_status() - data: list[dict[str, object]] = response.json().get("data", []) - return sorted(str(m["id"]) for m in data) - - def __repr__(self) -> str: - return "openai" - - -openai = _OpenAI() +from .provider import openai __all__ = ["openai"] diff --git a/src/ai/models/openai/adapter.py b/src/ai/models/openai/adapter.py index d1611785..b7f81df8 100644 --- a/src/ai/models/openai/adapter.py +++ b/src/ai/models/openai/adapter.py @@ -4,22 +4,14 @@ The SDK client is constructed from :class:`Client` params on each call. """ -from __future__ import annotations - from collections.abc import AsyncGenerator, Sequence from typing import Any import openai import pydantic -from ...types import events as events_ -from ...types import media -from ...types import messages as messages_ -from ...types import proto as proto_ -from ...types import usage as usage_ -from ..core import client as client_ -from ..core import model as model_ -from ..core.helpers import files +from ... import types +from .. import core # --------------------------------------------------------------------------- # Message / tool conversion — internal types → OpenAI wire format @@ -27,7 +19,7 @@ def _tools_to_openai( - tools: Sequence[proto_.ToolLike], + tools: Sequence[types.ToolLike], ) -> list[dict[str, Any]]: """Convert internal Tool objects to OpenAI tool schema format.""" return [ @@ -44,7 +36,7 @@ def _tools_to_openai( async def _file_part_to_openai( - part: messages_.FilePart, + part: types.FilePart, ) -> dict[str, Any]: """Convert a :class:`FilePart` to an OpenAI content-array element. @@ -59,25 +51,25 @@ async def _file_part_to_openai( if mt.startswith("image/"): media_type = "image/jpeg" if mt == "image/*" else mt - url = media.data_to_data_url(data, media_type) + url = types.media.data_to_data_url(data, media_type) return {"type": "image_url", "image_url": {"url": url}} if mt.startswith("audio/"): - if isinstance(data, str) and media.is_downloadable_url(data): - downloaded, _ = await files.download(data) + if isinstance(data, str) and types.media.is_downloadable_url(data): + downloaded, _ = await core.helpers.files.download(data) data = downloaded fmt = mt.split("/", 1)[1] if "/" in mt else mt - b64 = media.data_to_base64(data) + b64 = types.media.data_to_base64(data) return { "type": "input_audio", "input_audio": {"data": b64, "format": fmt}, } if mt == "application/pdf": - if isinstance(data, str) and media.is_downloadable_url(data): - downloaded, _ = await files.download(data) + if isinstance(data, str) and types.media.is_downloadable_url(data): + downloaded, _ = await core.helpers.files.download(data) data = downloaded - data_url = media.data_to_data_url(data, mt) + data_url = types.media.data_to_data_url(data, mt) filename = part.filename or "document.pdf" return { "type": "file", @@ -87,7 +79,7 @@ async def _file_part_to_openai( if mt.startswith("text/"): if isinstance(data, bytes): text_content = data.decode("utf-8") - elif media.is_url(data): + elif types.media.is_url(data): text_content = data else: import base64 as _b64 @@ -99,7 +91,7 @@ async def _file_part_to_openai( async def _messages_to_openai( - messages: list[messages_.Message], + messages: list[types.Message], ) -> list[dict[str, Any]]: """Convert internal messages to OpenAI API format. @@ -116,11 +108,11 @@ async def _messages_to_openai( for part in msg.parts: match part: - case messages_.ReasoningPart(text=text): + case types.ReasoningPart(text=text): reasoning += text - case messages_.TextPart(text=text): + case types.TextPart(text=text): content += text - case messages_.ToolCallPart(): + case types.ToolCallPart(): tool_calls.append( { "id": part.tool_call_id, @@ -143,7 +135,7 @@ async def _messages_to_openai( case "tool": for part in msg.parts: - if isinstance(part, messages_.ToolResultPart): + if isinstance(part, types.ToolResultPart): result.append( { "role": "tool", @@ -156,24 +148,24 @@ async def _messages_to_openai( case "system": content_text = "".join( - p.text for p in msg.parts if isinstance(p, messages_.TextPart) + p.text for p in msg.parts if isinstance(p, types.TextPart) ) result.append({"role": "system", "content": content_text}) case "user": - has_files = any(isinstance(p, messages_.FilePart) for p in msg.parts) + has_files = any(isinstance(p, types.FilePart) for p in msg.parts) if not has_files: text = "".join( - p.text for p in msg.parts if isinstance(p, messages_.TextPart) + p.text for p in msg.parts if isinstance(p, types.TextPart) ) result.append({"role": "user", "content": text}) else: parts: list[dict[str, Any]] = [] for p in msg.parts: match p: - case messages_.TextPart(text=text): + case types.TextPart(text=text): parts.append({"type": "text", "text": text}) - case messages_.FilePart(): + case types.FilePart(): parts.append(await _file_part_to_openai(p)) result.append({"role": "user", "content": parts}) return result @@ -184,7 +176,7 @@ async def _messages_to_openai( # --------------------------------------------------------------------------- -def _make_client(client: client_.Client) -> openai.AsyncOpenAI: +def _make_client(client: core.client.Client) -> openai.AsyncOpenAI: """Construct an ``AsyncOpenAI`` from our generic ``Client``.""" return openai.AsyncOpenAI( base_url=client.base_url, @@ -198,17 +190,17 @@ def _make_client(client: client_.Client) -> openai.AsyncOpenAI: async def stream( - client: client_.Client, - model: model_.Model, - messages: list[messages_.Message], + client: core.client.Client, + model: core.model.Model, + messages: list[types.Message], *, - tools: Sequence[proto_.ToolLike] | None = None, + tools: Sequence[types.ToolLike] | None = None, output_type: type[pydantic.BaseModel] | None = None, thinking: bool = False, budget_tokens: int | None = None, reasoning_effort: str | None = None, **kwargs: Any, -) -> AsyncGenerator[events_.Event]: +) -> AsyncGenerator[types.Event]: """Stream an LLM response via the OpenAI chat completions API. Yields :class:`~ai.types.events.Event` objects as the response streams in. @@ -264,9 +256,9 @@ async def stream( text_started = False reasoning_started = False tc_state: dict[int, dict[str, Any]] = {} - usage: usage_.Usage | None = None + usage: types.Usage | None = None - yield events_.StreamStart() + yield types.events.StreamStart() async for chunk in sdk_stream: if chunk.usage is not None: @@ -279,7 +271,7 @@ async def stream( pd = getattr(chunk.usage, "prompt_tokens_details", None) if pd: cache_read = getattr(pd, "cached_tokens", None) - usage = usage_.Usage( + usage = types.Usage( input_tokens=chunk.usage.prompt_tokens or 0, output_tokens=chunk.usage.completion_tokens or 0, reasoning_tokens=reasoning_tokens, @@ -303,20 +295,20 @@ async def stream( if reasoning_value: if not reasoning_started: reasoning_started = True - yield events_.ReasoningStart(block_id="reasoning") - yield events_.ReasoningDelta( + yield types.events.ReasoningStart(block_id="reasoning") + yield types.events.ReasoningDelta( chunk=reasoning_value, block_id="reasoning" ) if delta.content: if reasoning_started: - yield events_.ReasoningEnd(block_id="reasoning") + yield types.events.ReasoningEnd(block_id="reasoning") reasoning_started = False if not text_started: text_started = True - yield events_.TextStart(block_id="text") - yield events_.TextDelta(chunk=delta.content, block_id="text") + yield types.events.TextStart(block_id="text") + yield types.events.TextDelta(chunk=delta.content, block_id="text") if delta.tool_calls: for tc in delta.tool_calls: @@ -338,28 +330,28 @@ async def stream( if not tc_state[idx]["started"] and tid: tc_state[idx]["started"] = True - yield events_.ToolStart( + yield types.events.ToolStart( tool_call_id=tid, tool_name=tname ) if tid: - yield events_.ToolDelta( + yield types.events.ToolDelta( chunk=tc.function.arguments, tool_call_id=tid, ) if choice.finish_reason is not None: if reasoning_started: - yield events_.ReasoningEnd(block_id="reasoning") + yield types.events.ReasoningEnd(block_id="reasoning") reasoning_started = False if text_started: - yield events_.TextEnd(block_id="text") + yield types.events.TextEnd(block_id="text") text_started = False for tc in tc_state.values(): if tc["started"] and tc["id"]: - yield events_.ToolEnd(tool_call_id=tc["id"]) + yield types.events.ToolEnd(tool_call_id=tc["id"]) tc["started"] = False - yield events_.StreamEnd(usage=usage) + yield types.events.StreamEnd(usage=usage) finally: await sdk_client.close() diff --git a/src/ai/models/openai/check.py b/src/ai/models/openai/check.py index ac1e240b..40edb88c 100644 --- a/src/ai/models/openai/check.py +++ b/src/ai/models/openai/check.py @@ -6,16 +6,13 @@ This endpoint is free — no tokens or credits are consumed. """ -from __future__ import annotations - -from ..core import client as client_ -from ..core import model as model_ +from .. import core # HTTP status codes that indicate bad auth or a missing model. _FAIL_STATUSES = frozenset({401, 403, 404}) -async def check(client: client_.Client, model: model_.Model) -> bool: +async def check(client: core.client.Client, model: core.model.Model) -> bool: """Return ``True`` if *client* can reach OpenAI and *model* exists.""" if not client.api_key: return False diff --git a/src/ai/models/openai/provider.py b/src/ai/models/openai/provider.py new file mode 100644 index 00000000..02025caa --- /dev/null +++ b/src/ai/models/openai/provider.py @@ -0,0 +1,78 @@ +"""OpenAI provider. + +Defines the callable :data:`openai` provider, which satisfies the +:class:`~ai.models.core.proto.Provider` protocol.""" + +import os + +from .. import core + +_BASE_URL = "https://api.openai.com/v1" +_API_KEY_ENV = "OPENAI_API_KEY" + + +class _OpenAI: + """Callable provider — ``openai("gpt-5.4")`` returns a :class:`Model`. + + Satisfies the :class:`~ai.models.core.proto.Provider` protocol. + """ + + @property + def api_key_env(self) -> str: + return _API_KEY_ENV + + @property + def base_url(self) -> str: + return _BASE_URL + + @property + def adapter(self) -> str: + return "openai" + + @property + def name(self) -> str: + return "openai" + + def client(self) -> core.client.Client: + """Create a :class:`Client` from env-var credentials.""" + return core.client.Client( + base_url=_BASE_URL, + api_key=os.environ.get(_API_KEY_ENV), + ) + + async def check(self, client: core.client.Client, model: core.model.Model) -> bool: + """Delegate to :func:`openai.check.check`.""" + from . import check as check_ + + return await check_.check(client, model) + + def __call__( + self, + model_id: str, + *, + base_url: str | None = None, + client: core.client.Client | None = None, + ) -> core.model.Model: + return core.model.Model( + id=model_id, + adapter=self.adapter, + provider=self, + client=client, + ) + + async def list(self, *, client: core.client.Client | None = None) -> list[str]: + """List available model IDs from the OpenAI API.""" + c = client or self.client() + headers = {"Authorization": f"Bearer {c.api_key}"} + response = await c.http.get(f"{c.base_url.rstrip('/')}/models", headers=headers) + response.raise_for_status() + data: list[dict[str, object]] = response.json().get("data", []) + return sorted(str(m["id"]) for m in data) + + def __repr__(self) -> str: + return "openai" + + +openai = _OpenAI() + +__all__ = ["openai"] diff --git a/tests/models/ai_gateway/test_generate_image.py b/tests/models/ai_gateway/test_generate_image.py index b3d7bfd9..aa5c030e 100644 --- a/tests/models/ai_gateway/test_generate_image.py +++ b/tests/models/ai_gateway/test_generate_image.py @@ -22,7 +22,8 @@ import pytest from ai.models.ai_gateway import ai_gateway, errors -from ai.models.ai_gateway.generate import ImageParams, generate +from ai.models.ai_gateway.adapter import generate +from ai.models.core.params import ImageParams from ai.types import messages from .conftest import mock_client, user_msg diff --git a/tests/models/ai_gateway/test_generate_video.py b/tests/models/ai_gateway/test_generate_video.py index 8233daa6..1ebab74f 100644 --- a/tests/models/ai_gateway/test_generate_video.py +++ b/tests/models/ai_gateway/test_generate_video.py @@ -23,7 +23,8 @@ import pytest from ai.models.ai_gateway import ai_gateway, errors -from ai.models.ai_gateway.generate import VideoParams, generate +from ai.models.ai_gateway.adapter import generate +from ai.models.core.params import VideoParams from ai.types import messages from .conftest import mock_client, sse, user_msg diff --git a/tests/models/ai_gateway/test_protocol.py b/tests/models/ai_gateway/test_protocol.py index e16a05d8..6eb0fd86 100644 --- a/tests/models/ai_gateway/test_protocol.py +++ b/tests/models/ai_gateway/test_protocol.py @@ -12,19 +12,15 @@ from __future__ import annotations -import importlib import json from unittest.mock import AsyncMock, patch import pydantic +from ai.models.ai_gateway import adapter from ai.types import events as events_ from ai.types import messages -# The ai_gateway __init__.py re-exports `stream` as a function, which -# shadows the module. Use importlib to get the actual module. -stream_mod = importlib.import_module("ai.models.ai_gateway.stream") - # --------------------------------------------------------------------------- # _messages_to_prompt # --------------------------------------------------------------------------- @@ -38,7 +34,7 @@ async def test_system_message(self) -> None: parts=[messages.TextPart(text="You are helpful.")], ) ] - result = await stream_mod._messages_to_prompt(msgs) + result = await adapter._messages_to_prompt(msgs) assert result == [{"role": "system", "content": "You are helpful."}] async def test_user_message(self) -> None: @@ -48,7 +44,7 @@ async def test_user_message(self) -> None: parts=[messages.TextPart(text="Hello")], ) ] - result = await stream_mod._messages_to_prompt(msgs) + result = await adapter._messages_to_prompt(msgs) assert result == [ { "role": "user", @@ -66,7 +62,7 @@ async def test_assistant_with_reasoning_and_text(self) -> None: ], ) ] - result = await stream_mod._messages_to_prompt(msgs) + result = await adapter._messages_to_prompt(msgs) content = result[0]["content"] assert content[0] == {"type": "reasoning", "text": "Let me think..."} assert content[1] == {"type": "text", "text": "42"} @@ -96,7 +92,7 @@ async def test_tool_call_with_result_produces_two_messages(self) -> None: ], ), ] - result = await stream_mod._messages_to_prompt(msgs) + result = await adapter._messages_to_prompt(msgs) assert len(result) == 2 # Assistant message has the tool-call @@ -134,7 +130,7 @@ async def test_tool_error_result(self) -> None: ], ), ] - result = await stream_mod._messages_to_prompt(msgs) + result = await adapter._messages_to_prompt(msgs) tr = result[1]["content"][0] assert tr["output"]["type"] == "error-text" assert tr["output"]["value"] == "Connection timeout" @@ -158,7 +154,7 @@ async def test_user_message_with_image_url(self) -> None: new_callable=AsyncMock, return_value=(fake_jpeg, "image/jpeg"), ): - result = await stream_mod._messages_to_prompt(msgs) + result = await adapter._messages_to_prompt(msgs) content = result[0]["content"] assert content[0] == {"type": "text", "text": "Look at this"} assert content[1]["type"] == "file" @@ -177,7 +173,7 @@ async def test_user_message_with_file_bytes(self) -> None: ], ) ] - result = await stream_mod._messages_to_prompt(msgs) + result = await adapter._messages_to_prompt(msgs) part = result[0]["content"][0] assert part["type"] == "file" assert part["mediaType"] == "image/png" @@ -199,7 +195,7 @@ async def test_pending_tool_call_no_tool_message(self) -> None: ], ) ] - result = await stream_mod._messages_to_prompt(msgs) + result = await adapter._messages_to_prompt(msgs) assert len(result) == 1 assert result[0]["role"] == "assistant" @@ -221,7 +217,7 @@ class WeatherResult(pydantic.BaseModel): parts=[messages.TextPart(text="Weather?")], ) ] - body = await stream_mod._build_request_body(msgs, output_type=WeatherResult) + body = await adapter._build_request_body(msgs, output_type=WeatherResult) assert "responseFormat" in body rf = body["responseFormat"] @@ -234,7 +230,7 @@ class WeatherResult(pydantic.BaseModel): class TestParseStreamPartComplex: def test_text_delta_uses_textDelta_key(self) -> None: """The gateway sends ``textDelta`` (camelCase), not ``delta``.""" - events = stream_mod._parse_stream_part( + events = adapter._parse_stream_part( {"type": "text-delta", "id": "t1", "textDelta": "Hello"}, set() ) assert isinstance(events[0], events_.TextDelta) @@ -243,7 +239,7 @@ def test_text_delta_uses_textDelta_key(self) -> None: def test_tool_call_expands_to_three_events(self) -> None: """A complete ``tool-call`` part must expand into ToolStart -> ToolDelta -> ToolEnd.""" - events = stream_mod._parse_stream_part( + events = adapter._parse_stream_part( { "type": "tool-call", "toolCallId": "tc-1", @@ -262,11 +258,11 @@ def test_tool_call_expands_to_three_events(self) -> None: def test_tool_call_skipped_when_already_streamed(self) -> None: """A ``tool-call`` that duplicates a streamed tool is dropped.""" seen: set[str] = set() - stream_mod._parse_stream_part( + adapter._parse_stream_part( {"type": "tool-input-start", "id": "tc-1", "toolName": "get_weather"}, seen, ) - events = stream_mod._parse_stream_part( + events = adapter._parse_stream_part( { "type": "tool-call", "toolCallId": "tc-1", @@ -278,7 +274,7 @@ def test_tool_call_skipped_when_already_streamed(self) -> None: assert events == [] def test_finish_flat_usage(self) -> None: - events = stream_mod._parse_stream_part( + events = adapter._parse_stream_part( { "type": "finish", "finishReason": "stop", @@ -296,7 +292,7 @@ def test_finish_flat_usage(self) -> None: assert done.usage.output_tokens == 20 def test_finish_v3_nested_usage(self) -> None: - events = stream_mod._parse_stream_part( + events = adapter._parse_stream_part( { "type": "finish", "finishReason": { @@ -326,7 +322,7 @@ def test_finish_v3_nested_usage(self) -> None: def test_file_part(self) -> None: """A ``file`` stream part (inline image from Gemini/GPT-5) must produce a FileEvent.""" - events = stream_mod._parse_stream_part( + events = adapter._parse_stream_part( { "type": "file", "id": "f1", @@ -343,16 +339,14 @@ def test_file_part(self) -> None: def test_file_part_defaults(self) -> None: """A minimal ``file`` part uses sensible defaults.""" - events = stream_mod._parse_stream_part( - {"type": "file", "data": "somedata"}, set() - ) + events = adapter._parse_stream_part({"type": "file", "data": "somedata"}, set()) assert len(events) == 1 assert isinstance(events[0], events_.FileEvent) assert events[0].media_type == "application/octet-stream" def test_unknown_types_produce_no_events(self) -> None: for t in ("stream-start", "raw", "response-metadata", "banana"): - assert stream_mod._parse_stream_part({"type": t}, set()) == [] + assert adapter._parse_stream_part({"type": t}, set()) == [] # --------------------------------------------------------------------------- @@ -362,12 +356,12 @@ def test_unknown_types_produce_no_events(self) -> None: class TestParseUsage: def test_flat_format(self) -> None: - usage = stream_mod._parse_usage({"prompt_tokens": 10, "completion_tokens": 20}) + usage = adapter._parse_usage({"prompt_tokens": 10, "completion_tokens": 20}) assert usage.input_tokens == 10 assert usage.output_tokens == 20 def test_v3_nested_format(self) -> None: - usage = stream_mod._parse_usage( + usage = adapter._parse_usage( { "inputTokens": { "total": 100, @@ -384,6 +378,6 @@ def test_v3_nested_format(self) -> None: assert usage.reasoning_tokens == 10 def test_non_dict_returns_empty(self) -> None: - usage = stream_mod._parse_usage("not a dict") + usage = adapter._parse_usage("not a dict") assert usage.input_tokens == 0 assert usage.output_tokens == 0 diff --git a/tests/models/ai_gateway/test_stream.py b/tests/models/ai_gateway/test_stream.py index b3902d76..a2ad3637 100644 --- a/tests/models/ai_gateway/test_stream.py +++ b/tests/models/ai_gateway/test_stream.py @@ -15,7 +15,6 @@ from __future__ import annotations -import importlib import json from typing import Any @@ -24,16 +23,12 @@ import ai from ai import models -from ai.models.ai_gateway import ai_gateway, errors +from ai.models.ai_gateway import adapter, ai_gateway, errors from ai.models.core import model as model_ from ai.types import events, messages from .conftest import mock_client, sse, user_msg -# The ai_gateway __init__.py re-exports `stream` as a function, which -# shadows the module. Use importlib to get the actual module. -stream_mod = importlib.import_module("ai.models.ai_gateway.stream") - # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @@ -49,7 +44,7 @@ async def _collect( ) -> list[events.Event]: """Drain ``stream()`` and return all yielded events.""" result: list[events.Event] = [] - async for event in stream_mod.stream(client, model, msgs, **kwargs): + async for event in adapter.stream(client, model, msgs, **kwargs): result.append(event) return result @@ -61,7 +56,7 @@ async def _final( **kwargs: Any, ) -> messages.Message: """Drain the adapter's event stream and return the aggregated message.""" - s = models.Stream(stream_mod.stream(client, model, msgs, **kwargs)) + s = models.Stream(adapter.stream(client, model, msgs, **kwargs)) async for _ in s: pass return s.message diff --git a/tests/types/test_integrity.py b/tests/types/test_integrity.py index 49586172..756a3acb 100644 --- a/tests/types/test_integrity.py +++ b/tests/types/test_integrity.py @@ -480,7 +480,7 @@ async def test_stream_calls_prepare_messages() -> None: msgs = [ai.user_message("hi")] with patch( - "ai.models.core.api.integrity_.prepare_messages", wraps=lambda m: m + "ai.models.core.api.integrity.prepare_messages", wraps=lambda m: m ) as spy: s = models.stream(MOCK_MODEL, msgs) async for _ in s: @@ -538,7 +538,7 @@ async def test_generate_calls_prepare_messages() -> None: msgs = [ai.user_message("A cat")] with patch( - "ai.models.core.api.integrity_.prepare_messages", wraps=lambda m: m + "ai.models.core.api.integrity.prepare_messages", wraps=lambda m: m ) as spy: await models.generate(MOCK_MODEL, msgs, models.ImageParams(n=1)) spy.assert_called_once_with(msgs)