Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/.test_scripts/run-with-patched-model.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def __init__(
protocol: ai.ProviderProtocol[Any] | None = None,
) -> None:
super().__init__(
id,
id=id,
provider=provider,
protocol=selected_protocol_for_provider(provider) or protocol,
)
Expand Down
5 changes: 2 additions & 3 deletions examples/builtin_web_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,12 @@ def format(value: object) -> str:


async def main() -> None:
provider = ai.get_provider("anthropic")
model = ai.get_model("anthropic:claude-sonnet-4-6")
provider = model.provider
if not provider.is_configured():
print(f"[SKIP] {provider.name} provider is not configured")
return

model = ai.Model("claude-sonnet-4-6", provider=provider)

async with ai.stream(model, messages, tools=tools) as s:
async for event in s:
match event:
Expand Down
2 changes: 1 addition & 1 deletion examples/check_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ async def _check(name: str, provider: ai.Provider, model_id: str) -> None:
if not provider.is_configured():
print(f" [SKIP] {provider.name} provider is not configured")
return
model = ai.Model(model_id, provider=provider)
model = ai.Model(id=model_id, provider=provider)
try:
await ai.probe(model)
print(f" [OK] {name}/{model_id}")
Expand Down
2 changes: 1 addition & 1 deletion examples/explicit_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ async def main() -> None:
)

model = ai.Model(
os.environ.get("LOCAL_OPENAI_MODEL", "local-model"),
id=os.environ.get("LOCAL_OPENAI_MODEL", "local-model"),
provider=provider,
)

Expand Down
12 changes: 5 additions & 7 deletions examples/openai_chat_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,15 @@


async def main() -> None:
provider = ai.get_provider("openai")
model = ai.get_model(
"openai:gpt-5.5",
protocol=OpenAIChatCompletionsProtocol(),
)
provider = model.provider
if not provider.is_configured():
print(f"[SKIP] {provider.name} provider is not configured")
return

model = ai.Model(
"gpt-5.5",
provider=provider,
protocol=OpenAIChatCompletionsProtocol(),
)

try:
async with ai.stream(model, messages) as stream:
async for event in stream:
Expand Down
19 changes: 9 additions & 10 deletions examples/stream_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@

import ai

