From 259f5de6ca019e8c5e8460ffd7b52fd13ae77bc7 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Tue, 16 Jun 2026 08:43:33 -0700 Subject: [PATCH 1/3] Outline the serializable models implementation --- src/ai/models/__init__.py | 44 +++- src/ai/models/core/model.py | 320 ++++++++++++++++++++++--- src/ai/providers/anthropic/__init__.py | 10 +- src/ai/providers/base.py | 79 ++++-- src/ai/providers/openai/__init__.py | 7 +- 5 files changed, 396 insertions(+), 64 deletions(-) diff --git a/src/ai/models/__init__.py b/src/ai/models/__init__.py index 29cbb800..4cf49fb1 100644 --- a/src/ai/models/__init__.py +++ b/src/ai/models/__init__.py @@ -1,16 +1,28 @@ """models — composable model layer. +A :class:`Model` holds a *recipe* for its provider — a factory callable +plus its arguments — instead of a live provider object. The provider and +its client are built lazily on first use. When the factory is a named, +module-level callable and the args are JSON-friendly (everything +``get_model`` produces), the model serializes: ``model.model_dump()`` / +``Model.model_validate()`` round-trip. Any other callable (a lambda, a +closure over live objects) works normally in-process, but ``model_dump`` +raises and ``model.serializable`` is ``False``. + Usage:: import ai model = ai.get_model("openai:gpt-5.4") - provider = ai.get_provider("openai", base_url="http://localhost:11434/v1") - model = ai.Model("llama3", provider=provider) model = ai.get_model("anthropic:claude-sonnet-4-6") - provider = ai.get_provider("anthropic", base_url="https://anthropic.example.com") - model = ai.Model("claude-sonnet-4-6", provider=provider) model = ai.get_model("anthropic/claude-sonnet-4") # defaults to Gateway + # custom provider configuration — JSON-friendly args + model = ai.Model( + "llama3", + provider_factory=ai.get_provider, + provider_args={"id": "openai", "base_url": "http://localhost:11434/v1"}, + ) + # stream — auto-creates client from env vars msgs = [ai.user_message("hello")] async with ai.stream(model, msgs) as s: @@ -18,15 +30,21 @@ if isinstance(event, ai.events.TextDelta): print(event.chunk, end="", flush=True) - # explicit provider for custom auth / transport - provider = ai.get_provider( - "openai", - base_url="https://custom.example.com/v1", - api_key="sk-...", - ) - model = ai.Model("gpt-5.4", provider=provider) - async with ai.stream(model, msgs) as s: - ... + # models serialize and rebuild their provider on first use + data = model.model_dump(mode="json") + model = ai.Model.model_validate(data) + + # anything non-serializable (clients, custom auth) lives inside a + # named module-level factory; its import path is what's serialized + def my_provider() -> ai.Provider: + return ai.get_provider("openai", client=shared_client) + + model = ai.Model("gpt-5.4", provider_factory=my_provider) + + # if the model never crosses a process boundary, any callable works — + # the model just isn't serializable (model_dump() raises) + model = ai.Model("gpt-5.4", provider_factory=lambda: provider) + assert model.serializable is False # list available models ids = await ai.get_provider("openai").list_models() diff --git a/src/ai/models/core/model.py b/src/ai/models/core/model.py index 5ff6fd5d..a31c9669 100644 --- a/src/ai/models/core/model.py +++ b/src/ai/models/core/model.py @@ -1,7 +1,13 @@ """Model metadata types.""" +import importlib +import inspect +import json import os -from typing import Any, Self +from collections.abc import Callable +from typing import Any, Self, cast + +import pydantic from ... import _modelsdev from ...errors import ConfigurationError @@ -10,51 +16,293 @@ _DEFAULT_MODEL_ENV = "AI_SDK_DEFAULT_MODEL" -class Model: - """Lightweight reference to a model on a specific provider. +def _callable_ref(fn: Callable[..., Any]) -> str: + """Return the ``"package.module:qualname"`` reference for *fn*. + + Raises :class:`ai.ConfigurationError` when *fn* cannot be found again + by name — which is exactly what makes a factory unserializable. + """ + module_name = getattr(fn, "__module__", None) + qualname = getattr(fn, "__qualname__", None) + if not module_name or not qualname: + raise ConfigurationError( + f"factory {fn!r} has no importable name; it must be a named, " + "module-level function or class so the model can be serialized" + ) + if module_name == "__main__": + raise ConfigurationError( + f"factory {qualname!r} is defined in __main__ and cannot be " + "imported by other processes; move it into an importable module" + ) + if "<" in qualname: + raise ConfigurationError( + f"factory {module_name}.{qualname} must be a named, module-level " + "function or class so the model can be serialized; lambdas and " + "callables defined inside functions are not importable by name" + ) + try: + module = importlib.import_module(module_name) + except ImportError as error: + raise ConfigurationError( + f"cannot import module {module_name!r} of factory " + f"{qualname!r}: {error}" + ) from error + obj: Any = module + for part in qualname.split("."): + obj = getattr(obj, part, None) + if obj is not fn: + raise ConfigurationError( + f"factory {module_name}:{qualname} does not import back to the " + "same object; it must be a named, module-level function or class " + "(bound methods and decorated wrappers are not supported)" + ) + return f"{module_name}:{qualname}" + + +def _import_ref(ref: str) -> Callable[..., Any]: + """Import a factory from a ``"package.module:qualname"`` reference.""" + module_name, sep, qualname = ref.partition(":") + if not sep or not module_name or not qualname: + raise ConfigurationError( + f"malformed factory reference {ref!r}; " + "expected 'package.module:name'" + ) + try: + module = importlib.import_module(module_name) + except ImportError as error: + raise ConfigurationError( + f"cannot import factory {ref!r}: {error}" + ) from error + obj: Any = module + for part in qualname.split("."): + obj = getattr(obj, part, None) + if obj is None: + raise ConfigurationError( + f"cannot import factory {ref!r}: module {module_name!r} " + f"has no attribute {qualname!r}" + ) + if not callable(obj): + raise ConfigurationError(f"factory {ref!r} is not callable") + return cast("Callable[..., Any]", obj) + + +class Model(pydantic.BaseModel): + """Reference to a model on a specific provider. * ``id`` — identifier sent to the provider (e.g. ``"claude-sonnet-4-6"``). - * ``provider`` — :class:`Provider` that owns this model. - * ``protocol`` — optional wire-protocol override for this model. + * ``provider_factory`` — callable that builds the :class:`Provider`. + * ``provider_args`` — keyword arguments for the factory. + * ``protocol_factory`` / ``protocol_args`` — optional wire-protocol + override for this model, same rules as the provider factory. + + The provider is built lazily on first :attr:`provider` access and + cached. + + A model is **serializable** when the factory is a named module-level + function or class (dumped as a ``"package.module:name"`` reference) + and the args are JSON-friendly — everything :func:`get_model` + produces qualifies. Anything that cannot be expressed as JSON args + (custom clients, shared connection pools) can live inside the + factory body:: + + def my_provider() -> ai.Provider: + return ai.get_provider("openai", client=_shared_client) + + model = ai.Model("gpt-5", provider_factory=my_provider) + + Any other callable — a lambda, a closure over a live provider — is + accepted and works normally in-process, but the model then cannot + cross a JSON boundary: ``model_dump()`` raises with the reason. + Check :attr:`serializable` to know ahead of time. """ + id: str + provider_factory: Callable[..., Any] + provider_args: dict[str, Any] = pydantic.Field(default_factory=dict) + protocol_factory: Callable[..., Any] | None = None + protocol_args: dict[str, Any] = pydantic.Field(default_factory=dict) + + _provider_instance: base.Provider[Any] | None = pydantic.PrivateAttr( + default=None + ) + _protocol_instance: base.ProviderProtocol[Any] | None = ( + pydantic.PrivateAttr(default=None) + ) + def __init__( self, id: str, *, - provider: base.Provider, - protocol: base.ProviderProtocol[Any] | None = None, + provider_factory: Callable[..., Any] | str, + provider_args: dict[str, Any] | None = None, + protocol_factory: Callable[..., Any] | str | None = None, + protocol_args: dict[str, Any] | None = None, ) -> None: - self.id = id - self.provider = provider - self.protocol = protocol + super().__init__( + id=id, + provider_factory=provider_factory, + provider_args={} if provider_args is None else provider_args, + protocol_factory=protocol_factory, + protocol_args={} if protocol_args is None else protocol_args, + ) + + @pydantic.field_validator( + "provider_factory", "protocol_factory", mode="before" + ) + @classmethod + def _coerce_factory(cls, value: Any) -> Any: + if isinstance(value, str): + return _import_ref(value) + return value + + @pydantic.model_validator(mode="after") + def _check_factory_args(self) -> Self: + for label, factory, args in ( + ("provider_args", self.provider_factory, self.provider_args), + ("protocol_args", self.protocol_factory, self.protocol_args), + ): + if factory is None: + if args: + raise ConfigurationError( + "protocol_args given without protocol_factory" + ) + continue + try: + signature = inspect.signature(factory) + except (TypeError, ValueError): + continue # some builtins have no introspectable signature + try: + signature.bind(**args) + except TypeError as error: + raise ConfigurationError( + f"{label} do not match the signature of " + f"{factory!r}: {error}" + ) from error + return self + + # Dumps must round-trip or raise: both serializers run for python + # and JSON dumps alike, so a model built around a closure or live + # objects fails loudly the moment it tries to cross a boundary. + @pydantic.field_serializer("provider_factory", "protocol_factory") + def _serialize_factory( + self, value: Callable[..., Any] | None + ) -> str | None: + return None if value is None else _callable_ref(value) + + @pydantic.field_serializer("provider_args", "protocol_args") + def _serialize_args( + self, + value: dict[str, Any], + info: pydantic.FieldSerializationInfo, + ) -> dict[str, Any]: + try: + json.dumps(value) + except (TypeError, ValueError) as error: + raise ConfigurationError( + f"{info.field_name} cannot round-trip through JSON " + f"(put live objects inside a named module-level factory " + f"instead): {error}" + ) from error + return value + + @property + def serializable(self) -> bool: + """Whether this model round-trips through ``model_dump``. + + ``True`` for named module-level factories with JSON-friendly + args — everything :func:`get_model` produces. ``False`` when + the model was built around a lambda, closure, or live objects; + such models work normally in-process but ``model_dump`` raises. + """ + for factory, args in ( + (self.provider_factory, self.provider_args), + (self.protocol_factory, self.protocol_args), + ): + if factory is None: + continue + try: + _callable_ref(factory) + json.dumps(args) + except (ConfigurationError, TypeError, ValueError): + return False + return True + + @property + def provider(self) -> base.Provider[Any]: + """Provider instance, built lazily from the factory and cached.""" + provider = self._provider_instance + if provider is None: + provider = self.provider_factory(**self.provider_args) + if not isinstance(provider, base.Provider): + raise ConfigurationError( + f"provider factory {self.provider_factory!r} returned " + f"{type(provider).__name__}, expected a Provider" + ) + self._provider_instance = provider + return provider + + @property + def protocol(self) -> base.ProviderProtocol[Any] | None: + """Protocol override instance, built lazily and cached.""" + if self.protocol_factory is None: + return None + protocol = self._protocol_instance + if protocol is None: + protocol = self.protocol_factory(**self.protocol_args) + if not isinstance(protocol, base.ProviderProtocol): + raise ConfigurationError( + f"protocol factory {self.protocol_factory!r} returned " + f"{type(protocol).__name__}, expected a ProviderProtocol" + ) + self._protocol_instance = protocol + return protocol + + async def aclose(self) -> None: + """Close the provider if this model lazily built one.""" + if self._provider_instance is not None: + await self._provider_instance.aclose() + self._provider_instance = None def __eq__(self, other: object) -> bool: + # Pydantic's default __eq__ also compares private attributes, + # which would make models unequal once one of them lazily built + # its provider. Compare the recipe fields only; factories + # compare by identity (a round-tripped model imports the same + # factory object back, so equality survives serialization). return ( isinstance(other, Model) and self.id == other.id - and self.provider is other.provider - and self.protocol is other.protocol + and self.provider_factory is other.provider_factory + and self.provider_args == other.provider_args + and self.protocol_factory is other.protocol_factory + and self.protocol_args == other.protocol_args ) - def __repr__(self) -> str: - return f"Model(id={self.id!r}, provider={self.provider!r})" - def __hash__(self) -> int: - return hash((self.id, id(self.provider), id(self.protocol))) + return hash((self.id, self.provider_factory, self.protocol_factory)) - def with_protocol(self, protocol: base.ProviderProtocol[Any]) -> Self: - return self.__class__( - id=self.id, - provider=self.provider, - protocol=protocol, + def with_protocol( + self, + protocol_factory: Callable[..., Any] | str, + **protocol_args: Any, + ) -> Self: + model = self.__class__( + self.id, + provider_factory=self.provider_factory, + provider_args=self.provider_args, + protocol_factory=protocol_factory, + protocol_args=protocol_args, ) + # Keep sharing an already-built provider instance. + model._provider_instance = self._provider_instance + return model def get_model( model_id: str | None = None, *, - protocol: base.ProviderProtocol[Any] | None = None, + protocol_factory: Callable[..., Any] | str | None = None, + protocol_args: dict[str, Any] | None = None, ) -> Model: """Resolve a model ID into a :class:`Model`. @@ -65,9 +313,13 @@ def get_model( Vercel AI Gateway. Examples: ``"openai:gpt-5"`` or ``"anthropic/claude-sonnet-4"``. When omitted, reads ``AI_SDK_DEFAULT_MODEL`` from the environment. - protocol: - Optional wire-protocol override for this model. When omitted, - the provider chooses its default protocol. + protocol_factory: + Optional wire-protocol override for this model — a named, + module-level callable (usually the protocol class) that builds + the protocol. When omitted, the provider chooses its default + protocol. + protocol_args: + JSON-serializable keyword arguments for ``protocol_factory``. Raises: Raises :class:`ai.ConfigurationError` when ``model_id`` and @@ -102,9 +354,19 @@ def get_model( None if model_info is None else model_info.provider_config ) - provider = base.Provider.from_id( - provider_id, - model_provider_config=model_provider_config, + # Fail early on unknown or unsupported providers without building a + # provider (and its client); the model only stores the recipe. + base.resolve_provider_type( + provider_id, model_provider_config=model_provider_config ) - return Model(provider_model_id, provider=provider, protocol=protocol) + return Model( + provider_model_id, + provider_factory=base.provider_for_model, + provider_args={ + "provider_id": provider_id, + "model_id": provider_model_id, + }, + protocol_factory=protocol_factory, + protocol_args=protocol_args, + ) diff --git a/src/ai/providers/anthropic/__init__.py b/src/ai/providers/anthropic/__init__.py index a183b555..5577421d 100644 --- a/src/ai/providers/anthropic/__init__.py +++ b/src/ai/providers/anthropic/__init__.py @@ -6,8 +6,14 @@ from ai.providers.anthropic import tools as anthropic_tools model = ai.get_model("anthropic:claude-sonnet-4-6") - provider = ai.get_provider("anthropic", base_url="https://anthropic.example.com") - model = ai.Model("claude-sonnet-4-6", provider=provider) + model = ai.Model( + "claude-sonnet-4-6", + provider_factory=ai.get_provider, + provider_args={ + "id": "anthropic", + "base_url": "https://anthropic.example.com", + }, + ) ids = await ai.get_provider("anthropic").list_models() # built-in tools diff --git a/src/ai/providers/base.py b/src/ai/providers/base.py index 28d1af10..560c882a 100644 --- a/src/ai/providers/base.py +++ b/src/ai/providers/base.py @@ -279,24 +279,19 @@ def from_id( if modelsdev_provider is None: raise ValueError(f"unknown provider id: {known_id!r}") - for handle in ( - modelsdev_provider.id, - _modelsdev.provider_npm(modelsdev_provider, model_provider_config), - ): - provider_type = _PROVIDER_REGISTRY.get(handle) - if provider_type is not None: - return provider_type.from_modelsdev_provider( - modelsdev_provider, - model_provider_config=model_provider_config, - base_url=base_url, - api_key=api_key, - headers=headers, - env=env, - client=client, - protocol=protocol, - ) - - raise UnsupportedProviderError(modelsdev_provider.id) + provider_type = resolve_provider_type( + known_id, model_provider_config=model_provider_config + ) + return provider_type.from_modelsdev_provider( + modelsdev_provider, + model_provider_config=model_provider_config, + base_url=base_url, + api_key=api_key, + headers=headers, + env=env, + client=client, + protocol=protocol, + ) @classmethod def from_modelsdev_provider( @@ -318,6 +313,54 @@ def from_modelsdev_provider( _PROVIDER_REGISTRY: dict[str, type[Provider[Any]]] = {} +def resolve_provider_type( + known_id: str, + *, + model_provider_config: modelsdotdev.ModelProviderConfig | None = None, +) -> type[Provider[Any]]: + """Return the registered provider class for a models.dev provider ID. + + Performs the same lookup as :meth:`Provider.from_id` without building + a provider (or its client). + + Raises ``ValueError`` for unknown provider IDs and + :class:`ai.UnsupportedProviderError` for known providers without a + registered implementation. + """ + modelsdev_provider = _modelsdev.get_provider_by_id(known_id) + if modelsdev_provider is None: + raise ValueError(f"unknown provider id: {known_id!r}") + + for handle in ( + modelsdev_provider.id, + _modelsdev.provider_npm(modelsdev_provider, model_provider_config), + ): + provider_type = _PROVIDER_REGISTRY.get(handle) + if provider_type is not None: + return provider_type + + raise UnsupportedProviderError(modelsdev_provider.id) + + +def provider_for_model( + provider_id: str, model_id: str | None = None +) -> Provider[Any]: + """Build the provider serving *model_id* on *provider_id*. + + This is the default provider factory stored by :func:`ai.get_model`. + It looks up models.dev metadata at build time, so its arguments stay + plain JSON and models that reference it serialize cleanly. + """ + model_provider_config = None + if model_id is not None: + model_info = _modelsdev.get_model_by_id(f"{provider_id}:{model_id}") + if model_info is not None: + model_provider_config = model_info.provider_config + return Provider.from_id( + provider_id, model_provider_config=model_provider_config + ) + + def get_provider( id: str, *, diff --git a/src/ai/providers/openai/__init__.py b/src/ai/providers/openai/__init__.py index b590d20d..0b50b40b 100644 --- a/src/ai/providers/openai/__init__.py +++ b/src/ai/providers/openai/__init__.py @@ -5,8 +5,11 @@ import ai model = ai.get_model("openai:gpt-5.4") - provider = ai.get_provider("openai", base_url="http://localhost:11434/v1") - model = ai.Model("llama3", provider=provider) + model = ai.Model( + "llama3", + provider_factory=ai.get_provider, + provider_args={"id": "openai", "base_url": "http://localhost:11434/v1"}, + ) ids = await ai.get_provider("openai").list_models() The optional upstream OpenAI SDK is loaded lazily when the provider creates or From 868948f14defc8a69cde70db011c26f7e56af039 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Tue, 16 Jun 2026 10:39:05 -0700 Subject: [PATCH 2/3] Update examples and tests to use the new api --- .../.test_scripts/run-with-patched-model.py | 39 ++- examples/builtin_web_search.py | 8 +- examples/check_connection.py | 33 ++- examples/explicit_client.py | 22 +- examples/openai_chat_completions.py | 16 +- examples/stream_all.py | 22 +- examples/temporal-direct/main.py | 36 +-- tests/conftest.py | 10 +- tests/models/core/test_api.py | 105 ++++---- tests/models/core/test_model.py | 225 ++++++++++++++++++ tests/models/test_resolution.py | 8 +- tests/providers/ai_gateway/conftest.py | 13 +- tests/providers/ai_gateway/test_probe.py | 28 ++- tests/providers/anthropic/test_adapter.py | 2 +- tests/providers/anthropic/test_probe.py | 85 +++++-- tests/providers/anthropic/test_provider.py | 3 - tests/providers/anthropic/test_stream.py | 2 +- tests/providers/anthropic/test_tools.py | 2 +- tests/providers/openai/test_adapter.py | 5 +- tests/providers/openai/test_probe.py | 32 ++- tests/providers/openai/test_provider.py | 3 - 21 files changed, 485 insertions(+), 214 deletions(-) create mode 100644 tests/models/core/test_model.py diff --git a/examples/.test_scripts/run-with-patched-model.py b/examples/.test_scripts/run-with-patched-model.py index 253977ab..0001f279 100644 --- a/examples/.test_scripts/run-with-patched-model.py +++ b/examples/.test_scripts/run-with-patched-model.py @@ -88,36 +88,31 @@ def main() -> None: original_stream = _api.stream original_generate = _api.generate - def selected_protocol() -> ai.ProviderProtocol[Any] | None: - if protocol_factory is None: - return None - return protocol_factory() - def selected_protocol_for_provider( provider: ai.Provider[Any], - ) -> ai.ProviderProtocol[Any] | None: - if args.protocol is None: + ) -> Callable[[], ai.ProviderProtocol[Any]] | None: + if args.protocol is None or protocol_factory is None: return None if args.protocol in ("chat", "responses") and isinstance( provider, OpenAICompatibleProvider ): - return selected_protocol() + return protocol_factory if args.protocol == "messages" and isinstance( provider, AnthropicCompatibleProvider ): - return selected_protocol() + return protocol_factory return None def selected_protocol_for_model( model: ai.Model, - ) -> ai.ProviderProtocol[Any] | None: + ) -> Callable[[], ai.ProviderProtocol[Any]] | None: return selected_protocol_for_provider(model.provider) def with_selected_protocol(model: ModelT) -> ModelT: - protocol = selected_protocol_for_model(model) - if protocol is None: + factory = selected_protocol_for_model(model) + if factory is None: return model - return model.with_protocol(protocol) + return model.with_protocol(factory) class PatchedContext: def __init__(self, context: Any) -> None: @@ -174,18 +169,12 @@ async def patched_generate(*call_args: Any, **kwargs: Any) -> Any: return await original_generate(*call_args, **kwargs) class PatchedModel(_model.Model): - def __init__( - self, - id: str, - *, - provider: ai.Provider[Any], - protocol: ai.ProviderProtocol[Any] | None = None, - ) -> None: - super().__init__( - id, - provider=provider, - protocol=selected_protocol_for_provider(provider) or protocol, - ) + def __init__(self, id: str, **kwargs: Any) -> None: + super().__init__(id, **kwargs) + override = selected_protocol_for_provider(self.provider) + if override is not None: + self.protocol_factory = override + self.protocol_args = {} cast("Any", ai).get_model = patched_get_model cast("Any", models).get_model = patched_get_model diff --git a/examples/builtin_web_search.py b/examples/builtin_web_search.py index bba7ad86..97d49257 100644 --- a/examples/builtin_web_search.py +++ b/examples/builtin_web_search.py @@ -45,13 +45,11 @@ def format(value: object) -> str: async def main() -> None: - provider = ai.get_provider("anthropic") - if not provider.is_configured(): - print(f"[SKIP] {provider.name} provider is not configured") + model = ai.get_model("anthropic:claude-sonnet-4-6") + if not model.provider.is_configured(): + print(f"[SKIP] {model.provider.name} provider is not configured") return - model = ai.Model("claude-sonnet-4-6", provider=provider) - async with ai.stream(model, messages, tools=tools) as s: async for event in s: match event: diff --git a/examples/check_connection.py b/examples/check_connection.py index 85b82fff..320ff236 100644 --- a/examples/check_connection.py +++ b/examples/check_connection.py @@ -5,10 +5,10 @@ import ai -PROVIDERS: list[tuple[str, ai.Provider, str]] = [ - ("ai_gateway", ai.get_provider("vercel"), "anthropic/claude-sonnet-4.6"), - ("anthropic", ai.get_provider("anthropic"), "claude-sonnet-4-6"), - ("openai", ai.get_provider("openai"), "gpt-5.4-mini"), +MODELS: list[tuple[str, ai.Model]] = [ + ("ai_gateway", ai.get_model("gateway:anthropic/claude-sonnet-4.6")), + ("anthropic", ai.get_model("anthropic:claude-sonnet-4-6")), + ("openai", ai.get_model("openai:gpt-5.4-mini")), ] _failed = False @@ -20,23 +20,22 @@ def _fail(msg: str) -> None: print(msg) -async def _check(name: str, provider: ai.Provider, model_id: str) -> None: - if not provider.is_configured(): - print(f" [SKIP] {provider.name} provider is not configured") +async def _check(name: str, model: ai.Model) -> None: + if not model.provider.is_configured(): + print(f" [SKIP] {model.provider.name} provider is not configured") return - model = ai.Model(model_id, provider=provider) try: await ai.probe(model) - print(f" [OK] {name}/{model_id}") + print(f" [OK] {name}/{model.id}") except Exception as exc: - _fail(f" [ERR] {name}/{model_id}: {exc}") + _fail(f" [ERR] {name}/{model.id}: {exc}") -async def _list_models(name: str, provider: ai.Provider) -> None: - if not provider.is_configured(): +async def _list_models(name: str, model: ai.Model) -> None: + if not model.provider.is_configured(): return try: - ids: list[str] = await provider.list_models() + ids: list[str] = await model.provider.list_models() print(f" {name}: {len(ids)} models (last: {ids[-1]})") except Exception as exc: _fail(f" {name}: [ERR] {exc}") @@ -44,12 +43,12 @@ async def _list_models(name: str, provider: ai.Provider) -> None: async def main() -> None: print("Checking connections...\n") - for name, provider, model_id in PROVIDERS: - await _check(name, provider, model_id) + for name, model in MODELS: + await _check(name, model) print("\nListing models...\n") - for name, provider, _ in PROVIDERS: - await _list_models(name, provider) + for name, model in MODELS: + await _list_models(name, model) print() if _failed: diff --git a/examples/explicit_client.py b/examples/explicit_client.py index dbbadd18..822ff1ae 100644 --- a/examples/explicit_client.py +++ b/examples/explicit_client.py @@ -10,18 +10,16 @@ async def main() -> None: # Example for local OpenAI-compatible servers like LM Studio. - provider = ai.get_provider( - "openai", - base_url=os.environ.get( - "LOCAL_OPENAI_BASE_URL", "http://localhost:1234/v1" - ), - api_key=os.environ.get("LOCAL_OPENAI_API_KEY", "some-key"), - headers={"X-Custom-Header": "example"}, - ) - model = ai.Model( os.environ.get("LOCAL_OPENAI_MODEL", "local-model"), - provider=provider, + provider_factory=lambda: ai.get_provider( + "openai", + base_url=os.environ.get( + "LOCAL_OPENAI_BASE_URL", "http://localhost:1234/v1" + ), + api_key=os.environ.get("LOCAL_OPENAI_API_KEY", "some-key"), + headers={"X-Custom-Header": "example"}, + ), ) try: @@ -39,8 +37,8 @@ async def main() -> None: print(event.chunk, end="", flush=True) print() finally: - # Explicit providers need explicit cleanup. - await provider.aclose() + # Explicitly configured providers need explicit cleanup. + await model.aclose() if __name__ == "__main__": diff --git a/examples/openai_chat_completions.py b/examples/openai_chat_completions.py index 5a909725..cc86ff36 100644 --- a/examples/openai_chat_completions.py +++ b/examples/openai_chat_completions.py @@ -14,16 +14,12 @@ async def main() -> None: - provider = ai.get_provider("openai") - if not provider.is_configured(): - print(f"[SKIP] {provider.name} provider is not configured") - return - - model = ai.Model( - "gpt-5.5", - provider=provider, - protocol=OpenAIChatCompletionsProtocol(), + model = ai.get_model("openai:gpt-5.5").with_protocol( + OpenAIChatCompletionsProtocol ) + if not model.provider.is_configured(): + print(f"[SKIP] {model.provider.name} provider is not configured") + return try: async with ai.stream(model, messages) as stream: @@ -32,7 +28,7 @@ async def main() -> None: print(event.chunk, end="", flush=True) print() finally: - await provider.aclose() + await model.aclose() if __name__ == "__main__": diff --git a/examples/stream_all.py b/examples/stream_all.py index a4e01f7b..41574969 100644 --- a/examples/stream_all.py +++ b/examples/stream_all.py @@ -4,10 +4,10 @@ import ai -MODELS: list[tuple[str, ai.Provider, str]] = [ - ("ai_gateway", ai.get_provider("vercel"), "anthropic/claude-sonnet-4.6"), - ("anthropic", ai.get_provider("anthropic"), "claude-sonnet-4-6"), - ("openai", ai.get_provider("openai"), "gpt-5.5"), +MODELS: list[tuple[str, ai.Model]] = [ + ("ai_gateway", ai.get_model("gateway:anthropic/claude-sonnet-4.6")), + ("anthropic", ai.get_model("anthropic:claude-sonnet-4-6")), + ("openai", ai.get_model("openai:gpt-5.5")), ] messages = [ @@ -16,15 +16,13 @@ ] -async def _run(name: str, provider: ai.Provider, model_id: str) -> None: - print(f"\n{name} / {model_id}") +async def _run(name: str, model: ai.Model) -> None: + print(f"\n{name} / {model.id}") - if not provider.is_configured(): - print(f"[SKIP] {provider.name} provider is not configured") + if not model.provider.is_configured(): + print(f"[SKIP] {model.provider.name} provider is not configured") return - model = ai.Model(model_id, provider=provider) - try: async with ai.stream(model, messages) as s: async for event in s: @@ -36,8 +34,8 @@ async def _run(name: str, provider: ai.Provider, model_id: str) -> None: async def main() -> None: - for name, provider, model_id in MODELS: - await _run(name, provider, model_id) + for name, model in MODELS: + await _run(name, model) if __name__ == "__main__": diff --git a/examples/temporal-direct/main.py b/examples/temporal-direct/main.py index 7e30a462..59184f11 100644 --- a/examples/temporal-direct/main.py +++ b/examples/temporal-direct/main.py @@ -49,25 +49,13 @@ MODEL_ID = "gateway:anthropic/claude-sonnet-4.6" -# ── Workflow-safe model placeholder ────────────────────────────── -# -# ``agent.run`` requires a ``Model``, but a real one can't be built -# inside the workflow: ``ai.get_model("gateway:...")`` constructs an -# ``httpx.AsyncClient`` at provider-init time, which imports -# httpcore/anyio and trips the Temporal sandbox (``threading.local`` -# at module load). Our loop never calls the model directly anyway -- -# every LLM call is delegated to ``llm_call_activity``, which runs -# outside the sandbox and resolves the real model by id there. -# -# So hand the workflow a placeholder ``Model`` whose provider builds -# no client. It carries the real model id (so the activity can -# resolve it) but is safe to construct inside the sandbox. -class WorkflowModelProvider(ai.Provider[Any]): - """A clientless provider, safe to construct in a workflow sandbox.""" - - def __init__(self) -> None: - super().__init__(name="workflow-placeholder", base_url="") - +# ``ai.Model`` is fully serializable: it stores a provider *recipe* +# (factory reference + JSON args) and only builds the provider — and +# its ``httpx.AsyncClient`` — on first use. The workflow can therefore +# construct the model with ``ai.get_model`` inside the Temporal sandbox +# (nothing network-shaped is created there) and ship it to +# ``llm_call_activity`` as plain JSON, where the provider gets built +# for real. # ── Tool definitions ───────────────────────────────────────────── # @@ -116,7 +104,7 @@ async def get_population_activity(city: str) -> int: @dataclasses.dataclass class LLMParams: - model_id: str + model: dict[str, Any] messages: list[dict[str, Any]] tool_schemas: list[dict[str, Any]] @@ -129,7 +117,7 @@ class LLMResult: @temporalio.activity.defn async def llm_call_activity(params: LLMParams) -> LLMResult: """Call the LLM, drain the stream, return the final message.""" - model = ai.get_model(params.model_id) + model = ai.Model.model_validate(params.model) messages = [ai.messages.Message.model_validate(m) for m in params.messages] tools = [ ai.Tool( @@ -172,7 +160,7 @@ async def loop( result = await temporalio.workflow.execute_activity( llm_call_activity, LLMParams( - model_id=context.model.id, + model=context.model.model_dump(mode="json"), messages=[m.model_dump() for m in context.messages], tool_schemas=tool_schemas, ), @@ -238,7 +226,9 @@ async def _call() -> ai.events.ToolCallResult: class WeatherWorkflow: @temporalio.workflow.run async def run(self, user_query: str) -> str: - model = ai.Model(MODEL_ID, provider=WorkflowModelProvider()) + # Safe in the sandbox: the provider (and its HTTP client) is + # only built lazily, inside the LLM activity. + model = ai.get_model(MODEL_ID) messages: list[ai.messages.Message] = [ ai.system_message( "Answer questions using the weather and population tools." diff --git a/tests/conftest.py b/tests/conftest.py index 3f742f05..434b5620 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -98,10 +98,16 @@ async def generate( MOCK_PROVIDER = MockProvider() + +def mock_provider() -> MockProvider: + """Provider factory returning the shared mock provider.""" + return MOCK_PROVIDER + + # A fixed Model used in tests. MOCK_MODEL: models.Model = models.Model( - id="mock-model", - provider=MOCK_PROVIDER, + "mock-model", + provider_factory=mock_provider, ) diff --git a/tests/models/core/test_api.py b/tests/models/core/test_api.py index e7c740f0..02b3c409 100644 --- a/tests/models/core/test_api.py +++ b/tests/models/core/test_api.py @@ -271,32 +271,34 @@ async def test_stream_requires_model_messages_or_context() -> None: pass -async def test_stream_uses_model_protocol() -> None: - class OverrideProtocol(models.ProviderProtocol[Any]): - def stream( - self, - client: Any, - model: models.Model, - messages: list[messages_.Message], - *, - tools: Sequence[ai.tools.Tool] | None = None, - output_type: type[pydantic.BaseModel] | None = None, - params: models.InferenceRequestParams | None = None, - provider: str, - ) -> AsyncGenerator[events_.Event]: - _ = client, model, messages, tools, output_type, params, provider - - async def _stream() -> AsyncGenerator[events_.Event]: - yield events_.StreamStart() - yield events_.TextStart(block_id="text") - yield events_.TextDelta(block_id="text", chunk="override") - yield events_.TextEnd(block_id="text") - yield events_.StreamEnd() - - return _stream() +# Module-level so ``with_protocol`` can serialize a reference to it. +class _StreamOverrideProtocol(models.ProviderProtocol[Any]): + def stream( + self, + client: Any, + model: models.Model, + messages: list[messages_.Message], + *, + tools: Sequence[ai.tools.Tool] | None = None, + output_type: type[pydantic.BaseModel] | None = None, + params: models.InferenceRequestParams | None = None, + provider: str, + ) -> AsyncGenerator[events_.Event]: + _ = client, model, messages, tools, output_type, params, provider + async def _stream() -> AsyncGenerator[events_.Event]: + yield events_.StreamStart() + yield events_.TextStart(block_id="text") + yield events_.TextDelta(block_id="text", chunk="override") + yield events_.TextEnd(block_id="text") + yield events_.StreamEnd() + + return _stream() + + +async def test_stream_uses_model_protocol() -> None: async with models.stream( - MOCK_MODEL.with_protocol(OverrideProtocol()), + MOCK_MODEL.with_protocol(_StreamOverrideProtocol), [ai.user_message("Hi")], ) as stream: async for _ in stream: @@ -306,11 +308,10 @@ async def _stream() -> AsyncGenerator[events_.Event]: async def test_generate_dispatches_to_provider() -> None: - provider = MockProvider() - model = models.Model( - id="generate-model", - provider=provider, - ) + # A provider class is itself a valid model factory. + model = models.Model("generate-model", provider_factory=MockProvider) + provider = model.provider + assert isinstance(provider, MockProvider) sentinel = messages_.Message( role="assistant", parts=[messages_.FilePart(data=b"\x89PNG", media_type="image/png")], @@ -338,32 +339,35 @@ async def _generate( assert result is sentinel -async def test_generate_uses_model_protocol() -> None: - sentinel = messages_.Message( - role="assistant", - parts=[messages_.FilePart(data=b"\x89PNG", media_type="image/png")], - ) +# Module-level so ``with_protocol`` can serialize a reference to it. +_GENERATED_IMAGE = messages_.Message( + role="assistant", + parts=[messages_.FilePart(data=b"\x89PNG", media_type="image/png")], +) - class OverrideProtocol(models.ProviderProtocol[Any]): - async def generate( - self, - client: Any, - model: models.Model, - messages: list[messages_.Message], - params: models.GenerateParams, - *, - provider: str, - ) -> messages_.Message: - _ = client, model, messages, params, provider - return sentinel +class _GenerateOverrideProtocol(models.ProviderProtocol[Any]): + async def generate( + self, + client: Any, + model: models.Model, + messages: list[messages_.Message], + params: models.GenerateParams, + *, + provider: str, + ) -> messages_.Message: + _ = client, model, messages, params, provider + return _GENERATED_IMAGE + + +async def test_generate_uses_model_protocol() -> None: result = await models.generate( - MOCK_MODEL.with_protocol(OverrideProtocol()), + MOCK_MODEL.with_protocol(_GenerateOverrideProtocol), [ai.user_message("A cat")], models.ImageParams(n=1), ) - assert result is sentinel + assert result is _GENERATED_IMAGE class _CheckProvider(MockProvider): @@ -376,11 +380,12 @@ async def probe(self, model: models.Model) -> None: async def test_probe_delegates_to_model_provider() -> None: - provider = _CheckProvider() - model = models.Model("mock-model", provider=provider) + model = models.Model("mock-model", provider_factory=_CheckProvider) await models.probe(model) + provider = model.provider + assert isinstance(provider, _CheckProvider) assert provider.checked_model is model diff --git a/tests/models/core/test_model.py b/tests/models/core/test_model.py new file mode 100644 index 00000000..6edc56e1 --- /dev/null +++ b/tests/models/core/test_model.py @@ -0,0 +1,225 @@ +"""Tests for ``Model`` serialization and lazy provider construction.""" + +from __future__ import annotations + +import json + +import httpx +import pytest +from pydantic_core import PydanticSerializationError + +import ai +from ai import models +from ai.providers.openai import OpenAIChatCompletionsProtocol + +from ...conftest import MockProvider + + +def make_mock_provider(name: str = "mock") -> MockProvider: + """Module-level provider factory used by serialization tests.""" + return MockProvider(name=name) + + +def not_a_provider() -> str: + return "nope" + + +def test_model_dumps_factory_as_import_reference() -> None: + model = models.Model("mock-model", provider_factory=make_mock_provider) + + data = model.model_dump(mode="json") + + assert data == { + "id": "mock-model", + "provider_factory": ("tests.models.core.test_model:make_mock_provider"), + "provider_args": {}, + "protocol_factory": None, + "protocol_args": {}, + } + # the dump is real JSON + json.dumps(data) + + +def test_model_json_round_trip() -> None: + model = models.Model( + "mock-model", + provider_factory=make_mock_provider, + provider_args={"name": "custom"}, + ) + + restored = models.Model.model_validate( + json.loads(json.dumps(model.model_dump(mode="json"))) + ) + + assert restored == model + assert hash(restored) == hash(model) + assert restored.provider_factory is make_mock_provider + assert restored.provider.name == "custom" + + +def test_get_model_round_trip_builds_equivalent_provider() -> None: + model = ai.get_model("openai:gpt-5") + restored = models.Model.model_validate(model.model_dump(mode="json")) + + assert restored == model + assert restored.provider.name == "openai" + + +def test_provider_is_built_lazily_and_cached() -> None: + model = models.Model("mock-model", provider_factory=make_mock_provider) + + assert model._provider_instance is None + provider = model.provider + assert isinstance(provider, MockProvider) + assert model.provider is provider + + +async def test_aclose_drops_cached_provider() -> None: + model = models.Model("mock-model", provider_factory=make_mock_provider) + provider = model.provider + + await model.aclose() + + assert model._provider_instance is None + assert model.provider is not provider + + +def test_with_protocol_round_trips_and_shares_provider() -> None: + model = models.Model("mock-model", provider_factory=make_mock_provider) + provider = model.provider + + override = model.with_protocol(OpenAIChatCompletionsProtocol) + + assert override.provider is provider + assert isinstance(override.protocol, OpenAIChatCompletionsProtocol) + + restored = models.Model.model_validate(override.model_dump(mode="json")) + assert isinstance(restored.protocol, OpenAIChatCompletionsProtocol) + + +def test_accepts_string_factory_reference() -> None: + model = models.Model( + "mock-model", + provider_factory="tests.models.core.test_model:make_mock_provider", + ) + + assert model.provider_factory is make_mock_provider + + +def test_closure_factory_works_in_process_but_does_not_dump() -> None: + provider = MockProvider() + model = models.Model("mock-model", provider_factory=lambda: provider) + + assert model.provider is provider + assert model.serializable is False + with pytest.raises(PydanticSerializationError, match="lambdas"): + model.model_dump() + with pytest.raises(PydanticSerializationError, match="lambdas"): + model.model_dump(mode="json") + + +def test_factory_defined_inside_function_does_not_dump() -> None: + def local_factory() -> MockProvider: + return MockProvider() + + model = models.Model("mock-model", provider_factory=local_factory) + + assert isinstance(model.provider, MockProvider) + assert model.serializable is False + with pytest.raises(PydanticSerializationError, match="module-level"): + model.model_dump() + + +def test_factory_defined_in_main_does_not_dump( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(make_mock_provider, "__module__", "__main__") + model = models.Model("mock-model", provider_factory=make_mock_provider) + + assert model.serializable is False + with pytest.raises(PydanticSerializationError, match="__main__"): + model.model_dump() + + +def test_bound_method_factory_does_not_dump() -> None: + provider = MockProvider() + model = models.Model("mock-model", provider_factory=provider.list_models) + + assert model.serializable is False + with pytest.raises(PydanticSerializationError, match="same object"): + model.model_dump() + + +def test_non_json_provider_args_do_not_dump() -> None: + model = models.Model( + "mock-model", + provider_factory=ai.get_provider, + provider_args={"id": "openai", "client": httpx.AsyncClient()}, + ) + + assert model.serializable is False + with pytest.raises(PydanticSerializationError, match="round-trip"): + model.model_dump() + + +def test_serializable_is_true_for_factory_models() -> None: + assert ai.get_model("openai:gpt-5").serializable is True + assert ( + models.Model( + "mock-model", provider_factory=make_mock_provider + ).serializable + is True + ) + + +def test_rejects_args_not_matching_factory_signature() -> None: + with pytest.raises(ai.ConfigurationError, match="signature"): + models.Model( + "mock-model", + provider_factory=make_mock_provider, + provider_args={"unknown_arg": 1}, + ) + + +def test_rejects_protocol_args_without_protocol_factory() -> None: + with pytest.raises(ai.ConfigurationError, match="protocol_factory"): + models.Model( + "mock-model", + provider_factory=make_mock_provider, + protocol_args={"x": 1}, + ) + + +def test_rejects_malformed_string_reference() -> None: + with pytest.raises(ai.ConfigurationError, match="malformed"): + models.Model("mock-model", provider_factory="no-colon-here") + + +def test_rejects_unimportable_string_reference() -> None: + with pytest.raises(ai.ConfigurationError, match="cannot import"): + models.Model( + "mock-model", provider_factory="tests.no_such_module:factory" + ) + + +def test_factory_returning_non_provider_fails_on_access() -> None: + model = models.Model("mock-model", provider_factory=not_a_provider) + + with pytest.raises(ai.ConfigurationError, match="expected a Provider"): + _ = model.provider + + +def test_equality_ignores_cached_provider_instance() -> None: + a = models.Model("mock-model", provider_factory=make_mock_provider) + b = models.Model("mock-model", provider_factory=make_mock_provider) + _ = a.provider # cache an instance on one side only + + assert a == b + assert hash(a) == hash(b) + + c = models.Model( + "mock-model", + provider_factory=make_mock_provider, + provider_args={"name": "other"}, + ) + assert a != c diff --git a/tests/models/test_resolution.py b/tests/models/test_resolution.py index 98a286f3..08fb5426 100644 --- a/tests/models/test_resolution.py +++ b/tests/models/test_resolution.py @@ -171,10 +171,12 @@ def test_get_rejects_empty_model_id() -> None: def test_get_model_accepts_model_protocol_override() -> None: - protocol = OpenAIChatCompletionsProtocol() - model = models.get_model("openai:gpt-5", protocol=protocol) + model = models.get_model( + "openai:gpt-5", protocol_factory=OpenAIChatCompletionsProtocol + ) - assert model.protocol is protocol + assert isinstance(model.protocol, OpenAIChatCompletionsProtocol) + assert model.protocol is model.protocol # built once, cached assert isinstance(model.provider.protocol, OpenAIResponsesProtocol) diff --git a/tests/providers/ai_gateway/conftest.py b/tests/providers/ai_gateway/conftest.py index c3cba9d2..6d7bb9e1 100644 --- a/tests/providers/ai_gateway/conftest.py +++ b/tests/providers/ai_gateway/conftest.py @@ -10,6 +10,8 @@ import ai from ai.types import messages +_BASE_URL = "https://gw.test/v3/ai" + def sse(*events: dict[str, Any]) -> str: """Build SSE response text from event dicts.""" @@ -22,14 +24,19 @@ def mock_model( model_id: str = "test-provider/test-model", api_key: str = "test-key", ) -> ai.Model: - """Create a Gateway model wired to a mock transport.""" + """Create a Gateway model wired to a mock transport. + + Per-test handlers are live objects, so the factory is a closure and + the model is deliberately not serializable (it never crosses a JSON + boundary in these tests; ``model_dump`` would raise). + """ provider = ai.get_provider( "vercel", - base_url="https://gw.test/v3/ai", + base_url=_BASE_URL, api_key=api_key, client=httpx.AsyncClient(transport=handler), ) - return ai.Model(model_id, provider=provider) + return ai.Model(model_id, provider_factory=lambda: provider) mock_client = mock_model diff --git a/tests/providers/ai_gateway/test_probe.py b/tests/providers/ai_gateway/test_probe.py index 2803c914..9683bfc4 100644 --- a/tests/providers/ai_gateway/test_probe.py +++ b/tests/providers/ai_gateway/test_probe.py @@ -12,13 +12,13 @@ _MODEL_ID = "anthropic/claude-opus-4-6" -def _gateway_client( - *, +def probe_provider( credits_status: int = 200, config_status: int = 200, config_body: dict[str, Any] | None = None, api_key: str | None = "sk-test-key", -) -> ai.Model: +) -> ai.Provider[Any]: + """Gateway provider whose mock responses are built from JSON args.""" credits_body = json.dumps({"balance": "10.00", "totalUsed": "5.00"}) config_bytes = json.dumps(config_body or {"models": []}).encode() @@ -27,13 +27,31 @@ def _handler(request: httpx.Request) -> httpx.Response: return httpx.Response(credits_status, content=credits_body.encode()) return httpx.Response(config_status, content=config_bytes) - provider = ai.get_provider( + return ai.get_provider( "vercel", base_url="https://gateway.test/v3/ai", api_key=api_key, client=httpx.AsyncClient(transport=httpx.MockTransport(_handler)), ) - return ai.Model(_MODEL_ID, provider=provider) + + +def _gateway_client( + *, + credits_status: int = 200, + config_status: int = 200, + config_body: dict[str, Any] | None = None, + api_key: str | None = "sk-test-key", +) -> ai.Model: + return ai.Model( + _MODEL_ID, + provider_factory=probe_provider, + provider_args={ + "credits_status": credits_status, + "config_status": config_status, + "config_body": config_body, + "api_key": api_key, + }, + ) async def test_auth_ok_model_present_succeeds() -> None: diff --git a/tests/providers/anthropic/test_adapter.py b/tests/providers/anthropic/test_adapter.py index 15c4b59d..c7abc141 100644 --- a/tests/providers/anthropic/test_adapter.py +++ b/tests/providers/anthropic/test_adapter.py @@ -45,7 +45,7 @@ def _patch_client( return cast("anthropic.AsyncAnthropic", fake), captured -_MODEL = ai.Model("claude-sonnet-4-6", provider=ai.get_provider("anthropic")) +_MODEL = ai.get_model("anthropic:claude-sonnet-4-6") async def _drain(stream: Any) -> None: diff --git a/tests/providers/anthropic/test_probe.py b/tests/providers/anthropic/test_probe.py index e9e92884..963dc86f 100644 --- a/tests/providers/anthropic/test_probe.py +++ b/tests/providers/anthropic/test_probe.py @@ -8,7 +8,7 @@ from __future__ import annotations import json -from typing import Any +from typing import Any, ClassVar import httpx import pytest @@ -17,16 +17,18 @@ from ai.providers.anthropic import AnthropicCompatibleProvider -def _client_with_mock( +def probe_provider( status_code: int = 200, - json_body: Any = None, + json_body: dict[str, Any] | None = None, base_url: str = "https://anthropic.test", -) -> ai.Model: +) -> ai.Provider[Any]: + """Anthropic provider whose mock response is built from JSON args.""" + def _handler(request: httpx.Request) -> httpx.Response: body = json.dumps(json_body or {}).encode() return httpx.Response(status_code, content=body) - provider = ai.get_provider( + return ai.get_provider( "anthropic", base_url=base_url, api_key="sk-test-key", @@ -35,7 +37,22 @@ def _handler(request: httpx.Request) -> httpx.Response: transport=httpx.MockTransport(_handler), ), ) - return ai.Model("claude-opus-4-6", provider=provider) + + +def _client_with_mock( + status_code: int = 200, + json_body: dict[str, Any] | None = None, + base_url: str = "https://anthropic.test", +) -> ai.Model: + return ai.Model( + "claude-opus-4-6", + provider_factory=probe_provider, + provider_args={ + "status_code": status_code, + "json_body": json_body, + "base_url": base_url, + }, + ) async def test_200_succeeds() -> None: @@ -51,27 +68,43 @@ async def test_model_not_found_raises_model_not_found() -> None: assert exc_info.value.model_id == model.id -async def test_custom_anthropic_version_header() -> None: - captured_headers: dict[str, str] = {} +class _HeaderCaptureProvider(AnthropicCompatibleProvider): + """Custom provider that records request headers. - def _handler(request: httpx.Request) -> httpx.Response: - captured_headers.update(dict(request.headers)) - body = json.dumps({"id": "custom-model", "type": "model"}).encode() - return httpx.Response(200, content=body) + A provider subclass is itself a valid model factory: the class is + module-level, so ``ai.Model`` can serialize a reference to it, and + per-instance state (the captured headers) stays on the provider, + reachable through ``model.provider``. + """ - provider = AnthropicCompatibleProvider( - name="custom-anthropic", - default_base_url="https://anthropic.test", - api_key="sk-test-key", - anthropic_version="2024-01-01", - headers={"X-Custom-Header": "example"}, - client=httpx.AsyncClient( - base_url="https://anthropic.test", - transport=httpx.MockTransport(_handler), - ), - ) + handles: ClassVar[tuple[str, ...]] = () + + def __init__(self) -> None: + self.captured_headers: dict[str, str] = {} + + def _handler(request: httpx.Request) -> httpx.Response: + self.captured_headers.update(dict(request.headers)) + body = json.dumps({"id": "custom-model", "type": "model"}).encode() + return httpx.Response(200, content=body) + + super().__init__( + name="custom-anthropic", + default_base_url="https://anthropic.test", + api_key="sk-test-key", + anthropic_version="2024-01-01", + headers={"X-Custom-Header": "example"}, + client=httpx.AsyncClient( + base_url="https://anthropic.test", + transport=httpx.MockTransport(_handler), + ), + ) + + +async def test_custom_anthropic_version_header() -> None: + model = ai.Model("custom-model", provider_factory=_HeaderCaptureProvider) - model = ai.Model("custom-model", provider=provider) + provider = model.provider + assert isinstance(provider, _HeaderCaptureProvider) await provider.probe(model) - assert captured_headers["anthropic-version"] == "2024-01-01" - assert captured_headers["x-custom-header"] == "example" + assert provider.captured_headers["anthropic-version"] == "2024-01-01" + assert provider.captured_headers["x-custom-header"] == "example" diff --git a/tests/providers/anthropic/test_provider.py b/tests/providers/anthropic/test_provider.py index 1692b48d..41ba00a5 100644 --- a/tests/providers/anthropic/test_provider.py +++ b/tests/providers/anthropic/test_provider.py @@ -181,15 +181,12 @@ def test_get_provider_accepts_base_url_and_api_key() -> None: headers={"X-Custom-Header": "example"}, ) - model = ai.Model("custom-model", provider=provider) assert repr(provider) == "anthropic" assert isinstance(provider.protocol, AnthropicMessagesProtocol) assert provider.base_url == "https://custom.example.com" assert provider.api_key == "sk-custom" assert provider.headers == {"X-Custom-Header": "example"} assert provider.is_configured() is True - assert model.id == "custom-model" - assert model.provider is provider def test_get_provider_env_overrides_base_url_env() -> None: diff --git a/tests/providers/anthropic/test_stream.py b/tests/providers/anthropic/test_stream.py index ad0400dc..adfd2ee7 100644 --- a/tests/providers/anthropic/test_stream.py +++ b/tests/providers/anthropic/test_stream.py @@ -28,7 +28,7 @@ snapshot_block, ) -_MODEL = ai.Model("claude-sonnet-4-6", provider=ai.get_provider("anthropic")) +_MODEL = ai.get_model("anthropic:claude-sonnet-4-6") async def _drain( diff --git a/tests/providers/anthropic/test_tools.py b/tests/providers/anthropic/test_tools.py index 7c1e5008..a9f0fe33 100644 --- a/tests/providers/anthropic/test_tools.py +++ b/tests/providers/anthropic/test_tools.py @@ -23,7 +23,7 @@ from .conftest import FakeAnthropicClient -_MODEL = ai.Model("claude-sonnet-4-6", provider=ai.get_provider("anthropic")) +_MODEL = ai.get_model("anthropic:claude-sonnet-4-6") async def _capture_tools( diff --git a/tests/providers/openai/test_adapter.py b/tests/providers/openai/test_adapter.py index 30d6ba86..d8db6810 100644 --- a/tests/providers/openai/test_adapter.py +++ b/tests/providers/openai/test_adapter.py @@ -111,10 +111,7 @@ async def close(self) -> None: self.closed = True -_MODEL = ai.Model( - "gpt-5.4", - provider=ai.get_provider("openai", api_key="sk-test"), -) +_MODEL = ai.get_model("openai:gpt-5.4") def _patch( diff --git a/tests/providers/openai/test_probe.py b/tests/providers/openai/test_probe.py index 437067a8..f9a5cd1c 100644 --- a/tests/providers/openai/test_probe.py +++ b/tests/providers/openai/test_probe.py @@ -9,16 +9,18 @@ import ai -def _client_with_mock( +def probe_provider( status_code: int = 200, - json_body: Any = None, + json_body: dict[str, Any] | None = None, base_url: str = "https://openai.test/v1", -) -> ai.Model: +) -> ai.Provider[Any]: + """OpenAI provider whose mock response is built from JSON args.""" + def _handler(request: httpx.Request) -> httpx.Response: body = json.dumps(json_body or {}).encode() return httpx.Response(status_code, content=body) - provider = ai.get_provider( + return ai.get_provider( "openai", base_url=base_url, api_key="sk-test-key", @@ -27,7 +29,22 @@ def _handler(request: httpx.Request) -> httpx.Response: transport=httpx.MockTransport(_handler), ), ) - return ai.Model("gpt-5.4", provider=provider) + + +def _client_with_mock( + status_code: int = 200, + json_body: dict[str, Any] | None = None, + base_url: str = "https://openai.test/v1", +) -> ai.Model: + return ai.Model( + "gpt-5.4", + provider_factory=probe_provider, + provider_args={ + "status_code": status_code, + "json_body": json_body, + "base_url": base_url, + }, + ) async def test_200_succeeds() -> None: @@ -63,7 +80,6 @@ async def test_no_api_key_raises_not_configured( ) -> None: monkeypatch.delenv("OPENAI_API_KEY", raising=False) - provider = ai.get_provider("openai", base_url="https://openai.test/v1") - model = ai.Model("gpt-5.4", provider=provider) + model = ai.get_model("openai:gpt-5.4") with pytest.raises(ai.ProviderNotConfiguredError): - await provider.probe(model) + await model.provider.probe(model) diff --git a/tests/providers/openai/test_provider.py b/tests/providers/openai/test_provider.py index b7ff6429..8467815f 100644 --- a/tests/providers/openai/test_provider.py +++ b/tests/providers/openai/test_provider.py @@ -210,15 +210,12 @@ def test_get_provider_accepts_base_url_and_api_key() -> None: headers={"X-Custom-Header": "example"}, ) - model = ai.Model("custom-model", provider=provider) assert repr(provider) == "openai" assert isinstance(provider.protocol, OpenAIResponsesProtocol) assert provider.base_url == "https://custom.example.com/v1" assert provider.api_key == "sk-custom" assert provider.headers == {"X-Custom-Header": "example"} assert provider.is_configured() is True - assert model.id == "custom-model" - assert model.provider is provider assert isinstance(provider, ai.Provider) From 42f8a2b7320bcd00448f352e242fc157d61c5c4e Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Tue, 16 Jun 2026 15:33:13 -0700 Subject: [PATCH 3/3] Replace function-based factories with classes --- .../.test_scripts/run-with-patched-model.py | 35 +- examples/explicit_client.py | 2 +- examples/openai_chat_completions.py | 3 +- src/ai/__init__.py | 4 + src/ai/models/__init__.py | 34 +- src/ai/models/core/__init__.py | 4 +- src/ai/models/core/model.py | 396 +++++++++--------- src/ai/providers/anthropic/__init__.py | 9 +- src/ai/providers/anthropic/provider.py | 2 +- src/ai/providers/base.py | 106 ++--- src/ai/providers/openai/__init__.py | 6 +- tests/conftest.py | 26 +- tests/models/core/test_api.py | 48 ++- tests/models/core/test_model.py | 218 +++++----- tests/models/test_resolution.py | 4 +- tests/providers/ai_gateway/conftest.py | 19 +- tests/providers/ai_gateway/test_probe.py | 73 ++-- tests/providers/anthropic/test_probe.py | 80 ++-- tests/providers/openai/test_probe.py | 58 +-- 19 files changed, 584 insertions(+), 543 deletions(-) diff --git a/examples/.test_scripts/run-with-patched-model.py b/examples/.test_scripts/run-with-patched-model.py index 0001f279..cdba686c 100644 --- a/examples/.test_scripts/run-with-patched-model.py +++ b/examples/.test_scripts/run-with-patched-model.py @@ -20,7 +20,6 @@ import argparse import runpy import sys -from collections.abc import Callable from typing import Any, TypeVar, cast import ai @@ -30,12 +29,9 @@ from ai.models.core import model as _model from ai.providers.anthropic import ( AnthropicCompatibleProvider, - AnthropicMessagesProtocol, ) from ai.providers.openai import ( - OpenAIChatCompletionsProtocol, OpenAICompatibleProvider, - OpenAIResponsesProtocol, ) PROTOCOLS = ("chat", "messages", "responses") @@ -43,18 +39,16 @@ ModelT = TypeVar("ModelT", bound=ai.Model) -def _protocol_factory( - name: str | None, -) -> Callable[[], ai.ProviderProtocol[Any]] | None: +def _protocol_ref(name: str | None) -> ai.ProtocolRef | None: if name is None: return None if name == "chat": - return OpenAIChatCompletionsProtocol + return ai.ProtocolRef("openai.chat_completions") if name == "messages": - return AnthropicMessagesProtocol + return ai.ProtocolRef("anthropic.messages") if name == "responses": - return OpenAIResponsesProtocol + return ai.ProtocolRef("openai.responses") raise ValueError(f"unsupported protocol: {name}") @@ -82,7 +76,7 @@ def _parse_args() -> argparse.Namespace: def main() -> None: args = _parse_args() - protocol_factory = _protocol_factory(args.protocol) + protocol_ref = _protocol_ref(args.protocol) original_get_model = _model.get_model original_stream = _api.stream @@ -90,29 +84,29 @@ def main() -> None: def selected_protocol_for_provider( provider: ai.Provider[Any], - ) -> Callable[[], ai.ProviderProtocol[Any]] | None: - if args.protocol is None or protocol_factory is None: + ) -> ai.ProtocolRef | None: + if args.protocol is None or protocol_ref is None: return None if args.protocol in ("chat", "responses") and isinstance( provider, OpenAICompatibleProvider ): - return protocol_factory + return protocol_ref if args.protocol == "messages" and isinstance( provider, AnthropicCompatibleProvider ): - return protocol_factory + return protocol_ref return None def selected_protocol_for_model( model: ai.Model, - ) -> Callable[[], ai.ProviderProtocol[Any]] | None: + ) -> ai.ProtocolRef | None: return selected_protocol_for_provider(model.provider) def with_selected_protocol(model: ModelT) -> ModelT: - factory = selected_protocol_for_model(model) - if factory is None: + selected = selected_protocol_for_model(model) + if selected is None: return model - return model.with_protocol(factory) + return model.with_protocol(selected) class PatchedContext: def __init__(self, context: Any) -> None: @@ -173,8 +167,7 @@ def __init__(self, id: str, **kwargs: Any) -> None: super().__init__(id, **kwargs) override = selected_protocol_for_provider(self.provider) if override is not None: - self.protocol_factory = override - self.protocol_args = {} + self.protocol_ref = override cast("Any", ai).get_model = patched_get_model cast("Any", models).get_model = patched_get_model diff --git a/examples/explicit_client.py b/examples/explicit_client.py index 822ff1ae..e01d264e 100644 --- a/examples/explicit_client.py +++ b/examples/explicit_client.py @@ -12,7 +12,7 @@ async def main() -> None: # Example for local OpenAI-compatible servers like LM Studio. model = ai.Model( os.environ.get("LOCAL_OPENAI_MODEL", "local-model"), - provider_factory=lambda: ai.get_provider( + provider=ai.ProviderRef( "openai", base_url=os.environ.get( "LOCAL_OPENAI_BASE_URL", "http://localhost:1234/v1" diff --git a/examples/openai_chat_completions.py b/examples/openai_chat_completions.py index cc86ff36..b7f8d285 100644 --- a/examples/openai_chat_completions.py +++ b/examples/openai_chat_completions.py @@ -3,7 +3,6 @@ import asyncio import ai -from ai.providers.openai import OpenAIChatCompletionsProtocol messages = [ ai.system_message("Be concise."), @@ -15,7 +14,7 @@ async def main() -> None: model = ai.get_model("openai:gpt-5.5").with_protocol( - OpenAIChatCompletionsProtocol + "openai.chat_completions" ) if not model.provider.is_configured(): print(f"[SKIP] {model.provider.name} provider is not configured") diff --git a/src/ai/__init__.py b/src/ai/__init__.py index 2df6e7f4..757da4a7 100644 --- a/src/ai/__init__.py +++ b/src/ai/__init__.py @@ -62,9 +62,11 @@ Model, ModelProviderDefault, OutputParams, + ProtocolRef, Provider, ProviderProtocol, ProviderRankingStrategy, + ProviderRef, ProviderServiceParams, RandomSeed, ReasoningParams, @@ -125,6 +127,7 @@ "Model", "ModelProviderDefault", "OutputParams", + "ProtocolRef", "Provider", "ProviderAPIError", "ProviderAuthenticationError", @@ -142,6 +145,7 @@ "ProviderProtocol", "ProviderRankingStrategy", "ProviderRateLimitError", + "ProviderRef", "ProviderRequestTooLargeError", "ProviderResponseError", "ProviderServiceParams", diff --git a/src/ai/models/__init__.py b/src/ai/models/__init__.py index 4cf49fb1..612e530e 100644 --- a/src/ai/models/__init__.py +++ b/src/ai/models/__init__.py @@ -1,13 +1,9 @@ """models — composable model layer. -A :class:`Model` holds a *recipe* for its provider — a factory callable -plus its arguments — instead of a live provider object. The provider and -its client are built lazily on first use. When the factory is a named, -module-level callable and the args are JSON-friendly (everything -``get_model`` produces), the model serializes: ``model.model_dump()`` / -``Model.model_validate()`` round-trip. Any other callable (a lambda, a -closure over live objects) works normally in-process, but ``model_dump`` -raises and ``model.serializable`` is ``False``. +A :class:`Model` holds JSON-safe provider settings instead of a live +provider object. The provider and its client are built lazily on first +use. Models produced by ``get_model`` round-trip through +``model.model_dump(mode="json")`` / ``Model.model_validate()``. Usage:: @@ -19,8 +15,10 @@ # custom provider configuration — JSON-friendly args model = ai.Model( "llama3", - provider_factory=ai.get_provider, - provider_args={"id": "openai", "base_url": "http://localhost:11434/v1"}, + provider=ai.ProviderRef( + "openai", + base_url="http://localhost:11434/v1", + ), ) # stream — auto-creates client from env vars @@ -34,18 +32,6 @@ data = model.model_dump(mode="json") model = ai.Model.model_validate(data) - # anything non-serializable (clients, custom auth) lives inside a - # named module-level factory; its import path is what's serialized - def my_provider() -> ai.Provider: - return ai.get_provider("openai", client=shared_client) - - model = ai.Model("gpt-5.4", provider_factory=my_provider) - - # if the model never crosses a process boundary, any callable works — - # the model just isn't serializable (model_dump() raises) - model = ai.Model("gpt-5.4", provider_factory=lambda: provider) - assert model.serializable is False - # list available models ids = await ai.get_provider("openai").list_models() """ @@ -62,7 +48,7 @@ def my_provider() -> ai.Provider: probe, stream, ) -from .core.model import Model, get_model +from .core.model import Model, ProtocolRef, ProviderRef, get_model from .core.params import ( DEFAULT, GLOBAL, @@ -118,9 +104,11 @@ def my_provider() -> ai.Provider: "Model", "ModelProviderDefault", "OutputParams", + "ProtocolRef", "Provider", "ProviderProtocol", "ProviderRankingStrategy", + "ProviderRef", "ProviderServiceParams", "RandomSeed", "ReasoningParams", diff --git a/src/ai/models/core/__init__.py b/src/ai/models/core/__init__.py index a826cfec..7aabe535 100644 --- a/src/ai/models/core/__init__.py +++ b/src/ai/models/core/__init__.py @@ -13,7 +13,7 @@ probe, stream, ) -from .model import Model, get_model +from .model import Model, ProtocolRef, ProviderRef, get_model from .params import ( DEFAULT, GLOBAL, @@ -69,8 +69,10 @@ "Model", "ModelProviderDefault", "OutputParams", + "ProtocolRef", "Provider", "ProviderRankingStrategy", + "ProviderRef", "ProviderServiceParams", "RandomSeed", "ReasoningParams", diff --git a/src/ai/models/core/model.py b/src/ai/models/core/model.py index a31c9669..51068ff7 100644 --- a/src/ai/models/core/model.py +++ b/src/ai/models/core/model.py @@ -1,10 +1,7 @@ """Model metadata types.""" import importlib -import inspect -import json import os -from collections.abc import Callable from typing import Any, Self, cast import pydantic @@ -16,111 +13,134 @@ _DEFAULT_MODEL_ENV = "AI_SDK_DEFAULT_MODEL" -def _callable_ref(fn: Callable[..., Any]) -> str: - """Return the ``"package.module:qualname"`` reference for *fn*. +type ProtocolName = str - Raises :class:`ai.ConfigurationError` when *fn* cannot be found again - by name — which is exactly what makes a factory unserializable. - """ - module_name = getattr(fn, "__module__", None) - qualname = getattr(fn, "__qualname__", None) - if not module_name or not qualname: - raise ConfigurationError( - f"factory {fn!r} has no importable name; it must be a named, " - "module-level function or class so the model can be serialized" - ) - if module_name == "__main__": - raise ConfigurationError( - f"factory {qualname!r} is defined in __main__ and cannot be " - "imported by other processes; move it into an importable module" - ) - if "<" in qualname: - raise ConfigurationError( - f"factory {module_name}.{qualname} must be a named, module-level " - "function or class so the model can be serialized; lambdas and " - "callables defined inside functions are not importable by name" - ) - try: - module = importlib.import_module(module_name) - except ImportError as error: - raise ConfigurationError( - f"cannot import module {module_name!r} of factory " - f"{qualname!r}: {error}" - ) from error - obj: Any = module - for part in qualname.split("."): - obj = getattr(obj, part, None) - if obj is not fn: - raise ConfigurationError( - f"factory {module_name}:{qualname} does not import back to the " - "same object; it must be a named, module-level function or class " - "(bound methods and decorated wrappers are not supported)" - ) - return f"{module_name}:{qualname}" +class ProtocolRef(pydantic.BaseModel): + """JSON-safe reference to a built-in provider protocol.""" + + name: ProtocolName + + model_config = pydantic.ConfigDict(frozen=True) + + def __init__(self, name: ProtocolName | None = None, **data: Any) -> None: + if name is not None: + data["name"] = name + super().__init__(**data) + + @pydantic.model_validator(mode="before") + @classmethod + def _coerce(cls, value: Any) -> Any: + if isinstance(value, str): + return {"name": value} + return value + + def build(self) -> base.ProviderProtocol[Any]: + match self.name: + case "openai.responses": + protocol: Any = importlib.import_module( + "ai.providers.openai" + ).OpenAIResponsesProtocol + return cast("base.ProviderProtocol[Any]", protocol()) + case "openai.chat_completions": + protocol = importlib.import_module( + "ai.providers.openai" + ).OpenAIChatCompletionsProtocol + return cast("base.ProviderProtocol[Any]", protocol()) + case "anthropic.messages": + protocol = importlib.import_module( + "ai.providers.anthropic" + ).AnthropicMessagesProtocol + return cast("base.ProviderProtocol[Any]", protocol()) + case "gateway.v3": + protocol = importlib.import_module( + "ai.providers.ai_gateway" + ).GatewayV3Protocol + return cast("base.ProviderProtocol[Any]", protocol()) + case _: + raise ConfigurationError( + f"unknown protocol reference: {self.name!r}" + ) + + def __hash__(self) -> int: + return hash(self.name) + + +class ProviderRef(pydantic.BaseModel): + """JSON-safe provider settings used by :class:`Model`.""" + + id: str + model_id: str | None = None + base_url: str | None = None + api_key: str | None = None + headers: dict[str, str] = pydantic.Field(default_factory=dict) + env: dict[str, str] = pydantic.Field(default_factory=dict) + + model_config = pydantic.ConfigDict(frozen=True) + + def __init__(self, id: str | None = None, **data: Any) -> None: + if id is not None: + data["id"] = id + super().__init__(**data) -def _import_ref(ref: str) -> Callable[..., Any]: - """Import a factory from a ``"package.module:qualname"`` reference.""" - module_name, sep, qualname = ref.partition(":") - if not sep or not module_name or not qualname: - raise ConfigurationError( - f"malformed factory reference {ref!r}; " - "expected 'package.module:name'" + @pydantic.model_validator(mode="before") + @classmethod + def _coerce(cls, value: Any) -> Any: + if isinstance(value, str): + return {"id": value} + return value + + def build(self) -> base.Provider[Any]: + model_provider_config = None + if self.model_id is not None: + model_info = _modelsdev.get_model_by_id( + f"{self.id}:{self.model_id}" + ) + if model_info is not None: + model_provider_config = model_info.provider_config + return base.Provider.from_id( + self.id, + model_provider_config=model_provider_config, + base_url=self.base_url, + api_key=self.api_key, + headers=self.headers, + env=self.env, ) - try: - module = importlib.import_module(module_name) - except ImportError as error: - raise ConfigurationError( - f"cannot import factory {ref!r}: {error}" - ) from error - obj: Any = module - for part in qualname.split("."): - obj = getattr(obj, part, None) - if obj is None: - raise ConfigurationError( - f"cannot import factory {ref!r}: module {module_name!r} " - f"has no attribute {qualname!r}" + + def __hash__(self) -> int: + return hash( + ( + self.id, + self.model_id, + self.base_url, + self.api_key, + tuple(sorted(self.headers.items())), + tuple(sorted(self.env.items())), ) - if not callable(obj): - raise ConfigurationError(f"factory {ref!r} is not callable") - return cast("Callable[..., Any]", obj) + ) class Model(pydantic.BaseModel): """Reference to a model on a specific provider. * ``id`` — identifier sent to the provider (e.g. ``"claude-sonnet-4-6"``). - * ``provider_factory`` — callable that builds the :class:`Provider`. - * ``provider_args`` — keyword arguments for the factory. - * ``protocol_factory`` / ``protocol_args`` — optional wire-protocol - override for this model, same rules as the provider factory. + * ``provider_ref`` — JSON-safe provider settings. + * ``protocol_ref`` — optional JSON-safe wire-protocol override. The provider is built lazily on first :attr:`provider` access and cached. - - A model is **serializable** when the factory is a named module-level - function or class (dumped as a ``"package.module:name"`` reference) - and the args are JSON-friendly — everything :func:`get_model` - produces qualifies. Anything that cannot be expressed as JSON args - (custom clients, shared connection pools) can live inside the - factory body:: - - def my_provider() -> ai.Provider: - return ai.get_provider("openai", client=_shared_client) - - model = ai.Model("gpt-5", provider_factory=my_provider) - - Any other callable — a lambda, a closure over a live provider — is - accepted and works normally in-process, but the model then cannot - cross a JSON boundary: ``model_dump()`` raises with the reason. - Check :attr:`serializable` to know ahead of time. """ id: str - provider_factory: Callable[..., Any] - provider_args: dict[str, Any] = pydantic.Field(default_factory=dict) - protocol_factory: Callable[..., Any] | None = None - protocol_args: dict[str, Any] = pydantic.Field(default_factory=dict) + provider_ref: ProviderRef = pydantic.Field( + validation_alias=pydantic.AliasChoices("provider_ref", "provider"), + serialization_alias="provider", + ) + protocol_ref: ProtocolRef | None = pydantic.Field( + default=None, + validation_alias=pydantic.AliasChoices("protocol_ref", "protocol"), + serialization_alias="protocol", + ) _provider_instance: base.Provider[Any] | None = pydantic.PrivateAttr( default=None @@ -133,109 +153,90 @@ def __init__( self, id: str, *, - provider_factory: Callable[..., Any] | str, - provider_args: dict[str, Any] | None = None, - protocol_factory: Callable[..., Any] | str | None = None, - protocol_args: dict[str, Any] | None = None, + provider: ProviderRef | str | None = None, + protocol: ProtocolRef | ProtocolName | None = None, + provider_ref: ProviderRef | str | None = None, + protocol_ref: ProtocolRef | ProtocolName | None = None, ) -> None: + provider_ref = provider_ref if provider_ref is not None else provider + protocol_ref = protocol_ref if protocol_ref is not None else protocol + if provider_ref is None: + raise TypeError("Model requires a provider") super().__init__( id=id, - provider_factory=provider_factory, - provider_args={} if provider_args is None else provider_args, - protocol_factory=protocol_factory, - protocol_args={} if protocol_args is None else protocol_args, + provider_ref=provider_ref, + protocol_ref=protocol_ref, ) - @pydantic.field_validator( - "provider_factory", "protocol_factory", mode="before" + model_config = pydantic.ConfigDict( + populate_by_name=True, + serialize_by_alias=True, ) + + @pydantic.field_validator("provider_ref", mode="before") @classmethod - def _coerce_factory(cls, value: Any) -> Any: + def _coerce_provider_ref(cls, value: Any) -> Any: if isinstance(value, str): - return _import_ref(value) + return ProviderRef(value) return value - @pydantic.model_validator(mode="after") - def _check_factory_args(self) -> Self: - for label, factory, args in ( - ("provider_args", self.provider_factory, self.provider_args), - ("protocol_args", self.protocol_factory, self.protocol_args), - ): - if factory is None: - if args: - raise ConfigurationError( - "protocol_args given without protocol_factory" - ) - continue - try: - signature = inspect.signature(factory) - except (TypeError, ValueError): - continue # some builtins have no introspectable signature - try: - signature.bind(**args) - except TypeError as error: - raise ConfigurationError( - f"{label} do not match the signature of " - f"{factory!r}: {error}" - ) from error - return self - - # Dumps must round-trip or raise: both serializers run for python - # and JSON dumps alike, so a model built around a closure or live - # objects fails loudly the moment it tries to cross a boundary. - @pydantic.field_serializer("provider_factory", "protocol_factory") - def _serialize_factory( - self, value: Callable[..., Any] | None - ) -> str | None: - return None if value is None else _callable_ref(value) - - @pydantic.field_serializer("provider_args", "protocol_args") - def _serialize_args( + @pydantic.field_validator("protocol_ref", mode="before") + @classmethod + def _coerce_protocol_ref(cls, value: Any) -> Any: + if isinstance(value, str): + return ProtocolRef(value) + return value + + @pydantic.field_serializer("provider_ref") + def _serialize_provider_ref( self, - value: dict[str, Any], + value: ProviderRef, info: pydantic.FieldSerializationInfo, ) -> dict[str, Any]: - try: - json.dumps(value) - except (TypeError, ValueError) as error: + if type(value) is not ProviderRef: raise ConfigurationError( - f"{info.field_name} cannot round-trip through JSON " - f"(put live objects inside a named module-level factory " - f"instead): {error}" - ) from error - return value + "custom provider refs cannot be serialized" + ) + return value.model_dump( + mode=info.mode, + exclude_defaults=True, + exclude_none=True, + ) + + @pydantic.field_serializer("protocol_ref") + def _serialize_protocol_ref( + self, + value: ProtocolRef | None, + info: pydantic.FieldSerializationInfo, + ) -> dict[str, Any] | None: + if value is None: + return None + if type(value) is not ProtocolRef: + raise ConfigurationError( + "custom protocol refs cannot be serialized" + ) + return value.model_dump( + mode=info.mode, + exclude_defaults=True, + exclude_none=True, + ) @property def serializable(self) -> bool: - """Whether this model round-trips through ``model_dump``. - - ``True`` for named module-level factories with JSON-friendly - args — everything :func:`get_model` produces. ``False`` when - the model was built around a lambda, closure, or live objects; - such models work normally in-process but ``model_dump`` raises. - """ - for factory, args in ( - (self.provider_factory, self.provider_args), - (self.protocol_factory, self.protocol_args), - ): - if factory is None: - continue - try: - _callable_ref(factory) - json.dumps(args) - except (ConfigurationError, TypeError, ValueError): - return False - return True + """Whether this model round-trips through ``model_dump``.""" + return type(self.provider_ref) is ProviderRef and ( + self.protocol_ref is None or type(self.protocol_ref) is ProtocolRef + ) @property def provider(self) -> base.Provider[Any]: - """Provider instance, built lazily from the factory and cached.""" + """Provider instance, built lazily from the provider ref and cached.""" provider = self._provider_instance if provider is None: - provider = self.provider_factory(**self.provider_args) + provider = self.provider_ref.build() if not isinstance(provider, base.Provider): raise ConfigurationError( - f"provider factory {self.provider_factory!r} returned " + f"provider ref {self.provider_ref!r} returned " f"{type(provider).__name__}, expected a Provider" ) self._provider_instance = provider @@ -244,14 +245,14 @@ def provider(self) -> base.Provider[Any]: @property def protocol(self) -> base.ProviderProtocol[Any] | None: """Protocol override instance, built lazily and cached.""" - if self.protocol_factory is None: + if self.protocol_ref is None: return None protocol = self._protocol_instance if protocol is None: - protocol = self.protocol_factory(**self.protocol_args) + protocol = self.protocol_ref.build() if not isinstance(protocol, base.ProviderProtocol): raise ConfigurationError( - f"protocol factory {self.protocol_factory!r} returned " + f"protocol ref {self.protocol_ref!r} returned " f"{type(protocol).__name__}, expected a ProviderProtocol" ) self._protocol_instance = protocol @@ -264,34 +265,27 @@ async def aclose(self) -> None: self._provider_instance = None def __eq__(self, other: object) -> bool: - # Pydantic's default __eq__ also compares private attributes, - # which would make models unequal once one of them lazily built - # its provider. Compare the recipe fields only; factories - # compare by identity (a round-tripped model imports the same - # factory object back, so equality survives serialization). + # Pydantic's default __eq__ also compares private attributes. + # Compare the JSON recipe only, so a lazily built provider does + # not affect equality. return ( isinstance(other, Model) and self.id == other.id - and self.provider_factory is other.provider_factory - and self.provider_args == other.provider_args - and self.protocol_factory is other.protocol_factory - and self.protocol_args == other.protocol_args + and self.provider_ref == other.provider_ref + and self.protocol_ref == other.protocol_ref ) def __hash__(self) -> int: - return hash((self.id, self.provider_factory, self.protocol_factory)) + return hash((self.id, self.provider_ref, self.protocol_ref)) def with_protocol( self, - protocol_factory: Callable[..., Any] | str, - **protocol_args: Any, + protocol: ProtocolRef | ProtocolName, ) -> Self: model = self.__class__( self.id, - provider_factory=self.provider_factory, - provider_args=self.provider_args, - protocol_factory=protocol_factory, - protocol_args=protocol_args, + provider_ref=self.provider_ref, + protocol_ref=protocol, ) # Keep sharing an already-built provider instance. model._provider_instance = self._provider_instance @@ -301,8 +295,7 @@ def with_protocol( def get_model( model_id: str | None = None, *, - protocol_factory: Callable[..., Any] | str | None = None, - protocol_args: dict[str, Any] | None = None, + protocol: ProtocolRef | ProtocolName | None = None, ) -> Model: """Resolve a model ID into a :class:`Model`. @@ -313,13 +306,9 @@ def get_model( Vercel AI Gateway. Examples: ``"openai:gpt-5"`` or ``"anthropic/claude-sonnet-4"``. When omitted, reads ``AI_SDK_DEFAULT_MODEL`` from the environment. - protocol_factory: - Optional wire-protocol override for this model — a named, - module-level callable (usually the protocol class) that builds - the protocol. When omitted, the provider chooses its default - protocol. - protocol_args: - JSON-serializable keyword arguments for ``protocol_factory``. + protocol: + Optional wire-protocol override for this model. When omitted, + the provider chooses its default protocol. Raises: Raises :class:`ai.ConfigurationError` when ``model_id`` and @@ -356,17 +345,12 @@ def get_model( # Fail early on unknown or unsupported providers without building a # provider (and its client); the model only stores the recipe. - base.resolve_provider_type( + base.Provider.resolve_type( provider_id, model_provider_config=model_provider_config ) return Model( provider_model_id, - provider_factory=base.provider_for_model, - provider_args={ - "provider_id": provider_id, - "model_id": provider_model_id, - }, - protocol_factory=protocol_factory, - protocol_args=protocol_args, + provider=ProviderRef(provider_id, model_id=provider_model_id), + protocol=protocol, ) diff --git a/src/ai/providers/anthropic/__init__.py b/src/ai/providers/anthropic/__init__.py index 5577421d..785b9290 100644 --- a/src/ai/providers/anthropic/__init__.py +++ b/src/ai/providers/anthropic/__init__.py @@ -8,11 +8,10 @@ model = ai.get_model("anthropic:claude-sonnet-4-6") model = ai.Model( "claude-sonnet-4-6", - provider_factory=ai.get_provider, - provider_args={ - "id": "anthropic", - "base_url": "https://anthropic.example.com", - }, + provider=ai.ProviderRef( + "anthropic", + base_url="https://anthropic.example.com", + ), ) ids = await ai.get_provider("anthropic").list_models() diff --git a/src/ai/providers/anthropic/provider.py b/src/ai/providers/anthropic/provider.py index ef5ac5fb..43854c0e 100644 --- a/src/ai/providers/anthropic/provider.py +++ b/src/ai/providers/anthropic/provider.py @@ -39,7 +39,7 @@ class AnthropicCompatibleProvider(base.Provider[AnthropicSDKClient]): - """Callable provider for Anthropic-compatible APIs.""" + """Provider for Anthropic-compatible APIs.""" handles: ClassVar[tuple[str, ...]] = ("anthropic", "@ai-sdk/anthropic") diff --git a/src/ai/providers/base.py b/src/ai/providers/base.py index 560c882a..1350b7f0 100644 --- a/src/ai/providers/base.py +++ b/src/ai/providers/base.py @@ -73,7 +73,7 @@ class Provider(Generic[ClientT]): def __init_subclass__(cls, **kwargs: Any) -> None: super().__init_subclass__(**kwargs) - for handle in cls.handles: + for handle in cls.__dict__.get("handles", ()): existing = _PROVIDER_REGISTRY.get(handle) if existing is not None and existing is not cls: raise RuntimeError(f"duplicate provider handle: {handle!r}") @@ -277,9 +277,20 @@ def from_id( """Return a concrete provider for a models.dev provider ID.""" modelsdev_provider = _modelsdev.get_provider_by_id(known_id) if modelsdev_provider is None: - raise ValueError(f"unknown provider id: {known_id!r}") + provider_type = _PROVIDER_REGISTRY.get(known_id) + if provider_type is None: + raise ValueError(f"unknown provider id: {known_id!r}") + return provider_type.from_provider_id( + known_id, + base_url=base_url, + api_key=api_key, + headers=headers, + env=env, + client=client, + protocol=protocol, + ) - provider_type = resolve_provider_type( + provider_type = cls.resolve_type( known_id, model_provider_config=model_provider_config ) return provider_type.from_modelsdev_provider( @@ -293,6 +304,47 @@ def from_id( protocol=protocol, ) + @classmethod + def resolve_type( + cls, + known_id: str, + *, + model_provider_config: modelsdotdev.ModelProviderConfig | None = None, + ) -> type[Provider[Any]]: + """Return the registered provider class without building a client.""" + modelsdev_provider = _modelsdev.get_provider_by_id(known_id) + if modelsdev_provider is None: + provider_type = _PROVIDER_REGISTRY.get(known_id) + if provider_type is not None: + return provider_type + raise ValueError(f"unknown provider id: {known_id!r}") + + for handle in ( + modelsdev_provider.id, + _modelsdev.provider_npm(modelsdev_provider, model_provider_config), + ): + provider_type = _PROVIDER_REGISTRY.get(handle) + if provider_type is not None: + return provider_type + + raise UnsupportedProviderError(modelsdev_provider.id) + + @classmethod + def from_provider_id( + cls, + known_id: str, + *, + base_url: str | None = None, + api_key: str | None = None, + headers: Mapping[str, str] | None = None, + env: Mapping[str, str] | None = None, + client: Any | None = None, + protocol: ProviderProtocol[Any] | None = None, + ) -> Provider[Any]: + """Construct this provider from a directly registered provider ID.""" + _ = base_url, api_key, headers, env, client, protocol + raise UnsupportedProviderError(known_id) + @classmethod def from_modelsdev_provider( cls, @@ -313,54 +365,6 @@ def from_modelsdev_provider( _PROVIDER_REGISTRY: dict[str, type[Provider[Any]]] = {} -def resolve_provider_type( - known_id: str, - *, - model_provider_config: modelsdotdev.ModelProviderConfig | None = None, -) -> type[Provider[Any]]: - """Return the registered provider class for a models.dev provider ID. - - Performs the same lookup as :meth:`Provider.from_id` without building - a provider (or its client). - - Raises ``ValueError`` for unknown provider IDs and - :class:`ai.UnsupportedProviderError` for known providers without a - registered implementation. - """ - modelsdev_provider = _modelsdev.get_provider_by_id(known_id) - if modelsdev_provider is None: - raise ValueError(f"unknown provider id: {known_id!r}") - - for handle in ( - modelsdev_provider.id, - _modelsdev.provider_npm(modelsdev_provider, model_provider_config), - ): - provider_type = _PROVIDER_REGISTRY.get(handle) - if provider_type is not None: - return provider_type - - raise UnsupportedProviderError(modelsdev_provider.id) - - -def provider_for_model( - provider_id: str, model_id: str | None = None -) -> Provider[Any]: - """Build the provider serving *model_id* on *provider_id*. - - This is the default provider factory stored by :func:`ai.get_model`. - It looks up models.dev metadata at build time, so its arguments stay - plain JSON and models that reference it serialize cleanly. - """ - model_provider_config = None - if model_id is not None: - model_info = _modelsdev.get_model_by_id(f"{provider_id}:{model_id}") - if model_info is not None: - model_provider_config = model_info.provider_config - return Provider.from_id( - provider_id, model_provider_config=model_provider_config - ) - - def get_provider( id: str, *, diff --git a/src/ai/providers/openai/__init__.py b/src/ai/providers/openai/__init__.py index 0b50b40b..4c490652 100644 --- a/src/ai/providers/openai/__init__.py +++ b/src/ai/providers/openai/__init__.py @@ -7,8 +7,10 @@ model = ai.get_model("openai:gpt-5.4") model = ai.Model( "llama3", - provider_factory=ai.get_provider, - provider_args={"id": "openai", "base_url": "http://localhost:11434/v1"}, + provider=ai.ProviderRef( + "openai", + base_url="http://localhost:11434/v1", + ), ) ids = await ai.get_provider("openai").list_models() diff --git a/tests/conftest.py b/tests/conftest.py index 434b5620..871181d3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import AsyncGenerator, AsyncIterable, Sequence +from collections.abc import AsyncGenerator, AsyncIterable, Mapping, Sequence from typing import Any, cast import pydantic @@ -20,6 +20,8 @@ class MockProvider(models.Provider): Carries just enough state so that ``Model`` objects can be constructed. """ + handles = ("mock",) + def __init__( self, *, @@ -35,6 +37,21 @@ def __init__( self._stream_impl: Any | None = None self._generate_impl: Any | None = None + @classmethod + def from_provider_id( + cls, + known_id: str, + *, + base_url: str | None = None, + api_key: str | None = None, + headers: Mapping[str, str] | None = None, + env: Mapping[str, str] | None = None, + client: Any | None = None, + protocol: models.ProviderProtocol[Any] | None = None, + ) -> models.Provider[Any]: + _ = known_id, base_url, api_key, headers, env, client, protocol + return MOCK_PROVIDER + async def list_models(self) -> list[str]: return [] @@ -99,15 +116,10 @@ async def generate( MOCK_PROVIDER = MockProvider() -def mock_provider() -> MockProvider: - """Provider factory returning the shared mock provider.""" - return MOCK_PROVIDER - - # A fixed Model used in tests. MOCK_MODEL: models.Model = models.Model( "mock-model", - provider_factory=mock_provider, + provider="mock", ) diff --git a/tests/models/core/test_api.py b/tests/models/core/test_api.py index 02b3c409..04979e9b 100644 --- a/tests/models/core/test_api.py +++ b/tests/models/core/test_api.py @@ -33,6 +33,17 @@ def _provider_metadata_marker( return marker +class _StaticProviderRef(models.ProviderRef): + _provider: models.Provider[Any] = pydantic.PrivateAttr() + + def __init__(self, provider: models.Provider[Any]) -> None: + super().__init__("mock") + self._provider = provider + + def build(self) -> models.Provider[Any]: + return self._provider + + def test_inference_request_params_with_provider_params() -> None: class GatewayParams: pass @@ -271,7 +282,6 @@ async def test_stream_requires_model_messages_or_context() -> None: pass -# Module-level so ``with_protocol`` can serialize a reference to it. class _StreamOverrideProtocol(models.ProviderProtocol[Any]): def stream( self, @@ -296,9 +306,17 @@ async def _stream() -> AsyncGenerator[events_.Event]: return _stream() +class _StreamOverrideProtocolRef(models.ProtocolRef): + def __init__(self) -> None: + super().__init__("test.stream") + + def build(self) -> models.ProviderProtocol[Any]: + return _StreamOverrideProtocol() + + async def test_stream_uses_model_protocol() -> None: async with models.stream( - MOCK_MODEL.with_protocol(_StreamOverrideProtocol), + MOCK_MODEL.with_protocol(_StreamOverrideProtocolRef()), [ai.user_message("Hi")], ) as stream: async for _ in stream: @@ -308,10 +326,10 @@ async def test_stream_uses_model_protocol() -> None: async def test_generate_dispatches_to_provider() -> None: - # A provider class is itself a valid model factory. - model = models.Model("generate-model", provider_factory=MockProvider) - provider = model.provider - assert isinstance(provider, MockProvider) + provider = MockProvider() + model = models.Model( + "generate-model", provider_ref=_StaticProviderRef(provider) + ) sentinel = messages_.Message( role="assistant", parts=[messages_.FilePart(data=b"\x89PNG", media_type="image/png")], @@ -339,7 +357,6 @@ async def _generate( assert result is sentinel -# Module-level so ``with_protocol`` can serialize a reference to it. _GENERATED_IMAGE = messages_.Message( role="assistant", parts=[messages_.FilePart(data=b"\x89PNG", media_type="image/png")], @@ -360,9 +377,17 @@ async def generate( return _GENERATED_IMAGE +class _GenerateOverrideProtocolRef(models.ProtocolRef): + def __init__(self) -> None: + super().__init__("test.generate") + + def build(self) -> models.ProviderProtocol[Any]: + return _GenerateOverrideProtocol() + + async def test_generate_uses_model_protocol() -> None: result = await models.generate( - MOCK_MODEL.with_protocol(_GenerateOverrideProtocol), + MOCK_MODEL.with_protocol(_GenerateOverrideProtocolRef()), [ai.user_message("A cat")], models.ImageParams(n=1), ) @@ -380,12 +405,13 @@ async def probe(self, model: models.Model) -> None: async def test_probe_delegates_to_model_provider() -> None: - model = models.Model("mock-model", provider_factory=_CheckProvider) + provider = _CheckProvider() + model = models.Model( + "mock-model", provider_ref=_StaticProviderRef(provider) + ) await models.probe(model) - provider = model.provider - assert isinstance(provider, _CheckProvider) assert provider.checked_model is model diff --git a/tests/models/core/test_model.py b/tests/models/core/test_model.py index 6edc56e1..db87f797 100644 --- a/tests/models/core/test_model.py +++ b/tests/models/core/test_model.py @@ -3,8 +3,9 @@ from __future__ import annotations import json +from typing import Any -import httpx +import pydantic import pytest from pydantic_core import PydanticSerializationError @@ -12,40 +13,42 @@ from ai import models from ai.providers.openai import OpenAIChatCompletionsProtocol -from ...conftest import MockProvider +from ...conftest import MOCK_PROVIDER, MockProvider -def make_mock_provider(name: str = "mock") -> MockProvider: - """Module-level provider factory used by serialization tests.""" - return MockProvider(name=name) +class _FreshMockProviderRef(models.ProviderRef): + def __init__(self) -> None: + super().__init__("mock") + + def build(self) -> MockProvider: + return MockProvider() -def not_a_provider() -> str: - return "nope" +class _Box(pydantic.BaseModel): + model: models.Model -def test_model_dumps_factory_as_import_reference() -> None: - model = models.Model("mock-model", provider_factory=make_mock_provider) +def test_model_dumps_provider_ref_as_json_data() -> None: + model = models.Model( + "mock-model", + provider=models.ProviderRef("mock", base_url="http://mock.test"), + ) data = model.model_dump(mode="json") assert data == { "id": "mock-model", - "provider_factory": ("tests.models.core.test_model:make_mock_provider"), - "provider_args": {}, - "protocol_factory": None, - "protocol_args": {}, + "provider": { + "id": "mock", + "base_url": "http://mock.test", + }, + "protocol": None, } - # the dump is real JSON json.dumps(data) def test_model_json_round_trip() -> None: - model = models.Model( - "mock-model", - provider_factory=make_mock_provider, - provider_args={"name": "custom"}, - ) + model = ai.get_model("openai:gpt-5") restored = models.Model.model_validate( json.loads(json.dumps(model.model_dump(mode="json"))) @@ -53,8 +56,29 @@ def test_model_json_round_trip() -> None: assert restored == model assert hash(restored) == hash(model) - assert restored.provider_factory is make_mock_provider - assert restored.provider.name == "custom" + assert restored.provider_ref == models.ProviderRef( + "openai", model_id="gpt-5" + ) + + +def test_model_inside_pydantic_model_round_trips() -> None: + box = _Box(model=ai.get_model("openai:gpt-5")) + + restored = _Box.model_validate( + json.loads(json.dumps(box.model_dump(mode="json"))) + ) + + assert restored == box + + +def test_model_validate_does_not_resolve_provider() -> None: + model = models.Model.model_validate( + {"id": "test-model", "provider": {"id": "not-registered"}} + ) + + assert model.provider_ref.id == "not-registered" + with pytest.raises(ValueError, match="unknown provider"): + _ = model.provider def test_get_model_round_trip_builds_equivalent_provider() -> None: @@ -66,16 +90,15 @@ def test_get_model_round_trip_builds_equivalent_provider() -> None: def test_provider_is_built_lazily_and_cached() -> None: - model = models.Model("mock-model", provider_factory=make_mock_provider) + model = models.Model("mock-model", provider="mock") assert model._provider_instance is None - provider = model.provider - assert isinstance(provider, MockProvider) - assert model.provider is provider + assert model.provider is MOCK_PROVIDER + assert model.provider is MOCK_PROVIDER async def test_aclose_drops_cached_provider() -> None: - model = models.Model("mock-model", provider_factory=make_mock_provider) + model = models.Model("mock-model", provider_ref=_FreshMockProviderRef()) provider = model.provider await model.aclose() @@ -85,10 +108,10 @@ async def test_aclose_drops_cached_provider() -> None: def test_with_protocol_round_trips_and_shares_provider() -> None: - model = models.Model("mock-model", provider_factory=make_mock_provider) + model = models.Model("mock-model", provider="mock") provider = model.provider - override = model.with_protocol(OpenAIChatCompletionsProtocol) + override = model.with_protocol("openai.chat_completions") assert override.provider is provider assert isinstance(override.protocol, OpenAIChatCompletionsProtocol) @@ -97,129 +120,86 @@ def test_with_protocol_round_trips_and_shares_provider() -> None: assert isinstance(restored.protocol, OpenAIChatCompletionsProtocol) -def test_accepts_string_factory_reference() -> None: +def test_accepts_string_provider_ref() -> None: + model = models.Model("mock-model", provider="mock") + + assert model.provider_ref == models.ProviderRef("mock") + + +def test_accepts_string_protocol_ref() -> None: model = models.Model( "mock-model", - provider_factory="tests.models.core.test_model:make_mock_provider", + provider="mock", + protocol="openai.chat_completions", ) - assert model.provider_factory is make_mock_provider + assert model.protocol_ref == models.ProtocolRef("openai.chat_completions") -def test_closure_factory_works_in_process_but_does_not_dump() -> None: - provider = MockProvider() - model = models.Model("mock-model", provider_factory=lambda: provider) - - assert model.provider is provider - assert model.serializable is False - with pytest.raises(PydanticSerializationError, match="lambdas"): - model.model_dump() - with pytest.raises(PydanticSerializationError, match="lambdas"): - model.model_dump(mode="json") +def test_unknown_protocol_fails_on_access_not_validation() -> None: + model = models.Model.model_validate( + { + "id": "mock-model", + "provider": {"id": "mock"}, + "protocol": {"name": "not-real"}, + } + ) + with pytest.raises(ai.ConfigurationError, match="unknown protocol"): + _ = model.protocol -def test_factory_defined_inside_function_does_not_dump() -> None: - def local_factory() -> MockProvider: - return MockProvider() - model = models.Model("mock-model", provider_factory=local_factory) +def test_custom_provider_ref_works_in_process_but_does_not_dump() -> None: + model = models.Model("mock-model", provider_ref=_FreshMockProviderRef()) assert isinstance(model.provider, MockProvider) assert model.serializable is False - with pytest.raises(PydanticSerializationError, match="module-level"): + with pytest.raises(PydanticSerializationError, match="provider refs"): model.model_dump() -def test_factory_defined_in_main_does_not_dump( - monkeypatch: pytest.MonkeyPatch, -) -> None: - monkeypatch.setattr(make_mock_provider, "__module__", "__main__") - model = models.Model("mock-model", provider_factory=make_mock_provider) - - assert model.serializable is False - with pytest.raises(PydanticSerializationError, match="__main__"): - model.model_dump() - - -def test_bound_method_factory_does_not_dump() -> None: - provider = MockProvider() - model = models.Model("mock-model", provider_factory=provider.list_models) - - assert model.serializable is False - with pytest.raises(PydanticSerializationError, match="same object"): - model.model_dump() - - -def test_non_json_provider_args_do_not_dump() -> None: - model = models.Model( - "mock-model", - provider_factory=ai.get_provider, - provider_args={"id": "openai", "client": httpx.AsyncClient()}, - ) - - assert model.serializable is False - with pytest.raises(PydanticSerializationError, match="round-trip"): - model.model_dump() - - -def test_serializable_is_true_for_factory_models() -> None: +def test_serializable_is_true_for_plain_refs() -> None: assert ai.get_model("openai:gpt-5").serializable is True - assert ( - models.Model( - "mock-model", provider_factory=make_mock_provider - ).serializable - is True - ) - - -def test_rejects_args_not_matching_factory_signature() -> None: - with pytest.raises(ai.ConfigurationError, match="signature"): - models.Model( - "mock-model", - provider_factory=make_mock_provider, - provider_args={"unknown_arg": 1}, - ) - - -def test_rejects_protocol_args_without_protocol_factory() -> None: - with pytest.raises(ai.ConfigurationError, match="protocol_factory"): - models.Model( - "mock-model", - provider_factory=make_mock_provider, - protocol_args={"x": 1}, + assert models.Model("mock-model", provider="mock").serializable is True + + +def test_provider_ref_rejects_bad_header_type() -> None: + with pytest.raises(pydantic.ValidationError): + models.Model.model_validate( + { + "id": "mock-model", + "provider": { + "id": "mock", + "headers": {"x-test": object()}, + }, + } ) -def test_rejects_malformed_string_reference() -> None: - with pytest.raises(ai.ConfigurationError, match="malformed"): - models.Model("mock-model", provider_factory="no-colon-here") - - -def test_rejects_unimportable_string_reference() -> None: - with pytest.raises(ai.ConfigurationError, match="cannot import"): - models.Model( - "mock-model", provider_factory="tests.no_such_module:factory" - ) +def test_factory_returning_non_provider_fails_on_access() -> None: + class NotAProviderRef(models.ProviderRef): + def __init__(self) -> None: + super().__init__("mock") + def build(self) -> Any: + return "nope" -def test_factory_returning_non_provider_fails_on_access() -> None: - model = models.Model("mock-model", provider_factory=not_a_provider) + model = models.Model("mock-model", provider_ref=NotAProviderRef()) with pytest.raises(ai.ConfigurationError, match="expected a Provider"): _ = model.provider def test_equality_ignores_cached_provider_instance() -> None: - a = models.Model("mock-model", provider_factory=make_mock_provider) - b = models.Model("mock-model", provider_factory=make_mock_provider) - _ = a.provider # cache an instance on one side only + a = models.Model("mock-model", provider="mock") + b = models.Model("mock-model", provider="mock") + _ = a.provider assert a == b assert hash(a) == hash(b) c = models.Model( "mock-model", - provider_factory=make_mock_provider, - provider_args={"name": "other"}, + provider=models.ProviderRef("mock", base_url="http://other.test"), ) assert a != c diff --git a/tests/models/test_resolution.py b/tests/models/test_resolution.py index 08fb5426..53c24512 100644 --- a/tests/models/test_resolution.py +++ b/tests/models/test_resolution.py @@ -171,9 +171,7 @@ def test_get_rejects_empty_model_id() -> None: def test_get_model_accepts_model_protocol_override() -> None: - model = models.get_model( - "openai:gpt-5", protocol_factory=OpenAIChatCompletionsProtocol - ) + model = models.get_model("openai:gpt-5", protocol="openai.chat_completions") assert isinstance(model.protocol, OpenAIChatCompletionsProtocol) assert model.protocol is model.protocol # built once, cached diff --git a/tests/providers/ai_gateway/conftest.py b/tests/providers/ai_gateway/conftest.py index 6d7bb9e1..1bd77902 100644 --- a/tests/providers/ai_gateway/conftest.py +++ b/tests/providers/ai_gateway/conftest.py @@ -6,6 +6,7 @@ from typing import Any import httpx +import pydantic import ai from ai.types import messages @@ -13,6 +14,17 @@ _BASE_URL = "https://gw.test/v3/ai" +class _GatewayProviderRef(ai.ProviderRef): + _provider: ai.Provider[Any] = pydantic.PrivateAttr() + + def __init__(self, provider: ai.Provider[Any]) -> None: + super().__init__("vercel") + self._provider = provider + + def build(self) -> ai.Provider[Any]: + return self._provider + + def sse(*events: dict[str, Any]) -> str: """Build SSE response text from event dicts.""" return "".join(f"data: {json.dumps(e)}\n\n" for e in events) @@ -26,9 +38,8 @@ def mock_model( ) -> ai.Model: """Create a Gateway model wired to a mock transport. - Per-test handlers are live objects, so the factory is a closure and - the model is deliberately not serializable (it never crosses a JSON - boundary in these tests; ``model_dump`` would raise). + Per-test handlers are live objects, so this uses a test-only provider + ref. It never crosses a JSON boundary in these tests. """ provider = ai.get_provider( "vercel", @@ -36,7 +47,7 @@ def mock_model( api_key=api_key, client=httpx.AsyncClient(transport=handler), ) - return ai.Model(model_id, provider_factory=lambda: provider) + return ai.Model(model_id, provider_ref=_GatewayProviderRef(provider)) mock_client = mock_model diff --git a/tests/providers/ai_gateway/test_probe.py b/tests/providers/ai_gateway/test_probe.py index 9683bfc4..7ca82170 100644 --- a/tests/providers/ai_gateway/test_probe.py +++ b/tests/providers/ai_gateway/test_probe.py @@ -12,27 +12,45 @@ _MODEL_ID = "anthropic/claude-opus-4-6" -def probe_provider( - credits_status: int = 200, - config_status: int = 200, - config_body: dict[str, Any] | None = None, - api_key: str | None = "sk-test-key", -) -> ai.Provider[Any]: - """Gateway provider whose mock responses are built from JSON args.""" - credits_body = json.dumps({"balance": "10.00", "totalUsed": "5.00"}) - config_bytes = json.dumps(config_body or {"models": []}).encode() - - def _handler(request: httpx.Request) -> httpx.Response: - if "/v1/credits" in str(request.url): - return httpx.Response(credits_status, content=credits_body.encode()) - return httpx.Response(config_status, content=config_bytes) - - return ai.get_provider( - "vercel", - base_url="https://gateway.test/v3/ai", - api_key=api_key, - client=httpx.AsyncClient(transport=httpx.MockTransport(_handler)), - ) +class _ProbeProviderRef(ai.ProviderRef): + credits_status: int = 200 + config_status: int = 200 + config_body: dict[str, Any] | None = None + api_key: str | None = "sk-test-key" + + def __init__( + self, + *, + credits_status: int = 200, + config_status: int = 200, + config_body: dict[str, Any] | None = None, + api_key: str | None = "sk-test-key", + ) -> None: + super().__init__( + "vercel", + credits_status=credits_status, + config_status=config_status, + config_body=config_body, + api_key=api_key, + ) + + def build(self) -> ai.Provider[Any]: + credits_body = json.dumps({"balance": "10.00", "totalUsed": "5.00"}) + config_bytes = json.dumps(self.config_body or {"models": []}).encode() + + def _handler(request: httpx.Request) -> httpx.Response: + if "/v1/credits" in str(request.url): + return httpx.Response( + self.credits_status, content=credits_body.encode() + ) + return httpx.Response(self.config_status, content=config_bytes) + + return ai.get_provider( + "vercel", + base_url="https://gateway.test/v3/ai", + api_key=self.api_key, + client=httpx.AsyncClient(transport=httpx.MockTransport(_handler)), + ) def _gateway_client( @@ -44,13 +62,12 @@ def _gateway_client( ) -> ai.Model: return ai.Model( _MODEL_ID, - provider_factory=probe_provider, - provider_args={ - "credits_status": credits_status, - "config_status": config_status, - "config_body": config_body, - "api_key": api_key, - }, + provider_ref=_ProbeProviderRef( + credits_status=credits_status, + config_status=config_status, + config_body=config_body, + api_key=api_key, + ), ) diff --git a/tests/providers/anthropic/test_probe.py b/tests/providers/anthropic/test_probe.py index 963dc86f..9e5ea67b 100644 --- a/tests/providers/anthropic/test_probe.py +++ b/tests/providers/anthropic/test_probe.py @@ -11,32 +11,46 @@ from typing import Any, ClassVar import httpx +import pydantic import pytest import ai from ai.providers.anthropic import AnthropicCompatibleProvider -def probe_provider( - status_code: int = 200, - json_body: dict[str, Any] | None = None, - base_url: str = "https://anthropic.test", -) -> ai.Provider[Any]: - """Anthropic provider whose mock response is built from JSON args.""" - - def _handler(request: httpx.Request) -> httpx.Response: - body = json.dumps(json_body or {}).encode() - return httpx.Response(status_code, content=body) - - return ai.get_provider( - "anthropic", - base_url=base_url, - api_key="sk-test-key", - client=httpx.AsyncClient( +class _ProbeProviderRef(ai.ProviderRef): + status_code: int = 200 + json_body: dict[str, Any] | None = None + + def __init__( + self, + status_code: int = 200, + json_body: dict[str, Any] | None = None, + base_url: str = "https://anthropic.test", + ) -> None: + super().__init__( + "anthropic", + status_code=status_code, + json_body=json_body, base_url=base_url, - transport=httpx.MockTransport(_handler), - ), - ) + ) + + def build(self) -> ai.Provider[Any]: + def _handler(request: httpx.Request) -> httpx.Response: + _ = request + body = json.dumps(self.json_body or {}).encode() + return httpx.Response(self.status_code, content=body) + + assert self.base_url is not None + return ai.get_provider( + "anthropic", + base_url=self.base_url, + api_key="sk-test-key", + client=httpx.AsyncClient( + base_url=self.base_url, + transport=httpx.MockTransport(_handler), + ), + ) def _client_with_mock( @@ -46,12 +60,7 @@ def _client_with_mock( ) -> ai.Model: return ai.Model( "claude-opus-4-6", - provider_factory=probe_provider, - provider_args={ - "status_code": status_code, - "json_body": json_body, - "base_url": base_url, - }, + provider_ref=_ProbeProviderRef(status_code, json_body, base_url), ) @@ -69,13 +78,7 @@ async def test_model_not_found_raises_model_not_found() -> None: class _HeaderCaptureProvider(AnthropicCompatibleProvider): - """Custom provider that records request headers. - - A provider subclass is itself a valid model factory: the class is - module-level, so ``ai.Model`` can serialize a reference to it, and - per-instance state (the captured headers) stays on the provider, - reachable through ``model.provider``. - """ + """Custom provider that records request headers.""" handles: ClassVar[tuple[str, ...]] = () @@ -100,8 +103,19 @@ def _handler(request: httpx.Request) -> httpx.Response: ) +class _HeaderCaptureProviderRef(ai.ProviderRef): + _provider: _HeaderCaptureProvider = pydantic.PrivateAttr() + + def __init__(self) -> None: + super().__init__("anthropic") + self._provider = _HeaderCaptureProvider() + + def build(self) -> _HeaderCaptureProvider: + return self._provider + + async def test_custom_anthropic_version_header() -> None: - model = ai.Model("custom-model", provider_factory=_HeaderCaptureProvider) + model = ai.Model("custom-model", provider_ref=_HeaderCaptureProviderRef()) provider = model.provider assert isinstance(provider, _HeaderCaptureProvider) diff --git a/tests/providers/openai/test_probe.py b/tests/providers/openai/test_probe.py index f9a5cd1c..64821423 100644 --- a/tests/providers/openai/test_probe.py +++ b/tests/providers/openai/test_probe.py @@ -9,26 +9,39 @@ import ai -def probe_provider( - status_code: int = 200, - json_body: dict[str, Any] | None = None, - base_url: str = "https://openai.test/v1", -) -> ai.Provider[Any]: - """OpenAI provider whose mock response is built from JSON args.""" - - def _handler(request: httpx.Request) -> httpx.Response: - body = json.dumps(json_body or {}).encode() - return httpx.Response(status_code, content=body) - - return ai.get_provider( - "openai", - base_url=base_url, - api_key="sk-test-key", - client=httpx.AsyncClient( +class _ProbeProviderRef(ai.ProviderRef): + status_code: int = 200 + json_body: dict[str, Any] | None = None + + def __init__( + self, + status_code: int = 200, + json_body: dict[str, Any] | None = None, + base_url: str = "https://openai.test/v1", + ) -> None: + super().__init__( + "openai", + status_code=status_code, + json_body=json_body, base_url=base_url, - transport=httpx.MockTransport(_handler), - ), - ) + ) + + def build(self) -> ai.Provider[Any]: + def _handler(request: httpx.Request) -> httpx.Response: + _ = request + body = json.dumps(self.json_body or {}).encode() + return httpx.Response(self.status_code, content=body) + + assert self.base_url is not None + return ai.get_provider( + "openai", + base_url=self.base_url, + api_key="sk-test-key", + client=httpx.AsyncClient( + base_url=self.base_url, + transport=httpx.MockTransport(_handler), + ), + ) def _client_with_mock( @@ -38,12 +51,7 @@ def _client_with_mock( ) -> ai.Model: return ai.Model( "gpt-5.4", - provider_factory=probe_provider, - provider_args={ - "status_code": status_code, - "json_body": json_body, - "base_url": base_url, - }, + provider_ref=_ProbeProviderRef(status_code, json_body, base_url), )