From 9d725e2b31c8d757a8ac169659dfd6793f262704 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Wed, 17 Jun 2026 15:40:29 -0700 Subject: [PATCH 1/2] Rework models api to accept either provider or provider_name --- examples/temporal-direct/main.py | 28 +--- src/ai/models/core/model.py | 276 ++++++++++++++++++++++++++++--- tests/models/test_resolution.py | 101 ++++++++++- 3 files changed, 350 insertions(+), 55 deletions(-) diff --git a/examples/temporal-direct/main.py b/examples/temporal-direct/main.py index 7e30a462..7276cac2 100644 --- a/examples/temporal-direct/main.py +++ b/examples/temporal-direct/main.py @@ -49,26 +49,6 @@ MODEL_ID = "gateway:anthropic/claude-sonnet-4.6" -# ── Workflow-safe model placeholder ────────────────────────────── -# -# ``agent.run`` requires a ``Model``, but a real one can't be built -# inside the workflow: ``ai.get_model("gateway:...")`` constructs an -# ``httpx.AsyncClient`` at provider-init time, which imports -# httpcore/anyio and trips the Temporal sandbox (``threading.local`` -# at module load). Our loop never calls the model directly anyway -- -# every LLM call is delegated to ``llm_call_activity``, which runs -# outside the sandbox and resolves the real model by id there. -# -# So hand the workflow a placeholder ``Model`` whose provider builds -# no client. It carries the real model id (so the activity can -# resolve it) but is safe to construct inside the sandbox. -class WorkflowModelProvider(ai.Provider[Any]): - """A clientless provider, safe to construct in a workflow sandbox.""" - - def __init__(self) -> None: - super().__init__(name="workflow-placeholder", base_url="") - - # ── Tool definitions ───────────────────────────────────────────── # # Declared with @ai.tool so the framework can extract JSON schemas @@ -116,7 +96,7 @@ async def get_population_activity(city: str) -> int: @dataclasses.dataclass class LLMParams: - model_id: str + model: dict[str, Any] messages: list[dict[str, Any]] tool_schemas: list[dict[str, Any]] @@ -129,7 +109,7 @@ class LLMResult: @temporalio.activity.defn async def llm_call_activity(params: LLMParams) -> LLMResult: """Call the LLM, drain the stream, return the final message.""" - model = ai.get_model(params.model_id) + model = ai.Model.model_validate(params.model) messages = [ai.messages.Message.model_validate(m) for m in params.messages] tools = [ ai.Tool( @@ -172,7 +152,7 @@ async def loop( result = await temporalio.workflow.execute_activity( llm_call_activity, LLMParams( - model_id=context.model.id, + model=context.model.model_dump(mode="json"), messages=[m.model_dump() for m in context.messages], tool_schemas=tool_schemas, ), @@ -238,7 +218,7 @@ async def _call() -> ai.events.ToolCallResult: class WeatherWorkflow: @temporalio.workflow.run async def run(self, user_query: str) -> str: - model = ai.Model(MODEL_ID, provider=WorkflowModelProvider()) + model = ai.get_model(MODEL_ID) messages: list[ai.messages.Message] = [ ai.system_message( "Answer questions using the weather and population tools." diff --git a/src/ai/models/core/model.py b/src/ai/models/core/model.py index 5ff6fd5d..ebc2caf8 100644 --- a/src/ai/models/core/model.py +++ b/src/ai/models/core/model.py @@ -1,47 +1,166 @@ """Model metadata types.""" +from __future__ import annotations + +import json import os -from typing import Any, Self +from typing import TYPE_CHECKING, Any, Literal, Self, cast, overload + +import pydantic from ... import _modelsdev from ...errors import ConfigurationError from ...providers import base +if TYPE_CHECKING: + from collections.abc import Callable + + import modelsdotdev + _DEFAULT_MODEL_ENV = "AI_SDK_DEFAULT_MODEL" -class Model: +class Model(pydantic.BaseModel): """Lightweight reference to a model on a specific provider. * ``id`` — identifier sent to the provider (e.g. ``"claude-sonnet-4-6"``). - * ``provider`` — :class:`Provider` that owns this model. - * ``protocol`` — optional wire-protocol override for this model. + * ``provider_name`` — models.dev provider id used to rebuild the provider. + * ``provider_args`` — JSON-friendly provider configuration. + + Passing a live ``provider`` makes the model non-serializable. """ + id: str + provider_name: str | None = None + provider_args: dict[str, Any] = pydantic.Field(default_factory=dict) + + _provider: base.Provider[Any] | None = pydantic.PrivateAttr(default=None) + _protocol: base.ProviderProtocol[Any] | None = pydantic.PrivateAttr( + default=None + ) + _is_serializable: bool = pydantic.PrivateAttr(default=True) + + model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) + + @overload + def __init__( + self, + id: str, + *, + provider_name: str, + provider_args: dict[str, Any] | None = None, + ) -> None: ... + + @overload + def __init__( + self, + id: str, + *, + provider: base.Provider[Any], + protocol: base.ProviderProtocol[Any] | None = None, + ) -> None: ... + def __init__( self, id: str, *, - provider: base.Provider, + provider_name: str | None = None, + provider_args: dict[str, Any] | None = None, + provider: base.Provider[Any] | None = None, protocol: base.ProviderProtocol[Any] | None = None, ) -> None: - self.id = id - self.provider = provider - self.protocol = protocol + if (provider is None) == (provider_name is None): + raise ConfigurationError( + "pass exactly one of provider_name or provider" + ) + if provider_name == "": + raise ConfigurationError("provider_name must not be empty") + + if provider is not None: + if provider_args is not None: + raise ConfigurationError("provider_args requires provider_name") + super().__init__( + id=id, + provider_name=provider.name, + provider_args={}, + ) + self._provider = provider + self._protocol = protocol + self._is_serializable = False + return + + if protocol is not None: + raise ConfigurationError( + "protocol objects are not serializable; " + "use provider=... live-object mode" + ) + + super().__init__( + id=id, + provider_name=provider_name, + provider_args=provider_args or {}, + ) def __eq__(self, other: object) -> bool: + if not isinstance(other, Model): + return False + if self._is_serializable and other._is_serializable: + return ( + self.id == other.id + and self.provider_name == other.provider_name + and self.provider_args == other.provider_args + ) return ( - isinstance(other, Model) + not self._is_serializable + and not other._is_serializable and self.id == other.id - and self.provider is other.provider - and self.protocol is other.protocol + and self._provider is other._provider + and self._protocol is other._protocol ) def __repr__(self) -> str: - return f"Model(id={self.id!r}, provider={self.provider!r})" + provider = ( + self._provider if self._provider is not None else self.provider_name + ) + return f"Model(id={self.id!r}, provider={provider!r})" def __hash__(self) -> int: - return hash((self.id, id(self.provider), id(self.protocol))) + if self._is_serializable: + return hash( + ( + self.id, + self.provider_name, + json.dumps(self.provider_args, sort_keys=True), + ) + ) + return hash((self.id, id(self._provider), id(self._protocol))) + + @property + def is_serializable(self) -> bool: + """Whether this model can be serialized as durable JSON data.""" + return self._is_serializable + + @property + def provider(self) -> base.Provider[Any]: + """Provider for this model, lazily rebuilt for durable models.""" + if self._provider is None: + if self.provider_name is None: + raise ConfigurationError("model has no provider_name") + self._provider = base.Provider.from_id( + self.provider_name, + model_provider_config=self._model_provider_config(), + **{ + key: value + for key, value in self.provider_args.items() + if value is not None + }, + ) + return self._provider + + @property + def protocol(self) -> base.ProviderProtocol[Any] | None: + """Optional wire-protocol override for this model.""" + return self._protocol def with_protocol(self, protocol: base.ProviderProtocol[Any]) -> Self: return self.__class__( @@ -50,6 +169,117 @@ def with_protocol(self, protocol: base.ProviderProtocol[Any]) -> Self: protocol=protocol, ) + def _model_provider_config( + self, + ) -> modelsdotdev.ModelProviderConfig | None: + if self.provider_name is None: + return None + model_info = _modelsdev.get_model_by_id( + f"{self.provider_name}:{self.id}" + ) + return None if model_info is None else model_info.provider_config + + @pydantic.field_serializer("provider_args") + def _serialize_provider_args( + self, provider_args: dict[str, Any] + ) -> dict[str, Any]: + return { + key: value + for key, value in provider_args.items() + if value is not None + } + + @pydantic.model_serializer(mode="wrap") + def _serialize_model( + self, + handler: pydantic.SerializerFunctionWrapHandler, + info: pydantic.SerializationInfo, + ) -> dict[str, Any]: + if info.mode == "json" and not self._is_serializable: + raise ConfigurationError( + "Model was constructed with a live provider/protocol and " + "cannot be serialized. Use provider_name/provider_args instead." + ) + return cast("dict[str, Any]", handler(self)) + + def model_dump( + self, + *, + mode: Literal["json", "python"] | str = "python", + include: Any = None, + exclude: Any = None, + context: Any | None = None, + by_alias: bool | None = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + exclude_computed_fields: bool = False, + round_trip: bool = False, + warnings: bool | Literal["none", "warn", "error"] = True, + fallback: Callable[[Any], Any] | None = None, + serialize_as_any: bool = False, + ) -> dict[str, Any]: + if mode == "json" and not self._is_serializable: + raise ConfigurationError( + "Model was constructed with a live provider/protocol and " + "cannot be serialized. Use provider_name/provider_args instead." + ) + return super().model_dump( + mode=mode, + include=include, + exclude=exclude, + context=context, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + exclude_computed_fields=exclude_computed_fields, + round_trip=round_trip, + warnings=warnings, + fallback=fallback, + serialize_as_any=serialize_as_any, + ) + + def model_dump_json( + self, + *, + indent: int | None = None, + ensure_ascii: bool = False, + include: Any = None, + exclude: Any = None, + context: Any | None = None, + by_alias: bool | None = None, + exclude_unset: bool = False, + exclude_defaults: bool = False, + exclude_none: bool = False, + exclude_computed_fields: bool = False, + round_trip: bool = False, + warnings: bool | Literal["none", "warn", "error"] = True, + fallback: Callable[[Any], Any] | None = None, + serialize_as_any: bool = False, + ) -> str: + if not self._is_serializable: + raise ConfigurationError( + "Model was constructed with a live provider/protocol and " + "cannot be serialized. Use provider_name/provider_args instead." + ) + return super().model_dump_json( + indent=indent, + ensure_ascii=ensure_ascii, + include=include, + exclude=exclude, + context=context, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + exclude_computed_fields=exclude_computed_fields, + round_trip=round_trip, + warnings=warnings, + fallback=fallback, + serialize_as_any=serialize_as_any, + ) + def get_model( model_id: str | None = None, @@ -87,6 +317,12 @@ def get_model( if not model_id: raise ConfigurationError(f"get_model: malformed model_id: {model_id!r}") + if protocol is not None: + raise ConfigurationError( + "protocol objects are not serializable; " + "construct Model with provider=... live-object mode instead" + ) + if ":" not in model_id: model_id = f"gateway:{model_id}" @@ -95,16 +331,4 @@ def get_model( provider_id = ref.provider_id provider_model_id = ref.model_id - model_info = _modelsdev.get_model_by_id( - f"{provider_id}:{provider_model_id}" - ) - model_provider_config = ( - None if model_info is None else model_info.provider_config - ) - - provider = base.Provider.from_id( - provider_id, - model_provider_config=model_provider_config, - ) - - return Model(provider_model_id, provider=provider, protocol=protocol) + return Model(provider_model_id, provider_name=provider_id) diff --git a/tests/models/test_resolution.py b/tests/models/test_resolution.py index 98a286f3..ab345614 100644 --- a/tests/models/test_resolution.py +++ b/tests/models/test_resolution.py @@ -1,3 +1,5 @@ +from typing import Any, cast + import pytest import ai @@ -9,11 +11,99 @@ OpenAIResponsesProtocol, ) +from ..conftest import MockProvider + + +def test_model_durable_json_roundtrips() -> None: + model = ai.Model( + "gpt-5", + provider_name="openai", + provider_args={ + "api_key": None, + "base_url": None, + "headers": None, + "env": None, + }, + ) + + dumped = model.model_dump(mode="json") + + assert dumped == { + "id": "gpt-5", + "provider_name": "openai", + "provider_args": {}, + } + assert model.is_serializable is True + assert ai.Model.model_validate(dumped).model_dump(mode="json") == dumped + + +def test_model_durable_provider_args_serialize_without_none() -> None: + model = ai.Model( + "gpt-5", + provider_name="openai", + provider_args={ + "api_key": None, + "base_url": "https://example.test/v1", + "headers": {"x-test": "yes"}, + }, + ) + + assert model.model_dump(mode="json") == { + "id": "gpt-5", + "provider_name": "openai", + "provider_args": { + "base_url": "https://example.test/v1", + "headers": {"x-test": "yes"}, + }, + } + + +def test_model_live_provider_rejects_json_serialization() -> None: + model = ai.Model("mock-model", provider=MockProvider()) + + assert model.is_serializable is False + with pytest.raises(ConfigurationError, match="live provider/protocol"): + model.model_dump(mode="json") + with pytest.raises(ConfigurationError, match="live provider/protocol"): + model.model_dump_json() + + +def test_model_rejects_invalid_constructor_pairs() -> None: + provider = MockProvider() + model = cast("Any", ai.Model) + + with pytest.raises(ConfigurationError, match="exactly one"): + model("mock-model") + with pytest.raises(ConfigurationError, match="exactly one"): + model("mock-model", provider=provider, provider_name="mock") + with pytest.raises(ConfigurationError, match="provider_args"): + model("mock-model", provider=provider, provider_args={}) + with pytest.raises(ConfigurationError, match="protocol objects"): + model( + "mock-model", + provider_name="mock", + protocol=OpenAIChatCompletionsProtocol(), + ) + + +def test_model_durable_provider_is_cached() -> None: + model = ai.Model("gpt-5", provider_name="openai") + + provider = model.provider + + assert model.provider is provider + def test_get_resolves_provider_qualified_model_id() -> None: model = ai.get_model("openai:gpt-5") assert model.id == "gpt-5" + assert model.model_dump(mode="json") == { + "id": "gpt-5", + "provider_name": "openai", + "provider_args": {}, + } + assert model.is_serializable is True assert model.provider.name == "openai" assert isinstance(model.provider.protocol, OpenAIResponsesProtocol) @@ -161,8 +251,10 @@ def test_provider_from_id_rejects_unsupported_provider_package() -> None: def test_get_rejects_unsupported_provider_package() -> None: + model = models.get_model("google:gemini-2.5-pro") + with pytest.raises(ai.errors.UnsupportedProviderError): - models.get_model("google:gemini-2.5-pro") + _ = model.provider def test_get_rejects_empty_model_id() -> None: @@ -170,12 +262,11 @@ def test_get_rejects_empty_model_id() -> None: models.get_model("") -def test_get_model_accepts_model_protocol_override() -> None: +def test_get_model_rejects_model_protocol_override() -> None: protocol = OpenAIChatCompletionsProtocol() - model = models.get_model("openai:gpt-5", protocol=protocol) - assert model.protocol is protocol - assert isinstance(model.provider.protocol, OpenAIResponsesProtocol) + with pytest.raises(ConfigurationError, match="protocol objects"): + models.get_model("openai:gpt-5", protocol=protocol) def test_get_provider_accepts_provider_protocol_override() -> None: From 9220f5a6a6bd7158228739666c6bb48e8b77e407 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Wed, 17 Jun 2026 16:12:08 -0700 Subject: [PATCH 2/2] Fix None values making it into serialized model representation --- src/ai/models/core/model.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/src/ai/models/core/model.py b/src/ai/models/core/model.py index ebc2caf8..ef1af27f 100644 --- a/src/ai/models/core/model.py +++ b/src/ai/models/core/model.py @@ -140,6 +140,17 @@ def is_serializable(self) -> bool: """Whether this model can be serialized as durable JSON data.""" return self._is_serializable + @pydantic.field_validator("provider_args", mode="after") + @classmethod + def _normalize_provider_args( + cls, provider_args: dict[str, Any] + ) -> dict[str, Any]: + return { + key: value + for key, value in provider_args.items() + if value is not None + } + @property def provider(self) -> base.Provider[Any]: """Provider for this model, lazily rebuilt for durable models.""" @@ -179,16 +190,6 @@ def _model_provider_config( ) return None if model_info is None else model_info.provider_config - @pydantic.field_serializer("provider_args") - def _serialize_provider_args( - self, provider_args: dict[str, Any] - ) -> dict[str, Any]: - return { - key: value - for key, value in provider_args.items() - if value is not None - } - @pydantic.model_serializer(mode="wrap") def _serialize_model( self,