-
Notifications
You must be signed in to change notification settings - Fork 10
Implement a provider / provider_name version of serializable models #176
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
+351
−55
Closed
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,47 +1,177 @@ | ||
| """Model metadata types.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import json | ||
| import os | ||
| from typing import Any, Self | ||
| from typing import TYPE_CHECKING, Any, Literal, Self, cast, overload | ||
|
|
||
| import pydantic | ||
|
|
||
| from ... import _modelsdev | ||
| from ...errors import ConfigurationError | ||
| from ...providers import base | ||
|
|
||
| if TYPE_CHECKING: | ||
| from collections.abc import Callable | ||
|
|
||
| import modelsdotdev | ||
|
|
||
| _DEFAULT_MODEL_ENV = "AI_SDK_DEFAULT_MODEL" | ||
|
|
||
|
|
||
| class Model: | ||
| class Model(pydantic.BaseModel): | ||
| """Lightweight reference to a model on a specific provider. | ||
|
|
||
| * ``id`` — identifier sent to the provider (e.g. ``"claude-sonnet-4-6"``). | ||
| * ``provider`` — :class:`Provider` that owns this model. | ||
| * ``protocol`` — optional wire-protocol override for this model. | ||
| * ``provider_name`` — models.dev provider id used to rebuild the provider. | ||
| * ``provider_args`` — JSON-friendly provider configuration. | ||
|
|
||
| Passing a live ``provider`` makes the model non-serializable. | ||
| """ | ||
|
|
||
| id: str | ||
| provider_name: str | None = None | ||
| provider_args: dict[str, Any] = pydantic.Field(default_factory=dict) | ||
|
|
||
| _provider: base.Provider[Any] | None = pydantic.PrivateAttr(default=None) | ||
| _protocol: base.ProviderProtocol[Any] | None = pydantic.PrivateAttr( | ||
| default=None | ||
| ) | ||
| _is_serializable: bool = pydantic.PrivateAttr(default=True) | ||
|
|
||
| model_config = pydantic.ConfigDict(arbitrary_types_allowed=True) | ||
|
|
||
| @overload | ||
| def __init__( | ||
| self, | ||
| id: str, | ||
| *, | ||
| provider_name: str, | ||
| provider_args: dict[str, Any] | None = None, | ||
| ) -> None: ... | ||
|
|
||
| @overload | ||
| def __init__( | ||
| self, | ||
| id: str, | ||
| *, | ||
| provider: base.Provider[Any], | ||
| protocol: base.ProviderProtocol[Any] | None = None, | ||
| ) -> None: ... | ||
|
|
||
| def __init__( | ||
| self, | ||
| id: str, | ||
| *, | ||
| provider: base.Provider, | ||
| provider_name: str | None = None, | ||
| provider_args: dict[str, Any] | None = None, | ||
| provider: base.Provider[Any] | None = None, | ||
| protocol: base.ProviderProtocol[Any] | None = None, | ||
| ) -> None: | ||
| self.id = id | ||
| self.provider = provider | ||
| self.protocol = protocol | ||
| if (provider is None) == (provider_name is None): | ||
| raise ConfigurationError( | ||
| "pass exactly one of provider_name or provider" | ||
| ) | ||
| if provider_name == "": | ||
| raise ConfigurationError("provider_name must not be empty") | ||
|
|
||
| if provider is not None: | ||
| if provider_args is not None: | ||
| raise ConfigurationError("provider_args requires provider_name") | ||
| super().__init__( | ||
| id=id, | ||
| provider_name=provider.name, | ||
| provider_args={}, | ||
| ) | ||
| self._provider = provider | ||
| self._protocol = protocol | ||
| self._is_serializable = False | ||
| return | ||
|
|
||
| if protocol is not None: | ||
| raise ConfigurationError( | ||
| "protocol objects are not serializable; " | ||
| "use provider=... live-object mode" | ||
| ) | ||
|
|
||
| super().__init__( | ||
| id=id, | ||
| provider_name=provider_name, | ||
| provider_args=provider_args or {}, | ||
| ) | ||
|
|
||
| def __eq__(self, other: object) -> bool: | ||
| if not isinstance(other, Model): | ||
| return False | ||
| if self._is_serializable and other._is_serializable: | ||
| return ( | ||
| self.id == other.id | ||
| and self.provider_name == other.provider_name | ||
| and self.provider_args == other.provider_args | ||
| ) | ||
| return ( | ||
| isinstance(other, Model) | ||
| not self._is_serializable | ||
| and not other._is_serializable | ||
| and self.id == other.id | ||
| and self.provider is other.provider | ||
| and self.protocol is other.protocol | ||
| and self._provider is other._provider | ||
| and self._protocol is other._protocol | ||
| ) | ||
|
|
||
| def __repr__(self) -> str: | ||
| return f"Model(id={self.id!r}, provider={self.provider!r})" | ||
| provider = ( | ||
| self._provider if self._provider is not None else self.provider_name | ||
| ) | ||
| return f"Model(id={self.id!r}, provider={provider!r})" | ||
|
|
||
| def __hash__(self) -> int: | ||
| return hash((self.id, id(self.provider), id(self.protocol))) | ||
| if self._is_serializable: | ||
| return hash( | ||
| ( | ||
| self.id, | ||
| self.provider_name, | ||
| json.dumps(self.provider_args, sort_keys=True), | ||
| ) | ||
| ) | ||
| return hash((self.id, id(self._provider), id(self._protocol))) | ||
|
|
||
| @property | ||
| def is_serializable(self) -> bool: | ||
| """Whether this model can be serialized as durable JSON data.""" | ||
| return self._is_serializable | ||
|
|
||
| @pydantic.field_validator("provider_args", mode="after") | ||
| @classmethod | ||
| def _normalize_provider_args( | ||
| cls, provider_args: dict[str, Any] | ||
| ) -> dict[str, Any]: | ||
| return { | ||
| key: value | ||
| for key, value in provider_args.items() | ||
| if value is not None | ||
| } | ||
|
|
||
| @property | ||
| def provider(self) -> base.Provider[Any]: | ||
| """Provider for this model, lazily rebuilt for durable models.""" | ||
| if self._provider is None: | ||
| if self.provider_name is None: | ||
| raise ConfigurationError("model has no provider_name") | ||
| self._provider = base.Provider.from_id( | ||
| self.provider_name, | ||
| model_provider_config=self._model_provider_config(), | ||
| **{ | ||
| key: value | ||
| for key, value in self.provider_args.items() | ||
| if value is not None | ||
| }, | ||
| ) | ||
| return self._provider | ||
|
|
||
| @property | ||
| def protocol(self) -> base.ProviderProtocol[Any] | None: | ||
| """Optional wire-protocol override for this model.""" | ||
| return self._protocol | ||
|
|
||
| def with_protocol(self, protocol: base.ProviderProtocol[Any]) -> Self: | ||
| return self.__class__( | ||
|
|
@@ -50,6 +180,107 @@ def with_protocol(self, protocol: base.ProviderProtocol[Any]) -> Self: | |
| protocol=protocol, | ||
| ) | ||
|
|
||
| def _model_provider_config( | ||
| self, | ||
| ) -> modelsdotdev.ModelProviderConfig | None: | ||
| if self.provider_name is None: | ||
| return None | ||
| model_info = _modelsdev.get_model_by_id( | ||
| f"{self.provider_name}:{self.id}" | ||
| ) | ||
| return None if model_info is None else model_info.provider_config | ||
|
|
||
| @pydantic.model_serializer(mode="wrap") | ||
| def _serialize_model( | ||
| self, | ||
| handler: pydantic.SerializerFunctionWrapHandler, | ||
| info: pydantic.SerializationInfo, | ||
| ) -> dict[str, Any]: | ||
| if info.mode == "json" and not self._is_serializable: | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Huh, why only json mode? |
||
| raise ConfigurationError( | ||
| "Model was constructed with a live provider/protocol and " | ||
| "cannot be serialized. Use provider_name/provider_args instead." | ||
| ) | ||
| return cast("dict[str, Any]", handler(self)) | ||
|
|
||
| def model_dump( | ||
| self, | ||
| *, | ||
| mode: Literal["json", "python"] | str = "python", | ||
| include: Any = None, | ||
| exclude: Any = None, | ||
| context: Any | None = None, | ||
| by_alias: bool | None = None, | ||
| exclude_unset: bool = False, | ||
| exclude_defaults: bool = False, | ||
| exclude_none: bool = False, | ||
| exclude_computed_fields: bool = False, | ||
| round_trip: bool = False, | ||
| warnings: bool | Literal["none", "warn", "error"] = True, | ||
| fallback: Callable[[Any], Any] | None = None, | ||
| serialize_as_any: bool = False, | ||
| ) -> dict[str, Any]: | ||
| if mode == "json" and not self._is_serializable: | ||
| raise ConfigurationError( | ||
| "Model was constructed with a live provider/protocol and " | ||
| "cannot be serialized. Use provider_name/provider_args instead." | ||
| ) | ||
| return super().model_dump( | ||
| mode=mode, | ||
| include=include, | ||
| exclude=exclude, | ||
| context=context, | ||
| by_alias=by_alias, | ||
| exclude_unset=exclude_unset, | ||
| exclude_defaults=exclude_defaults, | ||
| exclude_none=exclude_none, | ||
| exclude_computed_fields=exclude_computed_fields, | ||
| round_trip=round_trip, | ||
| warnings=warnings, | ||
| fallback=fallback, | ||
| serialize_as_any=serialize_as_any, | ||
| ) | ||
|
|
||
| def model_dump_json( | ||
| self, | ||
| *, | ||
| indent: int | None = None, | ||
| ensure_ascii: bool = False, | ||
| include: Any = None, | ||
| exclude: Any = None, | ||
| context: Any | None = None, | ||
| by_alias: bool | None = None, | ||
| exclude_unset: bool = False, | ||
| exclude_defaults: bool = False, | ||
| exclude_none: bool = False, | ||
| exclude_computed_fields: bool = False, | ||
| round_trip: bool = False, | ||
| warnings: bool | Literal["none", "warn", "error"] = True, | ||
| fallback: Callable[[Any], Any] | None = None, | ||
| serialize_as_any: bool = False, | ||
| ) -> str: | ||
| if not self._is_serializable: | ||
| raise ConfigurationError( | ||
| "Model was constructed with a live provider/protocol and " | ||
| "cannot be serialized. Use provider_name/provider_args instead." | ||
| ) | ||
| return super().model_dump_json( | ||
| indent=indent, | ||
| ensure_ascii=ensure_ascii, | ||
| include=include, | ||
| exclude=exclude, | ||
| context=context, | ||
| by_alias=by_alias, | ||
| exclude_unset=exclude_unset, | ||
| exclude_defaults=exclude_defaults, | ||
| exclude_none=exclude_none, | ||
| exclude_computed_fields=exclude_computed_fields, | ||
| round_trip=round_trip, | ||
| warnings=warnings, | ||
| fallback=fallback, | ||
| serialize_as_any=serialize_as_any, | ||
| ) | ||
|
Comment on lines
+206
to
+282
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is this needed? Shouldn't _serialize_model handle this? |
||
|
|
||
|
|
||
| def get_model( | ||
| model_id: str | None = None, | ||
|
|
@@ -87,6 +318,12 @@ def get_model( | |
| if not model_id: | ||
| raise ConfigurationError(f"get_model: malformed model_id: {model_id!r}") | ||
|
|
||
| if protocol is not None: | ||
| raise ConfigurationError( | ||
| "protocol objects are not serializable; " | ||
| "construct Model with provider=... live-object mode instead" | ||
| ) | ||
|
|
||
| if ":" not in model_id: | ||
| model_id = f"gateway:{model_id}" | ||
|
|
||
|
|
@@ -95,16 +332,4 @@ def get_model( | |
| provider_id = ref.provider_id | ||
| provider_model_id = ref.model_id | ||
|
|
||
| model_info = _modelsdev.get_model_by_id( | ||
| f"{provider_id}:{provider_model_id}" | ||
| ) | ||
| model_provider_config = ( | ||
| None if model_info is None else model_info.provider_config | ||
| ) | ||
|
|
||
| provider = base.Provider.from_id( | ||
| provider_id, | ||
| model_provider_config=model_provider_config, | ||
| ) | ||
|
|
||
| return Model(provider_model_id, provider=provider, protocol=protocol) | ||
| return Model(provider_model_id, provider_name=provider_id) | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.