Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 18 additions & 36 deletions examples/.test_scripts/run-with-patched-model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import argparse
import runpy
import sys
from collections.abc import Callable
from typing import Any, TypeVar, cast

import ai
Expand All @@ -30,31 +29,26 @@
from ai.models.core import model as _model
from ai.providers.anthropic import (
AnthropicCompatibleProvider,
AnthropicMessagesProtocol,
)
from ai.providers.openai import (
OpenAIChatCompletionsProtocol,
OpenAICompatibleProvider,
OpenAIResponsesProtocol,
)

PROTOCOLS = ("chat", "messages", "responses")

ModelT = TypeVar("ModelT", bound=ai.Model)


def _protocol_factory(
name: str | None,
) -> Callable[[], ai.ProviderProtocol[Any]] | None:
def _protocol_ref(name: str | None) -> ai.ProtocolRef | None:
if name is None:
return None

if name == "chat":
return OpenAIChatCompletionsProtocol
return ai.ProtocolRef("openai.chat_completions")
if name == "messages":
return AnthropicMessagesProtocol
return ai.ProtocolRef("anthropic.messages")
if name == "responses":
return OpenAIResponsesProtocol
return ai.ProtocolRef("openai.responses")

raise ValueError(f"unsupported protocol: {name}")

Expand Down Expand Up @@ -82,42 +76,37 @@ def _parse_args() -> argparse.Namespace:
def main() -> None:
args = _parse_args()

protocol_factory = _protocol_factory(args.protocol)
protocol_ref = _protocol_ref(args.protocol)

original_get_model = _model.get_model
original_stream = _api.stream
original_generate = _api.generate

def selected_protocol() -> ai.ProviderProtocol[Any] | None:
if protocol_factory is None:
return None
return protocol_factory()

def selected_protocol_for_provider(
provider: ai.Provider[Any],
) -> ai.ProviderProtocol[Any] | None:
if args.protocol is None:
) -> ai.ProtocolRef | None:
if args.protocol is None or protocol_ref is None:
return None
if args.protocol in ("chat", "responses") and isinstance(
provider, OpenAICompatibleProvider
):
return selected_protocol()
return protocol_ref
if args.protocol == "messages" and isinstance(
provider, AnthropicCompatibleProvider
):
return selected_protocol()
return protocol_ref
return None

def selected_protocol_for_model(
model: ai.Model,
) -> ai.ProviderProtocol[Any] | None:
) -> ai.ProtocolRef | None:
return selected_protocol_for_provider(model.provider)

def with_selected_protocol(model: ModelT) -> ModelT:
protocol = selected_protocol_for_model(model)
if protocol is None:
selected = selected_protocol_for_model(model)
if selected is None:
return model
return model.with_protocol(protocol)
return model.with_protocol(selected)

class PatchedContext:
def __init__(self, context: Any) -> None:
Expand Down Expand Up @@ -174,18 +163,11 @@ async def patched_generate(*call_args: Any, **kwargs: Any) -> Any:
return await original_generate(*call_args, **kwargs)

class PatchedModel(_model.Model):
def __init__(
self,
id: str,
*,
provider: ai.Provider[Any],
protocol: ai.ProviderProtocol[Any] | None = None,
) -> None:
super().__init__(
id,
provider=provider,
protocol=selected_protocol_for_provider(provider) or protocol,
)
def __init__(self, id: str, **kwargs: Any) -> None:
super().__init__(id, **kwargs)
override = selected_protocol_for_provider(self.provider)
if override is not None:
self.protocol_ref = override

cast("Any", ai).get_model = patched_get_model
cast("Any", models).get_model = patched_get_model
Expand Down
8 changes: 3 additions & 5 deletions examples/builtin_web_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,11 @@ def format(value: object) -> str:


async def main() -> None:
provider = ai.get_provider("anthropic")
if not provider.is_configured():
print(f"[SKIP] {provider.name} provider is not configured")
model = ai.get_model("anthropic:claude-sonnet-4-6")
if not model.provider.is_configured():
print(f"[SKIP] {model.provider.name} provider is not configured")
return

model = ai.Model("claude-sonnet-4-6", provider=provider)

async with ai.stream(model, messages, tools=tools) as s:
async for event in s:
match event:
Expand Down
33 changes: 16 additions & 17 deletions examples/check_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

import ai

