From 580b6aca18bde2c2a6af9e0303032f4e3c76b01d Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Mon, 13 Apr 2026 10:46:08 -0700 Subject: [PATCH 1/4] Split models top-level file --- src/ai/models/__init__.py | 332 +--------------------------- src/ai/models/core/__init__.py | 13 +- src/ai/models/core/adapters.py | 124 +++++++++++ src/ai/models/core/api.py | 125 +++++++++++ src/ai/models/core/client.py | 26 +++ src/ai/models/core/stream_result.py | 65 ++++++ 6 files changed, 362 insertions(+), 323 deletions(-) create mode 100644 src/ai/models/core/adapters.py create mode 100644 src/ai/models/core/api.py create mode 100644 src/ai/models/core/stream_result.py diff --git a/src/ai/models/__init__.py b/src/ai/models/__init__.py index 2dfe243f..d2634cc5 100644 --- a/src/ai/models/__init__.py +++ b/src/ai/models/__init__.py @@ -15,10 +15,6 @@ async for msg in s: print(msg.text_delta, end="") - # buffer the whole response - result = await ai.models.buffer(await ai.stream(opus, msgs)) - print(result.text) - # explicit client client = ai.Client( base_url="https://custom.example.com/v3/ai", api_key="sk-...", @@ -28,326 +24,16 @@ ... """ -from __future__ import annotations - -import os -from collections.abc import AsyncGenerator, Sequence -from typing import Any - -import pydantic - -from .. import middleware as middleware_ -from ..types import messages as messages_ -from ..types import tools as tools_ from ..types.stream import StreamResultLike from .ai_gateway.types import GenerateParams, ImageParams, VideoParams +from .core.adapters import register_check, register_generate, register_stream +from .core.api import check_connection, generate, stream from .core.catalog import get_models, get_providers, register_catalog from .core.catalog import model as model -from .core.client import Client +from .core.client import _PROVIDER_DEFAULTS, Client from .core.model import Model, ModelCost from .core.proto import CheckConnFn, GenerateFn, StreamFn - -# --------------------------------------------------------------------------- -# Adapter registry — maps adapter string → adapter function. -# Adapter modules are imported lazily on first use. -# --------------------------------------------------------------------------- - -_stream_adapters: dict[str, StreamFn] = {} -_generate_adapters: dict[str, GenerateFn] = {} -_adapters_loaded = False - - -def _ensure_adapters() -> None: - """Lazily register built-in adapter functions on first call.""" - global _adapters_loaded # noqa: PLW0603 - if _adapters_loaded: - return - _adapters_loaded = True - - from .ai_gateway.generate import generate as ai_gw_generate - from .ai_gateway.stream import stream as ai_gw_stream - from .anthropic.adapter import stream as anthropic_stream - from .openai.adapter import stream as openai_stream - - _stream_adapters["ai-gateway-v3"] = ai_gw_stream - _generate_adapters["ai-gateway-v3"] = ai_gw_generate - _stream_adapters["openai"] = openai_stream - _stream_adapters["anthropic"] = anthropic_stream - - -def register_stream(adapter: str, fn: StreamFn) -> None: - """Register a stream adapter function for the given adapter key. - - Use this to add custom adapters (or override built-in ones). - """ - _stream_adapters[adapter] = fn - - -def register_generate(adapter: str, fn: GenerateFn) -> None: - """Register a generate adapter function for the given adapter key. - - Use this to add custom adapters (or override built-in ones). - """ - _generate_adapters[adapter] = fn - - -# --------------------------------------------------------------------------- -# Connection-check registry — maps *provider* string → check function. -# Keyed by provider (not adapter) because the check verifies "can this -# client reach this provider and does this model exist there". -# --------------------------------------------------------------------------- - -_check_fns: dict[str, CheckConnFn] = {} -_check_fns_loaded = False - - -def _ensure_check_fns() -> None: - """Lazily register built-in check functions on first call.""" - global _check_fns_loaded # noqa: PLW0603 - if _check_fns_loaded: - return - _check_fns_loaded = True - - from .ai_gateway import check as ai_gw_check - from .anthropic import check as anthropic_check - from .openai import check as openai_check - - _check_fns["ai-gateway"] = ai_gw_check.check - _check_fns["anthropic"] = anthropic_check.check - _check_fns["openai"] = openai_check.check - - -def register_check(provider: str, fn: CheckConnFn) -> None: - """Register a connection-check function for a provider. - - Use this to add checks for custom providers. - """ - _check_fns[provider] = fn - - -# --------------------------------------------------------------------------- -# Provider defaults — base URLs and env var names for auto-client creation. -# --------------------------------------------------------------------------- - -_PROVIDER_DEFAULTS: dict[str, tuple[str, str]] = { - "ai-gateway": ("https://ai-gateway.vercel.sh/v3/ai", "AI_GATEWAY_API_KEY"), - "anthropic": ("https://api.anthropic.com/v1", "ANTHROPIC_API_KEY"), - "openai": ("https://api.openai.com/v1", "OPENAI_API_KEY"), -} - - -def _auto_client(model: Model) -> Client: - """Create a :class:`Client` from env vars for the given model's provider.""" - defaults = _PROVIDER_DEFAULTS.get(model.provider) - if defaults is None: - raise ValueError( - f"No default client config for provider {model.provider!r}. " - f"Pass an explicit client= argument." - ) - base_url, env_var = defaults - return Client(base_url=base_url, api_key=os.environ.get(env_var)) - - -# --------------------------------------------------------------------------- -# Public API -# --------------------------------------------------------------------------- - - -class StreamResult: - """Wrapper around a message stream. Async-iterable; collects the final result. - - Properties like ``.text`` and ``.tool_calls`` delegate to the final - ``Message`` snapshot and are available after iteration completes. - - Satisfies :class:`~ai.types.StreamResultLike`. - """ - - def __init__(self, gen: AsyncGenerator[messages_.Message]) -> None: - self._gen = gen - self._final: messages_.Message | None = None - - @classmethod - def from_generator(cls, gen: AsyncGenerator[messages_.Message]) -> StreamResult: - """Create a :class:`StreamResult` from an async generator. - - This is the public API for middleware that needs to transform or - replace the stream returned by ``wrap_model``:: - - async def wrap_model(self, call, next): - original = await next(call) - - async def _transformed(): - async for msg in original: - yield modify(msg) - - return StreamResult.from_generator(_transformed()) - """ - return cls(gen) - - def __aiter__(self) -> AsyncGenerator[messages_.Message]: - return self._iterate() - - async def _iterate(self) -> AsyncGenerator[messages_.Message]: - async for msg in self._gen: - self._final = msg - yield msg - - @property - def text(self) -> str: - return self._final.text if self._final else "" - - @property - def tool_calls(self) -> list[messages_.ToolCallPart]: - return self._final.tool_calls if self._final else [] - - @property - def usage(self) -> messages_.Usage | None: - return self._final.usage if self._final else None - - @property - def output(self) -> Any: - """Parsed structured output from the final message, if available.""" - return self._final.output if self._final else None - - -async def stream( - model: Model, - messages: list[messages_.Message], - *, - tools: Sequence[tools_.ToolLike] | None = None, - output_type: type[pydantic.BaseModel] | None = None, - client: Client | None = None, - **kwargs: Any, -) -> StreamResultLike: - """Stream an LLM response. - - Returns a :class:`StreamResultLike` that is async-iterable and - collects the final ``Message``. After iteration, access ``.text``, - ``.tool_calls``, ``.usage``, etc. - - Without middleware the concrete type is :class:`StreamResult`; with - middleware it may be any :class:`~ai.StreamResultLike`. - """ - call = middleware_.ModelContext( - model=model, - messages=messages, - tools=tools, - output_type=output_type, - client=client, - kwargs=kwargs, - ) - - async def _real(call: middleware_.ModelContext) -> StreamResultLike: - _ensure_adapters() - c = call.client or _auto_client(call.model) - adapter_fn = _stream_adapters.get(call.model.adapter) - if adapter_fn is None: - registered = ", ".join(sorted(_stream_adapters)) or "(none)" - raise KeyError( - f"No stream adapter registered for adapter={call.model.adapter!r}. " - f"Registered: {registered}" - ) - return StreamResult( - adapter_fn( - c, - call.model, - call.messages, - tools=call.tools, - output_type=call.output_type, - **call.kwargs, - ) - ) - - chain = middleware_._build_model_chain(_real) - return await chain(call) - - -async def generate( - model: Model, - messages: list[messages_.Message], - params: GenerateParams | None = None, - *, - client: Client | None = None, -) -> messages_.Message: - """Generate a response (images, video, etc.). - - Resolves the adapter function from ``model.adapter``, auto-creates a - :class:`Client` from env vars if none is provided. - - ``params`` controls the generation type: - - * :class:`ImageParams` — image generation (``/image-model``). - * :class:`VideoParams` — video generation (``/video-model``). - * ``None`` — auto-detect from ``model.capabilities``. - """ - call = middleware_.GenerateContext( - model=model, - messages=messages, - params=params, - client=client, - ) - - async def _real(call: middleware_.GenerateContext) -> messages_.Message: - _ensure_adapters() - c = call.client or _auto_client(call.model) - adapter_fn = _generate_adapters.get(call.model.adapter) - if adapter_fn is None: - registered = ", ".join(sorted(_generate_adapters)) or "(none)" - raise KeyError( - f"No generate adapter registered for adapter={call.model.adapter!r}. " - f"Registered: {registered}" - ) - return await adapter_fn(c, call.model, call.messages, params=call.params) - - chain = middleware_._build_generate_chain(_real) - return await chain(call) - - -async def check_connection( - model: Model, - *, - client: Client | None = None, -) -> bool: - """Check whether *client* can reach *model*'s provider and the model exists. - - Returns ``True`` when the credentials are valid **and** the model is - available on the remote side — i.e. a subsequent :func:`stream` or - :func:`generate` call should succeed (network conditions permitting). - - This only hits free metadata endpoints; no tokens or credits are - consumed. - - If no *client* is given, one is auto-created from environment - variables (same logic as :func:`stream`). - - Non-auth transport errors (network failures, 5xx) are raised rather - than returning ``False`` so that callers can distinguish "bad - credentials / unknown model" from "provider unreachable". - """ - _ensure_check_fns() - c = client or _auto_client(model) - check_fn = _check_fns.get(model.provider) - if check_fn is None: - registered = ", ".join(sorted(_check_fns)) or "(none)" - raise KeyError( - f"No check function registered for provider={model.provider!r}. " - f"Registered: {registered}" - ) - return await check_fn(c, model) - - -async def buffer(gen: AsyncGenerator[messages_.Message]) -> messages_.Message: - """Drain a stream and return the final ``Message``. - - Raises :class:`ValueError` if the stream yields nothing. - """ - result: messages_.Message | None = None - async for msg in gen: - result = msg - if result is None: - raise ValueError("empty stream") - return result - +from .core.stream_result import StreamResult __all__ = [ # Core types @@ -367,12 +53,14 @@ async def buffer(gen: AsyncGenerator[messages_.Message]) -> messages_.Message: "get_providers", "model", "register_catalog", - # Public API - "buffer", - "check_connection", - "generate", + # Adapter / check registration "register_check", "register_generate", "register_stream", + # Public API + "check_connection", + "generate", "stream", + # Internal (used by tests) + "_PROVIDER_DEFAULTS", ] diff --git a/src/ai/models/core/__init__.py b/src/ai/models/core/__init__.py index 83892315..40841556 100644 --- a/src/ai/models/core/__init__.py +++ b/src/ai/models/core/__init__.py @@ -1,19 +1,30 @@ """Core types for models.""" +from .adapters import register_check, register_generate, register_stream +from .api import check_connection, generate, stream from .catalog import get_models, get_providers, register_catalog from .catalog import model as model_factory from .client import Client from .model import Model, ModelCost -from .proto import GenerateFn, StreamFn +from .proto import CheckConnFn, GenerateFn, StreamFn +from .stream_result import StreamResult __all__ = [ + "CheckConnFn", "Client", "GenerateFn", "Model", "ModelCost", "StreamFn", + "StreamResult", + "check_connection", + "generate", "get_models", "get_providers", "model_factory", "register_catalog", + "register_check", + "register_generate", + "register_stream", + "stream", ] diff --git a/src/ai/models/core/adapters.py b/src/ai/models/core/adapters.py new file mode 100644 index 00000000..f7a5c114 --- /dev/null +++ b/src/ai/models/core/adapters.py @@ -0,0 +1,124 @@ +"""Adapter and check-function registries. + +Maps adapter/provider strings to their handler functions. Adapter modules +are imported lazily on first use to keep import-time lightweight. +""" + +from __future__ import annotations + +from . import proto + +# --------------------------------------------------------------------------- +# Stream / generate adapter registry +# --------------------------------------------------------------------------- + +_stream_adapters: dict[str, proto.StreamFn] = {} +_generate_adapters: dict[str, proto.GenerateFn] = {} +_adapters_loaded = False + + +def _ensure_adapters() -> None: + """Lazily register built-in adapter functions on first call.""" + global _adapters_loaded # noqa: PLW0603 + if _adapters_loaded: + return + _adapters_loaded = True + + from ..ai_gateway.generate import generate as ai_gw_generate + from ..ai_gateway.stream import stream as ai_gw_stream + from ..anthropic.adapter import stream as anthropic_stream + from ..openai.adapter import stream as openai_stream + + _stream_adapters["ai-gateway-v3"] = ai_gw_stream + _generate_adapters["ai-gateway-v3"] = ai_gw_generate + _stream_adapters["openai"] = openai_stream + _stream_adapters["anthropic"] = anthropic_stream + + +def register_stream(adapter: str, fn: proto.StreamFn) -> None: + """Register a stream adapter function for the given adapter key. + + Use this to add custom adapters (or override built-in ones). + """ + _stream_adapters[adapter] = fn + + +def register_generate(adapter: str, fn: proto.GenerateFn) -> None: + """Register a generate adapter function for the given adapter key. + + Use this to add custom adapters (or override built-in ones). + """ + _generate_adapters[adapter] = fn + + +def get_stream_adapter(adapter: str) -> proto.StreamFn: + """Return the stream adapter for *adapter*, raising on miss.""" + _ensure_adapters() + fn = _stream_adapters.get(adapter) + if fn is None: + registered = ", ".join(sorted(_stream_adapters)) or "(none)" + raise KeyError( + f"No stream adapter registered for adapter={adapter!r}. " + f"Registered: {registered}" + ) + return fn + + +def get_generate_adapter(adapter: str) -> proto.GenerateFn: + """Return the generate adapter for *adapter*, raising on miss.""" + _ensure_adapters() + fn = _generate_adapters.get(adapter) + if fn is None: + registered = ", ".join(sorted(_generate_adapters)) or "(none)" + raise KeyError( + f"No generate adapter registered for adapter={adapter!r}. " + f"Registered: {registered}" + ) + return fn + + +# --------------------------------------------------------------------------- +# Connection-check registry — keyed by *provider* (not adapter) because the +# check verifies "can this client reach this provider and does this model +# exist there". +# --------------------------------------------------------------------------- + +_check_fns: dict[str, proto.CheckConnFn] = {} +_check_fns_loaded = False + + +def _ensure_check_fns() -> None: + """Lazily register built-in check functions on first call.""" + global _check_fns_loaded # noqa: PLW0603 + if _check_fns_loaded: + return + _check_fns_loaded = True + + from ..ai_gateway import check as ai_gw_check + from ..anthropic import check as anthropic_check + from ..openai import check as openai_check + + _check_fns["ai-gateway"] = ai_gw_check.check + _check_fns["anthropic"] = anthropic_check.check + _check_fns["openai"] = openai_check.check + + +def register_check(provider: str, fn: proto.CheckConnFn) -> None: + """Register a connection-check function for a provider. + + Use this to add checks for custom providers. + """ + _check_fns[provider] = fn + + +def get_check_fn(provider: str) -> proto.CheckConnFn: + """Return the check function for *provider*, raising on miss.""" + _ensure_check_fns() + fn = _check_fns.get(provider) + if fn is None: + registered = ", ".join(sorted(_check_fns)) or "(none)" + raise KeyError( + f"No check function registered for provider={provider!r}. " + f"Registered: {registered}" + ) + return fn diff --git a/src/ai/models/core/api.py b/src/ai/models/core/api.py new file mode 100644 index 00000000..fc7400a2 --- /dev/null +++ b/src/ai/models/core/api.py @@ -0,0 +1,125 @@ +"""Top-level orchestration — stream(), generate(), check_connection(). + +These wire together adapters, middleware chains, and auto-client creation. +""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any + +import pydantic + +from ... import middleware as middleware_ +from ...types import messages as messages_ +from ...types import stream as stream_ +from ...types import tools as tools_ +from ..ai_gateway import types as ai_gw_types +from . import adapters, stream_result +from . import client as client_ +from . import model as model_ + + +async def stream( + model: model_.Model, + messages: list[messages_.Message], + *, + tools: Sequence[tools_.ToolLike] | None = None, + output_type: type[pydantic.BaseModel] | None = None, + client: client_.Client | None = None, + **kwargs: Any, +) -> stream_.StreamResultLike: + """Stream an LLM response. + + Returns a :class:`StreamResultLike` that is async-iterable and + collects the final ``Message``. After iteration, access ``.text``, + ``.tool_calls``, ``.usage``, etc. + + Without middleware the concrete type is :class:`StreamResult`; with + middleware it may be any :class:`~ai.StreamResultLike`. + """ + call = middleware_.ModelContext( + model=model, + messages=messages, + tools=tools, + output_type=output_type, + client=client, + kwargs=kwargs, + ) + + async def _real(call: middleware_.ModelContext) -> stream_.StreamResultLike: + c = call.client or client_.auto_client(call.model) + adapter_fn = adapters.get_stream_adapter(call.model.adapter) + return stream_result.StreamResult( + adapter_fn( + c, + call.model, + call.messages, + tools=call.tools, + output_type=call.output_type, + **call.kwargs, + ) + ) + + chain = middleware_._build_model_chain(_real) + return await chain(call) + + +async def generate( + model: model_.Model, + messages: list[messages_.Message], + params: ai_gw_types.GenerateParams | None = None, + *, + client: client_.Client | None = None, +) -> messages_.Message: + """Generate a response (images, video, etc.). + + Resolves the adapter function from ``model.adapter``, auto-creates a + :class:`Client` from env vars if none is provided. + + ``params`` controls the generation type: + + * :class:`ImageParams` — image generation (``/image-model``). + * :class:`VideoParams` — video generation (``/video-model``). + * ``None`` — auto-detect from ``model.capabilities``. + """ + call = middleware_.GenerateContext( + model=model, + messages=messages, + params=params, + client=client, + ) + + async def _real(call: middleware_.GenerateContext) -> messages_.Message: + c = call.client or client_.auto_client(call.model) + adapter_fn = adapters.get_generate_adapter(call.model.adapter) + return await adapter_fn(c, call.model, call.messages, params=call.params) + + chain = middleware_._build_generate_chain(_real) + return await chain(call) + + +async def check_connection( + model: model_.Model, + *, + client: client_.Client | None = None, +) -> bool: + """Check whether *client* can reach *model*'s provider and the model exists. + + Returns ``True`` when the credentials are valid **and** the model is + available on the remote side — i.e. a subsequent :func:`stream` or + :func:`generate` call should succeed (network conditions permitting). + + This only hits free metadata endpoints; no tokens or credits are + consumed. + + If no *client* is given, one is auto-created from environment + variables (same logic as :func:`stream`). + + Non-auth transport errors (network failures, 5xx) are raised rather + than returning ``False`` so that callers can distinguish "bad + credentials / unknown model" from "provider unreachable". + """ + c = client or client_.auto_client(model) + check_fn = adapters.get_check_fn(model.provider) + return await check_fn(c, model) diff --git a/src/ai/models/core/client.py b/src/ai/models/core/client.py index fabed6cf..5db11d19 100644 --- a/src/ai/models/core/client.py +++ b/src/ai/models/core/client.py @@ -3,11 +3,14 @@ from __future__ import annotations import dataclasses +import os from typing import TYPE_CHECKING if TYPE_CHECKING: import httpx + from . import model as model_ + @dataclasses.dataclass class Client: @@ -47,3 +50,26 @@ async def aclose(self) -> None: if self._http is not None and not self._http.is_closed: await self._http.aclose() self._http = None + + +# --------------------------------------------------------------------------- +# Provider defaults — base URLs and env var names for auto-client creation. +# --------------------------------------------------------------------------- + +_PROVIDER_DEFAULTS: dict[str, tuple[str, str]] = { + "ai-gateway": ("https://ai-gateway.vercel.sh/v3/ai", "AI_GATEWAY_API_KEY"), + "anthropic": ("https://api.anthropic.com/v1", "ANTHROPIC_API_KEY"), + "openai": ("https://api.openai.com/v1", "OPENAI_API_KEY"), +} + + +def auto_client(model: model_.Model) -> Client: + """Create a :class:`Client` from env vars for the given model's provider.""" + defaults = _PROVIDER_DEFAULTS.get(model.provider) + if defaults is None: + raise ValueError( + f"No default client config for provider {model.provider!r}. " + f"Pass an explicit client= argument." + ) + base_url, env_var = defaults + return Client(base_url=base_url, api_key=os.environ.get(env_var)) diff --git a/src/ai/models/core/stream_result.py b/src/ai/models/core/stream_result.py new file mode 100644 index 00000000..2eeb3577 --- /dev/null +++ b/src/ai/models/core/stream_result.py @@ -0,0 +1,65 @@ +"""StreamResult — concrete wrapper around a message stream.""" + +from __future__ import annotations + +from collections.abc import AsyncGenerator +from typing import Any + +from ...types import messages as messages_ + + +class StreamResult: + """Wrapper around a message stream. Async-iterable; collects the final result. + + Properties like ``.text`` and ``.tool_calls`` delegate to the final + ``Message`` snapshot and are available after iteration completes. + + Satisfies :class:`~ai.types.StreamResultLike`. + """ + + def __init__(self, gen: AsyncGenerator[messages_.Message]) -> None: + self._gen = gen + self._final: messages_.Message | None = None + + @classmethod + def from_generator(cls, gen: AsyncGenerator[messages_.Message]) -> StreamResult: + """Create a :class:`StreamResult` from an async generator. + + This is the public API for middleware that needs to transform or + replace the stream returned by ``wrap_model``:: + + async def wrap_model(self, call, next): + original = await next(call) + + async def _transformed(): + async for msg in original: + yield modify(msg) + + return StreamResult.from_generator(_transformed()) + """ + return cls(gen) + + def __aiter__(self) -> AsyncGenerator[messages_.Message]: + return self._iterate() + + async def _iterate(self) -> AsyncGenerator[messages_.Message]: + async for msg in self._gen: + self._final = msg + yield msg + + @property + def text(self) -> str: + return self._final.text if self._final else "" + + @property + def tool_calls(self) -> list[messages_.ToolCallPart]: + return self._final.tool_calls if self._final else [] + + @property + def usage(self) -> messages_.Usage | None: + return self._final.usage if self._final else None + + @property + def output(self) -> Any: + """Parsed structured output from the final message, if available.""" + return self._final.output if self._final else None From 66326aa795d36ee6b21747de4e3236a5c352a575 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Mon, 13 Apr 2026 11:01:36 -0700 Subject: [PATCH 2/4] Move some types around --- src/ai/models/__init__.py | 3 +- src/ai/models/ai_gateway/__init__.py | 4 -- src/ai/models/ai_gateway/generate.py | 9 ++- src/ai/models/ai_gateway/types.py | 50 ---------------- src/ai/models/core/__init__.py | 5 +- src/ai/models/core/api.py | 8 +-- .../core/{stream_result.py => types.py} | 57 ++++++++++++++++++- 7 files changed, 69 insertions(+), 67 deletions(-) delete mode 100644 src/ai/models/ai_gateway/types.py rename src/ai/models/core/{stream_result.py => types.py} (53%) diff --git a/src/ai/models/__init__.py b/src/ai/models/__init__.py index d2634cc5..30a17a4e 100644 --- a/src/ai/models/__init__.py +++ b/src/ai/models/__init__.py @@ -25,7 +25,6 @@ """ from ..types.stream import StreamResultLike -from .ai_gateway.types import GenerateParams, ImageParams, VideoParams from .core.adapters import register_check, register_generate, register_stream from .core.api import check_connection, generate, stream from .core.catalog import get_models, get_providers, register_catalog @@ -33,7 +32,7 @@ from .core.client import _PROVIDER_DEFAULTS, Client from .core.model import Model, ModelCost from .core.proto import CheckConnFn, GenerateFn, StreamFn -from .core.stream_result import StreamResult +from .core.types import GenerateParams, ImageParams, StreamResult, VideoParams __all__ = [ # Core types diff --git a/src/ai/models/ai_gateway/__init__.py b/src/ai/models/ai_gateway/__init__.py index 4ec1f11e..5d1f9fbd 100644 --- a/src/ai/models/ai_gateway/__init__.py +++ b/src/ai/models/ai_gateway/__init__.py @@ -6,12 +6,8 @@ """ from . import errors -from .types import GenerateParams, ImageParams, VideoParams __all__ = [ - "GenerateParams", - "ImageParams", - "VideoParams", "errors", ] diff --git a/src/ai/models/ai_gateway/generate.py b/src/ai/models/ai_gateway/generate.py index 0d0407fe..218eb324 100644 --- a/src/ai/models/ai_gateway/generate.py +++ b/src/ai/models/ai_gateway/generate.py @@ -1,7 +1,6 @@ """AI Gateway v3 generation adapter — image-model and video-model endpoints. -Provides typed parameter objects (:class:`ImageParams`, :class:`VideoParams`) -and a unified :func:`generate` entry point that dispatches based on param type +Unified :func:`generate` entry point that dispatches based on param type and validates against model capabilities. """ @@ -16,10 +15,10 @@ from ..core import client as client_ from ..core import model as model_ from ..core.helpers import files +from ..core.types import GenerateParams as GenerateParams +from ..core.types import ImageParams as ImageParams +from ..core.types import VideoParams as VideoParams from . import _common, errors -from .types import GenerateParams as GenerateParams -from .types import ImageParams as ImageParams -from .types import VideoParams as VideoParams # --------------------------------------------------------------------------- # Image generation — /image-model diff --git a/src/ai/models/ai_gateway/types.py b/src/ai/models/ai_gateway/types.py deleted file mode 100644 index d1ca0658..00000000 --- a/src/ai/models/ai_gateway/types.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Parameter types for AI Gateway generation endpoints. - -Extracted into a standalone module so they can be imported without pulling in -heavy dependencies (``httpx``, etc.) that the adapter functions need. -""" - -from __future__ import annotations - -from typing import Any - -import pydantic - -_PARAMS_CONFIG = pydantic.ConfigDict(frozen=True, populate_by_name=True) - - -class ImageParams(pydantic.BaseModel): - """Parameters for image generation (``/image-model`` endpoint).""" - - model_config = _PARAMS_CONFIG - - n: int = 1 - size: str | None = None - aspect_ratio: str | None = pydantic.Field( - default=None, serialization_alias="aspectRatio" - ) - seed: int | None = None - provider_options: dict[str, Any] = pydantic.Field( - default_factory=dict, serialization_alias="providerOptions" - ) - - -class VideoParams(pydantic.BaseModel): - """Parameters for video generation (``/video-model`` endpoint).""" - - model_config = _PARAMS_CONFIG - - n: int = 1 - aspect_ratio: str | None = pydantic.Field( - default=None, serialization_alias="aspectRatio" - ) - resolution: str | None = None - duration: int | None = None - fps: int | None = None - seed: int | None = None - provider_options: dict[str, Any] = pydantic.Field( - default_factory=dict, serialization_alias="providerOptions" - ) - - -GenerateParams = ImageParams | VideoParams diff --git a/src/ai/models/core/__init__.py b/src/ai/models/core/__init__.py index 40841556..0c2c40a4 100644 --- a/src/ai/models/core/__init__.py +++ b/src/ai/models/core/__init__.py @@ -7,16 +7,19 @@ from .client import Client from .model import Model, ModelCost from .proto import CheckConnFn, GenerateFn, StreamFn -from .stream_result import StreamResult +from .types import GenerateParams, ImageParams, StreamResult, VideoParams __all__ = [ "CheckConnFn", "Client", "GenerateFn", + "GenerateParams", + "ImageParams", "Model", "ModelCost", "StreamFn", "StreamResult", + "VideoParams", "check_connection", "generate", "get_models", diff --git a/src/ai/models/core/api.py b/src/ai/models/core/api.py index fc7400a2..39b1cb2f 100644 --- a/src/ai/models/core/api.py +++ b/src/ai/models/core/api.py @@ -14,10 +14,10 @@ from ...types import messages as messages_ from ...types import stream as stream_ from ...types import tools as tools_ -from ..ai_gateway import types as ai_gw_types -from . import adapters, stream_result +from . import adapters from . import client as client_ from . import model as model_ +from . import types as types_ async def stream( @@ -50,7 +50,7 @@ async def stream( async def _real(call: middleware_.ModelContext) -> stream_.StreamResultLike: c = call.client or client_.auto_client(call.model) adapter_fn = adapters.get_stream_adapter(call.model.adapter) - return stream_result.StreamResult( + return types_.StreamResult( adapter_fn( c, call.model, @@ -68,7 +68,7 @@ async def _real(call: middleware_.ModelContext) -> stream_.StreamResultLike: async def generate( model: model_.Model, messages: list[messages_.Message], - params: ai_gw_types.GenerateParams | None = None, + params: types_.GenerateParams | None = None, *, client: client_.Client | None = None, ) -> messages_.Message: diff --git a/src/ai/models/core/stream_result.py b/src/ai/models/core/types.py similarity index 53% rename from src/ai/models/core/stream_result.py rename to src/ai/models/core/types.py index 2eeb3577..eff6e5af 100644 --- a/src/ai/models/core/stream_result.py +++ b/src/ai/models/core/types.py @@ -1,12 +1,67 @@ -"""StreamResult — concrete wrapper around a message stream.""" +"""Core model-layer types — parameter objects and StreamResult. + +Parameter types (:class:`ImageParams`, :class:`VideoParams`) live here +because they parameterise the public :func:`generate` API. + +:class:`StreamResult` is the concrete wrapper returned by :func:`stream`. +""" from __future__ import annotations from collections.abc import AsyncGenerator from typing import Any +import pydantic + from ...types import messages as messages_ +# --------------------------------------------------------------------------- +# Generation parameter types +# --------------------------------------------------------------------------- + +_PARAMS_CONFIG = pydantic.ConfigDict(frozen=True, populate_by_name=True) + + +class ImageParams(pydantic.BaseModel): + """Parameters for image generation (``/image-model`` endpoint).""" + + model_config = _PARAMS_CONFIG + + n: int = 1 + size: str | None = None + aspect_ratio: str | None = pydantic.Field( + default=None, serialization_alias="aspectRatio" + ) + seed: int | None = None + provider_options: dict[str, Any] = pydantic.Field( + default_factory=dict, serialization_alias="providerOptions" + ) + + +class VideoParams(pydantic.BaseModel): + """Parameters for video generation (``/video-model`` endpoint).""" + + model_config = _PARAMS_CONFIG + + n: int = 1 + aspect_ratio: str | None = pydantic.Field( + default=None, serialization_alias="aspectRatio" + ) + resolution: str | None = None + duration: int | None = None + fps: int | None = None + seed: int | None = None + provider_options: dict[str, Any] = pydantic.Field( + default_factory=dict, serialization_alias="providerOptions" + ) + + +GenerateParams = ImageParams | VideoParams + +# --------------------------------------------------------------------------- +# StreamResult +# --------------------------------------------------------------------------- + class StreamResult: """Wrapper around a message stream. Async-iterable; collects the final result. From 29b9db2d6aacaf9661ccbb7212176ae167d22a83 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Mon, 13 Apr 2026 13:03:41 -0700 Subject: [PATCH 3/4] Reshape the API into provider("model") --- examples/fastapi-vite/backend/agent.py | 6 +- examples/multiagent-textual/server.py | 2 +- examples/samples/agent_custom_loop.py | 2 +- examples/samples/agent_hooks.py | 2 +- examples/samples/agent_hooks_serverless.py | 2 +- examples/samples/check_connection.py | 28 +++- examples/samples/explicit_client.py | 6 +- examples/samples/image_generation.py | 2 +- examples/samples/middleware_simple.py | 2 +- examples/samples/stream.py | 2 +- examples/samples/video_generation.py | 2 +- examples/temporal-direct/main.py | 4 +- examples/temporal-middleware/main.py | 4 +- src/ai/__init__.py | 13 +- src/ai/middleware.py | 3 - src/ai/models/__init__.py | 47 +++--- src/ai/models/ai_gateway/__init__.py | 70 +++++++- src/ai/models/ai_gateway/catalog.py | 150 ------------------ src/ai/models/ai_gateway/generate.py | 47 +----- src/ai/models/anthropic/__init__.py | 67 +++++++- src/ai/models/anthropic/catalog.py | 55 ------- src/ai/models/core/__init__.py | 9 +- src/ai/models/core/api.py | 31 ++-- src/ai/models/core/catalog.py | 103 ------------ src/ai/models/core/client.py | 42 +++-- src/ai/models/core/model.py | 27 ++-- src/ai/models/core/proto.py | 2 +- src/ai/models/openai/__init__.py | 63 +++++++- src/ai/models/openai/catalog.py | 45 ------ tests/conftest.py | 11 +- .../models/ai_gateway/test_generate_image.py | 15 +- .../models/ai_gateway/test_generate_video.py | 1 - tests/models/test_check.py | 31 ++-- tests/models/test_public_api.py | 10 +- tests/test_middleware.py | 4 +- 35 files changed, 363 insertions(+), 547 deletions(-) delete mode 100644 src/ai/models/ai_gateway/catalog.py delete mode 100644 src/ai/models/anthropic/catalog.py delete mode 100644 src/ai/models/core/catalog.py delete mode 100644 src/ai/models/openai/catalog.py diff --git a/examples/fastapi-vite/backend/agent.py b/examples/fastapi-vite/backend/agent.py index edfdf068..b3af15c3 100644 --- a/examples/fastapi-vite/backend/agent.py +++ b/examples/fastapi-vite/backend/agent.py @@ -10,7 +10,7 @@ import ai -MODEL = ai.model("ai-gateway", "anthropic/claude-sonnet-4") +MODEL = ai.ai_gateway("anthropic/claude-sonnet-4") @ai.tool @@ -42,7 +42,9 @@ async def graph(context: ai.Context) -> AsyncGenerator[ai.Message]: if not tool_calls: return - results = await asyncio.gather(*(_execute_with_approval(tc) for tc in tool_calls)) + results = await asyncio.gather( + *(_execute_with_approval(tc) for tc in tool_calls) + ) yield ai.tool_message(*results) diff --git a/examples/multiagent-textual/server.py b/examples/multiagent-textual/server.py index 6b85ffcb..a35935f0 100644 --- a/examples/multiagent-textual/server.py +++ b/examples/multiagent-textual/server.py @@ -58,7 +58,7 @@ class Approval(pydantic.BaseModel): # Model # --------------------------------------------------------------------------- -MODEL = ai.model("ai-gateway", "anthropic/claude-sonnet-4") +MODEL = ai.ai_gateway("anthropic/claude-sonnet-4") # --------------------------------------------------------------------------- diff --git a/examples/samples/agent_custom_loop.py b/examples/samples/agent_custom_loop.py index 8c540660..27a5a17a 100644 --- a/examples/samples/agent_custom_loop.py +++ b/examples/samples/agent_custom_loop.py @@ -19,7 +19,7 @@ async def get_population(city: str) -> int: async def main() -> None: - model = ai.model("ai-gateway", "anthropic/claude-sonnet-4") + model = ai.ai_gateway("anthropic/claude-sonnet-4") tools = [get_weather, get_population] my_agent = ai.agent(tools=tools) diff --git a/examples/samples/agent_hooks.py b/examples/samples/agent_hooks.py index 361fe7a2..7180311b 100644 --- a/examples/samples/agent_hooks.py +++ b/examples/samples/agent_hooks.py @@ -26,7 +26,7 @@ async def contact_mothership(query: str) -> str: async def main() -> None: - model = ai.model("ai-gateway", "anthropic/claude-sonnet-4") + model = ai.ai_gateway("anthropic/claude-sonnet-4") my_agent = ai.agent(tools=[contact_mothership]) diff --git a/examples/samples/agent_hooks_serverless.py b/examples/samples/agent_hooks_serverless.py index 2d7de2c2..23717689 100644 --- a/examples/samples/agent_hooks_serverless.py +++ b/examples/samples/agent_hooks_serverless.py @@ -31,7 +31,7 @@ async def delete_file(path: str) -> str: async def main() -> None: - model = ai.model("ai-gateway", "anthropic/claude-sonnet-4") + model = ai.ai_gateway("anthropic/claude-sonnet-4") my_agent = ai.agent(tools=[delete_file]) diff --git a/examples/samples/check_connection.py b/examples/samples/check_connection.py index 9efc6a66..b89d2f16 100644 --- a/examples/samples/check_connection.py +++ b/examples/samples/check_connection.py @@ -1,13 +1,19 @@ -"""Check connection — verify credentials and model availability for all providers.""" +"""Check connection and list models — verify credentials and model availability.""" import asyncio import ai MODELS = [ - ai.model("ai-gateway", "anthropic/claude-sonnet-4"), - ai.model("anthropic", "claude-sonnet-4-20250514"), - ai.model("openai", "gpt-5.4-mini"), + ai.ai_gateway("anthropic/claude-sonnet-4"), + ai.anthropic("claude-sonnet-4-20250514"), + ai.openai("gpt-5.4-mini"), +] + +PROVIDERS = [ + ("ai_gateway", ai.ai_gateway), + ("anthropic", ai.anthropic), + ("openai", ai.openai), ] @@ -20,9 +26,23 @@ async def _check(model: ai.Model) -> None: print(f" {status} {model.provider}/{model.id}") +async def _list_models(name: str, provider: object) -> None: + try: + ids: list[str] = await provider.list() # type: ignore[union-attr] + print(f" {name}: {len(ids)} models") + for mid in ids: + print(f" - {mid}") + except Exception as exc: + print(f" {name}: [ERR] {exc}") + + async def main() -> None: print("Checking connections...\n") await asyncio.gather(*[_check(m) for m in MODELS]) + + print("\nListing models...\n") + await asyncio.gather(*[_list_models(n, p) for n, p in PROVIDERS]) + print() diff --git a/examples/samples/explicit_client.py b/examples/samples/explicit_client.py index 9438a4f4..fa95ada8 100644 --- a/examples/samples/explicit_client.py +++ b/examples/samples/explicit_client.py @@ -5,8 +5,6 @@ import ai -model = ai.model("ai-gateway", "anthropic/claude-sonnet-4") - # Explicit client — useful for custom auth, proxies, or self-hosted gateways. client = ai.Client( base_url="https://ai-gateway.vercel.sh/v3/ai", @@ -14,12 +12,14 @@ headers={"X-Custom-Header": "example"}, ) +model = ai.ai_gateway("anthropic/claude-sonnet-4", client=client) + messages = [ai.user_message("Hello!")] async def main() -> None: try: - async for msg in await ai.models.stream(model, messages, client=client): + async for msg in await ai.models.stream(model, messages): if msg.text_delta: print(msg.text_delta, end="", flush=True) print() diff --git a/examples/samples/image_generation.py b/examples/samples/image_generation.py index 99b301d9..cd669f3a 100644 --- a/examples/samples/image_generation.py +++ b/examples/samples/image_generation.py @@ -6,7 +6,7 @@ import ai -model = ai.model("ai-gateway", "google/imagen-4.0-generate-001") +model = ai.ai_gateway("google/imagen-4.0-generate-001") messages = [ ai.user_message( diff --git a/examples/samples/middleware_simple.py b/examples/samples/middleware_simple.py index e9cb7859..e0deb06b 100644 --- a/examples/samples/middleware_simple.py +++ b/examples/samples/middleware_simple.py @@ -88,7 +88,7 @@ async def get_population(city: str) -> int: async def main() -> None: - model = ai.model("ai-gateway", "anthropic/claude-sonnet-4") + model = ai.ai_gateway("anthropic/claude-sonnet-4") my_agent = ai.agent(tools=[get_weather, get_population]) diff --git a/examples/samples/stream.py b/examples/samples/stream.py index 762afb84..dbfaae48 100644 --- a/examples/samples/stream.py +++ b/examples/samples/stream.py @@ -4,7 +4,7 @@ import ai -model = ai.model("ai-gateway", "anthropic/claude-sonnet-4") +model = ai.ai_gateway("anthropic/claude-sonnet-4") messages = [ ai.system_message("Be concise."), diff --git a/examples/samples/video_generation.py b/examples/samples/video_generation.py index 8692faee..196a63d4 100644 --- a/examples/samples/video_generation.py +++ b/examples/samples/video_generation.py @@ -6,7 +6,7 @@ import ai -model = ai.model("ai-gateway", "google/veo-3.0-generate-001") +model = ai.ai_gateway("google/veo-3.0-generate-001") messages = [ ai.user_message( diff --git a/examples/temporal-direct/main.py b/examples/temporal-direct/main.py index dcf54b0c..8039df76 100644 --- a/examples/temporal-direct/main.py +++ b/examples/temporal-direct/main.py @@ -96,7 +96,7 @@ class LLMResult: @temporalio.activity.defn async def llm_call_activity(params: LLMParams) -> LLMResult: """Call the LLM, drain the stream, return the final message.""" - model = ai.model("ai-gateway", "anthropic/claude-sonnet-4") + model = ai.ai_gateway("anthropic/claude-sonnet-4") messages = [ai.Message.model_validate(m) for m in params.messages] tools = [ai.ToolSchema(return_type=None, **t) for t in params.tool_schemas] @@ -166,7 +166,7 @@ async def run_tool(tc: ai.ToolCallPart) -> ai.ToolResultPart: class WeatherWorkflow: @temporalio.workflow.run async def run(self, user_query: str) -> str: - model = ai.model("ai-gateway", "anthropic/claude-sonnet-4") + model = ai.ai_gateway("anthropic/claude-sonnet-4") messages: list[ai.Message] = [ ai.system_message( "Answer questions using the weather and population tools." diff --git a/examples/temporal-middleware/main.py b/examples/temporal-middleware/main.py index 9764d007..7e8a56fb 100644 --- a/examples/temporal-middleware/main.py +++ b/examples/temporal-middleware/main.py @@ -113,7 +113,7 @@ class LLMResult: @temporalio.activity.defn async def llm_call_activity(params: LLMParams) -> LLMResult: """Call the LLM, drain the stream, return the final message.""" - model = ai.model("ai-gateway", "anthropic/claude-sonnet-4") + model = ai.ai_gateway("anthropic/claude-sonnet-4") messages = [ai.Message.model_validate(m) for m in params.messages] tools = [ai.ToolSchema(return_type=None, **t) for t in params.tool_schemas] @@ -194,7 +194,7 @@ async def wrap_tool( class WeatherWorkflow: @temporalio.workflow.run async def run(self, user_query: str) -> str: - model = ai.model("ai-gateway", "anthropic/claude-sonnet-4") + model = ai.ai_gateway("anthropic/claude-sonnet-4") messages: list[ai.Message] = [ ai.system_message( "Answer questions using the weather and population tools." diff --git a/src/ai/__init__.py b/src/ai/__init__.py index 47757ec5..23475bbc 100644 --- a/src/ai/__init__.py +++ b/src/ai/__init__.py @@ -20,12 +20,13 @@ Client, ImageParams, Model, - ModelCost, StreamResult, VideoParams, + ai_gateway, + anthropic, check_connection, generate, - model, + openai, stream, ) @@ -83,17 +84,19 @@ "thinking", # Models (from models/) "Model", - "ModelCost", "ImageParams", "VideoParams", "Client", "StreamResult", "StreamResultLike", "check_connection", - "model", - "models", "stream", "generate", + "models", + # Provider factories + "openai", + "anthropic", + "ai_gateway", # Agents — primary API "Agent", "agent", diff --git a/src/ai/middleware.py b/src/ai/middleware.py index bc703859..e5eaaa97 100644 --- a/src/ai/middleware.py +++ b/src/ai/middleware.py @@ -38,7 +38,6 @@ if TYPE_CHECKING: from .agents.agent import Tool - from .models.core.client import Client from .models.core.model import Model @@ -50,7 +49,6 @@ class ModelContext: messages: list[messages_.Message] tools: Sequence[tools_.ToolLike] | None output_type: type[pydantic.BaseModel] | None - client: Client | None kwargs: dict[str, Any] def __post_init__(self) -> None: @@ -67,7 +65,6 @@ class GenerateContext: model: Model messages: list[messages_.Message] params: Any - client: Client | None = None def __post_init__(self) -> None: object.__setattr__(self, "messages", list(self.messages)) diff --git a/src/ai/models/__init__.py b/src/ai/models/__init__.py index 30a17a4e..c03b7266 100644 --- a/src/ai/models/__init__.py +++ b/src/ai/models/__init__.py @@ -3,36 +3,37 @@ Usage:: import ai - from ai.types import Message, TextPart + from ai.models import openai, anthropic, ai_gateway - # look up a model from the catalog - opus = ai.model("ai-gateway", "anthropic/claude-opus-4-6") - - msgs = [Message(role="user", parts=[TextPart(text="hello")])] + model = openai("gpt-5.4") + model = anthropic("claude-sonnet-4-6") + model = ai_gateway("anthropic/claude-sonnet-4") # stream — auto-creates client from env vars - s = await ai.stream(opus, msgs) + msgs = [ai.user_message("hello")] + s = await ai.stream(model, msgs) async for msg in s: print(msg.text_delta, end="") - # explicit client - client = ai.Client( - base_url="https://custom.example.com/v3/ai", api_key="sk-...", - ) - s = await ai.stream(opus, msgs, client=client) - async for msg in s: - ... + # explicit client for custom auth + client = ai.Client(base_url="https://custom.example.com/v1", api_key="sk-...") + model = openai("gpt-5.4", client=client) + s = await ai.stream(model, msgs) + + # list available models + ids = await openai.list() """ from ..types.stream import StreamResultLike +from .ai_gateway import ai_gateway +from .anthropic import anthropic from .core.adapters import register_check, register_generate, register_stream from .core.api import check_connection, generate, stream -from .core.catalog import get_models, get_providers, register_catalog -from .core.catalog import model as model -from .core.client import _PROVIDER_DEFAULTS, Client -from .core.model import Model, ModelCost +from .core.client import Client +from .core.model import Model from .core.proto import CheckConnFn, GenerateFn, StreamFn from .core.types import GenerateParams, ImageParams, StreamResult, VideoParams +from .openai import openai __all__ = [ # Core types @@ -42,16 +43,14 @@ "GenerateParams", "ImageParams", "Model", - "ModelCost", "StreamFn", "StreamResult", "StreamResultLike", "VideoParams", - # Catalog - "get_models", - "get_providers", - "model", - "register_catalog", + # Provider factories + "ai_gateway", + "anthropic", + "openai", # Adapter / check registration "register_check", "register_generate", @@ -60,6 +59,4 @@ "check_connection", "generate", "stream", - # Internal (used by tests) - "_PROVIDER_DEFAULTS", ] diff --git a/src/ai/models/ai_gateway/__init__.py b/src/ai/models/ai_gateway/__init__.py index 5d1f9fbd..18a73b6a 100644 --- a/src/ai/models/ai_gateway/__init__.py +++ b/src/ai/models/ai_gateway/__init__.py @@ -1,13 +1,81 @@ -"""AI Gateway provider — adapter for the Vercel AI Gateway v3 protocol. +"""AI Gateway provider. + +Usage:: + + from ai.models import ai_gateway + + model = ai_gateway("anthropic/claude-sonnet-4") + ids = await ai_gateway.list() Heavy adapter modules (``.generate``, ``.stream``) are loaded lazily so that ``import ai`` does not pull in ``httpx`` and other I/O libraries at import time. This matters for sandboxed runtimes (e.g. Temporal workflow workers). """ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from ..core.model import Model from . import errors +if TYPE_CHECKING: + from ..core.client import Client + +_BASE_URL = "https://ai-gateway.vercel.sh/v3/ai" +_API_KEY_ENV = "AI_GATEWAY_API_KEY" +_PROTOCOL_VERSION = "0.0.1" + + +class _AIGateway: + """Callable provider factory for the Vercel AI Gateway.""" + + def __call__( + self, + model_id: str, + *, + base_url: str | None = None, + client: Client | None = None, + ) -> Model: + return Model( + id=model_id, + adapter="ai-gateway-v3", + provider="ai-gateway", + base_url=base_url or _BASE_URL, + api_key_env=_API_KEY_ENV, + client=client, + ) + + async def list(self, *, client: Client | None = None) -> list[str]: + """List available model IDs from the AI Gateway.""" + from ..core import client as client_ + + c = client or client_.Client( + base_url=_BASE_URL, + api_key=__import__("os").environ.get(_API_KEY_ENV), + ) + base_url = c.base_url.rstrip("/") + headers: dict[str, str] = { + "ai-gateway-protocol-version": _PROTOCOL_VERSION, + } + if c.api_key: + headers["Authorization"] = f"Bearer {c.api_key}" + headers["ai-gateway-auth-method"] = "api-key" + + config_url = f"{base_url}/config" + response = await c.http.get(config_url, headers=headers) + response.raise_for_status() + data: dict[str, Any] = response.json() + return sorted(str(m["id"]) for m in data.get("models", [])) + + def __repr__(self) -> str: + return "ai_gateway" + + +ai_gateway = _AIGateway() + __all__ = [ + "ai_gateway", "errors", ] diff --git a/src/ai/models/ai_gateway/catalog.py b/src/ai/models/ai_gateway/catalog.py deleted file mode 100644 index a57fc044..00000000 --- a/src/ai/models/ai_gateway/catalog.py +++ /dev/null @@ -1,150 +0,0 @@ -"""AI Gateway model catalog. - -Model IDs use the gateway's ``provider/model-name`` format. -Pricing is per million tokens (USD). -""" - -from __future__ import annotations - -from ..core.model import Model, ModelCost - -_ADAPTER = "ai-gateway-v3" -_PROVIDER = "ai-gateway" - - -def _text( - id: str, - name: str, - *, - context_window: int, - max_output_tokens: int, - cost: ModelCost, -) -> Model: - return Model( - id=id, - adapter=_ADAPTER, - provider=_PROVIDER, - name=name, - capabilities=("text",), - context_window=context_window, - max_output_tokens=max_output_tokens, - cost=cost, - ) - - -def _media( - id: str, - name: str, - *, - capabilities: tuple[str, ...], - context_window: int = 0, - max_output_tokens: int = 0, - cost: ModelCost | None = None, -) -> Model: - return Model( - id=id, - adapter=_ADAPTER, - provider=_PROVIDER, - name=name, - capabilities=capabilities, - context_window=context_window, - max_output_tokens=max_output_tokens, - cost=cost, - ) - - -# --------------------------------------------------------------------------- -# Catalog -# --------------------------------------------------------------------------- - -CATALOG: dict[str, Model] = { - # -- Anthropic (via gateway) ------------------------------------------- - "anthropic/claude-opus-4-6": _text( - "anthropic/claude-opus-4-6", - "Claude Opus 4.6", - context_window=1_000_000, - max_output_tokens=128_000, - cost=ModelCost(input=5.0, output=25.0, cache_read=0.50, cache_write=6.25), - ), - "anthropic/claude-sonnet-4-6": _text( - "anthropic/claude-sonnet-4-6", - "Claude Sonnet 4.6", - context_window=1_000_000, - max_output_tokens=64_000, - cost=ModelCost(input=3.0, output=15.0, cache_read=0.30, cache_write=3.75), - ), - "anthropic/claude-haiku-4-5": _text( - "anthropic/claude-haiku-4-5", - "Claude Haiku 4.5", - context_window=200_000, - max_output_tokens=64_000, - cost=ModelCost(input=1.0, output=5.0, cache_read=0.10, cache_write=1.25), - ), - "anthropic/claude-sonnet-4": _text( - "anthropic/claude-sonnet-4", - "Claude Sonnet 4", - context_window=200_000, - max_output_tokens=64_000, - cost=ModelCost(input=3.0, output=15.0, cache_read=0.30, cache_write=3.75), - ), - # -- OpenAI (via gateway) ---------------------------------------------- - "openai/gpt-5.4": _text( - "openai/gpt-5.4", - "GPT-5.4", - context_window=1_000_000, - max_output_tokens=128_000, - cost=ModelCost(input=2.50, output=15.0, cache_read=0.25), - ), - "openai/gpt-5.4-mini": _text( - "openai/gpt-5.4-mini", - "GPT-5.4 Mini", - context_window=400_000, - max_output_tokens=128_000, - cost=ModelCost(input=0.75, output=4.50, cache_read=0.075), - ), - "openai/gpt-5.4-nano": _text( - "openai/gpt-5.4-nano", - "GPT-5.4 Nano", - context_window=400_000, - max_output_tokens=128_000, - cost=ModelCost(input=0.20, output=1.25, cache_read=0.02), - ), - # -- Google (via gateway) ---------------------------------------------- - "google/gemini-2.5-pro": _text( - "google/gemini-2.5-pro", - "Gemini 2.5 Pro", - context_window=1_000_000, - max_output_tokens=65_536, - cost=ModelCost(input=1.25, output=10.0, cache_read=0.315), - ), - "google/gemini-2.5-flash": _text( - "google/gemini-2.5-flash", - "Gemini 2.5 Flash", - context_window=1_000_000, - max_output_tokens=65_536, - cost=ModelCost(input=0.15, output=0.60, cache_read=0.0375), - ), - # -- Image / video models (via gateway) -------------------------------- - "google/gemini-3-pro-image": _media( - "google/gemini-3-pro-image", - "Gemini 3 Pro Image", - capabilities=("text", "image"), - context_window=1_000_000, - max_output_tokens=65_536, - ), - "google/imagen-4.0-generate-001": _media( - "google/imagen-4.0-generate-001", - "Imagen 4.0", - capabilities=("image",), - ), - "openai/gpt-image-1": _media( - "openai/gpt-image-1", - "GPT Image 1", - capabilities=("image",), - ), - "google/veo-3.0-generate-001": _media( - "google/veo-3.0-generate-001", - "Veo 3.0", - capabilities=("video",), - ), -} diff --git a/src/ai/models/ai_gateway/generate.py b/src/ai/models/ai_gateway/generate.py index 218eb324..9a6dee45 100644 --- a/src/ai/models/ai_gateway/generate.py +++ b/src/ai/models/ai_gateway/generate.py @@ -1,7 +1,6 @@ """AI Gateway v3 generation adapter — image-model and video-model endpoints. -Unified :func:`generate` entry point that dispatches based on param type -and validates against model capabilities. +Unified :func:`generate` entry point that dispatches based on param type. """ from __future__ import annotations @@ -154,57 +153,17 @@ async def _generate_video( # --------------------------------------------------------------------------- -def _check_capabilities( - model: model_.Model, - params: GenerateParams, -) -> None: - """Validate that model capabilities match the requested generation type.""" - caps = model.capabilities - - if isinstance(params, VideoParams): - if "video" not in caps: - raise ValueError( - f"Model {model.id!r} does not have 'video' capability " - f"(capabilities={caps}). Use ImageParams for image models." - ) - if "text" in caps and "video" not in caps: - raise ValueError( - f"Model {model.id!r} is a text model (capabilities={caps}). " - f"Use stream() for text generation, not generate()." - ) - elif isinstance(params, ImageParams): - if "video" in caps and "image" not in caps: - raise ValueError( - f"Model {model.id!r} has 'video' capability but not 'image' " - f"(capabilities={caps}). Use VideoParams for video models." - ) - if "text" in caps and "image" not in caps: - raise ValueError( - f"Model {model.id!r} is a text model (capabilities={caps}). " - f"Use stream() for text generation, not generate()." - ) - - async def generate( client: client_.Client, model: model_.Model, messages: list[messages_.Message], - params: GenerateParams | None = None, + params: GenerateParams, ) -> messages_.Message: """Generate media (images or video) through the AI Gateway. Dispatches to ``/image-model`` or ``/video-model`` based on ``params`` - type, with fallback to model capabilities when ``params`` is ``None``. - - Raises :class:`ValueError` if the model capabilities don't match the - requested generation type. + type. """ - # Auto-detect from capabilities when no params provided - if params is None: - params = VideoParams() if "video" in model.capabilities else ImageParams() - - _check_capabilities(model, params) - if isinstance(params, VideoParams): return await _generate_video(client, model, messages, params) return await _generate_image(client, model, messages, params) diff --git a/src/ai/models/anthropic/__init__.py b/src/ai/models/anthropic/__init__.py index 2a0fd334..e1c2c2b8 100644 --- a/src/ai/models/anthropic/__init__.py +++ b/src/ai/models/anthropic/__init__.py @@ -1,10 +1,73 @@ -"""Anthropic provider — adapter for the Anthropic messages API. +"""Anthropic provider. + +Usage:: + + from ai.models import anthropic + + model = anthropic("claude-sonnet-4-6") + ids = await anthropic.list() The adapter module is loaded lazily to avoid pulling in the ``anthropic`` SDK at import time. """ -__all__: list[str] = [] +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ..core.model import Model + +if TYPE_CHECKING: + from ..core.client import Client + +_BASE_URL = "https://api.anthropic.com/v1" +_API_KEY_ENV = "ANTHROPIC_API_KEY" +_ANTHROPIC_VERSION = "2023-06-01" + + +class _Anthropic: + """Callable provider factory for Anthropic.""" + + def __call__( + self, + model_id: str, + *, + base_url: str | None = None, + client: Client | None = None, + ) -> Model: + return Model( + id=model_id, + adapter="anthropic", + provider="anthropic", + base_url=base_url or _BASE_URL, + api_key_env=_API_KEY_ENV, + client=client, + ) + + async def list(self, *, client: Client | None = None) -> list[str]: + """List available model IDs from the Anthropic API.""" + from ..core import client as client_ + + c = client or client_.Client( + base_url=_BASE_URL, + api_key=__import__("os").environ.get(_API_KEY_ENV), + ) + headers = { + "x-api-key": c.api_key or "", + "anthropic-version": _ANTHROPIC_VERSION, + } + response = await c.http.get(f"{c.base_url.rstrip('/')}/models", headers=headers) + response.raise_for_status() + data: list[dict[str, object]] = response.json().get("data", []) + return sorted(str(m["id"]) for m in data) + + def __repr__(self) -> str: + return "anthropic" + + +anthropic = _Anthropic() + +__all__ = ["anthropic"] def __getattr__(name: str) -> object: diff --git a/src/ai/models/anthropic/catalog.py b/src/ai/models/anthropic/catalog.py deleted file mode 100644 index 333f9584..00000000 --- a/src/ai/models/anthropic/catalog.py +++ /dev/null @@ -1,55 +0,0 @@ -"""Anthropic direct-API model catalog. - -Model IDs are what the Anthropic Messages API expects. -Pricing is per million tokens (USD). -""" - -from __future__ import annotations - -from ..core.model import Model, ModelCost - -_ADAPTER = "anthropic" -_PROVIDER = "anthropic" - -CATALOG: dict[str, Model] = { - "claude-opus-4-6": Model( - id="claude-opus-4-6", - adapter=_ADAPTER, - provider=_PROVIDER, - name="Claude Opus 4.6", - capabilities=("text",), - context_window=1_000_000, - max_output_tokens=128_000, - cost=ModelCost(input=5.0, output=25.0, cache_read=0.50, cache_write=6.25), - ), - "claude-sonnet-4-6": Model( - id="claude-sonnet-4-6", - adapter=_ADAPTER, - provider=_PROVIDER, - name="Claude Sonnet 4.6", - capabilities=("text",), - context_window=1_000_000, - max_output_tokens=64_000, - cost=ModelCost(input=3.0, output=15.0, cache_read=0.30, cache_write=3.75), - ), - "claude-haiku-4-5": Model( - id="claude-haiku-4-5", - adapter=_ADAPTER, - provider=_PROVIDER, - name="Claude Haiku 4.5", - capabilities=("text",), - context_window=200_000, - max_output_tokens=64_000, - cost=ModelCost(input=1.0, output=5.0, cache_read=0.10, cache_write=1.25), - ), - "claude-sonnet-4-20250514": Model( - id="claude-sonnet-4-20250514", - adapter=_ADAPTER, - provider=_PROVIDER, - name="Claude Sonnet 4", - capabilities=("text",), - context_window=200_000, - max_output_tokens=64_000, - cost=ModelCost(input=3.0, output=15.0, cache_read=0.30, cache_write=3.75), - ), -} diff --git a/src/ai/models/core/__init__.py b/src/ai/models/core/__init__.py index 0c2c40a4..0be684e4 100644 --- a/src/ai/models/core/__init__.py +++ b/src/ai/models/core/__init__.py @@ -2,10 +2,8 @@ from .adapters import register_check, register_generate, register_stream from .api import check_connection, generate, stream -from .catalog import get_models, get_providers, register_catalog -from .catalog import model as model_factory from .client import Client -from .model import Model, ModelCost +from .model import Model from .proto import CheckConnFn, GenerateFn, StreamFn from .types import GenerateParams, ImageParams, StreamResult, VideoParams @@ -16,16 +14,11 @@ "GenerateParams", "ImageParams", "Model", - "ModelCost", "StreamFn", "StreamResult", "VideoParams", "check_connection", "generate", - "get_models", - "get_providers", - "model_factory", - "register_catalog", "register_check", "register_generate", "register_stream", diff --git a/src/ai/models/core/api.py b/src/ai/models/core/api.py index 39b1cb2f..53f56738 100644 --- a/src/ai/models/core/api.py +++ b/src/ai/models/core/api.py @@ -26,7 +26,6 @@ async def stream( *, tools: Sequence[tools_.ToolLike] | None = None, output_type: type[pydantic.BaseModel] | None = None, - client: client_.Client | None = None, **kwargs: Any, ) -> stream_.StreamResultLike: """Stream an LLM response. @@ -35,20 +34,19 @@ async def stream( collects the final ``Message``. After iteration, access ``.text``, ``.tool_calls``, ``.usage``, etc. - Without middleware the concrete type is :class:`StreamResult`; with - middleware it may be any :class:`~ai.StreamResultLike`. + The client is resolved from the model: ``model.client`` if set, + otherwise auto-created from ``model.base_url`` / ``model.api_key_env``. """ call = middleware_.ModelContext( model=model, messages=messages, tools=tools, output_type=output_type, - client=client, kwargs=kwargs, ) async def _real(call: middleware_.ModelContext) -> stream_.StreamResultLike: - c = call.client or client_.auto_client(call.model) + c = client_.auto_client(call.model) adapter_fn = adapters.get_stream_adapter(call.model.adapter) return types_.StreamResult( adapter_fn( @@ -68,30 +66,27 @@ async def _real(call: middleware_.ModelContext) -> stream_.StreamResultLike: async def generate( model: model_.Model, messages: list[messages_.Message], - params: types_.GenerateParams | None = None, - *, - client: client_.Client | None = None, + params: types_.GenerateParams, + **kwargs: Any, ) -> messages_.Message: """Generate a response (images, video, etc.). Resolves the adapter function from ``model.adapter``, auto-creates a - :class:`Client` from env vars if none is provided. + :class:`Client` from the model if no explicit client is set. - ``params`` controls the generation type: + ``params`` is required and controls the generation type: * :class:`ImageParams` — image generation (``/image-model``). * :class:`VideoParams` — video generation (``/video-model``). - * ``None`` — auto-detect from ``model.capabilities``. """ call = middleware_.GenerateContext( model=model, messages=messages, params=params, - client=client, ) async def _real(call: middleware_.GenerateContext) -> messages_.Message: - c = call.client or client_.auto_client(call.model) + c = client_.auto_client(call.model) adapter_fn = adapters.get_generate_adapter(call.model.adapter) return await adapter_fn(c, call.model, call.messages, params=call.params) @@ -101,10 +96,8 @@ async def _real(call: middleware_.GenerateContext) -> messages_.Message: async def check_connection( model: model_.Model, - *, - client: client_.Client | None = None, ) -> bool: - """Check whether *client* can reach *model*'s provider and the model exists. + """Check whether the model's provider is reachable and the model exists. Returns ``True`` when the credentials are valid **and** the model is available on the remote side — i.e. a subsequent :func:`stream` or @@ -113,13 +106,13 @@ async def check_connection( This only hits free metadata endpoints; no tokens or credits are consumed. - If no *client* is given, one is auto-created from environment - variables (same logic as :func:`stream`). + The client is resolved from the model: ``model.client`` if set, + otherwise auto-created from ``model.base_url`` / ``model.api_key_env``. Non-auth transport errors (network failures, 5xx) are raised rather than returning ``False`` so that callers can distinguish "bad credentials / unknown model" from "provider unreachable". """ - c = client or client_.auto_client(model) + c = client_.auto_client(model) check_fn = adapters.get_check_fn(model.provider) return await check_fn(c, model) diff --git a/src/ai/models/core/catalog.py b/src/ai/models/core/catalog.py deleted file mode 100644 index 3244b757..00000000 --- a/src/ai/models/core/catalog.py +++ /dev/null @@ -1,103 +0,0 @@ -"""Model catalog — per-provider registry of known models. - -Usage:: - - from ai.models.core.catalog import model - - opus = model("anthropic", "claude-opus-4-6") - sonnet_gw = model("ai-gateway", "anthropic/claude-sonnet-4-6") -""" - -from __future__ import annotations - -from .model import Model - -# --------------------------------------------------------------------------- -# Global registry: provider name -> {model_id -> Model} -# --------------------------------------------------------------------------- - -_catalogs: dict[str, dict[str, Model]] = {} - - -def register_catalog(provider: str, models: dict[str, Model]) -> None: - """Register a provider's model catalog. - - Called by each provider's ``catalog`` module during initialisation, - and by users who want to add custom providers. - """ - _catalogs[provider] = models - - -def _load_builtin_catalogs() -> None: - """Import and register all built-in provider catalogs. - - Catalog submodules only depend on :mod:`..core.model` (pure - dataclasses) — no ``httpx``, ``anthropic``, or ``openai`` imports. - This makes it safe to run eagerly at import time, including inside - sandboxed runtimes (e.g. Temporal workflow workers). - """ - from ..ai_gateway.catalog import CATALOG as ai_gw_catalog - from ..anthropic.catalog import CATALOG as anthropic_catalog - from ..openai.catalog import CATALOG as openai_catalog - - register_catalog("ai-gateway", ai_gw_catalog) - register_catalog("anthropic", anthropic_catalog) - register_catalog("openai", openai_catalog) - - -_load_builtin_catalogs() - - -# --------------------------------------------------------------------------- -# Public API -# --------------------------------------------------------------------------- - - -def model(provider: str, model_id: str) -> Model: - """Look up a model by provider and model ID. - - Raises :class:`KeyError` with a helpful message if the provider or - model ID is not found. - - Examples:: - - model("ai-gateway", "anthropic/claude-opus-4-6") - model("anthropic", "claude-opus-4-6") - model("openai", "gpt-5.4") - """ - provider_catalog = _catalogs.get(provider) - if provider_catalog is None: - available = ", ".join(sorted(_catalogs)) or "(none)" - raise KeyError( - f"Unknown provider {provider!r}. Registered providers: {available}" - ) - - entry = provider_catalog.get(model_id) - if entry is None: - available = ", ".join(sorted(provider_catalog)) or "(none)" - raise KeyError( - f"Unknown model {model_id!r} for provider {provider!r}. " - f"Available models: {available}" - ) - - return entry - - -def get_providers() -> list[str]: - """Return all registered provider names.""" - return sorted(_catalogs) - - -def get_models(provider: str) -> dict[str, Model]: - """Return all models for a provider. - - Raises :class:`KeyError` if the provider is not registered. - """ - provider_catalog = _catalogs.get(provider) - if provider_catalog is None: - available = ", ".join(sorted(_catalogs)) or "(none)" - raise KeyError( - f"Unknown provider {provider!r}. Registered providers: {available}" - ) - - return dict(provider_catalog) diff --git a/src/ai/models/core/client.py b/src/ai/models/core/client.py index 5db11d19..2f4c319e 100644 --- a/src/ai/models/core/client.py +++ b/src/ai/models/core/client.py @@ -4,7 +4,7 @@ import dataclasses import os -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: import httpx @@ -28,9 +28,7 @@ class Client: api_key: str | None = None headers: dict[str, str] = dataclasses.field(default_factory=dict) - _http: httpx.AsyncClient | None = dataclasses.field( - default=None, repr=False, compare=False - ) + _http: Any = dataclasses.field(default=None, repr=False, compare=False) @property def http(self) -> httpx.AsyncClient: @@ -43,7 +41,7 @@ def http(self) -> httpx.AsyncClient: headers=self.headers, timeout=_httpx.Timeout(timeout=300.0, connect=10.0), ) - return self._http + return self._http # type: ignore[no-any-return] async def aclose(self) -> None: """Close the underlying HTTP client if open.""" @@ -52,24 +50,24 @@ async def aclose(self) -> None: self._http = None -# --------------------------------------------------------------------------- -# Provider defaults — base URLs and env var names for auto-client creation. -# --------------------------------------------------------------------------- - -_PROVIDER_DEFAULTS: dict[str, tuple[str, str]] = { - "ai-gateway": ("https://ai-gateway.vercel.sh/v3/ai", "AI_GATEWAY_API_KEY"), - "anthropic": ("https://api.anthropic.com/v1", "ANTHROPIC_API_KEY"), - "openai": ("https://api.openai.com/v1", "OPENAI_API_KEY"), -} +def auto_client(model: model_.Model) -> Client: + """Create a :class:`Client` from the model's connection info. + Uses ``model.client`` if set, otherwise builds one from + ``model.base_url`` and the env var named by ``model.api_key_env``. + """ + if model.client is not None: + return model.client -def auto_client(model: model_.Model) -> Client: - """Create a :class:`Client` from env vars for the given model's provider.""" - defaults = _PROVIDER_DEFAULTS.get(model.provider) - if defaults is None: + if model.base_url is None: raise ValueError( - f"No default client config for provider {model.provider!r}. " - f"Pass an explicit client= argument." + f"Model {model.id!r} (provider={model.provider!r}) has no " + f"base_url and no explicit client. Pass a client= to the " + f"provider factory or set base_url." ) - base_url, env_var = defaults - return Client(base_url=base_url, api_key=os.environ.get(env_var)) + + api_key: str | None = None + if model.api_key_env: + api_key = os.environ.get(model.api_key_env) + + return Client(base_url=model.base_url, api_key=api_key) diff --git a/src/ai/models/core/model.py b/src/ai/models/core/model.py index cbf59f50..26d5f2a4 100644 --- a/src/ai/models/core/model.py +++ b/src/ai/models/core/model.py @@ -4,31 +4,24 @@ import dataclasses - -@dataclasses.dataclass(frozen=True) -class ModelCost: - """Per-million-token pricing.""" - - input: float = 0.0 - output: float = 0.0 - cache_read: float = 0.0 - cache_write: float = 0.0 +from .client import Client @dataclasses.dataclass(frozen=True) class Model: - """Pure-data description of a model. + """Lightweight reference to a model on a specific provider. - * ``id`` — identifier sent to the provider (e.g. ``"claude-sonnet-4-20250514"``). - * ``adapter`` — adapter key (e.g. ``"ai-gateway-v3"``, ``"anthropic-messages"``). + * ``id`` — identifier sent to the provider (e.g. ``"claude-sonnet-4-6"``). + * ``adapter`` — wire protocol key (e.g. ``"ai-gateway-v3"``, ``"anthropic"``). * ``provider`` — hosting service (e.g. ``"ai-gateway"``, ``"anthropic"``). + * ``base_url`` — API endpoint for auto-client creation. + * ``api_key_env`` — env var name to read for auto-client creation. + * ``client`` — explicit :class:`Client` override (skips auto-client). """ id: str adapter: str provider: str - name: str = "" - capabilities: tuple[str, ...] = ("text",) - context_window: int = 0 - max_output_tokens: int = 0 - cost: ModelCost | None = None + base_url: str | None = None + api_key_env: str | None = None + client: Client | None = dataclasses.field(default=None, repr=False) diff --git a/src/ai/models/core/proto.py b/src/ai/models/core/proto.py index 754b2fa7..446417be 100644 --- a/src/ai/models/core/proto.py +++ b/src/ai/models/core/proto.py @@ -57,7 +57,7 @@ async def __call__( client: Client, model: Model, messages: list[messages_.Message], - params: Any = None, + params: Any, ) -> messages_.Message: ... diff --git a/src/ai/models/openai/__init__.py b/src/ai/models/openai/__init__.py index 5fc55fc0..69b2a041 100644 --- a/src/ai/models/openai/__init__.py +++ b/src/ai/models/openai/__init__.py @@ -1,10 +1,69 @@ -"""OpenAI provider — adapter for the OpenAI chat completions API. +"""OpenAI provider. + +Usage:: + + from ai.models import openai + + model = openai("gpt-5.4") + ids = await openai.list() The adapter module is loaded lazily to avoid pulling in the ``openai`` SDK at import time. """ -__all__: list[str] = [] +from __future__ import annotations + +from typing import TYPE_CHECKING + +from ..core.model import Model + +if TYPE_CHECKING: + from ..core.client import Client + +_BASE_URL = "https://api.openai.com/v1" +_API_KEY_ENV = "OPENAI_API_KEY" + + +class _OpenAI: + """Callable provider — ``openai("gpt-5.4")`` returns a :class:`Model`.""" + + def __call__( + self, + model_id: str, + *, + base_url: str | None = None, + client: Client | None = None, + ) -> Model: + return Model( + id=model_id, + adapter="openai", + provider="openai", + base_url=base_url or _BASE_URL, + api_key_env=_API_KEY_ENV, + client=client, + ) + + async def list(self, *, client: Client | None = None) -> list[str]: + """List available model IDs from the OpenAI API.""" + from ..core import client as client_ + + c = client or client_.Client( + base_url=_BASE_URL, + api_key=__import__("os").environ.get(_API_KEY_ENV), + ) + headers = {"Authorization": f"Bearer {c.api_key}"} + response = await c.http.get(f"{c.base_url.rstrip('/')}/models", headers=headers) + response.raise_for_status() + data: list[dict[str, object]] = response.json().get("data", []) + return sorted(str(m["id"]) for m in data) + + def __repr__(self) -> str: + return "openai" + + +openai = _OpenAI() + +__all__ = ["openai"] def __getattr__(name: str) -> object: diff --git a/src/ai/models/openai/catalog.py b/src/ai/models/openai/catalog.py deleted file mode 100644 index 91df6605..00000000 --- a/src/ai/models/openai/catalog.py +++ /dev/null @@ -1,45 +0,0 @@ -"""OpenAI direct-API model catalog. - -Model IDs are what the OpenAI Chat Completions API expects. -Pricing is per million tokens (USD). -""" - -from __future__ import annotations - -from ..core.model import Model, ModelCost - -_ADAPTER = "openai" -_PROVIDER = "openai" - -CATALOG: dict[str, Model] = { - "gpt-5.4": Model( - id="gpt-5.4", - adapter=_ADAPTER, - provider=_PROVIDER, - name="GPT-5.4", - capabilities=("text",), - context_window=1_000_000, - max_output_tokens=128_000, - cost=ModelCost(input=2.50, output=15.0, cache_read=0.25), - ), - "gpt-5.4-mini": Model( - id="gpt-5.4-mini", - adapter=_ADAPTER, - provider=_PROVIDER, - name="GPT-5.4 Mini", - capabilities=("text",), - context_window=400_000, - max_output_tokens=128_000, - cost=ModelCost(input=0.75, output=4.50, cache_read=0.075), - ), - "gpt-5.4-nano": Model( - id="gpt-5.4-nano", - adapter=_ADAPTER, - provider=_PROVIDER, - name="GPT-5.4 Nano", - capabilities=("text",), - context_window=400_000, - max_output_tokens=128_000, - cost=ModelCost(input=0.20, output=1.25, cache_read=0.02), - ), -} diff --git a/tests/conftest.py b/tests/conftest.py index 15796650..621bf264 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,10 +11,13 @@ from ai.types import messages as messages_ # A fixed Model used in tests — adapter="mock" dispatches to the mock adapter. -MOCK_MODEL = models.Model(id="mock-model", adapter="mock", provider="mock") - -# Register a dummy provider so _auto_client() doesn't error for provider="mock". -models._PROVIDER_DEFAULTS["mock"] = ("http://mock.test", "MOCK_API_KEY") +MOCK_MODEL = models.Model( + id="mock-model", + adapter="mock", + provider="mock", + base_url="http://mock.test", + api_key_env="MOCK_API_KEY", +) class MockAdapter: diff --git a/tests/models/ai_gateway/test_generate_image.py b/tests/models/ai_gateway/test_generate_image.py index a57abe3f..2b1614e0 100644 --- a/tests/models/ai_gateway/test_generate_image.py +++ b/tests/models/ai_gateway/test_generate_image.py @@ -40,7 +40,6 @@ id="google/imagen-4.0-generate-001", adapter="ai-gateway-v3", provider="ai-gateway", - capabilities=("image",), ) @@ -60,7 +59,9 @@ def handler(req: httpx.Request) -> httpx.Response: ) client = mock_client(httpx.MockTransport(handler)) - msg = await generate(client, _IMAGE_MODEL, [user_msg("A sunset over Tokyo")]) + msg = await generate( + client, _IMAGE_MODEL, [user_msg("A sunset over Tokyo")], ImageParams() + ) assert msg.role == "assistant" assert len(msg.images) == 1 @@ -104,7 +105,7 @@ def handler(req: httpx.Request) -> httpx.Response: ) client = mock_client(httpx.MockTransport(handler)) - msg = await generate(client, _IMAGE_MODEL, [user_msg("a dog")]) + msg = await generate(client, _IMAGE_MODEL, [user_msg("a dog")], ImageParams()) assert msg.usage is not None assert msg.usage.input_tokens == 50 @@ -128,10 +129,9 @@ def handler(req: httpx.Request) -> httpx.Response: id="openai/gpt-image-1", adapter="ai-gateway-v3", provider="ai-gateway", - capabilities=("image",), ) client = mock_client(httpx.MockTransport(handler), api_key="sk-test") - await generate(client, model, [user_msg("Hi")]) + await generate(client, model, [user_msg("Hi")], ImageParams()) assert captured["authorization"] == "Bearer sk-test" assert captured["ai-image-model-specification-version"] == "3" @@ -184,7 +184,7 @@ def handler(req: httpx.Request) -> httpx.Response: return httpx.Response(200, json={"images": [_PNG_B64]}) client = mock_client(httpx.MockTransport(handler)) - await generate(client, _IMAGE_MODEL, [user_msg("test")]) + await generate(client, _IMAGE_MODEL, [user_msg("test")], ImageParams()) assert captured_url[0] == "https://gw.test/v3/ai/image-model" @@ -212,6 +212,7 @@ def handler(req: httpx.Request) -> httpx.Response: mock_client(httpx.MockTransport(handler)), _IMAGE_MODEL, [user_msg("test")], + ImageParams(), ) async def test_429_rate_limit_error(self) -> None: @@ -231,6 +232,7 @@ def handler(req: httpx.Request) -> httpx.Response: mock_client(httpx.MockTransport(handler)), _IMAGE_MODEL, [user_msg("test")], + ImageParams(), ) async def test_empty_images_returns_empty_message(self) -> None: @@ -243,5 +245,6 @@ def handler(req: httpx.Request) -> httpx.Response: mock_client(httpx.MockTransport(handler)), _IMAGE_MODEL, [user_msg("test")], + ImageParams(), ) assert len(msg.images) == 0 diff --git a/tests/models/ai_gateway/test_generate_video.py b/tests/models/ai_gateway/test_generate_video.py index b283fcd9..e3dd3ac8 100644 --- a/tests/models/ai_gateway/test_generate_video.py +++ b/tests/models/ai_gateway/test_generate_video.py @@ -43,7 +43,6 @@ id="google/veo-3.0-generate-001", adapter="ai-gateway-v3", provider="ai-gateway", - capabilities=("video",), ) diff --git a/tests/models/test_check.py b/tests/models/test_check.py index 0889a0a8..f61be377 100644 --- a/tests/models/test_check.py +++ b/tests/models/test_check.py @@ -8,6 +8,7 @@ from __future__ import annotations +import dataclasses import json from typing import Any @@ -25,12 +26,23 @@ # Fixtures # --------------------------------------------------------------------------- -_OPENAI_MODEL = model_.Model(id="gpt-5.4", adapter="openai", provider="openai") +_OPENAI_MODEL = model_.Model( + id="gpt-5.4", + adapter="openai", + provider="openai", + base_url="https://api.openai.com/v1", +) _ANTHROPIC_MODEL = model_.Model( - id="claude-opus-4-6", adapter="anthropic", provider="anthropic" + id="claude-opus-4-6", + adapter="anthropic", + provider="anthropic", + base_url="https://api.anthropic.com/v1", ) _GATEWAY_MODEL = model_.Model( - id="anthropic/claude-opus-4-6", adapter="ai-gateway-v3", provider="ai-gateway" + id="anthropic/claude-opus-4-6", + adapter="ai-gateway-v3", + provider="ai-gateway", + base_url="https://ai-gateway.vercel.sh/v3/ai", ) _UNKNOWN_MODEL = model_.Model(id="x", adapter="x", provider="unknown-provider") @@ -155,15 +167,16 @@ class TestCheckConnection: async def test_gateway_dispatches(self) -> None: config = {"models": [{"id": "anthropic/claude-opus-4-6"}]} c = _gateway_client(config_body=config) - assert await check_connection(_GATEWAY_MODEL, client=c) is True + m = dataclasses.replace(_GATEWAY_MODEL, client=c) + assert await check_connection(m) is True async def test_unknown_provider_raises(self) -> None: c = _client_with_mock(200) + m = dataclasses.replace(_UNKNOWN_MODEL, client=c) with pytest.raises(KeyError, match="unknown-provider"): - await check_connection(_UNKNOWN_MODEL, client=c) + await check_connection(m) async def test_dispatch_false_propagates(self) -> None: - assert ( - await check_connection(_OPENAI_MODEL, client=_client_with_mock(401)) - is False - ) + c = _client_with_mock(401) + m = dataclasses.replace(_OPENAI_MODEL, client=c) + assert await check_connection(m) is False diff --git a/tests/models/test_public_api.py b/tests/models/test_public_api.py index 378a6e74..bac3853a 100644 --- a/tests/models/test_public_api.py +++ b/tests/models/test_public_api.py @@ -41,7 +41,7 @@ async def test_stream_basic() -> None: async def test_stream_with_explicit_client() -> None: - """ai.models.stream(client=...) forwards the provided client to the adapter.""" + """Model with explicit client= forwards it to the adapter.""" received_clients: list[models.Client] = [] async def _spy_stream( @@ -63,7 +63,10 @@ async def _spy_stream( models.register_stream("mock", _spy_stream) explicit = models.Client(base_url="https://custom.test", api_key="sk-custom") - s = await models.stream(MOCK_MODEL, [ai.user_message("Hi")], client=explicit) + explicit_model = models.Model( + id="mock-model", adapter="mock", provider="mock", client=explicit + ) + s = await models.stream(explicit_model, [ai.user_message("Hi")]) async for _ in s: pass @@ -121,7 +124,8 @@ async def _structured_stream( id="gen-model", adapter="mock-gen", provider="mock", - capabilities=("image",), + base_url="http://mock.test", + api_key_env="MOCK_API_KEY", ) diff --git a/tests/test_middleware.py b/tests/test_middleware.py index 2e335247..8d7a88ce 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -320,7 +320,9 @@ async def wrap_generate( @my_agent.loop async def gen_loop(context: ai.Context) -> AsyncGenerator[ai.Message]: - result = await models.generate(context.model, context.messages) + result = await models.generate( + context.model, context.messages, models.ImageParams() + ) yield result async for _m in my_agent.run( From e85bd493f119407ccd21f897dba96a02b4d30326 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Mon, 13 Apr 2026 14:39:23 -0700 Subject: [PATCH 4/4] Add a provider protocol, move provider-related props out of model --- src/ai/__init__.py | 2 + src/ai/models/__init__.py | 8 +-- src/ai/models/ai_gateway/__init__.py | 55 ++++++++++---- src/ai/models/anthropic/__init__.py | 55 ++++++++++---- src/ai/models/core/__init__.py | 6 +- src/ai/models/core/adapters.py | 58 +++------------ src/ai/models/core/api.py | 5 +- src/ai/models/core/client.py | 19 ++--- src/ai/models/core/model.py | 11 ++- src/ai/models/core/proto.py | 65 ++++++++++++++++- src/ai/models/openai/__init__.py | 55 ++++++++++---- tests/conftest.py | 72 ++++++++++++++++++- .../models/ai_gateway/test_generate_image.py | 15 +--- .../models/ai_gateway/test_generate_video.py | 9 +-- tests/models/ai_gateway/test_stream.py | 14 +--- tests/models/test_check.py | 64 +++++++++++------ tests/models/test_public_api.py | 10 +-- 17 files changed, 341 insertions(+), 182 deletions(-) diff --git a/src/ai/__init__.py b/src/ai/__init__.py index 23475bbc..00a7c589 100644 --- a/src/ai/__init__.py +++ b/src/ai/__init__.py @@ -20,6 +20,7 @@ Client, ImageParams, Model, + Provider, StreamResult, VideoParams, ai_gateway, @@ -84,6 +85,7 @@ "thinking", # Models (from models/) "Model", + "Provider", "ImageParams", "VideoParams", "Client", diff --git a/src/ai/models/__init__.py b/src/ai/models/__init__.py index c03b7266..55ba9648 100644 --- a/src/ai/models/__init__.py +++ b/src/ai/models/__init__.py @@ -27,11 +27,11 @@ from ..types.stream import StreamResultLike from .ai_gateway import ai_gateway from .anthropic import anthropic -from .core.adapters import register_check, register_generate, register_stream +from .core.adapters import register_generate, register_stream from .core.api import check_connection, generate, stream from .core.client import Client from .core.model import Model -from .core.proto import CheckConnFn, GenerateFn, StreamFn +from .core.proto import CheckConnFn, GenerateFn, Provider, StreamFn from .core.types import GenerateParams, ImageParams, StreamResult, VideoParams from .openai import openai @@ -43,6 +43,7 @@ "GenerateParams", "ImageParams", "Model", + "Provider", "StreamFn", "StreamResult", "StreamResultLike", @@ -51,8 +52,7 @@ "ai_gateway", "anthropic", "openai", - # Adapter / check registration - "register_check", + # Adapter registration "register_generate", "register_stream", # Public API diff --git a/src/ai/models/ai_gateway/__init__.py b/src/ai/models/ai_gateway/__init__.py index 18a73b6a..96d0bdf5 100644 --- a/src/ai/models/ai_gateway/__init__.py +++ b/src/ai/models/ai_gateway/__init__.py @@ -14,13 +14,15 @@ from __future__ import annotations +import os from typing import TYPE_CHECKING, Any +from ..core import client as client_ from ..core.model import Model from . import errors if TYPE_CHECKING: - from ..core.client import Client + pass _BASE_URL = "https://ai-gateway.vercel.sh/v3/ai" _API_KEY_ENV = "AI_GATEWAY_API_KEY" @@ -28,32 +30,57 @@ class _AIGateway: - """Callable provider factory for the Vercel AI Gateway.""" + """Callable provider factory for the Vercel AI Gateway. + + Satisfies the :class:`~ai.models.core.proto.Provider` protocol. + """ + + @property + def api_key_env(self) -> str: + return _API_KEY_ENV + + @property + def base_url(self) -> str: + return _BASE_URL + + @property + def adapter(self) -> str: + return "ai-gateway-v3" + + @property + def name(self) -> str: + return "ai-gateway" + + def client(self) -> client_.Client: + """Create a :class:`Client` from env-var credentials.""" + return client_.Client( + base_url=_BASE_URL, + api_key=os.environ.get(_API_KEY_ENV), + ) + + async def check(self, client: client_.Client, model: Model) -> bool: + """Delegate to :func:`ai_gateway.check.check`.""" + from . import check as check_ + + return await check_.check(client, model) def __call__( self, model_id: str, *, base_url: str | None = None, - client: Client | None = None, + client: client_.Client | None = None, ) -> Model: return Model( id=model_id, - adapter="ai-gateway-v3", - provider="ai-gateway", - base_url=base_url or _BASE_URL, - api_key_env=_API_KEY_ENV, + adapter=self.adapter, + provider=self, client=client, ) - async def list(self, *, client: Client | None = None) -> list[str]: + async def list(self, *, client: client_.Client | None = None) -> list[str]: """List available model IDs from the AI Gateway.""" - from ..core import client as client_ - - c = client or client_.Client( - base_url=_BASE_URL, - api_key=__import__("os").environ.get(_API_KEY_ENV), - ) + c = client or self.client() base_url = c.base_url.rstrip("/") headers: dict[str, str] = { "ai-gateway-protocol-version": _PROTOCOL_VERSION, diff --git a/src/ai/models/anthropic/__init__.py b/src/ai/models/anthropic/__init__.py index e1c2c2b8..659ce7c9 100644 --- a/src/ai/models/anthropic/__init__.py +++ b/src/ai/models/anthropic/__init__.py @@ -13,12 +13,14 @@ from __future__ import annotations +import os from typing import TYPE_CHECKING +from ..core import client as client_ from ..core.model import Model if TYPE_CHECKING: - from ..core.client import Client + pass _BASE_URL = "https://api.anthropic.com/v1" _API_KEY_ENV = "ANTHROPIC_API_KEY" @@ -26,32 +28,57 @@ class _Anthropic: - """Callable provider factory for Anthropic.""" + """Callable provider factory for Anthropic. + + Satisfies the :class:`~ai.models.core.proto.Provider` protocol. + """ + + @property + def api_key_env(self) -> str: + return _API_KEY_ENV + + @property + def base_url(self) -> str: + return _BASE_URL + + @property + def adapter(self) -> str: + return "anthropic" + + @property + def name(self) -> str: + return "anthropic" + + def client(self) -> client_.Client: + """Create a :class:`Client` from env-var credentials.""" + return client_.Client( + base_url=_BASE_URL, + api_key=os.environ.get(_API_KEY_ENV), + ) + + async def check(self, client: client_.Client, model: Model) -> bool: + """Delegate to :func:`anthropic.check.check`.""" + from . import check as check_ + + return await check_.check(client, model) def __call__( self, model_id: str, *, base_url: str | None = None, - client: Client | None = None, + client: client_.Client | None = None, ) -> Model: return Model( id=model_id, - adapter="anthropic", - provider="anthropic", - base_url=base_url or _BASE_URL, - api_key_env=_API_KEY_ENV, + adapter=self.adapter, + provider=self, client=client, ) - async def list(self, *, client: Client | None = None) -> list[str]: + async def list(self, *, client: client_.Client | None = None) -> list[str]: """List available model IDs from the Anthropic API.""" - from ..core import client as client_ - - c = client or client_.Client( - base_url=_BASE_URL, - api_key=__import__("os").environ.get(_API_KEY_ENV), - ) + c = client or self.client() headers = { "x-api-key": c.api_key or "", "anthropic-version": _ANTHROPIC_VERSION, diff --git a/src/ai/models/core/__init__.py b/src/ai/models/core/__init__.py index 0be684e4..5b5e580d 100644 --- a/src/ai/models/core/__init__.py +++ b/src/ai/models/core/__init__.py @@ -1,10 +1,10 @@ """Core types for models.""" -from .adapters import register_check, register_generate, register_stream +from .adapters import register_generate, register_stream from .api import check_connection, generate, stream from .client import Client from .model import Model -from .proto import CheckConnFn, GenerateFn, StreamFn +from .proto import CheckConnFn, GenerateFn, Provider, StreamFn from .types import GenerateParams, ImageParams, StreamResult, VideoParams __all__ = [ @@ -14,12 +14,12 @@ "GenerateParams", "ImageParams", "Model", + "Provider", "StreamFn", "StreamResult", "VideoParams", "check_connection", "generate", - "register_check", "register_generate", "register_stream", "stream", diff --git a/src/ai/models/core/adapters.py b/src/ai/models/core/adapters.py index f7a5c114..4496c3cd 100644 --- a/src/ai/models/core/adapters.py +++ b/src/ai/models/core/adapters.py @@ -1,7 +1,14 @@ -"""Adapter and check-function registries. +"""Adapter registries. -Maps adapter/provider strings to their handler functions. Adapter modules +Maps adapter strings to their handler functions. Adapter modules are imported lazily on first use to keep import-time lightweight. + +.. note:: + + Connection checks are no longer dispatched through a registry. + Each :class:`~ai.models.core.proto.Provider` implements ``check()`` + directly, and :func:`~ai.models.core.api.check_connection` delegates + to ``model.provider.check()``. """ from __future__ import annotations @@ -75,50 +82,3 @@ def get_generate_adapter(adapter: str) -> proto.GenerateFn: f"Registered: {registered}" ) return fn - - -# --------------------------------------------------------------------------- -# Connection-check registry — keyed by *provider* (not adapter) because the -# check verifies "can this client reach this provider and does this model -# exist there". -# --------------------------------------------------------------------------- - -_check_fns: dict[str, proto.CheckConnFn] = {} -_check_fns_loaded = False - - -def _ensure_check_fns() -> None: - """Lazily register built-in check functions on first call.""" - global _check_fns_loaded # noqa: PLW0603 - if _check_fns_loaded: - return - _check_fns_loaded = True - - from ..ai_gateway import check as ai_gw_check - from ..anthropic import check as anthropic_check - from ..openai import check as openai_check - - _check_fns["ai-gateway"] = ai_gw_check.check - _check_fns["anthropic"] = anthropic_check.check - _check_fns["openai"] = openai_check.check - - -def register_check(provider: str, fn: proto.CheckConnFn) -> None: - """Register a connection-check function for a provider. - - Use this to add checks for custom providers. - """ - _check_fns[provider] = fn - - -def get_check_fn(provider: str) -> proto.CheckConnFn: - """Return the check function for *provider*, raising on miss.""" - _ensure_check_fns() - fn = _check_fns.get(provider) - if fn is None: - registered = ", ".join(sorted(_check_fns)) or "(none)" - raise KeyError( - f"No check function registered for provider={provider!r}. " - f"Registered: {registered}" - ) - return fn diff --git a/src/ai/models/core/api.py b/src/ai/models/core/api.py index 53f56738..84f9f22b 100644 --- a/src/ai/models/core/api.py +++ b/src/ai/models/core/api.py @@ -107,12 +107,11 @@ async def check_connection( consumed. The client is resolved from the model: ``model.client`` if set, - otherwise auto-created from ``model.base_url`` / ``model.api_key_env``. + otherwise created by the provider. Non-auth transport errors (network failures, 5xx) are raised rather than returning ``False`` so that callers can distinguish "bad credentials / unknown model" from "provider unreachable". """ c = client_.auto_client(model) - check_fn = adapters.get_check_fn(model.provider) - return await check_fn(c, model) + return await model.provider.check(c, model) diff --git a/src/ai/models/core/client.py b/src/ai/models/core/client.py index 2f4c319e..a559f52d 100644 --- a/src/ai/models/core/client.py +++ b/src/ai/models/core/client.py @@ -3,7 +3,6 @@ from __future__ import annotations import dataclasses -import os from typing import TYPE_CHECKING, Any if TYPE_CHECKING: @@ -53,21 +52,11 @@ async def aclose(self) -> None: def auto_client(model: model_.Model) -> Client: """Create a :class:`Client` from the model's connection info. - Uses ``model.client`` if set, otherwise builds one from - ``model.base_url`` and the env var named by ``model.api_key_env``. + Uses ``model.client`` if set, otherwise delegates to + ``model.provider.client()`` which reads the provider's default + ``base_url`` and ``api_key_env``. """ if model.client is not None: return model.client - if model.base_url is None: - raise ValueError( - f"Model {model.id!r} (provider={model.provider!r}) has no " - f"base_url and no explicit client. Pass a client= to the " - f"provider factory or set base_url." - ) - - api_key: str | None = None - if model.api_key_env: - api_key = os.environ.get(model.api_key_env) - - return Client(base_url=model.base_url, api_key=api_key) + return model.provider.client() diff --git a/src/ai/models/core/model.py b/src/ai/models/core/model.py index 26d5f2a4..769330ed 100644 --- a/src/ai/models/core/model.py +++ b/src/ai/models/core/model.py @@ -5,6 +5,7 @@ import dataclasses from .client import Client +from .proto import Provider @dataclasses.dataclass(frozen=True) @@ -13,15 +14,11 @@ class Model: * ``id`` — identifier sent to the provider (e.g. ``"claude-sonnet-4-6"``). * ``adapter`` — wire protocol key (e.g. ``"ai-gateway-v3"``, ``"anthropic"``). - * ``provider`` — hosting service (e.g. ``"ai-gateway"``, ``"anthropic"``). - * ``base_url`` — API endpoint for auto-client creation. - * ``api_key_env`` — env var name to read for auto-client creation. - * ``client`` — explicit :class:`Client` override (skips auto-client). + * ``provider`` — :class:`Provider` that owns this model. + * ``client`` — explicit :class:`Client` override (skips provider's default). """ id: str adapter: str - provider: str - base_url: str | None = None - api_key_env: str | None = None + provider: Provider client: Client | None = dataclasses.field(default=None, repr=False) diff --git a/src/ai/models/core/proto.py b/src/ai/models/core/proto.py index 446417be..173bcfe7 100644 --- a/src/ai/models/core/proto.py +++ b/src/ai/models/core/proto.py @@ -16,10 +16,73 @@ from ...types import messages as messages_ from ...types import tools as tools_ -from .model import Model if TYPE_CHECKING: from .client import Client + from .model import Model + + +@runtime_checkable +class Provider(Protocol): + """Protocol for model providers. + + A provider carries all provider-specific configuration and behaviour: + API endpoint, authentication, client creation, connection checks, and + model enumeration. Model objects hold only pure metadata (``id``, + ``adapter``) plus a back-reference to their provider. + + Implementations must be **callable** — ``provider(model_id)`` returns + a :class:`Model`. + """ + + @property + def api_key_env(self) -> str | None: + """Env var name that holds the API key (e.g. ``"OPENAI_API_KEY"``).""" + ... + + @property + def base_url(self) -> str: + """Default base URL for the provider API.""" + ... + + @property + def adapter(self) -> str: + """Wire-protocol key used to look up stream/generate adapters.""" + ... + + @property + def name(self) -> str: + """Human-readable provider name (for repr, error messages).""" + ... + + def client(self) -> Client: + """Create a :class:`Client` from the provider's default config. + + Reads ``api_key_env`` from the environment and uses ``base_url`` + as the endpoint. + """ + ... + + async def check(self, client: Client, model: Model) -> bool: + """Check whether *client* can reach this provider and *model* exists. + + Returns ``True`` when credentials are valid **and** the model is + available. Non-auth transport errors should be raised. + """ + ... + + async def list(self, *, client: Client | None = None) -> list[str]: + """List available model IDs from the provider API.""" + ... + + def __call__( + self, + model_id: str, + *, + client: Client | None = None, + ) -> Model: + """Create a :class:`Model` for the given *model_id*.""" + ... @runtime_checkable diff --git a/src/ai/models/openai/__init__.py b/src/ai/models/openai/__init__.py index 69b2a041..9dfb113b 100644 --- a/src/ai/models/openai/__init__.py +++ b/src/ai/models/openai/__init__.py @@ -13,44 +13,71 @@ from __future__ import annotations +import os from typing import TYPE_CHECKING +from ..core import client as client_ from ..core.model import Model if TYPE_CHECKING: - from ..core.client import Client + pass _BASE_URL = "https://api.openai.com/v1" _API_KEY_ENV = "OPENAI_API_KEY" class _OpenAI: - """Callable provider — ``openai("gpt-5.4")`` returns a :class:`Model`.""" + """Callable provider — ``openai("gpt-5.4")`` returns a :class:`Model`. + + Satisfies the :class:`~ai.models.core.proto.Provider` protocol. + """ + + @property + def api_key_env(self) -> str: + return _API_KEY_ENV + + @property + def base_url(self) -> str: + return _BASE_URL + + @property + def adapter(self) -> str: + return "openai" + + @property + def name(self) -> str: + return "openai" + + def client(self) -> client_.Client: + """Create a :class:`Client` from env-var credentials.""" + return client_.Client( + base_url=_BASE_URL, + api_key=os.environ.get(_API_KEY_ENV), + ) + + async def check(self, client: client_.Client, model: Model) -> bool: + """Delegate to :func:`openai.check.check`.""" + from . import check as check_ + + return await check_.check(client, model) def __call__( self, model_id: str, *, base_url: str | None = None, - client: Client | None = None, + client: client_.Client | None = None, ) -> Model: return Model( id=model_id, - adapter="openai", - provider="openai", - base_url=base_url or _BASE_URL, - api_key_env=_API_KEY_ENV, + adapter=self.adapter, + provider=self, client=client, ) - async def list(self, *, client: Client | None = None) -> list[str]: + async def list(self, *, client: client_.Client | None = None) -> list[str]: """List available model IDs from the OpenAI API.""" - from ..core import client as client_ - - c = client or client_.Client( - base_url=_BASE_URL, - api_key=__import__("os").environ.get(_API_KEY_ENV), - ) + c = client or self.client() headers = {"Authorization": f"Bearer {c.api_key}"} response = await c.http.get(f"{c.base_url.rstrip('/')}/models", headers=headers) response.raise_for_status() diff --git a/tests/conftest.py b/tests/conftest.py index 621bf264..1e8529f1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,13 +10,79 @@ from ai.types import builders from ai.types import messages as messages_ + +class MockProvider: + """Minimal provider that satisfies the ``Provider`` protocol for tests. + + Carries just enough state so that ``Model`` objects can be constructed + and ``auto_client`` can resolve a client. + """ + + def __init__( + self, + *, + name: str = "mock", + adapter: str = "mock", + base_url: str = "http://mock.test", + api_key_env: str | None = "MOCK_API_KEY", + ) -> None: + self._name = name + self._adapter = adapter + self._base_url = base_url + self._api_key_env = api_key_env + + @property + def name(self) -> str: + return self._name + + @property + def adapter(self) -> str: + return self._adapter + + @property + def base_url(self) -> str: + return self._base_url + + @property + def api_key_env(self) -> str | None: + return self._api_key_env + + def client(self) -> models.Client: + import os + + api_key = os.environ.get(self._api_key_env) if self._api_key_env else None + return models.Client(base_url=self._base_url, api_key=api_key) + + async def check(self, client: models.Client, model: models.Model) -> bool: + return True + + async def list(self, *, client: models.Client | None = None) -> list[str]: + return [] + + def __call__( + self, + model_id: str, + *, + client: models.Client | None = None, + ) -> models.Model: + return models.Model( + id=model_id, + adapter=self._adapter, + provider=self, + client=client, + ) + + def __repr__(self) -> str: + return self._name + + +MOCK_PROVIDER = MockProvider() + # A fixed Model used in tests — adapter="mock" dispatches to the mock adapter. MOCK_MODEL = models.Model( id="mock-model", adapter="mock", - provider="mock", - base_url="http://mock.test", - api_key_env="MOCK_API_KEY", + provider=MOCK_PROVIDER, ) diff --git a/tests/models/ai_gateway/test_generate_image.py b/tests/models/ai_gateway/test_generate_image.py index 2b1614e0..b3d7bfd9 100644 --- a/tests/models/ai_gateway/test_generate_image.py +++ b/tests/models/ai_gateway/test_generate_image.py @@ -21,9 +21,8 @@ import httpx import pytest -from ai.models.ai_gateway import errors +from ai.models.ai_gateway import ai_gateway, errors from ai.models.ai_gateway.generate import ImageParams, generate -from ai.models.core import model as model_ from ai.types import messages from .conftest import mock_client, user_msg @@ -36,11 +35,7 @@ _JPEG_HEADER = bytes([0xFF, 0xD8, 0xFF, 0xE0]) _JPEG_B64 = base64.b64encode(_JPEG_HEADER).decode() -_IMAGE_MODEL = model_.Model( - id="google/imagen-4.0-generate-001", - adapter="ai-gateway-v3", - provider="ai-gateway", -) +_IMAGE_MODEL = ai_gateway("google/imagen-4.0-generate-001") # --------------------------------------------------------------------------- @@ -125,11 +120,7 @@ def handler(req: httpx.Request) -> httpx.Response: captured.update(dict(req.headers)) return httpx.Response(200, json={"images": [_PNG_B64]}) - model = model_.Model( - id="openai/gpt-image-1", - adapter="ai-gateway-v3", - provider="ai-gateway", - ) + model = ai_gateway("openai/gpt-image-1") client = mock_client(httpx.MockTransport(handler), api_key="sk-test") await generate(client, model, [user_msg("Hi")], ImageParams()) diff --git a/tests/models/ai_gateway/test_generate_video.py b/tests/models/ai_gateway/test_generate_video.py index e3dd3ac8..8233daa6 100644 --- a/tests/models/ai_gateway/test_generate_video.py +++ b/tests/models/ai_gateway/test_generate_video.py @@ -22,9 +22,8 @@ import httpx import pytest -from ai.models.ai_gateway import errors +from ai.models.ai_gateway import ai_gateway, errors from ai.models.ai_gateway.generate import VideoParams, generate -from ai.models.core import model as model_ from ai.types import messages from .conftest import mock_client, sse, user_msg @@ -39,11 +38,7 @@ _WEBM_HEADER = bytes([0x1A, 0x45, 0xDF, 0xA3]) _WEBM_B64 = base64.b64encode(_WEBM_HEADER).decode() -_VIDEO_MODEL = model_.Model( - id="google/veo-3.0-generate-001", - adapter="ai-gateway-v3", - provider="ai-gateway", -) +_VIDEO_MODEL = ai_gateway("google/veo-3.0-generate-001") # --------------------------------------------------------------------------- diff --git a/tests/models/ai_gateway/test_stream.py b/tests/models/ai_gateway/test_stream.py index 8195759b..71dfa5d9 100644 --- a/tests/models/ai_gateway/test_stream.py +++ b/tests/models/ai_gateway/test_stream.py @@ -23,7 +23,7 @@ import pytest import ai -from ai.models.ai_gateway import errors +from ai.models.ai_gateway import ai_gateway, errors from ai.models.core import model as model_ from ai.types import messages @@ -37,11 +37,7 @@ # Helpers # --------------------------------------------------------------------------- -_TEST_MODEL = model_.Model( - id="test-provider/test-model", - adapter="ai-gateway-v3", - provider="ai-gateway", -) +_TEST_MODEL = ai_gateway("test-provider/test-model") async def _collect( @@ -220,11 +216,7 @@ def handler(req: httpx.Request) -> httpx.Response: text=sse({"type": "finish", "finishReason": "stop", "usage": {}}), ) - model = model_.Model( - id="anthropic/claude-sonnet-4", - adapter="ai-gateway-v3", - provider="ai-gateway", - ) + model = ai_gateway("anthropic/claude-sonnet-4") client = mock_client(httpx.MockTransport(handler), api_key="sk-test") await _collect(client, [user_msg("Hi")], model=model) diff --git a/tests/models/test_check.py b/tests/models/test_check.py index f61be377..29500434 100644 --- a/tests/models/test_check.py +++ b/tests/models/test_check.py @@ -15,7 +15,8 @@ import httpx import pytest -from ai.models import check_connection +from ai.models import anthropic, check_connection, openai +from ai.models.ai_gateway import ai_gateway from ai.models.ai_gateway import check as ai_gw_check from ai.models.anthropic import check as anthropic_check from ai.models.core import client as client_ @@ -26,25 +27,48 @@ # Fixtures # --------------------------------------------------------------------------- -_OPENAI_MODEL = model_.Model( - id="gpt-5.4", - adapter="openai", - provider="openai", - base_url="https://api.openai.com/v1", -) -_ANTHROPIC_MODEL = model_.Model( - id="claude-opus-4-6", - adapter="anthropic", - provider="anthropic", - base_url="https://api.anthropic.com/v1", -) -_GATEWAY_MODEL = model_.Model( - id="anthropic/claude-opus-4-6", - adapter="ai-gateway-v3", - provider="ai-gateway", - base_url="https://ai-gateway.vercel.sh/v3/ai", -) -_UNKNOWN_MODEL = model_.Model(id="x", adapter="x", provider="unknown-provider") +_OPENAI_MODEL = openai("gpt-5.4") +_ANTHROPIC_MODEL = anthropic("claude-opus-4-6") +_GATEWAY_MODEL = ai_gateway("anthropic/claude-opus-4-6") + + +class _FailProvider: + """A provider whose check always raises KeyError (for testing unknown providers).""" + + @property + def api_key_env(self) -> None: + return None + + @property + def base_url(self) -> str: + return "http://unknown.test" + + @property + def adapter(self) -> str: + return "x" + + @property + def name(self) -> str: + return "unknown-provider" + + def client(self) -> client_.Client: + return client_.Client(base_url=self.base_url) + + async def check(self, client: client_.Client, model: model_.Model) -> bool: + raise KeyError(f"No check function registered for provider={self.name!r}.") + + async def list(self, *, client: client_.Client | None = None) -> list[str]: + return [] + + def __call__( + self, model_id: str, *, client: client_.Client | None = None + ) -> model_.Model: + return model_.Model( + id=model_id, adapter=self.adapter, provider=self, client=client + ) + + +_UNKNOWN_MODEL = _FailProvider()("x") def _client_with_mock( diff --git a/tests/models/test_public_api.py b/tests/models/test_public_api.py index bac3853a..7ca4b1b1 100644 --- a/tests/models/test_public_api.py +++ b/tests/models/test_public_api.py @@ -11,7 +11,7 @@ from ai import models from ai.types import messages as messages_ -from ..conftest import MOCK_MODEL, mock_llm, text_msg +from ..conftest import MOCK_MODEL, MOCK_PROVIDER, MockProvider, mock_llm, text_msg # Module-level model so StructuredOutputPart can resolve it by FQN. @@ -64,7 +64,7 @@ async def _spy_stream( explicit = models.Client(base_url="https://custom.test", api_key="sk-custom") explicit_model = models.Model( - id="mock-model", adapter="mock", provider="mock", client=explicit + id="mock-model", adapter="mock", provider=MOCK_PROVIDER, client=explicit ) s = await models.stream(explicit_model, [ai.user_message("Hi")]) async for _ in s: @@ -120,12 +120,12 @@ async def _structured_stream( # generate() tests # --------------------------------------------------------------------------- +_MOCK_GEN_PROVIDER = MockProvider(adapter="mock-gen") + GENERATE_MODEL = models.Model( id="gen-model", adapter="mock-gen", - provider="mock", - base_url="http://mock.test", - api_key_env="MOCK_API_KEY", + provider=_MOCK_GEN_PROVIDER, )