diff --git a/examples/.test_scripts/run-with-patched-model.py b/examples/.test_scripts/run-with-patched-model.py index 253977ab..cdba686c 100644 --- a/examples/.test_scripts/run-with-patched-model.py +++ b/examples/.test_scripts/run-with-patched-model.py @@ -20,7 +20,6 @@ import argparse import runpy import sys -from collections.abc import Callable from typing import Any, TypeVar, cast import ai @@ -30,12 +29,9 @@ from ai.models.core import model as _model from ai.providers.anthropic import ( AnthropicCompatibleProvider, - AnthropicMessagesProtocol, ) from ai.providers.openai import ( - OpenAIChatCompletionsProtocol, OpenAICompatibleProvider, - OpenAIResponsesProtocol, ) PROTOCOLS = ("chat", "messages", "responses") @@ -43,18 +39,16 @@ ModelT = TypeVar("ModelT", bound=ai.Model) -def _protocol_factory( - name: str | None, -) -> Callable[[], ai.ProviderProtocol[Any]] | None: +def _protocol_ref(name: str | None) -> ai.ProtocolRef | None: if name is None: return None if name == "chat": - return OpenAIChatCompletionsProtocol + return ai.ProtocolRef("openai.chat_completions") if name == "messages": - return AnthropicMessagesProtocol + return ai.ProtocolRef("anthropic.messages") if name == "responses": - return OpenAIResponsesProtocol + return ai.ProtocolRef("openai.responses") raise ValueError(f"unsupported protocol: {name}") @@ -82,42 +76,37 @@ def _parse_args() -> argparse.Namespace: def main() -> None: args = _parse_args() - protocol_factory = _protocol_factory(args.protocol) + protocol_ref = _protocol_ref(args.protocol) original_get_model = _model.get_model original_stream = _api.stream original_generate = _api.generate - def selected_protocol() -> ai.ProviderProtocol[Any] | None: - if protocol_factory is None: - return None - return protocol_factory() - def selected_protocol_for_provider( provider: ai.Provider[Any], - ) -> ai.ProviderProtocol[Any] | None: - if args.protocol is None: + ) -> ai.ProtocolRef | None: + if args.protocol is None or protocol_ref is None: return None if args.protocol in ("chat", "responses") and isinstance( provider, OpenAICompatibleProvider ): - return selected_protocol() + return protocol_ref if args.protocol == "messages" and isinstance( provider, AnthropicCompatibleProvider ): - return selected_protocol() + return protocol_ref return None def selected_protocol_for_model( model: ai.Model, - ) -> ai.ProviderProtocol[Any] | None: + ) -> ai.ProtocolRef | None: return selected_protocol_for_provider(model.provider) def with_selected_protocol(model: ModelT) -> ModelT: - protocol = selected_protocol_for_model(model) - if protocol is None: + selected = selected_protocol_for_model(model) + if selected is None: return model - return model.with_protocol(protocol) + return model.with_protocol(selected) class PatchedContext: def __init__(self, context: Any) -> None: @@ -174,18 +163,11 @@ async def patched_generate(*call_args: Any, **kwargs: Any) -> Any: return await original_generate(*call_args, **kwargs) class PatchedModel(_model.Model): - def __init__( - self, - id: str, - *, - provider: ai.Provider[Any], - protocol: ai.ProviderProtocol[Any] | None = None, - ) -> None: - super().__init__( - id, - provider=provider, - protocol=selected_protocol_for_provider(provider) or protocol, - ) + def __init__(self, id: str, **kwargs: Any) -> None: + super().__init__(id, **kwargs) + override = selected_protocol_for_provider(self.provider) + if override is not None: + self.protocol_ref = override cast("Any", ai).get_model = patched_get_model cast("Any", models).get_model = patched_get_model diff --git a/examples/builtin_web_search.py b/examples/builtin_web_search.py index bba7ad86..97d49257 100644 --- a/examples/builtin_web_search.py +++ b/examples/builtin_web_search.py @@ -45,13 +45,11 @@ def format(value: object) -> str: async def main() -> None: - provider = ai.get_provider("anthropic") - if not provider.is_configured(): - print(f"[SKIP] {provider.name} provider is not configured") + model = ai.get_model("anthropic:claude-sonnet-4-6") + if not model.provider.is_configured(): + print(f"[SKIP] {model.provider.name} provider is not configured") return - model = ai.Model("claude-sonnet-4-6", provider=provider) - async with ai.stream(model, messages, tools=tools) as s: async for event in s: match event: diff --git a/examples/check_connection.py b/examples/check_connection.py index 85b82fff..320ff236 100644 --- a/examples/check_connection.py +++ b/examples/check_connection.py @@ -5,10 +5,10 @@ import ai -PROVIDERS: list[tuple[str, ai.Provider, str]] = [ - ("ai_gateway", ai.get_provider("vercel"), "anthropic/claude-sonnet-4.6"), - ("anthropic", ai.get_provider("anthropic"), "claude-sonnet-4-6"), - ("openai", ai.get_provider("openai"), "gpt-5.4-mini"), +MODELS: list[tuple[str, ai.Model]] = [ + ("ai_gateway", ai.get_model("gateway:anthropic/claude-sonnet-4.6")), + ("anthropic", ai.get_model("anthropic:claude-sonnet-4-6")), + ("openai", ai.get_model("openai:gpt-5.4-mini")), ] _failed = False @@ -20,23 +20,22 @@ def _fail(msg: str) -> None: print(msg) -async def _check(name: str, provider: ai.Provider, model_id: str) -> None: - if not provider.is_configured(): - print(f" [SKIP] {provider.name} provider is not configured") +async def _check(name: str, model: ai.Model) -> None: + if not model.provider.is_configured(): + print(f" [SKIP] {model.provider.name} provider is not configured") return - model = ai.Model(model_id, provider=provider) try: await ai.probe(model) - print(f" [OK] {name}/{model_id}") + print(f" [OK] {name}/{model.id}") except Exception as exc: - _fail(f" [ERR] {name}/{model_id}: {exc}") + _fail(f" [ERR] {name}/{model.id}: {exc}") -async def _list_models(name: str, provider: ai.Provider) -> None: - if not provider.is_configured(): +async def _list_models(name: str, model: ai.Model) -> None: + if not model.provider.is_configured(): return try: - ids: list[str] = await provider.list_models() + ids: list[str] = await model.provider.list_models() print(f" {name}: {len(ids)} models (last: {ids[-1]})") except Exception as exc: _fail(f" {name}: [ERR] {exc}") @@ -44,12 +43,12 @@ async def _list_models(name: str, provider: ai.Provider) -> None: async def main() -> None: print("Checking connections...\n") - for name, provider, model_id in PROVIDERS: - await _check(name, provider, model_id) + for name, model in MODELS: + await _check(name, model) print("\nListing models...\n") - for name, provider, _ in PROVIDERS: - await _list_models(name, provider) + for name, model in MODELS: + await _list_models(name, model) print() if _failed: diff --git a/examples/explicit_client.py b/examples/explicit_client.py index dbbadd18..e01d264e 100644 --- a/examples/explicit_client.py +++ b/examples/explicit_client.py @@ -10,18 +10,16 @@ async def main() -> None: # Example for local OpenAI-compatible servers like LM Studio. - provider = ai.get_provider( - "openai", - base_url=os.environ.get( - "LOCAL_OPENAI_BASE_URL", "http://localhost:1234/v1" - ), - api_key=os.environ.get("LOCAL_OPENAI_API_KEY", "some-key"), - headers={"X-Custom-Header": "example"}, - ) - model = ai.Model( os.environ.get("LOCAL_OPENAI_MODEL", "local-model"), - provider=provider, + provider=ai.ProviderRef( + "openai", + base_url=os.environ.get( + "LOCAL_OPENAI_BASE_URL", "http://localhost:1234/v1" + ), + api_key=os.environ.get("LOCAL_OPENAI_API_KEY", "some-key"), + headers={"X-Custom-Header": "example"}, + ), ) try: @@ -39,8 +37,8 @@ async def main() -> None: print(event.chunk, end="", flush=True) print() finally: - # Explicit providers need explicit cleanup. - await provider.aclose() + # Explicitly configured providers need explicit cleanup. + await model.aclose() if __name__ == "__main__": diff --git a/examples/openai_chat_completions.py b/examples/openai_chat_completions.py index 5a909725..b7f8d285 100644 --- a/examples/openai_chat_completions.py +++ b/examples/openai_chat_completions.py @@ -3,7 +3,6 @@ import asyncio import ai -from ai.providers.openai import OpenAIChatCompletionsProtocol messages = [ ai.system_message("Be concise."), @@ -14,16 +13,12 @@ async def main() -> None: - provider = ai.get_provider("openai") - if not provider.is_configured(): - print(f"[SKIP] {provider.name} provider is not configured") - return - - model = ai.Model( - "gpt-5.5", - provider=provider, - protocol=OpenAIChatCompletionsProtocol(), + model = ai.get_model("openai:gpt-5.5").with_protocol( + "openai.chat_completions" ) + if not model.provider.is_configured(): + print(f"[SKIP] {model.provider.name} provider is not configured") + return try: async with ai.stream(model, messages) as stream: @@ -32,7 +27,7 @@ async def main() -> None: print(event.chunk, end="", flush=True) print() finally: - await provider.aclose() + await model.aclose() if __name__ == "__main__": diff --git a/examples/stream_all.py b/examples/stream_all.py index a4e01f7b..41574969 100644 --- a/examples/stream_all.py +++ b/examples/stream_all.py @@ -4,10 +4,10 @@ import ai -MODELS: list[tuple[str, ai.Provider, str]] = [ - ("ai_gateway", ai.get_provider("vercel"), "anthropic/claude-sonnet-4.6"), - ("anthropic", ai.get_provider("anthropic"), "claude-sonnet-4-6"), - ("openai", ai.get_provider("openai"), "gpt-5.5"), +MODELS: list[tuple[str, ai.Model]] = [ + ("ai_gateway", ai.get_model("gateway:anthropic/claude-sonnet-4.6")), + ("anthropic", ai.get_model("anthropic:claude-sonnet-4-6")), + ("openai", ai.get_model("openai:gpt-5.5")), ] messages = [ @@ -16,15 +16,13 @@ ] -async def _run(name: str, provider: ai.Provider, model_id: str) -> None: - print(f"\n{name} / {model_id}") +async def _run(name: str, model: ai.Model) -> None: + print(f"\n{name} / {model.id}") - if not provider.is_configured(): - print(f"[SKIP] {provider.name} provider is not configured") + if not model.provider.is_configured(): + print(f"[SKIP] {model.provider.name} provider is not configured") return - model = ai.Model(model_id, provider=provider) - try: async with ai.stream(model, messages) as s: async for event in s: @@ -36,8 +34,8 @@ async def _run(name: str, provider: ai.Provider, model_id: str) -> None: async def main() -> None: - for name, provider, model_id in MODELS: - await _run(name, provider, model_id) + for name, model in MODELS: + await _run(name, model) if __name__ == "__main__": diff --git a/examples/temporal-direct/main.py b/examples/temporal-direct/main.py index 7e30a462..59184f11 100644 --- a/examples/temporal-direct/main.py +++ b/examples/temporal-direct/main.py @@ -49,25 +49,13 @@ MODEL_ID = "gateway:anthropic/claude-sonnet-4.6" -# ── Workflow-safe model placeholder ────────────────────────────── -# -# ``agent.run`` requires a ``Model``, but a real one can't be built -# inside the workflow: ``ai.get_model("gateway:...")`` constructs an -# ``httpx.AsyncClient`` at provider-init time, which imports -# httpcore/anyio and trips the Temporal sandbox (``threading.local`` -# at module load). Our loop never calls the model directly anyway -- -# every LLM call is delegated to ``llm_call_activity``, which runs -# outside the sandbox and resolves the real model by id there. -# -# So hand the workflow a placeholder ``Model`` whose provider builds -# no client. It carries the real model id (so the activity can -# resolve it) but is safe to construct inside the sandbox. -class WorkflowModelProvider(ai.Provider[Any]): - """A clientless provider, safe to construct in a workflow sandbox.""" - - def __init__(self) -> None: - super().__init__(name="workflow-placeholder", base_url="") - +# ``ai.Model`` is fully serializable: it stores a provider *recipe* +# (factory reference + JSON args) and only builds the provider — and +# its ``httpx.AsyncClient`` — on first use. The workflow can therefore +# construct the model with ``ai.get_model`` inside the Temporal sandbox +# (nothing network-shaped is created there) and ship it to +# ``llm_call_activity`` as plain JSON, where the provider gets built +# for real. # ── Tool definitions ───────────────────────────────────────────── # @@ -116,7 +104,7 @@ async def get_population_activity(city: str) -> int: @dataclasses.dataclass class LLMParams: - model_id: str + model: dict[str, Any] messages: list[dict[str, Any]] tool_schemas: list[dict[str, Any]] @@ -129,7 +117,7 @@ class LLMResult: @temporalio.activity.defn async def llm_call_activity(params: LLMParams) -> LLMResult: """Call the LLM, drain the stream, return the final message.""" - model = ai.get_model(params.model_id) + model = ai.Model.model_validate(params.model) messages = [ai.messages.Message.model_validate(m) for m in params.messages] tools = [ ai.Tool( @@ -172,7 +160,7 @@ async def loop( result = await temporalio.workflow.execute_activity( llm_call_activity, LLMParams( - model_id=context.model.id, + model=context.model.model_dump(mode="json"), messages=[m.model_dump() for m in context.messages], tool_schemas=tool_schemas, ), @@ -238,7 +226,9 @@ async def _call() -> ai.events.ToolCallResult: class WeatherWorkflow: @temporalio.workflow.run async def run(self, user_query: str) -> str: - model = ai.Model(MODEL_ID, provider=WorkflowModelProvider()) + # Safe in the sandbox: the provider (and its HTTP client) is + # only built lazily, inside the LLM activity. + model = ai.get_model(MODEL_ID) messages: list[ai.messages.Message] = [ ai.system_message( "Answer questions using the weather and population tools." diff --git a/src/ai/__init__.py b/src/ai/__init__.py index 2df6e7f4..757da4a7 100644 --- a/src/ai/__init__.py +++ b/src/ai/__init__.py @@ -62,9 +62,11 @@ Model, ModelProviderDefault, OutputParams, + ProtocolRef, Provider, ProviderProtocol, ProviderRankingStrategy, + ProviderRef, ProviderServiceParams, RandomSeed, ReasoningParams, @@ -125,6 +127,7 @@ "Model", "ModelProviderDefault", "OutputParams", + "ProtocolRef", "Provider", "ProviderAPIError", "ProviderAuthenticationError", @@ -142,6 +145,7 @@ "ProviderProtocol", "ProviderRankingStrategy", "ProviderRateLimitError", + "ProviderRef", "ProviderRequestTooLargeError", "ProviderResponseError", "ProviderServiceParams", diff --git a/src/ai/models/__init__.py b/src/ai/models/__init__.py index 29cbb800..612e530e 100644 --- a/src/ai/models/__init__.py +++ b/src/ai/models/__init__.py @@ -1,16 +1,26 @@ """models — composable model layer. +A :class:`Model` holds JSON-safe provider settings instead of a live +provider object. The provider and its client are built lazily on first +use. Models produced by ``get_model`` round-trip through +``model.model_dump(mode="json")`` / ``Model.model_validate()``. + Usage:: import ai model = ai.get_model("openai:gpt-5.4") - provider = ai.get_provider("openai", base_url="http://localhost:11434/v1") - model = ai.Model("llama3", provider=provider) model = ai.get_model("anthropic:claude-sonnet-4-6") - provider = ai.get_provider("anthropic", base_url="https://anthropic.example.com") - model = ai.Model("claude-sonnet-4-6", provider=provider) model = ai.get_model("anthropic/claude-sonnet-4") # defaults to Gateway + # custom provider configuration — JSON-friendly args + model = ai.Model( + "llama3", + provider=ai.ProviderRef( + "openai", + base_url="http://localhost:11434/v1", + ), + ) + # stream — auto-creates client from env vars msgs = [ai.user_message("hello")] async with ai.stream(model, msgs) as s: @@ -18,15 +28,9 @@ if isinstance(event, ai.events.TextDelta): print(event.chunk, end="", flush=True) - # explicit provider for custom auth / transport - provider = ai.get_provider( - "openai", - base_url="https://custom.example.com/v1", - api_key="sk-...", - ) - model = ai.Model("gpt-5.4", provider=provider) - async with ai.stream(model, msgs) as s: - ... + # models serialize and rebuild their provider on first use + data = model.model_dump(mode="json") + model = ai.Model.model_validate(data) # list available models ids = await ai.get_provider("openai").list_models() @@ -44,7 +48,7 @@ probe, stream, ) -from .core.model import Model, get_model +from .core.model import Model, ProtocolRef, ProviderRef, get_model from .core.params import ( DEFAULT, GLOBAL, @@ -100,9 +104,11 @@ "Model", "ModelProviderDefault", "OutputParams", + "ProtocolRef", "Provider", "ProviderProtocol", "ProviderRankingStrategy", + "ProviderRef", "ProviderServiceParams", "RandomSeed", "ReasoningParams", diff --git a/src/ai/models/core/__init__.py b/src/ai/models/core/__init__.py index a826cfec..7aabe535 100644 --- a/src/ai/models/core/__init__.py +++ b/src/ai/models/core/__init__.py @@ -13,7 +13,7 @@ probe, stream, ) -from .model import Model, get_model +from .model import Model, ProtocolRef, ProviderRef, get_model from .params import ( DEFAULT, GLOBAL, @@ -69,8 +69,10 @@ "Model", "ModelProviderDefault", "OutputParams", + "ProtocolRef", "Provider", "ProviderRankingStrategy", + "ProviderRef", "ProviderServiceParams", "RandomSeed", "ReasoningParams", diff --git a/src/ai/models/core/model.py b/src/ai/models/core/model.py index 5ff6fd5d..51068ff7 100644 --- a/src/ai/models/core/model.py +++ b/src/ai/models/core/model.py @@ -1,7 +1,10 @@ """Model metadata types.""" +import importlib import os -from typing import Any, Self +from typing import Any, Self, cast + +import pydantic from ... import _modelsdev from ...errors import ConfigurationError @@ -10,51 +13,289 @@ _DEFAULT_MODEL_ENV = "AI_SDK_DEFAULT_MODEL" -class Model: - """Lightweight reference to a model on a specific provider. +type ProtocolName = str + + +class ProtocolRef(pydantic.BaseModel): + """JSON-safe reference to a built-in provider protocol.""" + + name: ProtocolName + + model_config = pydantic.ConfigDict(frozen=True) + + def __init__(self, name: ProtocolName | None = None, **data: Any) -> None: + if name is not None: + data["name"] = name + super().__init__(**data) + + @pydantic.model_validator(mode="before") + @classmethod + def _coerce(cls, value: Any) -> Any: + if isinstance(value, str): + return {"name": value} + return value + + def build(self) -> base.ProviderProtocol[Any]: + match self.name: + case "openai.responses": + protocol: Any = importlib.import_module( + "ai.providers.openai" + ).OpenAIResponsesProtocol + return cast("base.ProviderProtocol[Any]", protocol()) + case "openai.chat_completions": + protocol = importlib.import_module( + "ai.providers.openai" + ).OpenAIChatCompletionsProtocol + return cast("base.ProviderProtocol[Any]", protocol()) + case "anthropic.messages": + protocol = importlib.import_module( + "ai.providers.anthropic" + ).AnthropicMessagesProtocol + return cast("base.ProviderProtocol[Any]", protocol()) + case "gateway.v3": + protocol = importlib.import_module( + "ai.providers.ai_gateway" + ).GatewayV3Protocol + return cast("base.ProviderProtocol[Any]", protocol()) + case _: + raise ConfigurationError( + f"unknown protocol reference: {self.name!r}" + ) + + def __hash__(self) -> int: + return hash(self.name) + + +class ProviderRef(pydantic.BaseModel): + """JSON-safe provider settings used by :class:`Model`.""" + + id: str + model_id: str | None = None + base_url: str | None = None + api_key: str | None = None + headers: dict[str, str] = pydantic.Field(default_factory=dict) + env: dict[str, str] = pydantic.Field(default_factory=dict) + + model_config = pydantic.ConfigDict(frozen=True) + + def __init__(self, id: str | None = None, **data: Any) -> None: + if id is not None: + data["id"] = id + super().__init__(**data) + + @pydantic.model_validator(mode="before") + @classmethod + def _coerce(cls, value: Any) -> Any: + if isinstance(value, str): + return {"id": value} + return value + + def build(self) -> base.Provider[Any]: + model_provider_config = None + if self.model_id is not None: + model_info = _modelsdev.get_model_by_id( + f"{self.id}:{self.model_id}" + ) + if model_info is not None: + model_provider_config = model_info.provider_config + return base.Provider.from_id( + self.id, + model_provider_config=model_provider_config, + base_url=self.base_url, + api_key=self.api_key, + headers=self.headers, + env=self.env, + ) + + def __hash__(self) -> int: + return hash( + ( + self.id, + self.model_id, + self.base_url, + self.api_key, + tuple(sorted(self.headers.items())), + tuple(sorted(self.env.items())), + ) + ) + + +class Model(pydantic.BaseModel): + """Reference to a model on a specific provider. * ``id`` — identifier sent to the provider (e.g. ``"claude-sonnet-4-6"``). - * ``provider`` — :class:`Provider` that owns this model. - * ``protocol`` — optional wire-protocol override for this model. + * ``provider_ref`` — JSON-safe provider settings. + * ``protocol_ref`` — optional JSON-safe wire-protocol override. + + The provider is built lazily on first :attr:`provider` access and + cached. """ + id: str + provider_ref: ProviderRef = pydantic.Field( + validation_alias=pydantic.AliasChoices("provider_ref", "provider"), + serialization_alias="provider", + ) + protocol_ref: ProtocolRef | None = pydantic.Field( + default=None, + validation_alias=pydantic.AliasChoices("protocol_ref", "protocol"), + serialization_alias="protocol", + ) + + _provider_instance: base.Provider[Any] | None = pydantic.PrivateAttr( + default=None + ) + _protocol_instance: base.ProviderProtocol[Any] | None = ( + pydantic.PrivateAttr(default=None) + ) + def __init__( self, id: str, *, - provider: base.Provider, - protocol: base.ProviderProtocol[Any] | None = None, + provider: ProviderRef | str | None = None, + protocol: ProtocolRef | ProtocolName | None = None, + provider_ref: ProviderRef | str | None = None, + protocol_ref: ProtocolRef | ProtocolName | None = None, ) -> None: - self.id = id - self.provider = provider - self.protocol = protocol + provider_ref = provider_ref if provider_ref is not None else provider + protocol_ref = protocol_ref if protocol_ref is not None else protocol + if provider_ref is None: + raise TypeError("Model requires a provider") + super().__init__( + id=id, + provider_ref=provider_ref, + protocol_ref=protocol_ref, + ) + + model_config = pydantic.ConfigDict( + populate_by_name=True, + serialize_by_alias=True, + ) + + @pydantic.field_validator("provider_ref", mode="before") + @classmethod + def _coerce_provider_ref(cls, value: Any) -> Any: + if isinstance(value, str): + return ProviderRef(value) + return value + + @pydantic.field_validator("protocol_ref", mode="before") + @classmethod + def _coerce_protocol_ref(cls, value: Any) -> Any: + if isinstance(value, str): + return ProtocolRef(value) + return value + + @pydantic.field_serializer("provider_ref") + def _serialize_provider_ref( + self, + value: ProviderRef, + info: pydantic.FieldSerializationInfo, + ) -> dict[str, Any]: + if type(value) is not ProviderRef: + raise ConfigurationError( + "custom provider refs cannot be serialized" + ) + return value.model_dump( + mode=info.mode, + exclude_defaults=True, + exclude_none=True, + ) + + @pydantic.field_serializer("protocol_ref") + def _serialize_protocol_ref( + self, + value: ProtocolRef | None, + info: pydantic.FieldSerializationInfo, + ) -> dict[str, Any] | None: + if value is None: + return None + if type(value) is not ProtocolRef: + raise ConfigurationError( + "custom protocol refs cannot be serialized" + ) + return value.model_dump( + mode=info.mode, + exclude_defaults=True, + exclude_none=True, + ) + + @property + def serializable(self) -> bool: + """Whether this model round-trips through ``model_dump``.""" + return type(self.provider_ref) is ProviderRef and ( + self.protocol_ref is None or type(self.protocol_ref) is ProtocolRef + ) + + @property + def provider(self) -> base.Provider[Any]: + """Provider instance, built lazily from the provider ref and cached.""" + provider = self._provider_instance + if provider is None: + provider = self.provider_ref.build() + if not isinstance(provider, base.Provider): + raise ConfigurationError( + f"provider ref {self.provider_ref!r} returned " + f"{type(provider).__name__}, expected a Provider" + ) + self._provider_instance = provider + return provider + + @property + def protocol(self) -> base.ProviderProtocol[Any] | None: + """Protocol override instance, built lazily and cached.""" + if self.protocol_ref is None: + return None + protocol = self._protocol_instance + if protocol is None: + protocol = self.protocol_ref.build() + if not isinstance(protocol, base.ProviderProtocol): + raise ConfigurationError( + f"protocol ref {self.protocol_ref!r} returned " + f"{type(protocol).__name__}, expected a ProviderProtocol" + ) + self._protocol_instance = protocol + return protocol + + async def aclose(self) -> None: + """Close the provider if this model lazily built one.""" + if self._provider_instance is not None: + await self._provider_instance.aclose() + self._provider_instance = None def __eq__(self, other: object) -> bool: + # Pydantic's default __eq__ also compares private attributes. + # Compare the JSON recipe only, so a lazily built provider does + # not affect equality. return ( isinstance(other, Model) and self.id == other.id - and self.provider is other.provider - and self.protocol is other.protocol + and self.provider_ref == other.provider_ref + and self.protocol_ref == other.protocol_ref ) - def __repr__(self) -> str: - return f"Model(id={self.id!r}, provider={self.provider!r})" - def __hash__(self) -> int: - return hash((self.id, id(self.provider), id(self.protocol))) + return hash((self.id, self.provider_ref, self.protocol_ref)) - def with_protocol(self, protocol: base.ProviderProtocol[Any]) -> Self: - return self.__class__( - id=self.id, - provider=self.provider, - protocol=protocol, + def with_protocol( + self, + protocol: ProtocolRef | ProtocolName, + ) -> Self: + model = self.__class__( + self.id, + provider_ref=self.provider_ref, + protocol_ref=protocol, ) + # Keep sharing an already-built provider instance. + model._provider_instance = self._provider_instance + return model def get_model( model_id: str | None = None, *, - protocol: base.ProviderProtocol[Any] | None = None, + protocol: ProtocolRef | ProtocolName | None = None, ) -> Model: """Resolve a model ID into a :class:`Model`. @@ -102,9 +343,14 @@ def get_model( None if model_info is None else model_info.provider_config ) - provider = base.Provider.from_id( - provider_id, - model_provider_config=model_provider_config, + # Fail early on unknown or unsupported providers without building a + # provider (and its client); the model only stores the recipe. + base.Provider.resolve_type( + provider_id, model_provider_config=model_provider_config ) - return Model(provider_model_id, provider=provider, protocol=protocol) + return Model( + provider_model_id, + provider=ProviderRef(provider_id, model_id=provider_model_id), + protocol=protocol, + ) diff --git a/src/ai/providers/anthropic/__init__.py b/src/ai/providers/anthropic/__init__.py index a183b555..785b9290 100644 --- a/src/ai/providers/anthropic/__init__.py +++ b/src/ai/providers/anthropic/__init__.py @@ -6,8 +6,13 @@ from ai.providers.anthropic import tools as anthropic_tools model = ai.get_model("anthropic:claude-sonnet-4-6") - provider = ai.get_provider("anthropic", base_url="https://anthropic.example.com") - model = ai.Model("claude-sonnet-4-6", provider=provider) + model = ai.Model( + "claude-sonnet-4-6", + provider=ai.ProviderRef( + "anthropic", + base_url="https://anthropic.example.com", + ), + ) ids = await ai.get_provider("anthropic").list_models() # built-in tools diff --git a/src/ai/providers/anthropic/provider.py b/src/ai/providers/anthropic/provider.py index ef5ac5fb..43854c0e 100644 --- a/src/ai/providers/anthropic/provider.py +++ b/src/ai/providers/anthropic/provider.py @@ -39,7 +39,7 @@ class AnthropicCompatibleProvider(base.Provider[AnthropicSDKClient]): - """Callable provider for Anthropic-compatible APIs.""" + """Provider for Anthropic-compatible APIs.""" handles: ClassVar[tuple[str, ...]] = ("anthropic", "@ai-sdk/anthropic") diff --git a/src/ai/providers/base.py b/src/ai/providers/base.py index 28d1af10..1350b7f0 100644 --- a/src/ai/providers/base.py +++ b/src/ai/providers/base.py @@ -73,7 +73,7 @@ class Provider(Generic[ClientT]): def __init_subclass__(cls, **kwargs: Any) -> None: super().__init_subclass__(**kwargs) - for handle in cls.handles: + for handle in cls.__dict__.get("handles", ()): existing = _PROVIDER_REGISTRY.get(handle) if existing is not None and existing is not cls: raise RuntimeError(f"duplicate provider handle: {handle!r}") @@ -277,6 +277,46 @@ def from_id( """Return a concrete provider for a models.dev provider ID.""" modelsdev_provider = _modelsdev.get_provider_by_id(known_id) if modelsdev_provider is None: + provider_type = _PROVIDER_REGISTRY.get(known_id) + if provider_type is None: + raise ValueError(f"unknown provider id: {known_id!r}") + return provider_type.from_provider_id( + known_id, + base_url=base_url, + api_key=api_key, + headers=headers, + env=env, + client=client, + protocol=protocol, + ) + + provider_type = cls.resolve_type( + known_id, model_provider_config=model_provider_config + ) + return provider_type.from_modelsdev_provider( + modelsdev_provider, + model_provider_config=model_provider_config, + base_url=base_url, + api_key=api_key, + headers=headers, + env=env, + client=client, + protocol=protocol, + ) + + @classmethod + def resolve_type( + cls, + known_id: str, + *, + model_provider_config: modelsdotdev.ModelProviderConfig | None = None, + ) -> type[Provider[Any]]: + """Return the registered provider class without building a client.""" + modelsdev_provider = _modelsdev.get_provider_by_id(known_id) + if modelsdev_provider is None: + provider_type = _PROVIDER_REGISTRY.get(known_id) + if provider_type is not None: + return provider_type raise ValueError(f"unknown provider id: {known_id!r}") for handle in ( @@ -285,19 +325,26 @@ def from_id( ): provider_type = _PROVIDER_REGISTRY.get(handle) if provider_type is not None: - return provider_type.from_modelsdev_provider( - modelsdev_provider, - model_provider_config=model_provider_config, - base_url=base_url, - api_key=api_key, - headers=headers, - env=env, - client=client, - protocol=protocol, - ) + return provider_type raise UnsupportedProviderError(modelsdev_provider.id) + @classmethod + def from_provider_id( + cls, + known_id: str, + *, + base_url: str | None = None, + api_key: str | None = None, + headers: Mapping[str, str] | None = None, + env: Mapping[str, str] | None = None, + client: Any | None = None, + protocol: ProviderProtocol[Any] | None = None, + ) -> Provider[Any]: + """Construct this provider from a directly registered provider ID.""" + _ = base_url, api_key, headers, env, client, protocol + raise UnsupportedProviderError(known_id) + @classmethod def from_modelsdev_provider( cls, diff --git a/src/ai/providers/openai/__init__.py b/src/ai/providers/openai/__init__.py index b590d20d..4c490652 100644 --- a/src/ai/providers/openai/__init__.py +++ b/src/ai/providers/openai/__init__.py @@ -5,8 +5,13 @@ import ai model = ai.get_model("openai:gpt-5.4") - provider = ai.get_provider("openai", base_url="http://localhost:11434/v1") - model = ai.Model("llama3", provider=provider) + model = ai.Model( + "llama3", + provider=ai.ProviderRef( + "openai", + base_url="http://localhost:11434/v1", + ), + ) ids = await ai.get_provider("openai").list_models() The optional upstream OpenAI SDK is loaded lazily when the provider creates or diff --git a/tests/conftest.py b/tests/conftest.py index 3f742f05..871181d3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import AsyncGenerator, AsyncIterable, Sequence +from collections.abc import AsyncGenerator, AsyncIterable, Mapping, Sequence from typing import Any, cast import pydantic @@ -20,6 +20,8 @@ class MockProvider(models.Provider): Carries just enough state so that ``Model`` objects can be constructed. """ + handles = ("mock",) + def __init__( self, *, @@ -35,6 +37,21 @@ def __init__( self._stream_impl: Any | None = None self._generate_impl: Any | None = None + @classmethod + def from_provider_id( + cls, + known_id: str, + *, + base_url: str | None = None, + api_key: str | None = None, + headers: Mapping[str, str] | None = None, + env: Mapping[str, str] | None = None, + client: Any | None = None, + protocol: models.ProviderProtocol[Any] | None = None, + ) -> models.Provider[Any]: + _ = known_id, base_url, api_key, headers, env, client, protocol + return MOCK_PROVIDER + async def list_models(self) -> list[str]: return [] @@ -98,10 +115,11 @@ async def generate( MOCK_PROVIDER = MockProvider() + # A fixed Model used in tests. MOCK_MODEL: models.Model = models.Model( - id="mock-model", - provider=MOCK_PROVIDER, + "mock-model", + provider="mock", ) diff --git a/tests/models/core/test_api.py b/tests/models/core/test_api.py index e7c740f0..04979e9b 100644 --- a/tests/models/core/test_api.py +++ b/tests/models/core/test_api.py @@ -33,6 +33,17 @@ def _provider_metadata_marker( return marker +class _StaticProviderRef(models.ProviderRef): + _provider: models.Provider[Any] = pydantic.PrivateAttr() + + def __init__(self, provider: models.Provider[Any]) -> None: + super().__init__("mock") + self._provider = provider + + def build(self) -> models.Provider[Any]: + return self._provider + + def test_inference_request_params_with_provider_params() -> None: class GatewayParams: pass @@ -271,32 +282,41 @@ async def test_stream_requires_model_messages_or_context() -> None: pass -async def test_stream_uses_model_protocol() -> None: - class OverrideProtocol(models.ProviderProtocol[Any]): - def stream( - self, - client: Any, - model: models.Model, - messages: list[messages_.Message], - *, - tools: Sequence[ai.tools.Tool] | None = None, - output_type: type[pydantic.BaseModel] | None = None, - params: models.InferenceRequestParams | None = None, - provider: str, - ) -> AsyncGenerator[events_.Event]: - _ = client, model, messages, tools, output_type, params, provider - - async def _stream() -> AsyncGenerator[events_.Event]: - yield events_.StreamStart() - yield events_.TextStart(block_id="text") - yield events_.TextDelta(block_id="text", chunk="override") - yield events_.TextEnd(block_id="text") - yield events_.StreamEnd() - - return _stream() +class _StreamOverrideProtocol(models.ProviderProtocol[Any]): + def stream( + self, + client: Any, + model: models.Model, + messages: list[messages_.Message], + *, + tools: Sequence[ai.tools.Tool] | None = None, + output_type: type[pydantic.BaseModel] | None = None, + params: models.InferenceRequestParams | None = None, + provider: str, + ) -> AsyncGenerator[events_.Event]: + _ = client, model, messages, tools, output_type, params, provider + + async def _stream() -> AsyncGenerator[events_.Event]: + yield events_.StreamStart() + yield events_.TextStart(block_id="text") + yield events_.TextDelta(block_id="text", chunk="override") + yield events_.TextEnd(block_id="text") + yield events_.StreamEnd() + + return _stream() + + +class _StreamOverrideProtocolRef(models.ProtocolRef): + def __init__(self) -> None: + super().__init__("test.stream") + + def build(self) -> models.ProviderProtocol[Any]: + return _StreamOverrideProtocol() + +async def test_stream_uses_model_protocol() -> None: async with models.stream( - MOCK_MODEL.with_protocol(OverrideProtocol()), + MOCK_MODEL.with_protocol(_StreamOverrideProtocolRef()), [ai.user_message("Hi")], ) as stream: async for _ in stream: @@ -308,8 +328,7 @@ async def _stream() -> AsyncGenerator[events_.Event]: async def test_generate_dispatches_to_provider() -> None: provider = MockProvider() model = models.Model( - id="generate-model", - provider=provider, + "generate-model", provider_ref=_StaticProviderRef(provider) ) sentinel = messages_.Message( role="assistant", @@ -338,32 +357,42 @@ async def _generate( assert result is sentinel -async def test_generate_uses_model_protocol() -> None: - sentinel = messages_.Message( - role="assistant", - parts=[messages_.FilePart(data=b"\x89PNG", media_type="image/png")], - ) +_GENERATED_IMAGE = messages_.Message( + role="assistant", + parts=[messages_.FilePart(data=b"\x89PNG", media_type="image/png")], +) + + +class _GenerateOverrideProtocol(models.ProviderProtocol[Any]): + async def generate( + self, + client: Any, + model: models.Model, + messages: list[messages_.Message], + params: models.GenerateParams, + *, + provider: str, + ) -> messages_.Message: + _ = client, model, messages, params, provider + return _GENERATED_IMAGE + + +class _GenerateOverrideProtocolRef(models.ProtocolRef): + def __init__(self) -> None: + super().__init__("test.generate") - class OverrideProtocol(models.ProviderProtocol[Any]): - async def generate( - self, - client: Any, - model: models.Model, - messages: list[messages_.Message], - params: models.GenerateParams, - *, - provider: str, - ) -> messages_.Message: - _ = client, model, messages, params, provider - return sentinel + def build(self) -> models.ProviderProtocol[Any]: + return _GenerateOverrideProtocol() + +async def test_generate_uses_model_protocol() -> None: result = await models.generate( - MOCK_MODEL.with_protocol(OverrideProtocol()), + MOCK_MODEL.with_protocol(_GenerateOverrideProtocolRef()), [ai.user_message("A cat")], models.ImageParams(n=1), ) - assert result is sentinel + assert result is _GENERATED_IMAGE class _CheckProvider(MockProvider): @@ -377,7 +406,9 @@ async def probe(self, model: models.Model) -> None: async def test_probe_delegates_to_model_provider() -> None: provider = _CheckProvider() - model = models.Model("mock-model", provider=provider) + model = models.Model( + "mock-model", provider_ref=_StaticProviderRef(provider) + ) await models.probe(model) diff --git a/tests/models/core/test_model.py b/tests/models/core/test_model.py new file mode 100644 index 00000000..db87f797 --- /dev/null +++ b/tests/models/core/test_model.py @@ -0,0 +1,205 @@ +"""Tests for ``Model`` serialization and lazy provider construction.""" + +from __future__ import annotations + +import json +from typing import Any + +import pydantic +import pytest +from pydantic_core import PydanticSerializationError + +import ai +from ai import models +from ai.providers.openai import OpenAIChatCompletionsProtocol + +from ...conftest import MOCK_PROVIDER, MockProvider + + +class _FreshMockProviderRef(models.ProviderRef): + def __init__(self) -> None: + super().__init__("mock") + + def build(self) -> MockProvider: + return MockProvider() + + +class _Box(pydantic.BaseModel): + model: models.Model + + +def test_model_dumps_provider_ref_as_json_data() -> None: + model = models.Model( + "mock-model", + provider=models.ProviderRef("mock", base_url="http://mock.test"), + ) + + data = model.model_dump(mode="json") + + assert data == { + "id": "mock-model", + "provider": { + "id": "mock", + "base_url": "http://mock.test", + }, + "protocol": None, + } + json.dumps(data) + + +def test_model_json_round_trip() -> None: + model = ai.get_model("openai:gpt-5") + + restored = models.Model.model_validate( + json.loads(json.dumps(model.model_dump(mode="json"))) + ) + + assert restored == model + assert hash(restored) == hash(model) + assert restored.provider_ref == models.ProviderRef( + "openai", model_id="gpt-5" + ) + + +def test_model_inside_pydantic_model_round_trips() -> None: + box = _Box(model=ai.get_model("openai:gpt-5")) + + restored = _Box.model_validate( + json.loads(json.dumps(box.model_dump(mode="json"))) + ) + + assert restored == box + + +def test_model_validate_does_not_resolve_provider() -> None: + model = models.Model.model_validate( + {"id": "test-model", "provider": {"id": "not-registered"}} + ) + + assert model.provider_ref.id == "not-registered" + with pytest.raises(ValueError, match="unknown provider"): + _ = model.provider + + +def test_get_model_round_trip_builds_equivalent_provider() -> None: + model = ai.get_model("openai:gpt-5") + restored = models.Model.model_validate(model.model_dump(mode="json")) + + assert restored == model + assert restored.provider.name == "openai" + + +def test_provider_is_built_lazily_and_cached() -> None: + model = models.Model("mock-model", provider="mock") + + assert model._provider_instance is None + assert model.provider is MOCK_PROVIDER + assert model.provider is MOCK_PROVIDER + + +async def test_aclose_drops_cached_provider() -> None: + model = models.Model("mock-model", provider_ref=_FreshMockProviderRef()) + provider = model.provider + + await model.aclose() + + assert model._provider_instance is None + assert model.provider is not provider + + +def test_with_protocol_round_trips_and_shares_provider() -> None: + model = models.Model("mock-model", provider="mock") + provider = model.provider + + override = model.with_protocol("openai.chat_completions") + + assert override.provider is provider + assert isinstance(override.protocol, OpenAIChatCompletionsProtocol) + + restored = models.Model.model_validate(override.model_dump(mode="json")) + assert isinstance(restored.protocol, OpenAIChatCompletionsProtocol) + + +def test_accepts_string_provider_ref() -> None: + model = models.Model("mock-model", provider="mock") + + assert model.provider_ref == models.ProviderRef("mock") + + +def test_accepts_string_protocol_ref() -> None: + model = models.Model( + "mock-model", + provider="mock", + protocol="openai.chat_completions", + ) + + assert model.protocol_ref == models.ProtocolRef("openai.chat_completions") + + +def test_unknown_protocol_fails_on_access_not_validation() -> None: + model = models.Model.model_validate( + { + "id": "mock-model", + "provider": {"id": "mock"}, + "protocol": {"name": "not-real"}, + } + ) + + with pytest.raises(ai.ConfigurationError, match="unknown protocol"): + _ = model.protocol + + +def test_custom_provider_ref_works_in_process_but_does_not_dump() -> None: + model = models.Model("mock-model", provider_ref=_FreshMockProviderRef()) + + assert isinstance(model.provider, MockProvider) + assert model.serializable is False + with pytest.raises(PydanticSerializationError, match="provider refs"): + model.model_dump() + + +def test_serializable_is_true_for_plain_refs() -> None: + assert ai.get_model("openai:gpt-5").serializable is True + assert models.Model("mock-model", provider="mock").serializable is True + + +def test_provider_ref_rejects_bad_header_type() -> None: + with pytest.raises(pydantic.ValidationError): + models.Model.model_validate( + { + "id": "mock-model", + "provider": { + "id": "mock", + "headers": {"x-test": object()}, + }, + } + ) + + +def test_factory_returning_non_provider_fails_on_access() -> None: + class NotAProviderRef(models.ProviderRef): + def __init__(self) -> None: + super().__init__("mock") + + def build(self) -> Any: + return "nope" + + model = models.Model("mock-model", provider_ref=NotAProviderRef()) + + with pytest.raises(ai.ConfigurationError, match="expected a Provider"): + _ = model.provider + + +def test_equality_ignores_cached_provider_instance() -> None: + a = models.Model("mock-model", provider="mock") + b = models.Model("mock-model", provider="mock") + _ = a.provider + + assert a == b + assert hash(a) == hash(b) + + c = models.Model( + "mock-model", + provider=models.ProviderRef("mock", base_url="http://other.test"), + ) + assert a != c diff --git a/tests/models/test_resolution.py b/tests/models/test_resolution.py index 98a286f3..53c24512 100644 --- a/tests/models/test_resolution.py +++ b/tests/models/test_resolution.py @@ -171,10 +171,10 @@ def test_get_rejects_empty_model_id() -> None: def test_get_model_accepts_model_protocol_override() -> None: - protocol = OpenAIChatCompletionsProtocol() - model = models.get_model("openai:gpt-5", protocol=protocol) + model = models.get_model("openai:gpt-5", protocol="openai.chat_completions") - assert model.protocol is protocol + assert isinstance(model.protocol, OpenAIChatCompletionsProtocol) + assert model.protocol is model.protocol # built once, cached assert isinstance(model.provider.protocol, OpenAIResponsesProtocol) diff --git a/tests/providers/ai_gateway/conftest.py b/tests/providers/ai_gateway/conftest.py index c3cba9d2..1bd77902 100644 --- a/tests/providers/ai_gateway/conftest.py +++ b/tests/providers/ai_gateway/conftest.py @@ -6,10 +6,24 @@ from typing import Any import httpx +import pydantic import ai from ai.types import messages +_BASE_URL = "https://gw.test/v3/ai" + + +class _GatewayProviderRef(ai.ProviderRef): + _provider: ai.Provider[Any] = pydantic.PrivateAttr() + + def __init__(self, provider: ai.Provider[Any]) -> None: + super().__init__("vercel") + self._provider = provider + + def build(self) -> ai.Provider[Any]: + return self._provider + def sse(*events: dict[str, Any]) -> str: """Build SSE response text from event dicts.""" @@ -22,14 +36,18 @@ def mock_model( model_id: str = "test-provider/test-model", api_key: str = "test-key", ) -> ai.Model: - """Create a Gateway model wired to a mock transport.""" + """Create a Gateway model wired to a mock transport. + + Per-test handlers are live objects, so this uses a test-only provider + ref. It never crosses a JSON boundary in these tests. + """ provider = ai.get_provider( "vercel", - base_url="https://gw.test/v3/ai", + base_url=_BASE_URL, api_key=api_key, client=httpx.AsyncClient(transport=handler), ) - return ai.Model(model_id, provider=provider) + return ai.Model(model_id, provider_ref=_GatewayProviderRef(provider)) mock_client = mock_model diff --git a/tests/providers/ai_gateway/test_probe.py b/tests/providers/ai_gateway/test_probe.py index 2803c914..7ca82170 100644 --- a/tests/providers/ai_gateway/test_probe.py +++ b/tests/providers/ai_gateway/test_probe.py @@ -12,6 +12,47 @@ _MODEL_ID = "anthropic/claude-opus-4-6" +class _ProbeProviderRef(ai.ProviderRef): + credits_status: int = 200 + config_status: int = 200 + config_body: dict[str, Any] | None = None + api_key: str | None = "sk-test-key" + + def __init__( + self, + *, + credits_status: int = 200, + config_status: int = 200, + config_body: dict[str, Any] | None = None, + api_key: str | None = "sk-test-key", + ) -> None: + super().__init__( + "vercel", + credits_status=credits_status, + config_status=config_status, + config_body=config_body, + api_key=api_key, + ) + + def build(self) -> ai.Provider[Any]: + credits_body = json.dumps({"balance": "10.00", "totalUsed": "5.00"}) + config_bytes = json.dumps(self.config_body or {"models": []}).encode() + + def _handler(request: httpx.Request) -> httpx.Response: + if "/v1/credits" in str(request.url): + return httpx.Response( + self.credits_status, content=credits_body.encode() + ) + return httpx.Response(self.config_status, content=config_bytes) + + return ai.get_provider( + "vercel", + base_url="https://gateway.test/v3/ai", + api_key=self.api_key, + client=httpx.AsyncClient(transport=httpx.MockTransport(_handler)), + ) + + def _gateway_client( *, credits_status: int = 200, @@ -19,21 +60,15 @@ def _gateway_client( config_body: dict[str, Any] | None = None, api_key: str | None = "sk-test-key", ) -> ai.Model: - credits_body = json.dumps({"balance": "10.00", "totalUsed": "5.00"}) - config_bytes = json.dumps(config_body or {"models": []}).encode() - - def _handler(request: httpx.Request) -> httpx.Response: - if "/v1/credits" in str(request.url): - return httpx.Response(credits_status, content=credits_body.encode()) - return httpx.Response(config_status, content=config_bytes) - - provider = ai.get_provider( - "vercel", - base_url="https://gateway.test/v3/ai", - api_key=api_key, - client=httpx.AsyncClient(transport=httpx.MockTransport(_handler)), + return ai.Model( + _MODEL_ID, + provider_ref=_ProbeProviderRef( + credits_status=credits_status, + config_status=config_status, + config_body=config_body, + api_key=api_key, + ), ) - return ai.Model(_MODEL_ID, provider=provider) async def test_auth_ok_model_present_succeeds() -> None: diff --git a/tests/providers/anthropic/test_adapter.py b/tests/providers/anthropic/test_adapter.py index 15c4b59d..c7abc141 100644 --- a/tests/providers/anthropic/test_adapter.py +++ b/tests/providers/anthropic/test_adapter.py @@ -45,7 +45,7 @@ def _patch_client( return cast("anthropic.AsyncAnthropic", fake), captured -_MODEL = ai.Model("claude-sonnet-4-6", provider=ai.get_provider("anthropic")) +_MODEL = ai.get_model("anthropic:claude-sonnet-4-6") async def _drain(stream: Any) -> None: diff --git a/tests/providers/anthropic/test_probe.py b/tests/providers/anthropic/test_probe.py index e9e92884..9e5ea67b 100644 --- a/tests/providers/anthropic/test_probe.py +++ b/tests/providers/anthropic/test_probe.py @@ -8,34 +8,60 @@ from __future__ import annotations import json -from typing import Any +from typing import Any, ClassVar import httpx +import pydantic import pytest import ai from ai.providers.anthropic import AnthropicCompatibleProvider +class _ProbeProviderRef(ai.ProviderRef): + status_code: int = 200 + json_body: dict[str, Any] | None = None + + def __init__( + self, + status_code: int = 200, + json_body: dict[str, Any] | None = None, + base_url: str = "https://anthropic.test", + ) -> None: + super().__init__( + "anthropic", + status_code=status_code, + json_body=json_body, + base_url=base_url, + ) + + def build(self) -> ai.Provider[Any]: + def _handler(request: httpx.Request) -> httpx.Response: + _ = request + body = json.dumps(self.json_body or {}).encode() + return httpx.Response(self.status_code, content=body) + + assert self.base_url is not None + return ai.get_provider( + "anthropic", + base_url=self.base_url, + api_key="sk-test-key", + client=httpx.AsyncClient( + base_url=self.base_url, + transport=httpx.MockTransport(_handler), + ), + ) + + def _client_with_mock( status_code: int = 200, - json_body: Any = None, + json_body: dict[str, Any] | None = None, base_url: str = "https://anthropic.test", ) -> ai.Model: - def _handler(request: httpx.Request) -> httpx.Response: - body = json.dumps(json_body or {}).encode() - return httpx.Response(status_code, content=body) - - provider = ai.get_provider( - "anthropic", - base_url=base_url, - api_key="sk-test-key", - client=httpx.AsyncClient( - base_url=base_url, - transport=httpx.MockTransport(_handler), - ), + return ai.Model( + "claude-opus-4-6", + provider_ref=_ProbeProviderRef(status_code, json_body, base_url), ) - return ai.Model("claude-opus-4-6", provider=provider) async def test_200_succeeds() -> None: @@ -51,27 +77,48 @@ async def test_model_not_found_raises_model_not_found() -> None: assert exc_info.value.model_id == model.id +class _HeaderCaptureProvider(AnthropicCompatibleProvider): + """Custom provider that records request headers.""" + + handles: ClassVar[tuple[str, ...]] = () + + def __init__(self) -> None: + self.captured_headers: dict[str, str] = {} + + def _handler(request: httpx.Request) -> httpx.Response: + self.captured_headers.update(dict(request.headers)) + body = json.dumps({"id": "custom-model", "type": "model"}).encode() + return httpx.Response(200, content=body) + + super().__init__( + name="custom-anthropic", + default_base_url="https://anthropic.test", + api_key="sk-test-key", + anthropic_version="2024-01-01", + headers={"X-Custom-Header": "example"}, + client=httpx.AsyncClient( + base_url="https://anthropic.test", + transport=httpx.MockTransport(_handler), + ), + ) + + +class _HeaderCaptureProviderRef(ai.ProviderRef): + _provider: _HeaderCaptureProvider = pydantic.PrivateAttr() + + def __init__(self) -> None: + super().__init__("anthropic") + self._provider = _HeaderCaptureProvider() + + def build(self) -> _HeaderCaptureProvider: + return self._provider + + async def test_custom_anthropic_version_header() -> None: - captured_headers: dict[str, str] = {} - - def _handler(request: httpx.Request) -> httpx.Response: - captured_headers.update(dict(request.headers)) - body = json.dumps({"id": "custom-model", "type": "model"}).encode() - return httpx.Response(200, content=body) - - provider = AnthropicCompatibleProvider( - name="custom-anthropic", - default_base_url="https://anthropic.test", - api_key="sk-test-key", - anthropic_version="2024-01-01", - headers={"X-Custom-Header": "example"}, - client=httpx.AsyncClient( - base_url="https://anthropic.test", - transport=httpx.MockTransport(_handler), - ), - ) + model = ai.Model("custom-model", provider_ref=_HeaderCaptureProviderRef()) - model = ai.Model("custom-model", provider=provider) + provider = model.provider + assert isinstance(provider, _HeaderCaptureProvider) await provider.probe(model) - assert captured_headers["anthropic-version"] == "2024-01-01" - assert captured_headers["x-custom-header"] == "example" + assert provider.captured_headers["anthropic-version"] == "2024-01-01" + assert provider.captured_headers["x-custom-header"] == "example" diff --git a/tests/providers/anthropic/test_provider.py b/tests/providers/anthropic/test_provider.py index 1692b48d..41ba00a5 100644 --- a/tests/providers/anthropic/test_provider.py +++ b/tests/providers/anthropic/test_provider.py @@ -181,15 +181,12 @@ def test_get_provider_accepts_base_url_and_api_key() -> None: headers={"X-Custom-Header": "example"}, ) - model = ai.Model("custom-model", provider=provider) assert repr(provider) == "anthropic" assert isinstance(provider.protocol, AnthropicMessagesProtocol) assert provider.base_url == "https://custom.example.com" assert provider.api_key == "sk-custom" assert provider.headers == {"X-Custom-Header": "example"} assert provider.is_configured() is True - assert model.id == "custom-model" - assert model.provider is provider def test_get_provider_env_overrides_base_url_env() -> None: diff --git a/tests/providers/anthropic/test_stream.py b/tests/providers/anthropic/test_stream.py index ad0400dc..adfd2ee7 100644 --- a/tests/providers/anthropic/test_stream.py +++ b/tests/providers/anthropic/test_stream.py @@ -28,7 +28,7 @@ snapshot_block, ) -_MODEL = ai.Model("claude-sonnet-4-6", provider=ai.get_provider("anthropic")) +_MODEL = ai.get_model("anthropic:claude-sonnet-4-6") async def _drain( diff --git a/tests/providers/anthropic/test_tools.py b/tests/providers/anthropic/test_tools.py index 7c1e5008..a9f0fe33 100644 --- a/tests/providers/anthropic/test_tools.py +++ b/tests/providers/anthropic/test_tools.py @@ -23,7 +23,7 @@ from .conftest import FakeAnthropicClient -_MODEL = ai.Model("claude-sonnet-4-6", provider=ai.get_provider("anthropic")) +_MODEL = ai.get_model("anthropic:claude-sonnet-4-6") async def _capture_tools( diff --git a/tests/providers/openai/test_adapter.py b/tests/providers/openai/test_adapter.py index 30d6ba86..d8db6810 100644 --- a/tests/providers/openai/test_adapter.py +++ b/tests/providers/openai/test_adapter.py @@ -111,10 +111,7 @@ async def close(self) -> None: self.closed = True -_MODEL = ai.Model( - "gpt-5.4", - provider=ai.get_provider("openai", api_key="sk-test"), -) +_MODEL = ai.get_model("openai:gpt-5.4") def _patch( diff --git a/tests/providers/openai/test_probe.py b/tests/providers/openai/test_probe.py index 437067a8..64821423 100644 --- a/tests/providers/openai/test_probe.py +++ b/tests/providers/openai/test_probe.py @@ -9,25 +9,50 @@ import ai +class _ProbeProviderRef(ai.ProviderRef): + status_code: int = 200 + json_body: dict[str, Any] | None = None + + def __init__( + self, + status_code: int = 200, + json_body: dict[str, Any] | None = None, + base_url: str = "https://openai.test/v1", + ) -> None: + super().__init__( + "openai", + status_code=status_code, + json_body=json_body, + base_url=base_url, + ) + + def build(self) -> ai.Provider[Any]: + def _handler(request: httpx.Request) -> httpx.Response: + _ = request + body = json.dumps(self.json_body or {}).encode() + return httpx.Response(self.status_code, content=body) + + assert self.base_url is not None + return ai.get_provider( + "openai", + base_url=self.base_url, + api_key="sk-test-key", + client=httpx.AsyncClient( + base_url=self.base_url, + transport=httpx.MockTransport(_handler), + ), + ) + + def _client_with_mock( status_code: int = 200, - json_body: Any = None, + json_body: dict[str, Any] | None = None, base_url: str = "https://openai.test/v1", ) -> ai.Model: - def _handler(request: httpx.Request) -> httpx.Response: - body = json.dumps(json_body or {}).encode() - return httpx.Response(status_code, content=body) - - provider = ai.get_provider( - "openai", - base_url=base_url, - api_key="sk-test-key", - client=httpx.AsyncClient( - base_url=base_url, - transport=httpx.MockTransport(_handler), - ), + return ai.Model( + "gpt-5.4", + provider_ref=_ProbeProviderRef(status_code, json_body, base_url), ) - return ai.Model("gpt-5.4", provider=provider) async def test_200_succeeds() -> None: @@ -63,7 +88,6 @@ async def test_no_api_key_raises_not_configured( ) -> None: monkeypatch.delenv("OPENAI_API_KEY", raising=False) - provider = ai.get_provider("openai", base_url="https://openai.test/v1") - model = ai.Model("gpt-5.4", provider=provider) + model = ai.get_model("openai:gpt-5.4") with pytest.raises(ai.ProviderNotConfiguredError): - await provider.probe(model) + await model.provider.probe(model) diff --git a/tests/providers/openai/test_provider.py b/tests/providers/openai/test_provider.py index b7ff6429..8467815f 100644 --- a/tests/providers/openai/test_provider.py +++ b/tests/providers/openai/test_provider.py @@ -210,15 +210,12 @@ def test_get_provider_accepts_base_url_and_api_key() -> None: headers={"X-Custom-Header": "example"}, ) - model = ai.Model("custom-model", provider=provider) assert repr(provider) == "openai" assert isinstance(provider.protocol, OpenAIResponsesProtocol) assert provider.base_url == "https://custom.example.com/v1" assert provider.api_key == "sk-custom" assert provider.headers == {"X-Custom-Header": "example"} assert provider.is_configured() is True - assert model.id == "custom-model" - assert model.provider is provider assert isinstance(provider, ai.Provider)