PROVIDERS: list[tuple[str, ai.Provider, str]] = [
("ai_gateway", ai.get_provider("vercel"), "anthropic/claude-sonnet-4.6"),
("anthropic", ai.get_provider("anthropic"), "claude-sonnet-4-6"),
("openai", ai.get_provider("openai"), "gpt-5.4-mini"),
MODELS: list[tuple[str, ai.Model]] = [
("ai_gateway", ai.get_model("gateway:anthropic/claude-sonnet-4.6")),
("anthropic", ai.get_model("anthropic:claude-sonnet-4-6")),
("openai", ai.get_model("openai:gpt-5.4-mini")),
]

_failed = False
Expand All @@ -20,36 +20,35 @@ def _fail(msg: str) -> None:
print(msg)


async def _check(name: str, provider: ai.Provider, model_id: str) -> None:
if not provider.is_configured():
print(f" [SKIP] {provider.name} provider is not configured")
async def _check(name: str, model: ai.Model) -> None:
if not model.provider.is_configured():
print(f" [SKIP] {model.provider.name} provider is not configured")
return
model = ai.Model(model_id, provider=provider)
try:
await ai.probe(model)
print(f" [OK] {name}/{model_id}")
print(f" [OK] {name}/{model.id}")
except Exception as exc:
_fail(f" [ERR] {name}/{model_id}: {exc}")
_fail(f" [ERR] {name}/{model.id}: {exc}")


async def _list_models(name: str, provider: ai.Provider) -> None:
if not provider.is_configured():
async def _list_models(name: str, model: ai.Model) -> None:
if not model.provider.is_configured():
return
try:
ids: list[str] = await provider.list_models()
ids: list[str] = await model.provider.list_models()
print(f" {name}: {len(ids)} models (last: {ids[-1]})")
except Exception as exc:
_fail(f" {name}: [ERR] {exc}")


async def main() -> None:
print("Checking connections...\n")
for name, provider, model_id in PROVIDERS:
await _check(name, provider, model_id)
for name, model in MODELS:
await _check(name, model)

print("\nListing models...\n")
for name, provider, _ in PROVIDERS:
await _list_models(name, provider)
for name, model in MODELS:
await _list_models(name, model)

print()
if _failed:
Expand Down
22 changes: 10 additions & 12 deletions examples/explicit_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,16 @@

async def main() -> None:
# Example for local OpenAI-compatible servers like LM Studio.
provider = ai.get_provider(
"openai",
base_url=os.environ.get(
"LOCAL_OPENAI_BASE_URL", "http://localhost:1234/v1"
),
api_key=os.environ.get("LOCAL_OPENAI_API_KEY", "some-key"),
headers={"X-Custom-Header": "example"},
)

model = ai.Model(
os.environ.get("LOCAL_OPENAI_MODEL", "local-model"),
provider=provider,
provider=ai.ProviderRef(
"openai",
base_url=os.environ.get(
"LOCAL_OPENAI_BASE_URL", "http://localhost:1234/v1"
),
api_key=os.environ.get("LOCAL_OPENAI_API_KEY", "some-key"),
headers={"X-Custom-Header": "example"},
),
)

try:
Expand All @@ -39,8 +37,8 @@ async def main() -> None:
print(event.chunk, end="", flush=True)
print()
finally:
# Explicit providers need explicit cleanup.
await provider.aclose()
# Explicitly configured providers need explicit cleanup.
await model.aclose()


if __name__ == "__main__":
Expand Down
17 changes: 6 additions & 11 deletions examples/openai_chat_completions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import asyncio

import ai
from ai.providers.openai import OpenAIChatCompletionsProtocol

messages = [
ai.system_message("Be concise."),
Expand All @@ -14,16 +13,12 @@


async def main() -> None:
provider = ai.get_provider("openai")
if not provider.is_configured():
print(f"[SKIP] {provider.name} provider is not configured")
return

model = ai.Model(
"gpt-5.5",
provider=provider,
protocol=OpenAIChatCompletionsProtocol(),
model = ai.get_model("openai:gpt-5.5").with_protocol(
"openai.chat_completions"
)
if not model.provider.is_configured():
print(f"[SKIP] {model.provider.name} provider is not configured")
return

try:
async with ai.stream(model, messages) as stream:
Expand All @@ -32,7 +27,7 @@ async def main() -> None:
print(event.chunk, end="", flush=True)
print()
finally:
await provider.aclose()
await model.aclose()


if __name__ == "__main__":
Expand Down
22 changes: 10 additions & 12 deletions examples/stream_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@

import ai

MODELS: list[tuple[str, ai.Provider, str]] = [
("ai_gateway", ai.get_provider("vercel"), "anthropic/claude-sonnet-4.6"),
("anthropic", ai.get_provider("anthropic"), "claude-sonnet-4-6"),
("openai", ai.get_provider("openai"), "gpt-5.5"),
MODELS: list[tuple[str, ai.Model]] = [
("ai_gateway", ai.get_model("gateway:anthropic/claude-sonnet-4.6")),
("anthropic", ai.get_model("anthropic:claude-sonnet-4-6")),
("openai", ai.get_model("openai:gpt-5.5")),
]

messages = [
Expand All @@ -16,15 +16,13 @@
]


async def _run(name: str, provider: ai.Provider, model_id: str) -> None:
print(f"\n{name} / {model_id}")
async def _run(name: str, model: ai.Model) -> None:
print(f"\n{name} / {model.id}")

if not provider.is_configured():
print(f"[SKIP] {provider.name} provider is not configured")
if not model.provider.is_configured():
print(f"[SKIP] {model.provider.name} provider is not configured")
return

model = ai.Model(model_id, provider=provider)

try:
async with ai.stream(model, messages) as s:
async for event in s:
Expand All @@ -36,8 +34,8 @@ async def _run(name: str, provider: ai.Provider, model_id: str) -> None:


async def main() -> None:
for name, provider, model_id in MODELS:
await _run(name, provider, model_id)
for name, model in MODELS:
await _run(name, model)


if __name__ == "__main__":
Expand Down
36 changes: 13 additions & 23 deletions examples/temporal-direct/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,25 +49,13 @@
MODEL_ID = "gateway:anthropic/claude-sonnet-4.6"


# ── Workflow-safe model placeholder ──────────────────────────────
#
# ``agent.run`` requires a ``Model``, but a real one can't be built
# inside the workflow: ``ai.get_model("gateway:...")`` constructs an
# ``httpx.AsyncClient`` at provider-init time, which imports
# httpcore/anyio and trips the Temporal sandbox (``threading.local``
# at module load). Our loop never calls the model directly anyway --
# every LLM call is delegated to ``llm_call_activity``, which runs
# outside the sandbox and resolves the real model by id there.
#
# So hand the workflow a placeholder ``Model`` whose provider builds
# no client. It carries the real model id (so the activity can
# resolve it) but is safe to construct inside the sandbox.
class WorkflowModelProvider(ai.Provider[Any]):
"""A clientless provider, safe to construct in a workflow sandbox."""

def __init__(self) -> None:
super().__init__(name="workflow-placeholder", base_url="")

# ``ai.Model`` is fully serializable: it stores a provider *recipe*
# (factory reference + JSON args) and only builds the provider — and
# its ``httpx.AsyncClient`` — on first use. The workflow can therefore
# construct the model with ``ai.get_model`` inside the Temporal sandbox
# (nothing network-shaped is created there) and ship it to
# ``llm_call_activity`` as plain JSON, where the provider gets built
# for real.
Comment on lines +52 to +58

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Classic backward-looking AI comment. Probably delete?


# ── Tool definitions ─────────────────────────────────────────────
#
Expand Down Expand Up @@ -116,7 +104,7 @@ async def get_population_activity(city: str) -> int:

@dataclasses.dataclass
class LLMParams:
model_id: str
model: dict[str, Any]
messages: list[dict[str, Any]]
tool_schemas: list[dict[str, Any]]

Expand All @@ -129,7 +117,7 @@ class LLMResult:
@temporalio.activity.defn
async def llm_call_activity(params: LLMParams) -> LLMResult:
"""Call the LLM, drain the stream, return the final message."""
model = ai.get_model(params.model_id)
model = ai.Model.model_validate(params.model)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nothing to do about it, but the overloading of model is unfortunate

messages = [ai.messages.Message.model_validate(m) for m in params.messages]
tools = [
ai.Tool(
Expand Down Expand Up @@ -172,7 +160,7 @@ async def loop(
result = await temporalio.workflow.execute_activity(
llm_call_activity,
LLMParams(
model_id=context.model.id,
model=context.model.model_dump(mode="json"),
messages=[m.model_dump() for m in context.messages],
tool_schemas=tool_schemas,
),
Expand Down Expand Up @@ -238,7 +226,9 @@ async def _call() -> ai.events.ToolCallResult:
class WeatherWorkflow:
@temporalio.workflow.run
async def run(self, user_query: str) -> str:
model = ai.Model(MODEL_ID, provider=WorkflowModelProvider())
# Safe in the sandbox: the provider (and its HTTP client) is
# only built lazily, inside the LLM activity.
model = ai.get_model(MODEL_ID)
messages: list[ai.messages.Message] = [
ai.system_message(
"Answer questions using the weather and population tools."
Expand Down
Loading
Loading