Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 105 additions & 15 deletions examples/run-examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@
uv run examples/run-examples.py --e2e # also run e2e test scripts
uv run examples/run-examples.py --all # run everything
uv run examples/run-examples.py --parallel # run in parallel
uv run examples/run-examples.py stream.py tools_schema.py
# run selected example files
uv run examples/run-examples.py --model gateway:openai/gpt-5.4-mini
# patch ai.get_model() to use the given model for every sample
uv run examples/run-examples.py --protocol=responses
# patch model/provider helpers and ai.stream()/ai.generate()
"""

import argparse
Expand Down Expand Up @@ -43,6 +47,7 @@ class Sample:
Sample("agent_custom_loop.py"),
Sample("agent_nested.py"),
Sample("streaming_tool.py"),
Sample("openai_chat_completions.py"),
Sample("explicit_client.py"),
Sample("multimodal_input.py"),
Sample("check_connection.py"),
Expand Down Expand Up @@ -118,8 +123,62 @@ class Sample:
),
]

KNOWN_SAMPLES = [
*TEXT_SAMPLES,
*IMAGE_SAMPLES,
*VIDEO_SAMPLES,
*BROKEN_SAMPLES,
*E2E_TESTS,
]


def _sample_cmd(sample: Sample, model: str | None) -> list[str]:
def _path_key(path: Path | str) -> str:
return Path(path).as_posix()


def _known_sample_map() -> dict[str, Sample]:
samples: dict[str, Sample] = {}
for sample in KNOWN_SAMPLES:
samples[sample.name] = sample
if sample.cmd is None:
samples[f"samples/{sample.name}"] = sample
samples[f"examples/samples/{sample.name}"] = sample
samples[_path_key(SAMPLES / sample.name)] = sample
else:
samples[f"examples/{sample.name}"] = sample
samples[_path_key(REPO / "examples" / sample.name)] = sample
return samples


def _sample_path(name: str) -> Path:
path = Path(name)
if path.is_absolute():
return path
if path.parts[:1] == ("examples",):
return REPO / path
if path.parts[:1] == ("samples",):
return REPO / "examples" / path
return SAMPLES / path


def _select_sample(name: str, known_samples: dict[str, Sample]) -> Sample | None:
sample = known_samples.get(name)
if sample is not None:
return sample
sample = known_samples.get(_path_key(Path(name).resolve()))
if sample is not None:
return sample
if _sample_path(name).is_file():
return Sample(name)
path = Path(name)
if not path.is_absolute() and path.parts[:1] != ("examples",):
example_path = REPO / "examples" / path
if example_path.is_file():
return Sample(f"examples/{_path_key(path)}")
return None


def _sample_cmd(sample: Sample, model: str | None, protocol: str | None) -> list[str]:
if sample.cmd is not None:
return sample.cmd
base = [
Expand All @@ -132,9 +191,14 @@ def _sample_cmd(sample: Sample, model: str | None) -> list[str]:
str(REPO),
"python",
]
if model is not None:
return [*base, str(PATCH_SCRIPT), model, str(SAMPLES / sample.name)]
return [*base, str(SAMPLES / sample.name)]
if model is not None or protocol is not None:
cmd = [*base, str(PATCH_SCRIPT)]
if model is not None:
cmd.extend(["--model", model])
if protocol is not None:
cmd.extend(["--protocol", protocol])
return [*cmd, str(_sample_path(sample.name))]
return [*base, str(_sample_path(sample.name))]


_env = {k: v for k, v in os.environ.items() if k != "VIRTUAL_ENV"}
Expand All @@ -146,11 +210,11 @@ def _sample_env(sample: Sample) -> dict[str, str]:
return {**_env, **sample.extra_env}


def run_sample(sample: Sample, model: str | None) -> bool:
def run_sample(sample: Sample, model: str | None, protocol: str | None) -> bool:
print(f"{'=' * 20} {sample.name} {'=' * 20}")
sys.stdout.flush()
result = subprocess.run(
_sample_cmd(sample, model),
_sample_cmd(sample, model, protocol),
env=_sample_env(sample),
timeout=sample.timeout,
input=sample.stdin,
Expand All @@ -174,10 +238,12 @@ def print_summary(results: list[tuple[str, bool]]) -> bool:
return any_failed


def run_sample_quiet(sample: Sample, model: str | None) -> tuple[str, bool, str]:
def run_sample_quiet(
sample: Sample, model: str | None, protocol: str | None
) -> tuple[str, bool, str]:
try:
result = subprocess.run(
_sample_cmd(sample, model),
_sample_cmd(sample, model, protocol),
env=_sample_env(sample),
timeout=sample.timeout,
capture_output=True,
Expand Down Expand Up @@ -209,28 +275,52 @@ def main() -> None:
"samples with a custom cmd"
),
)
parser.add_argument(
"--protocol",
choices=("chat", "messages", "responses"),
help=(
"run each sample through run-with-patched-model.py with this "
"underlying protocol; ignored for samples with a custom cmd"
),
)
parser.add_argument(
"examples",
nargs="*",
metavar="example",
help="example file(s) to run, e.g. stream.py or examples/samples/stream.py",
)
args = parser.parse_args()

has_category = args.text or args.image or args.video or args.broken or args.e2e

samples: list[Sample] = []
if args.text or args.all or not has_category:
if args.examples:
known_samples = _known_sample_map()
for example in args.examples:
sample = _select_sample(example, known_samples)
if sample is None:
parser.error(f"unknown example file: {example}")
samples.append(sample)
elif args.text or args.all or not has_category:
samples.extend(TEXT_SAMPLES)
if args.image or args.all:
if not args.examples and (args.image or args.all):
samples.extend(IMAGE_SAMPLES)
if args.video or args.all:
if not args.examples and (args.video or args.all):
samples.extend(VIDEO_SAMPLES)
if args.broken or args.all:
if not args.examples and (args.broken or args.all):
samples.extend(BROKEN_SAMPLES)
if args.e2e or args.all:
if not args.examples and (args.e2e or args.all):
samples.extend(E2E_TESTS)

results: list[tuple[str, bool]] = []

if args.parallel:
outputs: dict[str, str] = {}
with concurrent.futures.ThreadPoolExecutor() as pool:
futures = {pool.submit(run_sample_quiet, s, args.model): s for s in samples}
futures = {
pool.submit(run_sample_quiet, s, args.model, args.protocol): s
for s in samples
}
for future in concurrent.futures.as_completed(futures):
name, ok, output = future.result()
status = "PASS" if ok else "FAIL"
Expand All @@ -253,7 +343,7 @@ def main() -> None:
else:
for sample in samples:
try:
ok = run_sample(sample, args.model)
ok = run_sample(sample, args.model, args.protocol)
except subprocess.TimeoutExpired:
print(f" TIMEOUT after {sample.timeout:g}s\n")
ok = False
Expand Down
148 changes: 135 additions & 13 deletions examples/run-with-patched-model.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
#!/usr/bin/env python3
"""Run a Python file with ``ai.get_model()`` patched to always return a fixed model.
"""Run a Python file with model/protocol selection patched in.

Useful for re-running an example against a different model without
editing it.
Useful for re-running an example against a different model or underlying
wire protocol without editing it.

Usage (from repo root):

uv run examples/run-with-patched-model.py <model> <file.py>
uv run examples/run-with-patched-model.py --protocol=responses <file.py>

Example:

Expand All @@ -18,31 +19,152 @@
import argparse
import runpy
import sys
from collections.abc import Callable
from typing import Any

import ai
from ai import models
from ai.models import core
from ai.models.core import api as _api
from ai.models.core import model as _model

PROTOCOLS = ("chat", "messages", "responses")

def main() -> None:

def _protocol_factory(
name: str | None,
) -> Callable[[], ai.ProviderProtocol[Any]] | None:
if name is None:
return None

if name == "chat":
from ai.providers.openai import OpenAIChatCompletionsProtocol

return OpenAIChatCompletionsProtocol
if name == "messages":
from ai.providers.anthropic import AnthropicMessagesProtocol

return AnthropicMessagesProtocol
if name == "responses":
from ai.providers.openai import OpenAIResponsesProtocol

return OpenAIResponsesProtocol

raise ValueError(f"unsupported protocol: {name}")


def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument(
"model", help="model id, e.g. 'gateway:anthropic/claude-sonnet-4.6'"
"--model", help="model id, e.g. 'gateway:anthropic/claude-sonnet-4.6'"
)
parser.add_argument("file", help="path to a python file to execute")
parser.add_argument("--protocol", choices=PROTOCOLS)
parser.add_argument("args", nargs="+", metavar="ARG")
args = parser.parse_args()

original = _model.get_model
if len(args.args) == 1:
args.file = args.args[0]
elif len(args.args) == 2 and args.model is None:
args.model = args.args[0]
args.file = args.args[1]
else:
parser.error("expected <file.py> or legacy <model> <file.py>")

return args


def main() -> None:
args = _parse_args()

protocol_factory = _protocol_factory(args.protocol)

original_get_model = _model.get_model
original_stream = _api.stream
original_generate = _api.generate
original_model = _model.Model

def selected_protocol() -> ai.ProviderProtocol[Any] | None:
if protocol_factory is None:
return None
return protocol_factory()

def selected_protocol_for_provider(
provider: ai.Provider[Any],
) -> ai.ProviderProtocol[Any] | None:
if args.protocol is None:
return None
if args.protocol in ("chat", "responses"):
from ai.providers.openai import OpenAICompatibleProvider

if isinstance(provider, OpenAICompatibleProvider):
return selected_protocol()
if args.protocol == "messages":
from ai.providers.anthropic import AnthropicCompatibleProvider

if isinstance(provider, AnthropicCompatibleProvider):
return selected_protocol()
return None

def selected_protocol_for_model(model: Any) -> ai.ProviderProtocol[Any] | None:
provider = getattr(model, "provider", None)
if provider is None:
return None
return selected_protocol_for_provider(provider)

def patched_get_model(*_args: Any, **_kwargs: Any) -> ai.Model:
model_id = args.model or (_args[0] if _args else _kwargs.get("model_id"))
model = original_get_model(model_id)
model.protocol = selected_protocol_for_model(model)
return model

def patched_stream(*args: Any, **kwargs: Any) -> Any:
model = args[0] if args else getattr(kwargs.get("context"), "model", None)
protocol = selected_protocol_for_model(model)
if protocol is not None:
kwargs["protocol"] = protocol
return original_stream(*args, **kwargs)

async def patched_generate(*args: Any, **kwargs: Any) -> Any:
model = args[0] if args else kwargs.get("model")
protocol = selected_protocol_for_model(model)
if protocol is not None:
kwargs["protocol"] = protocol
return await original_generate(*args, **kwargs)

class PatchedModel(original_model):
def __init__(
self,
id: str,
*,
provider: ai.Provider[Any],
protocol: ai.ProviderProtocol[Any] | None = None,
) -> None:
super().__init__(
id,
provider=provider,
protocol=selected_protocol_for_provider(provider) or protocol,
)

ai.get_model = patched_get_model
models.get_model = patched_get_model
core.get_model = patched_get_model
_model.get_model = patched_get_model

if args.protocol is not None:
ai.Model = PatchedModel
models.Model = PatchedModel
core.Model = PatchedModel
_model.Model = PatchedModel

def patched(*_args: Any, **_kwargs: Any) -> ai.Model:
return original(args.model)
ai.stream = patched_stream
models.stream = patched_stream
core.stream = patched_stream
_api.stream = patched_stream

ai.get_model = patched
models.get_model = patched
core.get_model = patched
_model.get_model = patched
ai.generate = patched_generate
models.generate = patched_generate
core.generate = patched_generate
_api.generate = patched_generate

sys.argv = [args.file]
runpy.run_path(args.file, run_name="__main__")
Expand Down
Loading
Loading