Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 4 additions & 24 deletions examples/temporal-direct/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,26 +49,6 @@
MODEL_ID = "gateway:anthropic/claude-sonnet-4.6"


# ── Workflow-safe model placeholder ──────────────────────────────
#
# ``agent.run`` requires a ``Model``, but a real one can't be built
# inside the workflow: ``ai.get_model("gateway:...")`` constructs an
# ``httpx.AsyncClient`` at provider-init time, which imports
# httpcore/anyio and trips the Temporal sandbox (``threading.local``
# at module load). Our loop never calls the model directly anyway --
# every LLM call is delegated to ``llm_call_activity``, which runs
# outside the sandbox and resolves the real model by id there.
#
# So hand the workflow a placeholder ``Model`` whose provider builds
# no client. It carries the real model id (so the activity can
# resolve it) but is safe to construct inside the sandbox.
class WorkflowModelProvider(ai.Provider[Any]):
"""A clientless provider, safe to construct in a workflow sandbox."""

def __init__(self) -> None:
super().__init__(name="workflow-placeholder", base_url="")


# ── Tool definitions ─────────────────────────────────────────────
#
# Declared with @ai.tool so the framework can extract JSON schemas
Expand Down Expand Up @@ -116,7 +96,7 @@ async def get_population_activity(city: str) -> int:

@dataclasses.dataclass
class LLMParams:
model_id: str
model: dict[str, Any]
messages: list[dict[str, Any]]
tool_schemas: list[dict[str, Any]]

