diff --git a/examples/fastapi-vite/backend/agent.py b/examples/fastapi-vite/backend/agent.py index edfdf068..b3af15c3 100644 --- a/examples/fastapi-vite/backend/agent.py +++ b/examples/fastapi-vite/backend/agent.py @@ -10,7 +10,7 @@ import ai -MODEL = ai.model("ai-gateway", "anthropic/claude-sonnet-4") +MODEL = ai.ai_gateway("anthropic/claude-sonnet-4") @ai.tool @@ -42,7 +42,9 @@ async def graph(context: ai.Context) -> AsyncGenerator[ai.Message]: if not tool_calls: return - results = await asyncio.gather(*(_execute_with_approval(tc) for tc in tool_calls)) + results = await asyncio.gather( + *(_execute_with_approval(tc) for tc in tool_calls) + ) yield ai.tool_message(*results) diff --git a/examples/multiagent-textual/server.py b/examples/multiagent-textual/server.py index 6b85ffcb..a35935f0 100644 --- a/examples/multiagent-textual/server.py +++ b/examples/multiagent-textual/server.py @@ -58,7 +58,7 @@ class Approval(pydantic.BaseModel): # Model # --------------------------------------------------------------------------- -MODEL = ai.model("ai-gateway", "anthropic/claude-sonnet-4") +MODEL = ai.ai_gateway("anthropic/claude-sonnet-4") # --------------------------------------------------------------------------- diff --git a/examples/samples/agent_custom_loop.py b/examples/samples/agent_custom_loop.py index 8c540660..27a5a17a 100644 --- a/examples/samples/agent_custom_loop.py +++ b/examples/samples/agent_custom_loop.py @@ -19,7 +19,7 @@ async def get_population(city: str) -> int: async def main() -> None: - model = ai.model("ai-gateway", "anthropic/claude-sonnet-4") + model = ai.ai_gateway("anthropic/claude-sonnet-4") tools = [get_weather, get_population] my_agent = ai.agent(tools=tools) diff --git a/examples/samples/agent_hooks.py b/examples/samples/agent_hooks.py index 361fe7a2..7180311b 100644 --- a/examples/samples/agent_hooks.py +++ b/examples/samples/agent_hooks.py @@ -26,7 +26,7 @@ async def contact_mothership(query: str) -> str: async def main() -> None: - model = ai.model("ai-gateway", "anthropic/claude-sonnet-4") + model = ai.ai_gateway("anthropic/claude-sonnet-4") my_agent = ai.agent(tools=[contact_mothership]) diff --git a/examples/samples/agent_hooks_serverless.py b/examples/samples/agent_hooks_serverless.py index 2d7de2c2..23717689 100644 --- a/examples/samples/agent_hooks_serverless.py +++ b/examples/samples/agent_hooks_serverless.py @@ -31,7 +31,7 @@ async def delete_file(path: str) -> str: async def main() -> None: - model = ai.model("ai-gateway", "anthropic/claude-sonnet-4") + model = ai.ai_gateway("anthropic/claude-sonnet-4") my_agent = ai.agent(tools=[delete_file]) diff --git a/examples/samples/check_connection.py b/examples/samples/check_connection.py index 9efc6a66..b89d2f16 100644 --- a/examples/samples/check_connection.py +++ b/examples/samples/check_connection.py @@ -1,13 +1,19 @@ -"""Check connection — verify credentials and model availability for all providers.""" +"""Check connection and list models — verify credentials and model availability.""" import asyncio import ai MODELS = [ - ai.model("ai-gateway", "anthropic/claude-sonnet-4"), - ai.model("anthropic", "claude-sonnet-4-20250514"), - ai.model("openai", "gpt-5.4-mini"), + ai.ai_gateway("anthropic/claude-sonnet-4"), + ai.anthropic("claude-sonnet-4-20250514"), + ai.openai("gpt-5.4-mini"), +] + +PROVIDERS = [ + ("ai_gateway", ai.ai_gateway), + ("anthropic", ai.anthropic), + ("openai", ai.openai), ] @@ -20,9 +26,23 @@ async def _check(model: ai.Model) -> None: print(f" {status} {model.provider}/{model.id}") +async def _list_models(name: str, provider: object) -> None: + try: + ids: list[str] = await provider.list() # type: ignore[union-attr] + print(f" {name}: {len(ids)} models") + for mid in ids: + print(f" - {mid}") + except Exception as exc: + print(f" {name}: [ERR] {exc}") + + async def main() -> None: print("Checking connections...\n") await asyncio.gather(*[_check(m) for m in MODELS]) + + print("\nListing models...\n") + await asyncio.gather(*[_list_models(n, p) for n, p in PROVIDERS]) + print() diff --git a/examples/samples/explicit_client.py b/examples/samples/explicit_client.py index 9438a4f4..fa95ada8 100644 --- a/examples/samples/explicit_client.py +++ b/examples/samples/explicit_client.py @@ -5,8 +5,6 @@ import ai -model = ai.model("ai-gateway", "anthropic/claude-sonnet-4") - # Explicit client — useful for custom auth, proxies, or self-hosted gateways. client = ai.Client( base_url="https://ai-gateway.vercel.sh/v3/ai", @@ -14,12 +12,14 @@ headers={"X-Custom-Header": "example"}, ) +model = ai.ai_gateway("anthropic/claude-sonnet-4", client=client) + messages = [ai.user_message("Hello!")] async def main() -> None: try: - async for msg in await ai.models.stream(model, messages, client=client): + async for msg in await ai.models.stream(model, messages): if msg.text_delta: print(msg.text_delta, end="", flush=True) print() diff --git a/examples/samples/image_generation.py b/examples/samples/image_generation.py index 99b301d9..cd669f3a 100644 --- a/examples/samples/image_generation.py +++ b/examples/samples/image_generation.py @@ -6,7 +6,7 @@ import ai -model = ai.model("ai-gateway", "google/imagen-4.0-generate-001") +model = ai.ai_gateway("google/imagen-4.0-generate-001") messages = [ ai.user_message( diff --git a/examples/samples/middleware_simple.py b/examples/samples/middleware_simple.py index e9cb7859..e0deb06b 100644 --- a/examples/samples/middleware_simple.py +++ b/examples/samples/middleware_simple.py @@ -88,7 +88,7 @@ async def get_population(city: str) -> int: async def main() -> None: - model = ai.model("ai-gateway", "anthropic/claude-sonnet-4") + model = ai.ai_gateway("anthropic/claude-sonnet-4") my_agent = ai.agent(tools=[get_weather, get_population]) diff --git a/examples/samples/stream.py b/examples/samples/stream.py index 762afb84..dbfaae48 100644 --- a/examples/samples/stream.py +++ b/examples/samples/stream.py @@ -4,7 +4,7 @@ import ai -model = ai.model("ai-gateway", "anthropic/claude-sonnet-4") +model = ai.ai_gateway("anthropic/claude-sonnet-4") messages = [ ai.system_message("Be concise."), diff --git a/examples/samples/video_generation.py b/examples/samples/video_generation.py index 8692faee..196a63d4 100644 --- a/examples/samples/video_generation.py +++ b/examples/samples/video_generation.py @@ -6,7 +6,7 @@ import ai -model = ai.model("ai-gateway", "google/veo-3.0-generate-001") +model = ai.ai_gateway("google/veo-3.0-generate-001") messages = [ ai.user_message( diff --git a/examples/temporal-direct/main.py b/examples/temporal-direct/main.py index dcf54b0c..8039df76 100644 --- a/examples/temporal-direct/main.py +++ b/examples/temporal-direct/main.py @@ -96,7 +96,7 @@ class LLMResult: @temporalio.activity.defn async def llm_call_activity(params: LLMParams) -> LLMResult: """Call the LLM, drain the stream, return the final message.""" - model = ai.model("ai-gateway", "anthropic/claude-sonnet-4") + model = ai.ai_gateway("anthropic/claude-sonnet-4") messages = [ai.Message.model_validate(m) for m in params.messages] tools = [ai.ToolSchema(return_type=None, **t) for t in params.tool_schemas] @@ -166,7 +166,7 @@ async def run_tool(tc: ai.ToolCallPart) -> ai.ToolResultPart: class WeatherWorkflow: @temporalio.workflow.run async def run(self, user_query: str) -> str: - model = ai.model("ai-gateway", "anthropic/claude-sonnet-4") + model = ai.ai_gateway("anthropic/claude-sonnet-4") messages: list[ai.Message] = [ ai.system_message( "Answer questions using the weather and population tools." diff --git a/examples/temporal-middleware/main.py b/examples/temporal-middleware/main.py index 9764d007..7e8a56fb 100644 --- a/examples/temporal-middleware/main.py +++ b/examples/temporal-middleware/main.py @@ -113,7 +113,7 @@ class LLMResult: @temporalio.activity.defn async def llm_call_activity(params: LLMParams) -> LLMResult: """Call the LLM, drain the stream, return the final message.""" - model = ai.model("ai-gateway", "anthropic/claude-sonnet-4") + model = ai.ai_gateway("anthropic/claude-sonnet-4") messages = [ai.Message.model_validate(m) for m in params.messages] tools = [ai.ToolSchema(return_type=None, **t) for t in params.tool_schemas] @@ -194,7 +194,7 @@ async def wrap_tool( class WeatherWorkflow: @temporalio.workflow.run async def run(self, user_query: str) -> str: - model = ai.model("ai-gateway", "anthropic/claude-sonnet-4") + model = ai.ai_gateway("anthropic/claude-sonnet-4") messages: list[ai.Message] = [ ai.system_message( "Answer questions using the weather and population tools." diff --git a/src/ai/__init__.py b/src/ai/__init__.py index 47757ec5..00a7c589 100644 --- a/src/ai/__init__.py +++ b/src/ai/__init__.py @@ -20,12 +20,14 @@ Client, ImageParams, Model, - ModelCost, + Provider, StreamResult, VideoParams, + ai_gateway, + anthropic, check_connection, generate, - model, + openai, stream, ) @@ -83,17 +85,20 @@ "thinking", # Models (from models/) "Model", - "ModelCost", + "Provider", "ImageParams", "VideoParams", "Client", "StreamResult", "StreamResultLike", "check_connection", - "model", - "models", "stream", "generate", + "models", + # Provider factories + "openai", + "anthropic", + "ai_gateway", # Agents — primary API "Agent", "agent", diff --git a/src/ai/middleware.py b/src/ai/middleware.py index bc703859..e5eaaa97 100644 --- a/src/ai/middleware.py +++ b/src/ai/middleware.py @@ -38,7 +38,6 @@ if TYPE_CHECKING: from .agents.agent import Tool - from .models.core.client import Client from .models.core.model import Model @@ -50,7 +49,6 @@ class ModelContext: messages: list[messages_.Message] tools: Sequence[tools_.ToolLike] | None output_type: type[pydantic.BaseModel] | None - client: Client | None kwargs: dict[str, Any] def __post_init__(self) -> None: @@ -67,7 +65,6 @@ class GenerateContext: model: Model messages: list[messages_.Message] params: Any - client: Client | None = None def __post_init__(self) -> None: object.__setattr__(self, "messages", list(self.messages)) diff --git a/src/ai/models/__init__.py b/src/ai/models/__init__.py index 2dfe243f..55ba9648 100644 --- a/src/ai/models/__init__.py +++ b/src/ai/models/__init__.py @@ -3,351 +3,37 @@ Usage:: import ai - from ai.types import Message, TextPart + from ai.models import openai, anthropic, ai_gateway - # look up a model from the catalog - opus = ai.model("ai-gateway", "anthropic/claude-opus-4-6") - - msgs = [Message(role="user", parts=[TextPart(text="hello")])] + model = openai("gpt-5.4") + model = anthropic("claude-sonnet-4-6") + model = ai_gateway("anthropic/claude-sonnet-4") # stream — auto-creates client from env vars - s = await ai.stream(opus, msgs) + msgs = [ai.user_message("hello")] + s = await ai.stream(model, msgs) async for msg in s: print(msg.text_delta, end="") - # buffer the whole response - result = await ai.models.buffer(await ai.stream(opus, msgs)) - print(result.text) + # explicit client for custom auth + client = ai.Client(base_url="https://custom.example.com/v1", api_key="sk-...") + model = openai("gpt-5.4", client=client) + s = await ai.stream(model, msgs) - # explicit client - client = ai.Client( - base_url="https://custom.example.com/v3/ai", api_key="sk-...", - ) - s = await ai.stream(opus, msgs, client=client) - async for msg in s: - ... + # list available models + ids = await openai.list() """ -from __future__ import annotations - -import os -from collections.abc import AsyncGenerator, Sequence -from typing import Any - -import pydantic - -from .. import middleware as middleware_ -from ..types import messages as messages_ -from ..types import tools as tools_ from ..types.stream import StreamResultLike -from .ai_gateway.types import GenerateParams, ImageParams, VideoParams -from .core.catalog import get_models, get_providers, register_catalog -from .core.catalog import model as model +from .ai_gateway import ai_gateway +from .anthropic import anthropic +from .core.adapters import register_generate, register_stream +from .core.api import check_connection, generate, stream from .core.client import Client -from .core.model import Model, ModelCost -from .core.proto import CheckConnFn, GenerateFn, StreamFn - -# --------------------------------------------------------------------------- -# Adapter registry — maps adapter string → adapter function. -# Adapter modules are imported lazily on first use. -# --------------------------------------------------------------------------- - -_stream_adapters: dict[str, StreamFn] = {} -_generate_adapters: dict[str, GenerateFn] = {} -_adapters_loaded = False - - -def _ensure_adapters() -> None: - """Lazily register built-in adapter functions on first call.""" - global _adapters_loaded # noqa: PLW0603 - if _adapters_loaded: - return - _adapters_loaded = True - - from .ai_gateway.generate import generate as ai_gw_generate - from .ai_gateway.stream import stream as ai_gw_stream - from .anthropic.adapter import stream as anthropic_stream - from .openai.adapter import stream as openai_stream - - _stream_adapters["ai-gateway-v3"] = ai_gw_stream - _generate_adapters["ai-gateway-v3"] = ai_gw_generate - _stream_adapters["openai"] = openai_stream - _stream_adapters["anthropic"] = anthropic_stream - - -def register_stream(adapter: str, fn: StreamFn) -> None: - """Register a stream adapter function for the given adapter key. - - Use this to add custom adapters (or override built-in ones). - """ - _stream_adapters[adapter] = fn - - -def register_generate(adapter: str, fn: GenerateFn) -> None: - """Register a generate adapter function for the given adapter key. - - Use this to add custom adapters (or override built-in ones). - """ - _generate_adapters[adapter] = fn - - -# --------------------------------------------------------------------------- -# Connection-check registry — maps *provider* string → check function. -# Keyed by provider (not adapter) because the check verifies "can this -# client reach this provider and does this model exist there". -# --------------------------------------------------------------------------- - -_check_fns: dict[str, CheckConnFn] = {} -_check_fns_loaded = False - - -def _ensure_check_fns() -> None: - """Lazily register built-in check functions on first call.""" - global _check_fns_loaded # noqa: PLW0603 - if _check_fns_loaded: - return - _check_fns_loaded = True - - from .ai_gateway import check as ai_gw_check - from .anthropic import check as anthropic_check - from .openai import check as openai_check - - _check_fns["ai-gateway"] = ai_gw_check.check - _check_fns["anthropic"] = anthropic_check.check - _check_fns["openai"] = openai_check.check - - -def register_check(provider: str, fn: CheckConnFn) -> None: - """Register a connection-check function for a provider. - - Use this to add checks for custom providers. - """ - _check_fns[provider] = fn - - -# --------------------------------------------------------------------------- -# Provider defaults — base URLs and env var names for auto-client creation. -# --------------------------------------------------------------------------- - -_PROVIDER_DEFAULTS: dict[str, tuple[str, str]] = { - "ai-gateway": ("https://ai-gateway.vercel.sh/v3/ai", "AI_GATEWAY_API_KEY"), - "anthropic": ("https://api.anthropic.com/v1", "ANTHROPIC_API_KEY"), - "openai": ("https://api.openai.com/v1", "OPENAI_API_KEY"), -} - - -def _auto_client(model: Model) -> Client: - """Create a :class:`Client` from env vars for the given model's provider.""" - defaults = _PROVIDER_DEFAULTS.get(model.provider) - if defaults is None: - raise ValueError( - f"No default client config for provider {model.provider!r}. " - f"Pass an explicit client= argument." - ) - base_url, env_var = defaults - return Client(base_url=base_url, api_key=os.environ.get(env_var)) - - -# --------------------------------------------------------------------------- -# Public API -# --------------------------------------------------------------------------- - - -class StreamResult: - """Wrapper around a message stream. Async-iterable; collects the final result. - - Properties like ``.text`` and ``.tool_calls`` delegate to the final - ``Message`` snapshot and are available after iteration completes. - - Satisfies :class:`~ai.types.StreamResultLike`. - """ - - def __init__(self, gen: AsyncGenerator[messages_.Message]) -> None: - self._gen = gen - self._final: messages_.Message | None = None - - @classmethod - def from_generator(cls, gen: AsyncGenerator[messages_.Message]) -> StreamResult: - """Create a :class:`StreamResult` from an async generator. - - This is the public API for middleware that needs to transform or - replace the stream returned by ``wrap_model``:: - - async def wrap_model(self, call, next): - original = await next(call) - - async def _transformed(): - async for msg in original: - yield modify(msg) - - return StreamResult.from_generator(_transformed()) - """ - return cls(gen) - - def __aiter__(self) -> AsyncGenerator[messages_.Message]: - return self._iterate() - - async def _iterate(self) -> AsyncGenerator[messages_.Message]: - async for msg in self._gen: - self._final = msg - yield msg - - @property - def text(self) -> str: - return self._final.text if self._final else "" - - @property - def tool_calls(self) -> list[messages_.ToolCallPart]: - return self._final.tool_calls if self._final else [] - - @property - def usage(self) -> messages_.Usage | None: - return self._final.usage if self._final else None - - @property - def output(self) -> Any: - """Parsed structured output from the final message, if available.""" - return self._final.output if self._final else None - - -async def stream( - model: Model, - messages: list[messages_.Message], - *, - tools: Sequence[tools_.ToolLike] | None = None, - output_type: type[pydantic.BaseModel] | None = None, - client: Client | None = None, - **kwargs: Any, -) -> StreamResultLike: - """Stream an LLM response. - - Returns a :class:`StreamResultLike` that is async-iterable and - collects the final ``Message``. After iteration, access ``.text``, - ``.tool_calls``, ``.usage``, etc. - - Without middleware the concrete type is :class:`StreamResult`; with - middleware it may be any :class:`~ai.StreamResultLike`. - """ - call = middleware_.ModelContext( - model=model, - messages=messages, - tools=tools, - output_type=output_type, - client=client, - kwargs=kwargs, - ) - - async def _real(call: middleware_.ModelContext) -> StreamResultLike: - _ensure_adapters() - c = call.client or _auto_client(call.model) - adapter_fn = _stream_adapters.get(call.model.adapter) - if adapter_fn is None: - registered = ", ".join(sorted(_stream_adapters)) or "(none)" - raise KeyError( - f"No stream adapter registered for adapter={call.model.adapter!r}. " - f"Registered: {registered}" - ) - return StreamResult( - adapter_fn( - c, - call.model, - call.messages, - tools=call.tools, - output_type=call.output_type, - **call.kwargs, - ) - ) - - chain = middleware_._build_model_chain(_real) - return await chain(call) - - -async def generate( - model: Model, - messages: list[messages_.Message], - params: GenerateParams | None = None, - *, - client: Client | None = None, -) -> messages_.Message: - """Generate a response (images, video, etc.). - - Resolves the adapter function from ``model.adapter``, auto-creates a - :class:`Client` from env vars if none is provided. - - ``params`` controls the generation type: - - * :class:`ImageParams` — image generation (``/image-model``). - * :class:`VideoParams` — video generation (``/video-model``). - * ``None`` — auto-detect from ``model.capabilities``. - """ - call = middleware_.GenerateContext( - model=model, - messages=messages, - params=params, - client=client, - ) - - async def _real(call: middleware_.GenerateContext) -> messages_.Message: - _ensure_adapters() - c = call.client or _auto_client(call.model) - adapter_fn = _generate_adapters.get(call.model.adapter) - if adapter_fn is None: - registered = ", ".join(sorted(_generate_adapters)) or "(none)" - raise KeyError( - f"No generate adapter registered for adapter={call.model.adapter!r}. " - f"Registered: {registered}" - ) - return await adapter_fn(c, call.model, call.messages, params=call.params) - - chain = middleware_._build_generate_chain(_real) - return await chain(call) - - -async def check_connection( - model: Model, - *, - client: Client | None = None, -) -> bool: - """Check whether *client* can reach *model*'s provider and the model exists. - - Returns ``True`` when the credentials are valid **and** the model is - available on the remote side — i.e. a subsequent :func:`stream` or - :func:`generate` call should succeed (network conditions permitting). - - This only hits free metadata endpoints; no tokens or credits are - consumed. - - If no *client* is given, one is auto-created from environment - variables (same logic as :func:`stream`). - - Non-auth transport errors (network failures, 5xx) are raised rather - than returning ``False`` so that callers can distinguish "bad - credentials / unknown model" from "provider unreachable". - """ - _ensure_check_fns() - c = client or _auto_client(model) - check_fn = _check_fns.get(model.provider) - if check_fn is None: - registered = ", ".join(sorted(_check_fns)) or "(none)" - raise KeyError( - f"No check function registered for provider={model.provider!r}. " - f"Registered: {registered}" - ) - return await check_fn(c, model) - - -async def buffer(gen: AsyncGenerator[messages_.Message]) -> messages_.Message: - """Drain a stream and return the final ``Message``. - - Raises :class:`ValueError` if the stream yields nothing. - """ - result: messages_.Message | None = None - async for msg in gen: - result = msg - if result is None: - raise ValueError("empty stream") - return result - +from .core.model import Model +from .core.proto import CheckConnFn, GenerateFn, Provider, StreamFn +from .core.types import GenerateParams, ImageParams, StreamResult, VideoParams +from .openai import openai __all__ = [ # Core types @@ -357,22 +43,20 @@ async def buffer(gen: AsyncGenerator[messages_.Message]) -> messages_.Message: "GenerateParams", "ImageParams", "Model", - "ModelCost", + "Provider", "StreamFn", "StreamResult", "StreamResultLike", "VideoParams", - # Catalog - "get_models", - "get_providers", - "model", - "register_catalog", + # Provider factories + "ai_gateway", + "anthropic", + "openai", + # Adapter registration + "register_generate", + "register_stream", # Public API - "buffer", "check_connection", "generate", - "register_check", - "register_generate", - "register_stream", "stream", ] diff --git a/src/ai/models/ai_gateway/__init__.py b/src/ai/models/ai_gateway/__init__.py index 4ec1f11e..96d0bdf5 100644 --- a/src/ai/models/ai_gateway/__init__.py +++ b/src/ai/models/ai_gateway/__init__.py @@ -1,17 +1,108 @@ -"""AI Gateway provider — adapter for the Vercel AI Gateway v3 protocol. +"""AI Gateway provider. + +Usage:: + + from ai.models import ai_gateway + + model = ai_gateway("anthropic/claude-sonnet-4") + ids = await ai_gateway.list() Heavy adapter modules (``.generate``, ``.stream``) are loaded lazily so that ``import ai`` does not pull in ``httpx`` and other I/O libraries at import time. This matters for sandboxed runtimes (e.g. Temporal workflow workers). """ +from __future__ import annotations + +import os +from typing import TYPE_CHECKING, Any + +from ..core import client as client_ +from ..core.model import Model from . import errors -from .types import GenerateParams, ImageParams, VideoParams + +if TYPE_CHECKING: + pass + +_BASE_URL = "https://ai-gateway.vercel.sh/v3/ai" +_API_KEY_ENV = "AI_GATEWAY_API_KEY" +_PROTOCOL_VERSION = "0.0.1" + + +class _AIGateway: + """Callable provider factory for the Vercel AI Gateway. + + Satisfies the :class:`~ai.models.core.proto.Provider` protocol. + """ + + @property + def api_key_env(self) -> str: + return _API_KEY_ENV + + @property + def base_url(self) -> str: + return _BASE_URL + + @property + def adapter(self) -> str: + return "ai-gateway-v3" + + @property + def name(self) -> str: + return "ai-gateway" + + def client(self) -> client_.Client: + """Create a :class:`Client` from env-var credentials.""" + return client_.Client( + base_url=_BASE_URL, + api_key=os.environ.get(_API_KEY_ENV), + ) + + async def check(self, client: client_.Client, model: Model) -> bool: + """Delegate to :func:`ai_gateway.check.check`.""" + from . import check as check_ + + return await check_.check(client, model) + + def __call__( + self, + model_id: str, + *, + base_url: str | None = None, + client: client_.Client | None = None, + ) -> Model: + return Model( + id=model_id, + adapter=self.adapter, + provider=self, + client=client, + ) + + async def list(self, *, client: client_.Client | None = None) -> list[str]: + """List available model IDs from the AI Gateway.""" + c = client or self.client() + base_url = c.base_url.rstrip("/") + headers: dict[str, str] = { + "ai-gateway-protocol-version": _PROTOCOL_VERSION, + } + if c.api_key: + headers["Authorization"] = f"Bearer {c.api_key}" + headers["ai-gateway-auth-method"] = "api-key" + + config_url = f"{base_url}/config" + response = await c.http.get(config_url, headers=headers) + response.raise_for_status() + data: dict[str, Any] = response.json() + return sorted(str(m["id"]) for m in data.get("models", [])) + + def __repr__(self) -> str: + return "ai_gateway" + + +ai_gateway = _AIGateway() __all__ = [ - "GenerateParams", - "ImageParams", - "VideoParams", + "ai_gateway", "errors", ] diff --git a/src/ai/models/ai_gateway/catalog.py b/src/ai/models/ai_gateway/catalog.py deleted file mode 100644 index a57fc044..00000000 --- a/src/ai/models/ai_gateway/catalog.py +++ /dev/null @@ -1,150 +0,0 @@ -"""AI Gateway model catalog. - -Model IDs use the gateway's ``provider/model-name`` format. -Pricing is per million tokens (USD). -""" - -from __future__ import annotations - -from ..core.model import Model, ModelCost - -_ADAPTER = "ai-gateway-v3" -_PROVIDER = "ai-gateway" - - -def _text( - id: str, - name: str, - *, - context_window: int, - max_output_tokens: int, - cost: ModelCost, -) -> Model: - return Model( - id=id, - adapter=_ADAPTER, - provider=_PROVIDER, - name=name, - capabilities=("text",), - context_window=context_window, - max_output_tokens=max_output_tokens, - cost=cost, - ) - - -def _media( - id: str, - name: str, - *, - capabilities: tuple[str, ...], - context_window: int = 0, - max_output_tokens: int = 0, - cost: ModelCost | None = None, -) -> Model: - return Model( - id=id, - adapter=_ADAPTER, - provider=_PROVIDER, - name=name, - capabilities=capabilities, - context_window=context_window, - max_output_tokens=max_output_tokens, - cost=cost, - ) - - -# --------------------------------------------------------------------------- -# Catalog -# --------------------------------------------------------------------------- - -CATALOG: dict[str, Model] = { - # -- Anthropic (via gateway) ------------------------------------------- - "anthropic/claude-opus-4-6": _text( - "anthropic/claude-opus-4-6", - "Claude Opus 4.6", - context_window=1_000_000, - max_output_tokens=128_000, - cost=ModelCost(input=5.0, output=25.0, cache_read=0.50, cache_write=6.25), - ), - "anthropic/claude-sonnet-4-6": _text( - "anthropic/claude-sonnet-4-6", - "Claude Sonnet 4.6", - context_window=1_000_000, - max_output_tokens=64_000, - cost=ModelCost(input=3.0, output=15.0, cache_read=0.30, cache_write=3.75), - ), - "anthropic/claude-haiku-4-5": _text( - "anthropic/claude-haiku-4-5", - "Claude Haiku 4.5", - context_window=200_000, - max_output_tokens=64_000, - cost=ModelCost(input=1.0, output=5.0, cache_read=0.10, cache_write=1.25), - ), - "anthropic/claude-sonnet-4": _text( - "anthropic/claude-sonnet-4", - "Claude Sonnet 4", - context_window=200_000, - max_output_tokens=64_000, - cost=ModelCost(input=3.0, output=15.0, cache_read=0.30, cache_write=3.75), - ), - # -- OpenAI (via gateway) ---------------------------------------------- - "openai/gpt-5.4": _text( - "openai/gpt-5.4", - "GPT-5.4", - context_window=1_000_000, - max_output_tokens=128_000, - cost=ModelCost(input=2.50, output=15.0, cache_read=0.25), - ), - "openai/gpt-5.4-mini": _text( - "openai/gpt-5.4-mini", - "GPT-5.4 Mini", - context_window=400_000, - max_output_tokens=128_000, - cost=ModelCost(input=0.75, output=4.50, cache_read=0.075), - ), - "openai/gpt-5.4-nano": _text( - "openai/gpt-5.4-nano", - "GPT-5.4 Nano", - context_window=400_000, - max_output_tokens=128_000, - cost=ModelCost(input=0.20, output=1.25, cache_read=0.02), - ), - # -- Google (via gateway) ---------------------------------------------- - "google/gemini-2.5-pro": _text( - "google/gemini-2.5-pro", - "Gemini 2.5 Pro", - context_window=1_000_000, - max_output_tokens=65_536, - cost=ModelCost(input=1.25, output=10.0, cache_read=0.315), - ), - "google/gemini-2.5-flash": _text( - "google/gemini-2.5-flash", - "Gemini 2.5 Flash", - context_window=1_000_000, - max_output_tokens=65_536, - cost=ModelCost(input=0.15, output=0.60, cache_read=0.0375), - ), - # -- Image / video models (via gateway) -------------------------------- - "google/gemini-3-pro-image": _media( - "google/gemini-3-pro-image", - "Gemini 3 Pro Image", - capabilities=("text", "image"), - context_window=1_000_000, - max_output_tokens=65_536, - ), - "google/imagen-4.0-generate-001": _media( - "google/imagen-4.0-generate-001", - "Imagen 4.0", - capabilities=("image",), - ), - "openai/gpt-image-1": _media( - "openai/gpt-image-1", - "GPT Image 1", - capabilities=("image",), - ), - "google/veo-3.0-generate-001": _media( - "google/veo-3.0-generate-001", - "Veo 3.0", - capabilities=("video",), - ), -} diff --git a/src/ai/models/ai_gateway/generate.py b/src/ai/models/ai_gateway/generate.py index 0d0407fe..9a6dee45 100644 --- a/src/ai/models/ai_gateway/generate.py +++ b/src/ai/models/ai_gateway/generate.py @@ -1,8 +1,6 @@ """AI Gateway v3 generation adapter — image-model and video-model endpoints. -Provides typed parameter objects (:class:`ImageParams`, :class:`VideoParams`) -and a unified :func:`generate` entry point that dispatches based on param type -and validates against model capabilities. +Unified :func:`generate` entry point that dispatches based on param type. """ from __future__ import annotations @@ -16,10 +14,10 @@ from ..core import client as client_ from ..core import model as model_ from ..core.helpers import files +from ..core.types import GenerateParams as GenerateParams +from ..core.types import ImageParams as ImageParams +from ..core.types import VideoParams as VideoParams from . import _common, errors -from .types import GenerateParams as GenerateParams -from .types import ImageParams as ImageParams -from .types import VideoParams as VideoParams # --------------------------------------------------------------------------- # Image generation — /image-model @@ -155,57 +153,17 @@ async def _generate_video( # --------------------------------------------------------------------------- -def _check_capabilities( - model: model_.Model, - params: GenerateParams, -) -> None: - """Validate that model capabilities match the requested generation type.""" - caps = model.capabilities - - if isinstance(params, VideoParams): - if "video" not in caps: - raise ValueError( - f"Model {model.id!r} does not have 'video' capability " - f"(capabilities={caps}). Use ImageParams for image models." - ) - if "text" in caps and "video" not in caps: - raise ValueError( - f"Model {model.id!r} is a text model (capabilities={caps}). " - f"Use stream() for text generation, not generate()." - ) - elif isinstance(params, ImageParams): - if "video" in caps and "image" not in caps: - raise ValueError( - f"Model {model.id!r} has 'video' capability but not 'image' " - f"(capabilities={caps}). Use VideoParams for video models." - ) - if "text" in caps and "image" not in caps: - raise ValueError( - f"Model {model.id!r} is a text model (capabilities={caps}). " - f"Use stream() for text generation, not generate()." - ) - - async def generate( client: client_.Client, model: model_.Model, messages: list[messages_.Message], - params: GenerateParams | None = None, + params: GenerateParams, ) -> messages_.Message: """Generate media (images or video) through the AI Gateway. Dispatches to ``/image-model`` or ``/video-model`` based on ``params`` - type, with fallback to model capabilities when ``params`` is ``None``. - - Raises :class:`ValueError` if the model capabilities don't match the - requested generation type. + type. """ - # Auto-detect from capabilities when no params provided - if params is None: - params = VideoParams() if "video" in model.capabilities else ImageParams() - - _check_capabilities(model, params) - if isinstance(params, VideoParams): return await _generate_video(client, model, messages, params) return await _generate_image(client, model, messages, params) diff --git a/src/ai/models/ai_gateway/types.py b/src/ai/models/ai_gateway/types.py deleted file mode 100644 index d1ca0658..00000000 --- a/src/ai/models/ai_gateway/types.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Parameter types for AI Gateway generation endpoints. - -Extracted into a standalone module so they can be imported without pulling in -heavy dependencies (``httpx``, etc.) that the adapter functions need. -""" - -from __future__ import annotations - -from typing import Any - -import pydantic - -_PARAMS_CONFIG = pydantic.ConfigDict(frozen=True, populate_by_name=True) - - -class ImageParams(pydantic.BaseModel): - """Parameters for image generation (``/image-model`` endpoint).""" - - model_config = _PARAMS_CONFIG - - n: int = 1 - size: str | None = None - aspect_ratio: str | None = pydantic.Field( - default=None, serialization_alias="aspectRatio" - ) - seed: int | None = None - provider_options: dict[str, Any] = pydantic.Field( - default_factory=dict, serialization_alias="providerOptions" - ) - - -class VideoParams(pydantic.BaseModel): - """Parameters for video generation (``/video-model`` endpoint).""" - - model_config = _PARAMS_CONFIG - - n: int = 1 - aspect_ratio: str | None = pydantic.Field( - default=None, serialization_alias="aspectRatio" - ) - resolution: str | None = None - duration: int | None = None - fps: int | None = None - seed: int | None = None - provider_options: dict[str, Any] = pydantic.Field( - default_factory=dict, serialization_alias="providerOptions" - ) - - -GenerateParams = ImageParams | VideoParams diff --git a/src/ai/models/anthropic/__init__.py b/src/ai/models/anthropic/__init__.py index 2a0fd334..659ce7c9 100644 --- a/src/ai/models/anthropic/__init__.py +++ b/src/ai/models/anthropic/__init__.py @@ -1,10 +1,100 @@ -"""Anthropic provider — adapter for the Anthropic messages API. +"""Anthropic provider. + +Usage:: + + from ai.models import anthropic + + model = anthropic("claude-sonnet-4-6") + ids = await anthropic.list() The adapter module is loaded lazily to avoid pulling in the ``anthropic`` SDK at import time. """ -__all__: list[str] = [] +from __future__ import annotations + +import os +from typing import TYPE_CHECKING + +from ..core import client as client_ +from ..core.model import Model + +if TYPE_CHECKING: + pass + +_BASE_URL = "https://api.anthropic.com/v1" +_API_KEY_ENV = "ANTHROPIC_API_KEY" +_ANTHROPIC_VERSION = "2023-06-01" + + +class _Anthropic: + """Callable provider factory for Anthropic. + + Satisfies the :class:`~ai.models.core.proto.Provider` protocol. + """ + + @property + def api_key_env(self) -> str: + return _API_KEY_ENV + + @property + def base_url(self) -> str: + return _BASE_URL + + @property + def adapter(self) -> str: + return "anthropic" + + @property + def name(self) -> str: + return "anthropic" + + def client(self) -> client_.Client: + """Create a :class:`Client` from env-var credentials.""" + return client_.Client( + base_url=_BASE_URL, + api_key=os.environ.get(_API_KEY_ENV), + ) + + async def check(self, client: client_.Client, model: Model) -> bool: + """Delegate to :func:`anthropic.check.check`.""" + from . import check as check_ + + return await check_.check(client, model) + + def __call__( + self, + model_id: str, + *, + base_url: str | None = None, + client: client_.Client | None = None, + ) -> Model: + return Model( + id=model_id, + adapter=self.adapter, + provider=self, + client=client, + ) + + async def list(self, *, client: client_.Client | None = None) -> list[str]: + """List available model IDs from the Anthropic API.""" + c = client or self.client() + headers = { + "x-api-key": c.api_key or "", + "anthropic-version": _ANTHROPIC_VERSION, + } + response = await c.http.get(f"{c.base_url.rstrip('/')}/models", headers=headers) + response.raise_for_status() + data: list[dict[str, object]] = response.json().get("data", []) + return sorted(str(m["id"]) for m in data) + + def __repr__(self) -> str: + return "anthropic" + + +anthropic = _Anthropic() + +__all__ = ["anthropic"] def __getattr__(name: str) -> object: diff --git a/src/ai/models/anthropic/catalog.py b/src/ai/models/anthropic/catalog.py deleted file mode 100644 index 333f9584..00000000 --- a/src/ai/models/anthropic/catalog.py +++ /dev/null @@ -1,55 +0,0 @@ -"""Anthropic direct-API model catalog. - -Model IDs are what the Anthropic Messages API expects. -Pricing is per million tokens (USD). -""" - -from __future__ import annotations - -from ..core.model import Model, ModelCost - -_ADAPTER = "anthropic" -_PROVIDER = "anthropic" - -CATALOG: dict[str, Model] = { - "claude-opus-4-6": Model( - id="claude-opus-4-6", - adapter=_ADAPTER, - provider=_PROVIDER, - name="Claude Opus 4.6", - capabilities=("text",), - context_window=1_000_000, - max_output_tokens=128_000, - cost=ModelCost(input=5.0, output=25.0, cache_read=0.50, cache_write=6.25), - ), - "claude-sonnet-4-6": Model( - id="claude-sonnet-4-6", - adapter=_ADAPTER, - provider=_PROVIDER, - name="Claude Sonnet 4.6", - capabilities=("text",), - context_window=1_000_000, - max_output_tokens=64_000, - cost=ModelCost(input=3.0, output=15.0, cache_read=0.30, cache_write=3.75), - ), - "claude-haiku-4-5": Model( - id="claude-haiku-4-5", - adapter=_ADAPTER, - provider=_PROVIDER, - name="Claude Haiku 4.5", - capabilities=("text",), - context_window=200_000, - max_output_tokens=64_000, - cost=ModelCost(input=1.0, output=5.0, cache_read=0.10, cache_write=1.25), - ), - "claude-sonnet-4-20250514": Model( - id="claude-sonnet-4-20250514", - adapter=_ADAPTER, - provider=_PROVIDER, - name="Claude Sonnet 4", - capabilities=("text",), - context_window=200_000, - max_output_tokens=64_000, - cost=ModelCost(input=3.0, output=15.0, cache_read=0.30, cache_write=3.75), - ), -} diff --git a/src/ai/models/core/__init__.py b/src/ai/models/core/__init__.py index 83892315..5b5e580d 100644 --- a/src/ai/models/core/__init__.py +++ b/src/ai/models/core/__init__.py @@ -1,19 +1,26 @@ """Core types for models.""" -from .catalog import get_models, get_providers, register_catalog -from .catalog import model as model_factory +from .adapters import register_generate, register_stream +from .api import check_connection, generate, stream from .client import Client -from .model import Model, ModelCost -from .proto import GenerateFn, StreamFn +from .model import Model +from .proto import CheckConnFn, GenerateFn, Provider, StreamFn +from .types import GenerateParams, ImageParams, StreamResult, VideoParams __all__ = [ + "CheckConnFn", "Client", "GenerateFn", + "GenerateParams", + "ImageParams", "Model", - "ModelCost", + "Provider", "StreamFn", - "get_models", - "get_providers", - "model_factory", - "register_catalog", + "StreamResult", + "VideoParams", + "check_connection", + "generate", + "register_generate", + "register_stream", + "stream", ] diff --git a/src/ai/models/core/adapters.py b/src/ai/models/core/adapters.py new file mode 100644 index 00000000..4496c3cd --- /dev/null +++ b/src/ai/models/core/adapters.py @@ -0,0 +1,84 @@ +"""Adapter registries. + +Maps adapter strings to their handler functions. Adapter modules +are imported lazily on first use to keep import-time lightweight. + +.. note:: + + Connection checks are no longer dispatched through a registry. + Each :class:`~ai.models.core.proto.Provider` implements ``check()`` + directly, and :func:`~ai.models.core.api.check_connection` delegates + to ``model.provider.check()``. +""" + +from __future__ import annotations + +from . import proto + +# --------------------------------------------------------------------------- +# Stream / generate adapter registry +# --------------------------------------------------------------------------- + +_stream_adapters: dict[str, proto.StreamFn] = {} +_generate_adapters: dict[str, proto.GenerateFn] = {} +_adapters_loaded = False + + +def _ensure_adapters() -> None: + """Lazily register built-in adapter functions on first call.""" + global _adapters_loaded # noqa: PLW0603 + if _adapters_loaded: + return + _adapters_loaded = True + + from ..ai_gateway.generate import generate as ai_gw_generate + from ..ai_gateway.stream import stream as ai_gw_stream + from ..anthropic.adapter import stream as anthropic_stream + from ..openai.adapter import stream as openai_stream + + _stream_adapters["ai-gateway-v3"] = ai_gw_stream + _generate_adapters["ai-gateway-v3"] = ai_gw_generate + _stream_adapters["openai"] = openai_stream + _stream_adapters["anthropic"] = anthropic_stream + + +def register_stream(adapter: str, fn: proto.StreamFn) -> None: + """Register a stream adapter function for the given adapter key. + + Use this to add custom adapters (or override built-in ones). + """ + _stream_adapters[adapter] = fn + + +def register_generate(adapter: str, fn: proto.GenerateFn) -> None: + """Register a generate adapter function for the given adapter key. + + Use this to add custom adapters (or override built-in ones). + """ + _generate_adapters[adapter] = fn + + +def get_stream_adapter(adapter: str) -> proto.StreamFn: + """Return the stream adapter for *adapter*, raising on miss.""" + _ensure_adapters() + fn = _stream_adapters.get(adapter) + if fn is None: + registered = ", ".join(sorted(_stream_adapters)) or "(none)" + raise KeyError( + f"No stream adapter registered for adapter={adapter!r}. " + f"Registered: {registered}" + ) + return fn + + +def get_generate_adapter(adapter: str) -> proto.GenerateFn: + """Return the generate adapter for *adapter*, raising on miss.""" + _ensure_adapters() + fn = _generate_adapters.get(adapter) + if fn is None: + registered = ", ".join(sorted(_generate_adapters)) or "(none)" + raise KeyError( + f"No generate adapter registered for adapter={adapter!r}. " + f"Registered: {registered}" + ) + return fn diff --git a/src/ai/models/core/api.py b/src/ai/models/core/api.py new file mode 100644 index 00000000..84f9f22b --- /dev/null +++ b/src/ai/models/core/api.py @@ -0,0 +1,117 @@ +"""Top-level orchestration — stream(), generate(), check_connection(). + +These wire together adapters, middleware chains, and auto-client creation. +""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any + +import pydantic + +from ... import middleware as middleware_ +from ...types import messages as messages_ +from ...types import stream as stream_ +from ...types import tools as tools_ +from . import adapters +from . import client as client_ +from . import model as model_ +from . import types as types_ + + +async def stream( + model: model_.Model, + messages: list[messages_.Message], + *, + tools: Sequence[tools_.ToolLike] | None = None, + output_type: type[pydantic.BaseModel] | None = None, + **kwargs: Any, +) -> stream_.StreamResultLike: + """Stream an LLM response. + + Returns a :class:`StreamResultLike` that is async-iterable and + collects the final ``Message``. After iteration, access ``.text``, + ``.tool_calls``, ``.usage``, etc. + + The client is resolved from the model: ``model.client`` if set, + otherwise auto-created from ``model.base_url`` / ``model.api_key_env``. + """ + call = middleware_.ModelContext( + model=model, + messages=messages, + tools=tools, + output_type=output_type, + kwargs=kwargs, + ) + + async def _real(call: middleware_.ModelContext) -> stream_.StreamResultLike: + c = client_.auto_client(call.model) + adapter_fn = adapters.get_stream_adapter(call.model.adapter) + return types_.StreamResult( + adapter_fn( + c, + call.model, + call.messages, + tools=call.tools, + output_type=call.output_type, + **call.kwargs, + ) + ) + + chain = middleware_._build_model_chain(_real) + return await chain(call) + + +async def generate( + model: model_.Model, + messages: list[messages_.Message], + params: types_.GenerateParams, + **kwargs: Any, +) -> messages_.Message: + """Generate a response (images, video, etc.). + + Resolves the adapter function from ``model.adapter``, auto-creates a + :class:`Client` from the model if no explicit client is set. + + ``params`` is required and controls the generation type: + + * :class:`ImageParams` — image generation (``/image-model``). + * :class:`VideoParams` — video generation (``/video-model``). + """ + call = middleware_.GenerateContext( + model=model, + messages=messages, + params=params, + ) + + async def _real(call: middleware_.GenerateContext) -> messages_.Message: + c = client_.auto_client(call.model) + adapter_fn = adapters.get_generate_adapter(call.model.adapter) + return await adapter_fn(c, call.model, call.messages, params=call.params) + + chain = middleware_._build_generate_chain(_real) + return await chain(call) + + +async def check_connection( + model: model_.Model, +) -> bool: + """Check whether the model's provider is reachable and the model exists. + + Returns ``True`` when the credentials are valid **and** the model is + available on the remote side — i.e. a subsequent :func:`stream` or + :func:`generate` call should succeed (network conditions permitting). + + This only hits free metadata endpoints; no tokens or credits are + consumed. + + The client is resolved from the model: ``model.client`` if set, + otherwise created by the provider. + + Non-auth transport errors (network failures, 5xx) are raised rather + than returning ``False`` so that callers can distinguish "bad + credentials / unknown model" from "provider unreachable". + """ + c = client_.auto_client(model) + return await model.provider.check(c, model) diff --git a/src/ai/models/core/catalog.py b/src/ai/models/core/catalog.py deleted file mode 100644 index 3244b757..00000000 --- a/src/ai/models/core/catalog.py +++ /dev/null @@ -1,103 +0,0 @@ -"""Model catalog — per-provider registry of known models. - -Usage:: - - from ai.models.core.catalog import model - - opus = model("anthropic", "claude-opus-4-6") - sonnet_gw = model("ai-gateway", "anthropic/claude-sonnet-4-6") -""" - -from __future__ import annotations - -from .model import Model - -# --------------------------------------------------------------------------- -# Global registry: provider name -> {model_id -> Model} -# --------------------------------------------------------------------------- - -_catalogs: dict[str, dict[str, Model]] = {} - - -def register_catalog(provider: str, models: dict[str, Model]) -> None: - """Register a provider's model catalog. - - Called by each provider's ``catalog`` module during initialisation, - and by users who want to add custom providers. - """ - _catalogs[provider] = models - - -def _load_builtin_catalogs() -> None: - """Import and register all built-in provider catalogs. - - Catalog submodules only depend on :mod:`..core.model` (pure - dataclasses) — no ``httpx``, ``anthropic``, or ``openai`` imports. - This makes it safe to run eagerly at import time, including inside - sandboxed runtimes (e.g. Temporal workflow workers). - """ - from ..ai_gateway.catalog import CATALOG as ai_gw_catalog - from ..anthropic.catalog import CATALOG as anthropic_catalog - from ..openai.catalog import CATALOG as openai_catalog - - register_catalog("ai-gateway", ai_gw_catalog) - register_catalog("anthropic", anthropic_catalog) - register_catalog("openai", openai_catalog) - - -_load_builtin_catalogs() - - -# --------------------------------------------------------------------------- -# Public API -# --------------------------------------------------------------------------- - - -def model(provider: str, model_id: str) -> Model: - """Look up a model by provider and model ID. - - Raises :class:`KeyError` with a helpful message if the provider or - model ID is not found. - - Examples:: - - model("ai-gateway", "anthropic/claude-opus-4-6") - model("anthropic", "claude-opus-4-6") - model("openai", "gpt-5.4") - """ - provider_catalog = _catalogs.get(provider) - if provider_catalog is None: - available = ", ".join(sorted(_catalogs)) or "(none)" - raise KeyError( - f"Unknown provider {provider!r}. Registered providers: {available}" - ) - - entry = provider_catalog.get(model_id) - if entry is None: - available = ", ".join(sorted(provider_catalog)) or "(none)" - raise KeyError( - f"Unknown model {model_id!r} for provider {provider!r}. " - f"Available models: {available}" - ) - - return entry - - -def get_providers() -> list[str]: - """Return all registered provider names.""" - return sorted(_catalogs) - - -def get_models(provider: str) -> dict[str, Model]: - """Return all models for a provider. - - Raises :class:`KeyError` if the provider is not registered. - """ - provider_catalog = _catalogs.get(provider) - if provider_catalog is None: - available = ", ".join(sorted(_catalogs)) or "(none)" - raise KeyError( - f"Unknown provider {provider!r}. Registered providers: {available}" - ) - - return dict(provider_catalog) diff --git a/src/ai/models/core/client.py b/src/ai/models/core/client.py index fabed6cf..a559f52d 100644 --- a/src/ai/models/core/client.py +++ b/src/ai/models/core/client.py @@ -3,11 +3,13 @@ from __future__ import annotations import dataclasses -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: import httpx + from . import model as model_ + @dataclasses.dataclass class Client: @@ -25,9 +27,7 @@ class Client: api_key: str | None = None headers: dict[str, str] = dataclasses.field(default_factory=dict) - _http: httpx.AsyncClient | None = dataclasses.field( - default=None, repr=False, compare=False - ) + _http: Any = dataclasses.field(default=None, repr=False, compare=False) @property def http(self) -> httpx.AsyncClient: @@ -40,10 +40,23 @@ def http(self) -> httpx.AsyncClient: headers=self.headers, timeout=_httpx.Timeout(timeout=300.0, connect=10.0), ) - return self._http + return self._http # type: ignore[no-any-return] async def aclose(self) -> None: """Close the underlying HTTP client if open.""" if self._http is not None and not self._http.is_closed: await self._http.aclose() self._http = None + + +def auto_client(model: model_.Model) -> Client: + """Create a :class:`Client` from the model's connection info. + + Uses ``model.client`` if set, otherwise delegates to + ``model.provider.client()`` which reads the provider's default + ``base_url`` and ``api_key_env``. + """ + if model.client is not None: + return model.client + + return model.provider.client() diff --git a/src/ai/models/core/model.py b/src/ai/models/core/model.py index cbf59f50..769330ed 100644 --- a/src/ai/models/core/model.py +++ b/src/ai/models/core/model.py @@ -4,31 +4,21 @@ import dataclasses - -@dataclasses.dataclass(frozen=True) -class ModelCost: - """Per-million-token pricing.""" - - input: float = 0.0 - output: float = 0.0 - cache_read: float = 0.0 - cache_write: float = 0.0 +from .client import Client +from .proto import Provider @dataclasses.dataclass(frozen=True) class Model: - """Pure-data description of a model. + """Lightweight reference to a model on a specific provider. - * ``id`` — identifier sent to the provider (e.g. ``"claude-sonnet-4-20250514"``). - * ``adapter`` — adapter key (e.g. ``"ai-gateway-v3"``, ``"anthropic-messages"``). - * ``provider`` — hosting service (e.g. ``"ai-gateway"``, ``"anthropic"``). + * ``id`` — identifier sent to the provider (e.g. ``"claude-sonnet-4-6"``). + * ``adapter`` — wire protocol key (e.g. ``"ai-gateway-v3"``, ``"anthropic"``). + * ``provider`` — :class:`Provider` that owns this model. + * ``client`` — explicit :class:`Client` override (skips provider's default). """ id: str adapter: str - provider: str - name: str = "" - capabilities: tuple[str, ...] = ("text",) - context_window: int = 0 - max_output_tokens: int = 0 - cost: ModelCost | None = None + provider: Provider + client: Client | None = dataclasses.field(default=None, repr=False) diff --git a/src/ai/models/core/proto.py b/src/ai/models/core/proto.py index 754b2fa7..173bcfe7 100644 --- a/src/ai/models/core/proto.py +++ b/src/ai/models/core/proto.py @@ -16,10 +16,73 @@ from ...types import messages as messages_ from ...types import tools as tools_ -from .model import Model if TYPE_CHECKING: from .client import Client + from .model import Model + + +@runtime_checkable +class Provider(Protocol): + """Protocol for model providers. + + A provider carries all provider-specific configuration and behaviour: + API endpoint, authentication, client creation, connection checks, and + model enumeration. Model objects hold only pure metadata (``id``, + ``adapter``) plus a back-reference to their provider. + + Implementations must be **callable** — ``provider(model_id)`` returns + a :class:`Model`. + """ + + @property + def api_key_env(self) -> str | None: + """Env var name that holds the API key (e.g. ``"OPENAI_API_KEY"``).""" + ... + + @property + def base_url(self) -> str: + """Default base URL for the provider API.""" + ... + + @property + def adapter(self) -> str: + """Wire-protocol key used to look up stream/generate adapters.""" + ... + + @property + def name(self) -> str: + """Human-readable provider name (for repr, error messages).""" + ... + + def client(self) -> Client: + """Create a :class:`Client` from the provider's default config. + + Reads ``api_key_env`` from the environment and uses ``base_url`` + as the endpoint. + """ + ... + + async def check(self, client: Client, model: Model) -> bool: + """Check whether *client* can reach this provider and *model* exists. + + Returns ``True`` when credentials are valid **and** the model is + available. Non-auth transport errors should be raised. + """ + ... + + async def list(self, *, client: Client | None = None) -> list[str]: + """List available model IDs from the provider API.""" + ... + + def __call__( + self, + model_id: str, + *, + client: Client | None = None, + ) -> Model: + """Create a :class:`Model` for the given *model_id*.""" + ... @runtime_checkable @@ -57,7 +120,7 @@ async def __call__( client: Client, model: Model, messages: list[messages_.Message], - params: Any = None, + params: Any, ) -> messages_.Message: ... diff --git a/src/ai/models/core/types.py b/src/ai/models/core/types.py new file mode 100644 index 00000000..eff6e5af --- /dev/null +++ b/src/ai/models/core/types.py @@ -0,0 +1,120 @@ +"""Core model-layer types — parameter objects and StreamResult. + +Parameter types (:class:`ImageParams`, :class:`VideoParams`) live here +because they parameterise the public :func:`generate` API. + +:class:`StreamResult` is the concrete wrapper returned by :func:`stream`. +""" + +from __future__ import annotations + +from collections.abc import AsyncGenerator +from typing import Any + +import pydantic + +from ...types import messages as messages_ + +# --------------------------------------------------------------------------- +# Generation parameter types +# --------------------------------------------------------------------------- + +_PARAMS_CONFIG = pydantic.ConfigDict(frozen=True, populate_by_name=True) + + +class ImageParams(pydantic.BaseModel): + """Parameters for image generation (``/image-model`` endpoint).""" + + model_config = _PARAMS_CONFIG + + n: int = 1 + size: str | None = None + aspect_ratio: str | None = pydantic.Field( + default=None, serialization_alias="aspectRatio" + ) + seed: int | None = None + provider_options: dict[str, Any] = pydantic.Field( + default_factory=dict, serialization_alias="providerOptions" + ) + + +class VideoParams(pydantic.BaseModel): + """Parameters for video generation (``/video-model`` endpoint).""" + + model_config = _PARAMS_CONFIG + + n: int = 1 + aspect_ratio: str | None = pydantic.Field( + default=None, serialization_alias="aspectRatio" + ) + resolution: str | None = None + duration: int | None = None + fps: int | None = None + seed: int | None = None + provider_options: dict[str, Any] = pydantic.Field( + default_factory=dict, serialization_alias="providerOptions" + ) + + +GenerateParams = ImageParams | VideoParams + +# --------------------------------------------------------------------------- +# StreamResult +# --------------------------------------------------------------------------- + + +class StreamResult: + """Wrapper around a message stream. Async-iterable; collects the final result. + + Properties like ``.text`` and ``.tool_calls`` delegate to the final + ``Message`` snapshot and are available after iteration completes. + + Satisfies :class:`~ai.types.StreamResultLike`. + """ + + def __init__(self, gen: AsyncGenerator[messages_.Message]) -> None: + self._gen = gen + self._final: messages_.Message | None = None + + @classmethod + def from_generator(cls, gen: AsyncGenerator[messages_.Message]) -> StreamResult: + """Create a :class:`StreamResult` from an async generator. + + This is the public API for middleware that needs to transform or + replace the stream returned by ``wrap_model``:: + + async def wrap_model(self, call, next): + original = await next(call) + + async def _transformed(): + async for msg in original: + yield modify(msg) + + return StreamResult.from_generator(_transformed()) + """ + return cls(gen) + + def __aiter__(self) -> AsyncGenerator[messages_.Message]: + return self._iterate() + + async def _iterate(self) -> AsyncGenerator[messages_.Message]: + async for msg in self._gen: + self._final = msg + yield msg + + @property + def text(self) -> str: + return self._final.text if self._final else "" + + @property + def tool_calls(self) -> list[messages_.ToolCallPart]: + return self._final.tool_calls if self._final else [] + + @property + def usage(self) -> messages_.Usage | None: + return self._final.usage if self._final else None + + @property + def output(self) -> Any: + """Parsed structured output from the final message, if available.""" + return self._final.output if self._final else None diff --git a/src/ai/models/openai/__init__.py b/src/ai/models/openai/__init__.py index 5fc55fc0..9dfb113b 100644 --- a/src/ai/models/openai/__init__.py +++ b/src/ai/models/openai/__init__.py @@ -1,10 +1,96 @@ -"""OpenAI provider — adapter for the OpenAI chat completions API. +"""OpenAI provider. + +Usage:: + + from ai.models import openai + + model = openai("gpt-5.4") + ids = await openai.list() The adapter module is loaded lazily to avoid pulling in the ``openai`` SDK at import time. """ -__all__: list[str] = [] +from __future__ import annotations + +import os +from typing import TYPE_CHECKING + +from ..core import client as client_ +from ..core.model import Model + +if TYPE_CHECKING: + pass + +_BASE_URL = "https://api.openai.com/v1" +_API_KEY_ENV = "OPENAI_API_KEY" + + +class _OpenAI: + """Callable provider — ``openai("gpt-5.4")`` returns a :class:`Model`. + + Satisfies the :class:`~ai.models.core.proto.Provider` protocol. + """ + + @property + def api_key_env(self) -> str: + return _API_KEY_ENV + + @property + def base_url(self) -> str: + return _BASE_URL + + @property + def adapter(self) -> str: + return "openai" + + @property + def name(self) -> str: + return "openai" + + def client(self) -> client_.Client: + """Create a :class:`Client` from env-var credentials.""" + return client_.Client( + base_url=_BASE_URL, + api_key=os.environ.get(_API_KEY_ENV), + ) + + async def check(self, client: client_.Client, model: Model) -> bool: + """Delegate to :func:`openai.check.check`.""" + from . import check as check_ + + return await check_.check(client, model) + + def __call__( + self, + model_id: str, + *, + base_url: str | None = None, + client: client_.Client | None = None, + ) -> Model: + return Model( + id=model_id, + adapter=self.adapter, + provider=self, + client=client, + ) + + async def list(self, *, client: client_.Client | None = None) -> list[str]: + """List available model IDs from the OpenAI API.""" + c = client or self.client() + headers = {"Authorization": f"Bearer {c.api_key}"} + response = await c.http.get(f"{c.base_url.rstrip('/')}/models", headers=headers) + response.raise_for_status() + data: list[dict[str, object]] = response.json().get("data", []) + return sorted(str(m["id"]) for m in data) + + def __repr__(self) -> str: + return "openai" + + +openai = _OpenAI() + +__all__ = ["openai"] def __getattr__(name: str) -> object: diff --git a/src/ai/models/openai/catalog.py b/src/ai/models/openai/catalog.py deleted file mode 100644 index 91df6605..00000000 --- a/src/ai/models/openai/catalog.py +++ /dev/null @@ -1,45 +0,0 @@ -"""OpenAI direct-API model catalog. - -Model IDs are what the OpenAI Chat Completions API expects. -Pricing is per million tokens (USD). -""" - -from __future__ import annotations - -from ..core.model import Model, ModelCost - -_ADAPTER = "openai" -_PROVIDER = "openai" - -CATALOG: dict[str, Model] = { - "gpt-5.4": Model( - id="gpt-5.4", - adapter=_ADAPTER, - provider=_PROVIDER, - name="GPT-5.4", - capabilities=("text",), - context_window=1_000_000, - max_output_tokens=128_000, - cost=ModelCost(input=2.50, output=15.0, cache_read=0.25), - ), - "gpt-5.4-mini": Model( - id="gpt-5.4-mini", - adapter=_ADAPTER, - provider=_PROVIDER, - name="GPT-5.4 Mini", - capabilities=("text",), - context_window=400_000, - max_output_tokens=128_000, - cost=ModelCost(input=0.75, output=4.50, cache_read=0.075), - ), - "gpt-5.4-nano": Model( - id="gpt-5.4-nano", - adapter=_ADAPTER, - provider=_PROVIDER, - name="GPT-5.4 Nano", - capabilities=("text",), - context_window=400_000, - max_output_tokens=128_000, - cost=ModelCost(input=0.20, output=1.25, cache_read=0.02), - ), -} diff --git a/tests/conftest.py b/tests/conftest.py index 15796650..1e8529f1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,11 +10,80 @@ from ai.types import builders from ai.types import messages as messages_ -# A fixed Model used in tests — adapter="mock" dispatches to the mock adapter. -MOCK_MODEL = models.Model(id="mock-model", adapter="mock", provider="mock") -# Register a dummy provider so _auto_client() doesn't error for provider="mock". -models._PROVIDER_DEFAULTS["mock"] = ("http://mock.test", "MOCK_API_KEY") +class MockProvider: + """Minimal provider that satisfies the ``Provider`` protocol for tests. + + Carries just enough state so that ``Model`` objects can be constructed + and ``auto_client`` can resolve a client. + """ + + def __init__( + self, + *, + name: str = "mock", + adapter: str = "mock", + base_url: str = "http://mock.test", + api_key_env: str | None = "MOCK_API_KEY", + ) -> None: + self._name = name + self._adapter = adapter + self._base_url = base_url + self._api_key_env = api_key_env + + @property + def name(self) -> str: + return self._name + + @property + def adapter(self) -> str: + return self._adapter + + @property + def base_url(self) -> str: + return self._base_url + + @property + def api_key_env(self) -> str | None: + return self._api_key_env + + def client(self) -> models.Client: + import os + + api_key = os.environ.get(self._api_key_env) if self._api_key_env else None + return models.Client(base_url=self._base_url, api_key=api_key) + + async def check(self, client: models.Client, model: models.Model) -> bool: + return True + + async def list(self, *, client: models.Client | None = None) -> list[str]: + return [] + + def __call__( + self, + model_id: str, + *, + client: models.Client | None = None, + ) -> models.Model: + return models.Model( + id=model_id, + adapter=self._adapter, + provider=self, + client=client, + ) + + def __repr__(self) -> str: + return self._name + + +MOCK_PROVIDER = MockProvider() + +# A fixed Model used in tests — adapter="mock" dispatches to the mock adapter. +MOCK_MODEL = models.Model( + id="mock-model", + adapter="mock", + provider=MOCK_PROVIDER, +) class MockAdapter: diff --git a/tests/models/ai_gateway/test_generate_image.py b/tests/models/ai_gateway/test_generate_image.py index a57abe3f..b3d7bfd9 100644 --- a/tests/models/ai_gateway/test_generate_image.py +++ b/tests/models/ai_gateway/test_generate_image.py @@ -21,9 +21,8 @@ import httpx import pytest -from ai.models.ai_gateway import errors +from ai.models.ai_gateway import ai_gateway, errors from ai.models.ai_gateway.generate import ImageParams, generate -from ai.models.core import model as model_ from ai.types import messages from .conftest import mock_client, user_msg @@ -36,12 +35,7 @@ _JPEG_HEADER = bytes([0xFF, 0xD8, 0xFF, 0xE0]) _JPEG_B64 = base64.b64encode(_JPEG_HEADER).decode() -_IMAGE_MODEL = model_.Model( - id="google/imagen-4.0-generate-001", - adapter="ai-gateway-v3", - provider="ai-gateway", - capabilities=("image",), -) +_IMAGE_MODEL = ai_gateway("google/imagen-4.0-generate-001") # --------------------------------------------------------------------------- @@ -60,7 +54,9 @@ def handler(req: httpx.Request) -> httpx.Response: ) client = mock_client(httpx.MockTransport(handler)) - msg = await generate(client, _IMAGE_MODEL, [user_msg("A sunset over Tokyo")]) + msg = await generate( + client, _IMAGE_MODEL, [user_msg("A sunset over Tokyo")], ImageParams() + ) assert msg.role == "assistant" assert len(msg.images) == 1 @@ -104,7 +100,7 @@ def handler(req: httpx.Request) -> httpx.Response: ) client = mock_client(httpx.MockTransport(handler)) - msg = await generate(client, _IMAGE_MODEL, [user_msg("a dog")]) + msg = await generate(client, _IMAGE_MODEL, [user_msg("a dog")], ImageParams()) assert msg.usage is not None assert msg.usage.input_tokens == 50 @@ -124,14 +120,9 @@ def handler(req: httpx.Request) -> httpx.Response: captured.update(dict(req.headers)) return httpx.Response(200, json={"images": [_PNG_B64]}) - model = model_.Model( - id="openai/gpt-image-1", - adapter="ai-gateway-v3", - provider="ai-gateway", - capabilities=("image",), - ) + model = ai_gateway("openai/gpt-image-1") client = mock_client(httpx.MockTransport(handler), api_key="sk-test") - await generate(client, model, [user_msg("Hi")]) + await generate(client, model, [user_msg("Hi")], ImageParams()) assert captured["authorization"] == "Bearer sk-test" assert captured["ai-image-model-specification-version"] == "3" @@ -184,7 +175,7 @@ def handler(req: httpx.Request) -> httpx.Response: return httpx.Response(200, json={"images": [_PNG_B64]}) client = mock_client(httpx.MockTransport(handler)) - await generate(client, _IMAGE_MODEL, [user_msg("test")]) + await generate(client, _IMAGE_MODEL, [user_msg("test")], ImageParams()) assert captured_url[0] == "https://gw.test/v3/ai/image-model" @@ -212,6 +203,7 @@ def handler(req: httpx.Request) -> httpx.Response: mock_client(httpx.MockTransport(handler)), _IMAGE_MODEL, [user_msg("test")], + ImageParams(), ) async def test_429_rate_limit_error(self) -> None: @@ -231,6 +223,7 @@ def handler(req: httpx.Request) -> httpx.Response: mock_client(httpx.MockTransport(handler)), _IMAGE_MODEL, [user_msg("test")], + ImageParams(), ) async def test_empty_images_returns_empty_message(self) -> None: @@ -243,5 +236,6 @@ def handler(req: httpx.Request) -> httpx.Response: mock_client(httpx.MockTransport(handler)), _IMAGE_MODEL, [user_msg("test")], + ImageParams(), ) assert len(msg.images) == 0 diff --git a/tests/models/ai_gateway/test_generate_video.py b/tests/models/ai_gateway/test_generate_video.py index b283fcd9..8233daa6 100644 --- a/tests/models/ai_gateway/test_generate_video.py +++ b/tests/models/ai_gateway/test_generate_video.py @@ -22,9 +22,8 @@ import httpx import pytest -from ai.models.ai_gateway import errors +from ai.models.ai_gateway import ai_gateway, errors from ai.models.ai_gateway.generate import VideoParams, generate -from ai.models.core import model as model_ from ai.types import messages from .conftest import mock_client, sse, user_msg @@ -39,12 +38,7 @@ _WEBM_HEADER = bytes([0x1A, 0x45, 0xDF, 0xA3]) _WEBM_B64 = base64.b64encode(_WEBM_HEADER).decode() -_VIDEO_MODEL = model_.Model( - id="google/veo-3.0-generate-001", - adapter="ai-gateway-v3", - provider="ai-gateway", - capabilities=("video",), -) +_VIDEO_MODEL = ai_gateway("google/veo-3.0-generate-001") # --------------------------------------------------------------------------- diff --git a/tests/models/ai_gateway/test_stream.py b/tests/models/ai_gateway/test_stream.py index 8195759b..71dfa5d9 100644 --- a/tests/models/ai_gateway/test_stream.py +++ b/tests/models/ai_gateway/test_stream.py @@ -23,7 +23,7 @@ import pytest import ai -from ai.models.ai_gateway import errors +from ai.models.ai_gateway import ai_gateway, errors from ai.models.core import model as model_ from ai.types import messages @@ -37,11 +37,7 @@ # Helpers # --------------------------------------------------------------------------- -_TEST_MODEL = model_.Model( - id="test-provider/test-model", - adapter="ai-gateway-v3", - provider="ai-gateway", -) +_TEST_MODEL = ai_gateway("test-provider/test-model") async def _collect( @@ -220,11 +216,7 @@ def handler(req: httpx.Request) -> httpx.Response: text=sse({"type": "finish", "finishReason": "stop", "usage": {}}), ) - model = model_.Model( - id="anthropic/claude-sonnet-4", - adapter="ai-gateway-v3", - provider="ai-gateway", - ) + model = ai_gateway("anthropic/claude-sonnet-4") client = mock_client(httpx.MockTransport(handler), api_key="sk-test") await _collect(client, [user_msg("Hi")], model=model) diff --git a/tests/models/test_check.py b/tests/models/test_check.py index 0889a0a8..29500434 100644 --- a/tests/models/test_check.py +++ b/tests/models/test_check.py @@ -8,13 +8,15 @@ from __future__ import annotations +import dataclasses import json from typing import Any import httpx import pytest -from ai.models import check_connection +from ai.models import anthropic, check_connection, openai +from ai.models.ai_gateway import ai_gateway from ai.models.ai_gateway import check as ai_gw_check from ai.models.anthropic import check as anthropic_check from ai.models.core import client as client_ @@ -25,14 +27,48 @@ # Fixtures # --------------------------------------------------------------------------- -_OPENAI_MODEL = model_.Model(id="gpt-5.4", adapter="openai", provider="openai") -_ANTHROPIC_MODEL = model_.Model( - id="claude-opus-4-6", adapter="anthropic", provider="anthropic" -) -_GATEWAY_MODEL = model_.Model( - id="anthropic/claude-opus-4-6", adapter="ai-gateway-v3", provider="ai-gateway" -) -_UNKNOWN_MODEL = model_.Model(id="x", adapter="x", provider="unknown-provider") +_OPENAI_MODEL = openai("gpt-5.4") +_ANTHROPIC_MODEL = anthropic("claude-opus-4-6") +_GATEWAY_MODEL = ai_gateway("anthropic/claude-opus-4-6") + + +class _FailProvider: + """A provider whose check always raises KeyError (for testing unknown providers).""" + + @property + def api_key_env(self) -> None: + return None + + @property + def base_url(self) -> str: + return "http://unknown.test" + + @property + def adapter(self) -> str: + return "x" + + @property + def name(self) -> str: + return "unknown-provider" + + def client(self) -> client_.Client: + return client_.Client(base_url=self.base_url) + + async def check(self, client: client_.Client, model: model_.Model) -> bool: + raise KeyError(f"No check function registered for provider={self.name!r}.") + + async def list(self, *, client: client_.Client | None = None) -> list[str]: + return [] + + def __call__( + self, model_id: str, *, client: client_.Client | None = None + ) -> model_.Model: + return model_.Model( + id=model_id, adapter=self.adapter, provider=self, client=client + ) + + +_UNKNOWN_MODEL = _FailProvider()("x") def _client_with_mock( @@ -155,15 +191,16 @@ class TestCheckConnection: async def test_gateway_dispatches(self) -> None: config = {"models": [{"id": "anthropic/claude-opus-4-6"}]} c = _gateway_client(config_body=config) - assert await check_connection(_GATEWAY_MODEL, client=c) is True + m = dataclasses.replace(_GATEWAY_MODEL, client=c) + assert await check_connection(m) is True async def test_unknown_provider_raises(self) -> None: c = _client_with_mock(200) + m = dataclasses.replace(_UNKNOWN_MODEL, client=c) with pytest.raises(KeyError, match="unknown-provider"): - await check_connection(_UNKNOWN_MODEL, client=c) + await check_connection(m) async def test_dispatch_false_propagates(self) -> None: - assert ( - await check_connection(_OPENAI_MODEL, client=_client_with_mock(401)) - is False - ) + c = _client_with_mock(401) + m = dataclasses.replace(_OPENAI_MODEL, client=c) + assert await check_connection(m) is False diff --git a/tests/models/test_public_api.py b/tests/models/test_public_api.py index 378a6e74..7ca4b1b1 100644 --- a/tests/models/test_public_api.py +++ b/tests/models/test_public_api.py @@ -11,7 +11,7 @@ from ai import models from ai.types import messages as messages_ -from ..conftest import MOCK_MODEL, mock_llm, text_msg +from ..conftest import MOCK_MODEL, MOCK_PROVIDER, MockProvider, mock_llm, text_msg # Module-level model so StructuredOutputPart can resolve it by FQN. @@ -41,7 +41,7 @@ async def test_stream_basic() -> None: async def test_stream_with_explicit_client() -> None: - """ai.models.stream(client=...) forwards the provided client to the adapter.""" + """Model with explicit client= forwards it to the adapter.""" received_clients: list[models.Client] = [] async def _spy_stream( @@ -63,7 +63,10 @@ async def _spy_stream( models.register_stream("mock", _spy_stream) explicit = models.Client(base_url="https://custom.test", api_key="sk-custom") - s = await models.stream(MOCK_MODEL, [ai.user_message("Hi")], client=explicit) + explicit_model = models.Model( + id="mock-model", adapter="mock", provider=MOCK_PROVIDER, client=explicit + ) + s = await models.stream(explicit_model, [ai.user_message("Hi")]) async for _ in s: pass @@ -117,11 +120,12 @@ async def _structured_stream( # generate() tests # --------------------------------------------------------------------------- +_MOCK_GEN_PROVIDER = MockProvider(adapter="mock-gen") + GENERATE_MODEL = models.Model( id="gen-model", adapter="mock-gen", - provider="mock", - capabilities=("image",), + provider=_MOCK_GEN_PROVIDER, ) diff --git a/tests/test_middleware.py b/tests/test_middleware.py index 2e335247..8d7a88ce 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -320,7 +320,9 @@ async def wrap_generate( @my_agent.loop async def gen_loop(context: ai.Context) -> AsyncGenerator[ai.Message]: - result = await models.generate(context.model, context.messages) + result = await models.generate( + context.model, context.messages, models.ImageParams() + ) yield result async for _m in my_agent.run(