MODELS: list[tuple[str, ai.Provider, str]] = [
("ai_gateway", ai.get_provider("vercel"), "anthropic/claude-sonnet-4.6"),
("anthropic", ai.get_provider("anthropic"), "claude-sonnet-4-6"),
("openai", ai.get_provider("openai"), "gpt-5.5"),
MODELS: list[tuple[str, ai.Model]] = [
("ai_gateway", ai.get_model("gateway:anthropic/claude-sonnet-4.6")),
("anthropic", ai.get_model("anthropic:claude-sonnet-4-6")),
("openai", ai.get_model("openai:gpt-5.5")),
]

messages = [
Expand All @@ -16,15 +16,14 @@
]


async def _run(name: str, provider: ai.Provider, model_id: str) -> None:
print(f"\n{name} / {model_id}")
async def _run(name: str, model: ai.Model) -> None:
print(f"\n{name} / {model.id}")

provider = model.provider
if not provider.is_configured():
print(f"[SKIP] {provider.name} provider is not configured")
return

model = ai.Model(model_id, provider=provider)

try:
async with ai.stream(model, messages) as s:
async for event in s:
Expand All @@ -36,8 +35,8 @@ async def _run(name: str, provider: ai.Provider, model_id: str) -> None:


async def main() -> None:
for name, provider, model_id in MODELS:
await _run(name, provider, model_id)
for name, model in MODELS:
await _run(name, model)


if __name__ == "__main__":
Expand Down
22 changes: 1 addition & 21 deletions examples/temporal-direct/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,26 +49,6 @@
MODEL_ID = "gateway:anthropic/claude-sonnet-4.6"


# ── Workflow-safe model placeholder ──────────────────────────────
#
# ``agent.run`` requires a ``Model``, but a real one can't be built
# inside the workflow: ``ai.get_model("gateway:...")`` constructs an
# ``httpx.AsyncClient`` at provider-init time, which imports
# httpcore/anyio and trips the Temporal sandbox (``threading.local``
# at module load). Our loop never calls the model directly anyway --
# every LLM call is delegated to ``llm_call_activity``, which runs
# outside the sandbox and resolves the real model by id there.
#
# So hand the workflow a placeholder ``Model`` whose provider builds
# no client. It carries the real model id (so the activity can
# resolve it) but is safe to construct inside the sandbox.
class WorkflowModelProvider(ai.Provider[Any]):
"""A clientless provider, safe to construct in a workflow sandbox."""

def __init__(self) -> None:
super().__init__(name="workflow-placeholder", base_url="")


# ── Tool definitions ─────────────────────────────────────────────
#
# Declared with @ai.tool so the framework can extract JSON schemas
Expand Down Expand Up @@ -238,7 +218,7 @@ async def _call() -> ai.events.ToolCallResult:
class WeatherWorkflow:
@temporalio.workflow.run
async def run(self, user_query: str) -> str:
model = ai.Model(MODEL_ID, provider=WorkflowModelProvider())
model = ai.get_model(MODEL_ID)
messages: list[ai.messages.Message] = [
ai.system_message(
"Answer questions using the weather and population tools."
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ requires-python = ">=3.12"
dependencies = [
"httpx>=0.28.1",
"modelsdotdev==0.*",
"pydantic>=2.12.5",
"pydantic>=2.13",
"typing-extensions>=4.15.0",
]

Expand Down
6 changes: 3 additions & 3 deletions src/ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import ai
model = ai.get_model("openai:gpt-5.4")
provider = ai.get_provider("openai", base_url="http://localhost:11434/v1")
model = ai.Model("llama3", provider=provider)
model = ai.Model(id="llama3", provider=provider)
model = ai.get_model("anthropic:claude-sonnet-4-6")
provider = ai.get_provider("anthropic", base_url="https://anthropic.example.com")
model = ai.Model("claude-sonnet-4-6", provider=provider)
model = ai.Model(id="claude-sonnet-4-6", provider=provider)
model = ai.get_model("anthropic/claude-sonnet-4") # defaults to Gateway

# stream — auto-creates client from env vars
Expand All @@ -24,7 +24,7 @@
base_url="https://custom.example.com/v1",
api_key="sk-...",
)
model = ai.Model("gpt-5.4", provider=provider)
model = ai.Model(id="gpt-5.4", provider=provider)
async with ai.stream(model, msgs) as s:
...

Expand Down
37 changes: 11 additions & 26 deletions src/ai/models/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,52 +3,37 @@
import os
from typing import Any, Self

import pydantic

from ... import _modelsdev
from ...errors import ConfigurationError
from ...providers import base

_DEFAULT_MODEL_ENV = "AI_SDK_DEFAULT_MODEL"


class Model:
class Model(pydantic.BaseModel):
"""Lightweight reference to a model on a specific provider.

* ``id`` — identifier sent to the provider (e.g. ``"claude-sonnet-4-6"``).
* ``provider`` — :class:`Provider` that owns this model.
* ``protocol`` — optional wire-protocol override for this model.
"""

def __init__(
self,
id: str,
*,
provider: base.Provider,
protocol: base.ProviderProtocol[Any] | None = None,
) -> None:
self.id = id
self.provider = provider
self.protocol = protocol

def __eq__(self, other: object) -> bool:
return (
isinstance(other, Model)
and self.id == other.id
and self.provider is other.provider
and self.protocol is other.protocol
)
id: str
provider: base.Provider[Any]
protocol: base.ProviderProtocol[Any] | None = pydantic.Field(
default=None, exclude_if=lambda v: v is None
)

def __repr__(self) -> str:
return f"Model(id={self.id!r}, provider={self.provider!r})"

def __hash__(self) -> int:
return hash((self.id, id(self.provider), id(self.protocol)))
return hash((self.id, self.provider, self.protocol))

def with_protocol(self, protocol: base.ProviderProtocol[Any]) -> Self:
return self.__class__(
id=self.id,
provider=self.provider,
protocol=protocol,
)
return self.model_copy(update={"protocol": protocol})


def get_model(
Expand Down Expand Up @@ -107,4 +92,4 @@ def get_model(
model_provider_config=model_provider_config,
)

return Model(provider_model_id, provider=provider, protocol=protocol)
return Model(id=provider_model_id, provider=provider, protocol=protocol)
4 changes: 3 additions & 1 deletion src/ai/providers/ai_gateway/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import base64
import json
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
from typing import Any, TypeVar
from typing import Any, Literal, TypeVar

import httpx
import pydantic
Expand Down Expand Up @@ -1179,6 +1179,8 @@ async def generate(
class GatewayV3Protocol(base.ProviderProtocol[gateway_client.GatewayClient]):
"""AI Gateway v3 wire protocol."""

protocol_class_id: Literal["gateway_v3"] = "gateway_v3"

def stream(
self,
client: gateway_client.GatewayClient,
Expand Down
80 changes: 38 additions & 42 deletions src/ai/providers/ai_gateway/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, ClassVar
from typing import TYPE_CHECKING, Any, ClassVar, Literal

import pydantic

from ... import errors as ai_errors
from .. import base
Expand All @@ -20,7 +22,6 @@

import httpx
import modelsdotdev
import pydantic

from ...models.core import model as model_
from ...models.core import params as params_
Expand All @@ -37,45 +38,39 @@ class GatewayProvider(base.Provider[gateway_client.GatewayClient]):

handles: ClassVar[tuple[str, ...]] = ("vercel", "@ai-sdk/gateway")

def __init__(
self,
*,
api_key: str | None = None,
base_url: str = _BASE_URL,
headers: Mapping[str, str] | None = None,
env: Mapping[str, str] | None = None,
client: httpx.AsyncClient | None = None,
protocol: base.ProviderProtocol[Any] | None = None,
) -> None:
super().__init__(
name="ai-gateway",
base_url=base_url,
protocol=protocol or protocol_module.GatewayV3Protocol(),
api_key=api_key,
api_key_env=_API_KEY_ENV,
headers=headers,
env=env,
)
self._set_client(
gateway_client.GatewayClient(
base_url=self.base_url,
api_key=self.api_key,
headers=self.headers,
client=client,
)
)
provider_class_id: Literal["gateway"] = "gateway"
name: Literal["ai-gateway"] = "ai-gateway"
default_base_url: str = _BASE_URL
api_key_env: str | None = _API_KEY_ENV

_http_client: httpx.AsyncClient | None = pydantic.PrivateAttr(default=None)

def _set_http_client(self, client: httpx.AsyncClient | None) -> None:
self._http_client = client

@property
def client(self) -> gateway_client.GatewayClient:
client = super().client
client.base_url = self.base_url
client.api_key = self.api_key
client.headers = self.headers
return client
if self._client is None:
self._set_client(
gateway_client.GatewayClient(
base_url=self.base_url,
api_key=self.api_key,
headers=self.headers,
client=self._http_client,
)
)
return super().client # same return value, no None in the type

def default_protocol(
self,
) -> base.ProviderProtocol[gateway_client.GatewayClient]:
"""Return the default Gateway protocol."""
return protocol_module.GatewayV3Protocol()

async def aclose(self) -> None:
"""Close the provider-owned Gateway client, if any."""
await self.client.aclose()
if self._client is not None:
await self.client.aclose()

def stream(
self,
Expand Down Expand Up @@ -117,14 +112,15 @@ def from_modelsdev_provider(
client: httpx.AsyncClient | None = None,
protocol: base.ProviderProtocol[Any] | None = None,
) -> base.Provider[gateway_client.GatewayClient]:
return cls(
api_key=api_key,
base_url=base_url or _BASE_URL,
headers=headers,
env=env,
client=client,
protocol=protocol,
provider_instance = cls(
default_base_url=base_url or _BASE_URL,
protocol_override=protocol,
api_key_value=api_key,
headers=dict(headers or {}),
env=dict(env or {}),
)
provider_instance._set_http_client(client)
return provider_instance

@property
def tools(self) -> ModuleType:
Expand Down
2 changes: 1 addition & 1 deletion src/ai/providers/anthropic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
model = ai.get_model("anthropic:claude-sonnet-4-6")
provider = ai.get_provider("anthropic", base_url="https://anthropic.example.com")
model = ai.Model("claude-sonnet-4-6", provider=provider)
model = ai.Model(id="claude-sonnet-4-6", provider=provider)
ids = await ai.get_provider("anthropic").list_models()
# built-in tools
Expand Down
Loading
Loading