Expand All @@ -129,7 +109,7 @@ class LLMResult:
@temporalio.activity.defn
async def llm_call_activity(params: LLMParams) -> LLMResult:
"""Call the LLM, drain the stream, return the final message."""
model = ai.get_model(params.model_id)
model = ai.Model.model_validate(params.model)
messages = [ai.messages.Message.model_validate(m) for m in params.messages]
tools = [
ai.Tool(
Expand Down Expand Up @@ -172,7 +152,7 @@ async def loop(
result = await temporalio.workflow.execute_activity(
llm_call_activity,
LLMParams(
model_id=context.model.id,
model=context.model.model_dump(mode="json"),
messages=[m.model_dump() for m in context.messages],
tool_schemas=tool_schemas,
),
Expand Down Expand Up @@ -238,7 +218,7 @@ async def _call() -> ai.events.ToolCallResult:
class WeatherWorkflow:
@temporalio.workflow.run
async def run(self, user_query: str) -> str:
model = ai.Model(MODEL_ID, provider=WorkflowModelProvider())
model = ai.get_model(MODEL_ID)
messages: list[ai.messages.Message] = [
ai.system_message(
"Answer questions using the weather and population tools."
Expand Down
277 changes: 251 additions & 26 deletions src/ai/models/core/model.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,177 @@
"""Model metadata types."""

from __future__ import annotations

import json
import os
from typing import Any, Self
from typing import TYPE_CHECKING, Any, Literal, Self, cast, overload

import pydantic

from ... import _modelsdev
from ...errors import ConfigurationError
from ...providers import base

if TYPE_CHECKING:
from collections.abc import Callable

import modelsdotdev

_DEFAULT_MODEL_ENV = "AI_SDK_DEFAULT_MODEL"


class Model:
class Model(pydantic.BaseModel):
"""Lightweight reference to a model on a specific provider.

* ``id`` — identifier sent to the provider (e.g. ``"claude-sonnet-4-6"``).
* ``provider`` — :class:`Provider` that owns this model.
* ``protocol`` — optional wire-protocol override for this model.
* ``provider_name`` — models.dev provider id used to rebuild the provider.
* ``provider_args`` — JSON-friendly provider configuration.

Passing a live ``provider`` makes the model non-serializable.
"""

id: str
provider_name: str | None = None
provider_args: dict[str, Any] = pydantic.Field(default_factory=dict)

_provider: base.Provider[Any] | None = pydantic.PrivateAttr(default=None)
_protocol: base.ProviderProtocol[Any] | None = pydantic.PrivateAttr(
default=None
)
_is_serializable: bool = pydantic.PrivateAttr(default=True)

model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)

@overload
def __init__(
self,
id: str,
*,
provider_name: str,
provider_args: dict[str, Any] | None = None,
) -> None: ...

@overload
def __init__(
self,
id: str,
*,
provider: base.Provider[Any],
protocol: base.ProviderProtocol[Any] | None = None,
) -> None: ...

def __init__(
self,
id: str,
*,
provider: base.Provider,
provider_name: str | None = None,
provider_args: dict[str, Any] | None = None,
provider: base.Provider[Any] | None = None,
protocol: base.ProviderProtocol[Any] | None = None,
) -> None:
self.id = id
self.provider = provider
self.protocol = protocol
if (provider is None) == (provider_name is None):
raise ConfigurationError(
"pass exactly one of provider_name or provider"
)
if provider_name == "":
raise ConfigurationError("provider_name must not be empty")

if provider is not None:
if provider_args is not None:
raise ConfigurationError("provider_args requires provider_name")
super().__init__(
id=id,
provider_name=provider.name,
provider_args={},
)
self._provider = provider
self._protocol = protocol
self._is_serializable = False
return

if protocol is not None:
raise ConfigurationError(
"protocol objects are not serializable; "
"use provider=... live-object mode"
)

super().__init__(
id=id,
provider_name=provider_name,
provider_args=provider_args or {},
Comment thread
anbuzin marked this conversation as resolved.
)

def __eq__(self, other: object) -> bool:
if not isinstance(other, Model):
return False
if self._is_serializable and other._is_serializable:
return (
self.id == other.id
and self.provider_name == other.provider_name
and self.provider_args == other.provider_args
)
return (
isinstance(other, Model)
not self._is_serializable
and not other._is_serializable
and self.id == other.id
and self.provider is other.provider
and self.protocol is other.protocol
and self._provider is other._provider
and self._protocol is other._protocol
)

def __repr__(self) -> str:
return f"Model(id={self.id!r}, provider={self.provider!r})"
provider = (
self._provider if self._provider is not None else self.provider_name
)
return f"Model(id={self.id!r}, provider={provider!r})"

def __hash__(self) -> int:
return hash((self.id, id(self.provider), id(self.protocol)))
if self._is_serializable:
return hash(
(
self.id,
self.provider_name,
json.dumps(self.provider_args, sort_keys=True),
)
)
return hash((self.id, id(self._provider), id(self._protocol)))

@property
def is_serializable(self) -> bool:
    """Report whether this model can be emitted as durable JSON data.

    Returns:
        The value of the private ``_is_serializable`` flag, unchanged.
    """
    durable: bool = self._is_serializable
    return durable

@pydantic.field_validator("provider_args", mode="after")
@classmethod
def _normalize_provider_args(
    cls, provider_args: dict[str, Any]
) -> dict[str, Any]:
    """Strip ``provider_args`` entries whose value is ``None``.

    Registered as an "after" field validator for ``provider_args``.

    Args:
        provider_args: The mapping to clean.

    Returns:
        A new dict holding every entry of ``provider_args`` whose value is
        not ``None``, in the original iteration order.
    """
    cleaned: dict[str, Any] = {}
    for name, setting in provider_args.items():
        if setting is None:
            continue
        cleaned[name] = setting
    return cleaned

@property
def provider(self) -> base.Provider[Any]:
    """Return this model's provider, building and caching it on first use.

    A provider already stored on ``_provider`` is returned as-is. Otherwise
    one is created with ``base.Provider.from_id`` from ``provider_name``,
    the result of ``_model_provider_config()`` and the non-``None`` entries
    of ``provider_args``, then stored on ``_provider`` for later calls.

    Raises:
        ConfigurationError: If nothing is cached and ``provider_name`` is
            ``None``.
    """
    cached = self._provider
    if cached is not None:
        return cached

    if self.provider_name is None:
        raise ConfigurationError("model has no provider_name")

    # Resolve the config before the keyword arguments, matching the order
    # in which the call arguments were evaluated originally.
    config = self._model_provider_config()
    extra_kwargs = {
        name: setting
        for name, setting in self.provider_args.items()
        if setting is not None
    }
    self._provider = base.Provider.from_id(
        self.provider_name,
        model_provider_config=config,
        **extra_kwargs,
    )
    return self._provider

@property
def protocol(self) -> base.ProviderProtocol[Any] | None:
    """Return the wire-protocol override held by this model, if any.

    Returns:
        The value of the private ``_protocol`` attribute, which may be
        ``None``.
    """
    override = self._protocol
    return override

def with_protocol(self, protocol: base.ProviderProtocol[Any]) -> Self:
return self.__class__(
Expand All @@ -50,6 +180,107 @@ def with_protocol(self, protocol: base.ProviderProtocol[Any]) -> Self:
protocol=protocol,
)

def _model_provider_config(
    self,
) -> modelsdotdev.ModelProviderConfig | None:
    """Look up the provider config for this model through ``_modelsdev``.

    Returns:
        ``None`` when ``provider_name`` is unset or when
        ``_modelsdev.get_model_by_id`` returns ``None`` for the key
        ``"<provider_name>:<id>"``; otherwise the ``provider_config``
        attribute of the entry it returned.
    """
    if self.provider_name is None:
        return None

    lookup_key = f"{self.provider_name}:{self.id}"
    entry = _modelsdev.get_model_by_id(lookup_key)
    if entry is None:
        return None
    return entry.provider_config

@pydantic.model_serializer(mode="wrap")
def _serialize_model(
self,
handler: pydantic.SerializerFunctionWrapHandler,
info: pydantic.SerializationInfo,
) -> dict[str, Any]:
if info.mode == "json" and not self._is_serializable:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Huh, why only json mode?

raise ConfigurationError(
"Model was constructed with a live provider/protocol and "
"cannot be serialized. Use provider_name/provider_args instead."
)
return cast("dict[str, Any]", handler(self))

def model_dump(
    self,
    *,
    mode: Literal["json", "python"] | str = "python",
    include: Any = None,
    exclude: Any = None,
    context: Any | None = None,
    by_alias: bool | None = None,
    exclude_unset: bool = False,
    exclude_defaults: bool = False,
    exclude_none: bool = False,
    exclude_computed_fields: bool = False,
    round_trip: bool = False,
    warnings: bool | Literal["none", "warn", "error"] = True,
    fallback: Callable[[Any], Any] | None = None,
    serialize_as_any: bool = False,
) -> dict[str, Any]:
    """Dump the model to a dict, refusing JSON mode for live-provider models.

    Every keyword argument is forwarded unchanged to the parent class's
    ``model_dump``.

    Raises:
        ConfigurationError: If ``mode`` is ``"json"`` and the private
            ``_is_serializable`` flag is false.
    """
    # NOTE(review): presumably overridden so callers see ConfigurationError
    # directly rather than an error raised from inside the wrap serializer
    # -- confirm against pydantic's error-wrapping behaviour.
    if mode != "json" or self._is_serializable:
        return super().model_dump(
            mode=mode,
            include=include,
            exclude=exclude,
            context=context,
            by_alias=by_alias,
            exclude_unset=exclude_unset,
            exclude_defaults=exclude_defaults,
            exclude_none=exclude_none,
            exclude_computed_fields=exclude_computed_fields,
            round_trip=round_trip,
            warnings=warnings,
            fallback=fallback,
            serialize_as_any=serialize_as_any,
        )

    raise ConfigurationError(
        "Model was constructed with a live provider/protocol and "
        "cannot be serialized. Use provider_name/provider_args instead."
    )

def model_dump_json(
    self,
    *,
    indent: int | None = None,
    ensure_ascii: bool = False,
    include: Any = None,
    exclude: Any = None,
    context: Any | None = None,
    by_alias: bool | None = None,
    exclude_unset: bool = False,
    exclude_defaults: bool = False,
    exclude_none: bool = False,
    exclude_computed_fields: bool = False,
    round_trip: bool = False,
    warnings: bool | Literal["none", "warn", "error"] = True,
    fallback: Callable[[Any], Any] | None = None,
    serialize_as_any: bool = False,
) -> str:
    """Dump the model to a JSON string, refusing live-provider models.

    Every keyword argument is forwarded unchanged to the parent class's
    ``model_dump_json``.

    Raises:
        ConfigurationError: If the private ``_is_serializable`` flag is
            false.
    """
    # NOTE(review): presumably overridden so callers see ConfigurationError
    # directly rather than an error raised from inside the wrap serializer
    # -- confirm against pydantic's error-wrapping behaviour.
    if self._is_serializable:
        return super().model_dump_json(
            indent=indent,
            ensure_ascii=ensure_ascii,
            include=include,
            exclude=exclude,
            context=context,
            by_alias=by_alias,
            exclude_unset=exclude_unset,
            exclude_defaults=exclude_defaults,
            exclude_none=exclude_none,
            exclude_computed_fields=exclude_computed_fields,
            round_trip=round_trip,
            warnings=warnings,
            fallback=fallback,
            serialize_as_any=serialize_as_any,
        )

    raise ConfigurationError(
        "Model was constructed with a live provider/protocol and "
        "cannot be serialized. Use provider_name/provider_args instead."
    )
Comment on lines +206 to +282

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this needed? Shouldn't _serialize_model handle this?



def get_model(
model_id: str | None = None,
Expand Down Expand Up @@ -87,6 +318,12 @@ def get_model(
if not model_id:
raise ConfigurationError(f"get_model: malformed model_id: {model_id!r}")

if protocol is not None:
raise ConfigurationError(
"protocol objects are not serializable; "
"construct Model with provider=... live-object mode instead"
)

if ":" not in model_id:
model_id = f"gateway:{model_id}"

Expand All @@ -95,16 +332,4 @@ def get_model(
provider_id = ref.provider_id
provider_model_id = ref.model_id

model_info = _modelsdev.get_model_by_id(
f"{provider_id}:{provider_model_id}"
)
model_provider_config = (
None if model_info is None else model_info.provider_config
)

provider = base.Provider.from_id(
provider_id,
model_provider_config=model_provider_config,
)

return Model(provider_model_id, provider=provider, protocol=protocol)
return Model(provider_model_id, provider_name=provider_id)
Loading
Loading