From 824efe06d0fef349110328e826c83c263e186a5b Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Tue, 19 May 2026 16:32:28 -0700 Subject: [PATCH] Replace `params: Any` in `stream`/`generate` with generalized types Using `params: Any` for the typing of inference parameters is not great for discoverability and validation. Even though the variability in upstream provider configuration shapes is great, the overlap in common functionality is big enough to warrant capturing it in common types and concepts. In addition to common parameters, add a provider-specific options section as well as a generic `extra_headers`/`extra_query`/`extra_body` escape hatch for dropping down to the untyped underlying API level. --- examples/coding_agent_minimal.py | 5 +- examples/model_params.py | 12 +- examples/prompt_caching.py | 2 +- src/ai/__init__.py | 67 ++- src/ai/_types.py | 10 + src/ai/agents/_middleware.py | 3 +- src/ai/agents/agent.py | 12 +- src/ai/models/__init__.py | 68 ++- src/ai/models/core/__init__.py | 66 ++- src/ai/models/core/api.py | 16 +- src/ai/models/core/params.py | 376 +++++++++++++- src/ai/providers/ai_gateway/__init__.py | 9 +- src/ai/providers/ai_gateway/client/_client.py | 3 + src/ai/providers/ai_gateway/params.py | 37 ++ src/ai/providers/ai_gateway/protocol.py | 467 +++++++++++++++++- src/ai/providers/ai_gateway/provider.py | 2 +- src/ai/providers/anthropic/protocol.py | 253 +++++++++- src/ai/providers/anthropic/provider.py | 3 +- src/ai/providers/base.py | 4 +- src/ai/providers/openai/protocol.py | 324 +++++++++++- src/ai/providers/openai/provider.py | 4 +- tests/conftest.py | 6 +- tests/models/core/test_api.py | 30 +- tests/providers/ai_gateway/test_stream.py | 198 +++++++- tests/providers/anthropic/test_adapter.py | 96 +++- tests/providers/anthropic/test_tools.py | 6 +- tests/providers/openai/test_adapter.py | 154 ++++-- tests/types/test_integrity.py | 2 +- uv.lock | 12 +- 29 files changed, 2086 insertions(+), 161 deletions(-) create mode 100644 
src/ai/_types.py 
create mode 100644 src/ai/providers/ai_gateway/params.py diff --git a/examples/coding_agent_minimal.py b/examples/coding_agent_minimal.py index e5cf41bc..b5c4f1e4 100644 --- a/examples/coding_agent_minimal.py +++ b/examples/coding_agent_minimal.py @@ -3,7 +3,6 @@ import asyncio import json import sys -import typing import ai @@ -16,9 +15,7 @@ with sed, make sure to double check the result. """ -STREAM_PARAMS: dict[str, typing.Any] = { - "providerOptions": {"gateway": {"caching": "auto"}}, -} +STREAM_PARAMS = ai.InferenceRequestParams(cache=ai.CacheParams(mode="auto")) @ai.tool diff --git a/examples/model_params.py b/examples/model_params.py index 75d0c65a..ec231766 100644 --- a/examples/model_params.py +++ b/examples/model_params.py @@ -13,12 +13,12 @@ async def main() -> None: - params = { - "providerOptions": { - "gateway": {"sort": "cost"}, - "anthropic": {"speed": "fast"}, - } - } + params = ai.InferenceRequestParams( + routing=ai.RoutingParams( + provider_ranking=ai.ProviderRankingStrategy.COST + ), + extra_body={"providerOptions": {"anthropic": {"speed": "fast"}}}, + ) async with ai.stream(model, messages, params=params) as stream: async for event in stream: if isinstance(event, ai.events.TextDelta): diff --git a/examples/prompt_caching.py b/examples/prompt_caching.py index 140e734d..1dcb0006 100644 --- a/examples/prompt_caching.py +++ b/examples/prompt_caching.py @@ -142,7 +142,7 @@ async def _run(user_text: str) -> ai.types.usage.Usage | None: ai.system_message(SYSTEM_PROMPT), ai.user_message(user_text), ] - params = {"providerOptions": {"gateway": {"caching": "auto"}}} + params = ai.InferenceRequestParams(cache=ai.CacheParams(mode="auto")) async with agent.run(model, messages, params=params) as stream: async for _event in stream: pass diff --git a/src/ai/__init__.py b/src/ai/__init__.py index 066d9aaa..d0bce301 100644 --- a/src/ai/__init__.py +++ b/src/ai/__init__.py @@ -48,11 +48,41 @@ UnsupportedProviderError, ) from .models import ( + DEFAULT, + GLOBAL, 
+ RANDOM, + UNSET, + CacheParams, + CloudRegion, + ContextManagementParams, + GeoRegion, ImageParams, + InferenceRequestParams, + MinPSamplerParams, Model, + ModelProviderDefault, + OutputParams, Provider, ProviderProtocol, + ProviderRankingStrategy, + ProviderServiceParams, + RandomSeed, + ReasoningParams, + RepetitionPenaltyParams, + RoutingParams, + RoutingTarget, + RoutingTargetChain, + SeedSamplerParams, Stream, + TemperatureSamplerParams, + TokenThreshold, + ToolCallingParams, + ToolChoiceMode, + ToolRef, + ToolSelection, + TopKSamplerParams, + TopPSamplerParams, + Unset, VideoParams, generate, get_model, @@ -72,18 +102,27 @@ ) __all__ = [ - # Models (from models/) + "DEFAULT", + "GLOBAL", + "RANDOM", + "UNSET", "AIError", - # Agents — primary API "Agent", - # Agents — tools "AgentTool", + "CacheParams", + "CloudRegion", "ConfigurationError", "Context", + "ContextManagementParams", + "GeoRegion", "HTTPErrorContext", "ImageParams", + "InferenceRequestParams", "InstallationError", + "MinPSamplerParams", "Model", + "ModelProviderDefault", + "OutputParams", "Provider", "ProviderAPIError", "ProviderAuthenticationError", @@ -99,20 +138,38 @@ "ProviderOverloadedError", "ProviderPermissionDeniedError", "ProviderProtocol", + "ProviderRankingStrategy", "ProviderRateLimitError", "ProviderRequestTooLargeError", "ProviderResponseError", + "ProviderServiceParams", "ProviderServiceUnavailableError", "ProviderStatusError", "ProviderTimeoutError", "ProviderUnprocessableEntityError", + "RandomSeed", + "ReasoningParams", + "RepetitionPenaltyParams", + "RoutingParams", + "RoutingTarget", + "RoutingTargetChain", + "SeedSamplerParams", "Stream", "StreamingStatusTool", "StreamingTextTool", "SubAgentTool", + "TemperatureSamplerParams", + "TokenThreshold", "Tool", "ToolCall", + "ToolCallingParams", + "ToolChoiceMode", + "ToolRef", "ToolRunner", + "ToolSelection", + "TopKSamplerParams", + "TopPSamplerParams", + "Unset", "UnsupportedProviderError", "VideoParams", "abort_pending_hook", 
@@ -120,13 +177,11 @@ "assistant_message", "cancel_hook", "errors", - # Submodules "events", "file_part", "generate", "get_model", "get_provider", - # Agents — hooks "hook", "mcp", "messages", @@ -143,9 +198,7 @@ "tool_result", "tool_result_part", "tools", - # Builders (from types/builders) "user_message", "util", - # Agents — composition "yield_from", ] diff --git a/src/ai/_types.py b/src/ai/_types.py new file mode 100644 index 00000000..e9646855 --- /dev/null +++ b/src/ai/_types.py @@ -0,0 +1,10 @@ +from collections.abc import Iterator +from typing import Protocol, TypeVar + +_T_co = TypeVar("_T_co", covariant=True) + + +class Collection(Protocol[_T_co]): + def __contains__(self, value: object, /) -> bool: ... + def __iter__(self) -> Iterator[_T_co]: ... + def __len__(self) -> int: ... diff --git a/src/ai/agents/_middleware.py b/src/ai/agents/_middleware.py index 4aa0245f..df1f4fe9 100644 --- a/src/ai/agents/_middleware.py +++ b/src/ai/agents/_middleware.py @@ -43,6 +43,7 @@ import pydantic from ..models.core.model import Model + from ..models.core.params import GenerateParams from ..types import events as events_ from ..types.tools import Tool from .agent import Context @@ -71,7 +72,7 @@ class GenerateContext: model: Model messages: list[messages_.Message] - params: Any + params: GenerateParams def __post_init__(self) -> None: object.__setattr__(self, "messages", list(self.messages)) diff --git a/src/ai/agents/agent.py b/src/ai/agents/agent.py index 90bfd449..6d7c658f 100644 --- a/src/ai/agents/agent.py +++ b/src/ai/agents/agent.py @@ -813,7 +813,9 @@ class Context(pydantic.BaseModel): output_type: type[pydantic.BaseModel] | None = pydantic.Field( default=None, exclude=True, repr=False ) - params: Any = pydantic.Field(default=None, exclude=True, repr=False) + params: models.InferenceRequestParams | None = pydantic.Field( + default=None, exclude=True, repr=False + ) _agent_tools_by_name: dict[str, AgentTool] = pydantic.PrivateAttr( default_factory=dict @@ -1184,7 
+1186,7 @@ def run( model: models.Model, messages: list[types.messages.Message], *, - params: Any = None, + params: models.InferenceRequestParams | None = None, _middleware: list[middleware_._Middleware] | None = None, ) -> AbstractAsyncContextManager[AgentStream[str]]: ... @overload @@ -1194,7 +1196,7 @@ def run[T: pydantic.BaseModel]( messages: list[types.messages.Message], *, output_type: type[T], - params: Any = None, + params: models.InferenceRequestParams | None = None, _middleware: list[middleware_._Middleware] | None = None, ) -> AbstractAsyncContextManager[AgentStream[T]]: ... def run( @@ -1203,7 +1205,7 @@ def run( messages: list[types.messages.Message], *, output_type: type[pydantic.BaseModel] | None = None, - params: Any = None, + params: models.InferenceRequestParams | None = None, _middleware: list[middleware_._Middleware] | None = None, ) -> AbstractAsyncContextManager[AgentStream[Any]]: """Run the agent loop, yielding events to the consumer. @@ -1243,7 +1245,7 @@ async def _run( messages: list[types.messages.Message], *, output_type: type[pydantic.BaseModel] | None, - params: Any, + params: models.InferenceRequestParams | None, _middleware: list[middleware_._Middleware] | None, ) -> AsyncIterator[AgentStream[Any]]: context = Context( diff --git a/src/ai/models/__init__.py b/src/ai/models/__init__.py index f2a2cb4f..29cbb800 100644 --- a/src/ai/models/__init__.py +++ b/src/ai/models/__init__.py @@ -45,23 +45,85 @@ stream, ) from .core.model import Model, get_model -from .core.params import GenerateParams, ImageParams, VideoParams +from .core.params import ( + DEFAULT, + GLOBAL, + RANDOM, + UNSET, + CacheParams, + CloudRegion, + ContextManagementParams, + GenerateParams, + GeoRegion, + ImageParams, + InferenceRequestParams, + MinPSamplerParams, + ModelProviderDefault, + OutputParams, + ProviderRankingStrategy, + ProviderServiceParams, + RandomSeed, + ReasoningParams, + RepetitionPenaltyParams, + RoutingParams, + RoutingTarget, + RoutingTargetChain, + 
SeedSamplerParams, + TemperatureSamplerParams, + TokenThreshold, + ToolCallingParams, + ToolChoiceMode, + ToolRef, + ToolSelection, + TopKSamplerParams, + TopPSamplerParams, + Unset, + VideoParams, +) __all__ = [ - # Core types + "DEFAULT", + "GLOBAL", + "RANDOM", + "UNSET", + "CacheParams", + "CloudRegion", + "ContextManagementParams", "Executor", "GenerateExecutor", "GenerateParams", "GenerateRequest", + "GeoRegion", "ImageParams", + "InferenceRequestParams", + "MinPSamplerParams", "Model", + "ModelProviderDefault", + "OutputParams", "Provider", "ProviderProtocol", + "ProviderRankingStrategy", + "ProviderServiceParams", + "RandomSeed", + "ReasoningParams", + "RepetitionPenaltyParams", + "RoutingParams", + "RoutingTarget", + "RoutingTargetChain", + "SeedSamplerParams", "Stream", "StreamExecutor", "StreamRequest", + "TemperatureSamplerParams", + "TokenThreshold", + "ToolCallingParams", + "ToolChoiceMode", + "ToolRef", + "ToolSelection", + "TopKSamplerParams", + "TopPSamplerParams", + "Unset", "VideoParams", - # Public API "generate", "get_model", "probe", diff --git a/src/ai/models/core/__init__.py b/src/ai/models/core/__init__.py index 03eecffc..a826cfec 100644 --- a/src/ai/models/core/__init__.py +++ b/src/ai/models/core/__init__.py @@ -14,19 +14,83 @@ stream, ) from .model import Model, get_model -from .params import GenerateParams, ImageParams, VideoParams +from .params import ( + DEFAULT, + GLOBAL, + RANDOM, + UNSET, + CacheParams, + CloudRegion, + ContextManagementParams, + GenerateParams, + GeoRegion, + ImageParams, + InferenceRequestParams, + MinPSamplerParams, + ModelProviderDefault, + OutputParams, + ProviderRankingStrategy, + ProviderServiceParams, + RandomSeed, + ReasoningParams, + RepetitionPenaltyParams, + RoutingParams, + RoutingTarget, + RoutingTargetChain, + SeedSamplerParams, + TemperatureSamplerParams, + TokenThreshold, + ToolCallingParams, + ToolChoiceMode, + ToolRef, + ToolSelection, + TopKSamplerParams, + TopPSamplerParams, + Unset, + 
VideoParams, +) __all__ = [ + "DEFAULT", + "GLOBAL", + "RANDOM", + "UNSET", + "CacheParams", + "CloudRegion", + "ContextManagementParams", "Executor", "GenerateExecutor", "GenerateParams", "GenerateRequest", + "GeoRegion", "ImageParams", + "InferenceRequestParams", + "MinPSamplerParams", "Model", + "ModelProviderDefault", + "OutputParams", "Provider", + "ProviderRankingStrategy", + "ProviderServiceParams", + "RandomSeed", + "ReasoningParams", + "RepetitionPenaltyParams", + "RoutingParams", + "RoutingTarget", + "RoutingTargetChain", + "SeedSamplerParams", "Stream", "StreamExecutor", "StreamRequest", + "TemperatureSamplerParams", + "TokenThreshold", + "ToolCallingParams", + "ToolChoiceMode", + "ToolRef", + "ToolSelection", + "TopKSamplerParams", + "TopPSamplerParams", + "Unset", "VideoParams", "generate", "get_model", diff --git a/src/ai/models/core/api.py b/src/ai/models/core/api.py index c95e906d..9916daff 100644 --- a/src/ai/models/core/api.py +++ b/src/ai/models/core/api.py @@ -41,7 +41,7 @@ class StreamRequest: messages: list[types.messages.Message] tools: Sequence[types.tools.Tool] | None = None output_type: type[pydantic.BaseModel] | None = None - params: Any = None + params: params_.InferenceRequestParams | None = None @dataclasses.dataclass(frozen=True) @@ -369,14 +369,14 @@ def tools(self) -> list[types.tools.Tool]: ... @property def output_type(self) -> type[pydantic.BaseModel] | None: ... @property - def params(self) -> Any: ... + def params(self) -> params_.InferenceRequestParams | None: ... @overload def stream( *, context: StreamContext, - params: Any = None, + params: params_.InferenceRequestParams | None = None, executor: StreamExecutor = _default_executor, ) -> AbstractAsyncContextManager[Stream[str]]: ... 
@overload @@ -384,7 +384,7 @@ def stream[T: pydantic.BaseModel]( *, context: StreamContext, output_type: type[T], - params: Any = None, + params: params_.InferenceRequestParams | None = None, executor: StreamExecutor = _default_executor, ) -> AbstractAsyncContextManager[Stream[T]]: ... @overload @@ -393,7 +393,7 @@ def stream( messages: list[types.messages.Message], *, tools: Sequence[types.tools.Tool] | None = None, - params: Any = None, + params: params_.InferenceRequestParams | None = None, executor: StreamExecutor = _default_executor, ) -> AbstractAsyncContextManager[Stream[str]]: ... @overload @@ -403,7 +403,7 @@ def stream[T: pydantic.BaseModel]( *, tools: Sequence[types.tools.Tool] | None = None, output_type: type[T], - params: Any = None, + params: params_.InferenceRequestParams | None = None, executor: StreamExecutor = _default_executor, ) -> AbstractAsyncContextManager[Stream[T]]: ... def stream( @@ -413,7 +413,7 @@ def stream( context: StreamContext | None = None, tools: Sequence[types.tools.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, - params: Any = None, + params: params_.InferenceRequestParams | None = None, executor: StreamExecutor = _default_executor, ) -> AbstractAsyncContextManager[Stream[Any]]: """Stream an LLM response. 
@@ -464,7 +464,7 @@ async def _stream( messages: list[types.messages.Message], tools: Sequence[types.tools.Tool] | None, output_type: type[pydantic.BaseModel] | None, - params: Any, + params: params_.InferenceRequestParams | None, executor: StreamExecutor, ) -> AsyncIterator[Stream[Any]]: if messages and messages[-1].replay: diff --git a/src/ai/models/core/params.py b/src/ai/models/core/params.py index 9969c31c..ee7d906f 100644 --- a/src/ai/models/core/params.py +++ b/src/ai/models/core/params.py @@ -1,7 +1,39 @@ -from typing import Any +import dataclasses +import enum +from collections.abc import Iterable, Mapping +from dataclasses import dataclass +from typing import Any, Self, final import pydantic + +@final +class ModelProviderDefault: + """Sentinel for params: default value used by the model/provider.""" + + +DEFAULT = ModelProviderDefault() +"""Sentinel for params: default value used by the model/provider.""" + + +@final +class Unset: + """Sentinel for a value that should be unset/omitted in a request.""" + + +UNSET = Unset() +"""Sentinel for a value that should be unset/omitted in a request.""" + + +@final +class RandomSeed: + """Sentinel for explicitly random seed selection.""" + + +RANDOM = RandomSeed() +"""Sentinel requesting provider-random seed selection.""" + + _PARAMS_CONFIG = pydantic.ConfigDict(frozen=True, populate_by_name=True) @@ -40,3 +72,345 @@ class VideoParams(pydantic.BaseModel): GenerateParams = ImageParams | VideoParams + + +@dataclass(frozen=True, kw_only=True) +class TemperatureSamplerParams: + """Temperature sampling controls.""" + + temperature: float | ModelProviderDefault = DEFAULT + + +@dataclass(frozen=True, kw_only=True) +class TopKSamplerParams: + """Top-k sampling controls.""" + + top_k: int | ModelProviderDefault | None = DEFAULT + + +@dataclass(frozen=True, kw_only=True) +class TopPSamplerParams: + """Nucleus sampling controls.""" + + top_p: float | ModelProviderDefault | None = DEFAULT + + +@dataclass(frozen=True, 
kw_only=True) +class MinPSamplerParams: + """Minimum probability sampling controls.""" + + min_p: float | ModelProviderDefault | None = DEFAULT + + +@dataclass(frozen=True, kw_only=True) +class RepetitionPenaltyParams: + """Penalty controls for repeated or overrepresented tokens.""" + + repetition_penalty: float | ModelProviderDefault | None = DEFAULT + frequency_penalty: float | ModelProviderDefault | None = DEFAULT + presence_penalty: float | ModelProviderDefault | None = DEFAULT + consideration_window: int | ModelProviderDefault | None = DEFAULT + + +@dataclass(frozen=True, kw_only=True) +class SeedSamplerParams: + """Random seed controls for sampling.""" + + seed: int | RandomSeed | ModelProviderDefault | None = DEFAULT + + +SamplerParams = ( + TemperatureSamplerParams + | TopKSamplerParams + | TopPSamplerParams + | MinPSamplerParams + | RepetitionPenaltyParams + | SeedSamplerParams +) + + +type SamplerParamsMap = dict[type[SamplerParams], SamplerParams] + + +type ProviderParamsMap = dict[type[Any], Any] + + +class ToolChoiceMode(enum.StrEnum): + """Built-in policies for model tool selection.""" + + AUTO = "auto" + NONE = "none" + REQUIRED = "required" + + +class ToolRef(str): + """A reference to a specific tool (by tool name).""" + + def __repr__(self) -> str: + return f"" + + +@dataclass(frozen=True, kw_only=True, init=False) +class ToolSelection: + """Tool subset paired with a tool choice policy.""" + + tools: frozenset[ToolRef] + mode: ToolChoiceMode + + def __init__( + self, + tools: Iterable[ToolRef] | Iterable[str], + *, + mode: ToolChoiceMode, + ) -> None: + object.__setattr__(self, "tools", frozenset(ToolRef(s) for s in tools)) + object.__setattr__(self, "mode", mode) + + +@dataclass(frozen=True, kw_only=True) +class ToolCallingParams: + """Tool calling parameters.""" + + max_tool_calls: int | ModelProviderDefault | None = DEFAULT + """The maximum number of tool calls the model may make.""" + + parallel_tool_calls: bool | ModelProviderDefault = DEFAULT 
+ """Whether the model may call multiple tools in parallel.""" + + tool_choice: ToolChoiceMode | ToolRef | ToolSelection + """Tool choice policy. + + * `ToolChoiceMode.AUTO`: the model can choose whether and what tools to call + * `ToolChoiceMode.REQUIRED`: the model must call (some) tool + * `ToolChoiceMode.NONE`: the model must not call tools + * `ToolRef("tool-name")`: the model must call the specified tool + * `ToolSelection(tools, mode)`: the model should treat the specified set of + tools according to mode. + """ + + +@dataclass(frozen=True, kw_only=True) +class ReasoningParams: + """Model reasoning/thinking options.""" + + effort: str | ModelProviderDefault | None = DEFAULT + """Provider-specific reasoning/thinking effort level. + + None means reasoning is disabled.""" + + +@dataclass(frozen=True, kw_only=True) +class ProviderServiceParams: + """Provider service parameters (service tier).""" + + service_tier: str | ModelProviderDefault = DEFAULT + """Provider-specific service tier.""" + + +@dataclass(frozen=True) +class TokenThreshold: + """Token count used as a trigger threshold.""" + + value: int + """Token count threshold.""" + + +@dataclass(frozen=True, kw_only=True) +class ContextManagementParams: + """Server-side context management parameters.""" + + compaction: TokenThreshold | None = None + """Compaction trigger threshold.""" + + +@dataclass(frozen=True, kw_only=True) +class OutputParams: + """Model output options.""" + + max_tokens: int | None = None + """The maximum number of tokens to generate before stopping.""" + + include: frozenset[str] | None = None + """Additional provider-specific data to include in the model response.""" + + text_verbosity: str | ModelProviderDefault | None = DEFAULT + """Provider-specific text verbosity level.""" + + reasoning_summary: str | ModelProviderDefault | None = DEFAULT + """Provider-specific reasoning summary emission level. 
+ + None means "disabled".""" + + +@dataclass(frozen=True, kw_only=True) +class CacheParams: + """Provider prompt caching behavior.""" + + mode: str | ModelProviderDefault = DEFAULT + """Provider-specific cache mode.""" + + retention: str | ModelProviderDefault = DEFAULT + """Provider-specific cache retention period (time-to-live).""" + + key: str | None = None + """Custom cache key component. Support is provider-specific.""" + + +@final +class GlobalRoutingTarget: + """Sentinel for globally scoped request routing.""" + + def __repr__(self) -> str: + return "GLOBAL" + + +GLOBAL = GlobalRoutingTarget() +"""Sentinel requesting globally scoped request routing.""" + + +class GeoRegion(str): + """A broad geography, e.g. ``us`` or ``eu``.""" + + +class CloudRegion(str): + """A specific cloud/provider region, e.g. ``us-east-1``.""" + + +type RoutingTarget = GlobalRoutingTarget | GeoRegion | CloudRegion + + +@dataclass(frozen=True, kw_only=True) +class RoutingTargetChain: + """Separate Gateway and provider routing targets.""" + + gateway: RoutingTarget + provider: RoutingTarget + + +type RoutingTargetParam = RoutingTarget | RoutingTargetChain + + +class ProviderRankingStrategy(enum.StrEnum): + """Provider ranking strategy.""" + + COST = "cost" + TTFT = "ttft" + TPS = "tps" + PRICE = "price" + LATENCY = "latency" + THROUGHPUT = "throughput" + + +@dataclass(frozen=True, kw_only=True) +class RoutingParams: + """Inference request routing options.""" + + routing_target: RoutingTargetParam | None = None + """Request (geo-/region-) routing target.""" + + provider_allowlist: frozenset[str] | None = None + """Restrict gateway routing to these providers.""" + + provider_order: tuple[str, ...] | None = None + """Preferred provider order.""" + + provider_ranking: ProviderRankingStrategy | None = None + """Dynamic provider sorting strategy.""" + + fallback_models: tuple[str, ...] 
| None = None + """Fallback models to try after the requested model.""" + + +@dataclass(frozen=True, kw_only=True) +class InferenceRequestParams: + """Model inference request parameters.""" + + sampling: SamplerParamsMap | ModelProviderDefault = DEFAULT + """Advanced token sampling parameters (e.g. temperature, min_p, etc.).""" + + reasoning: ReasoningParams | ModelProviderDefault = DEFAULT + """Model reasoning parameters.""" + + tool_calling: ToolCallingParams | None = None + """Tool calling parameters.""" + + provider_service: ProviderServiceParams | None = None + """Provider-specific service parameters (service tier etc).""" + + safety_identifier: str | None = None + """A stable identifier used for safety monitoring and abuse detection.""" + + metadata: Mapping[str, str] | None = None + """User-specified metadata associated with the request. + + Note that not all providers support attaching metadata to inference + requests, and the ones that do might place restrictions on length of + metadata both in terms of overall length and in terms of individual items. + For example, OpenAI and Open Responses-compatible providers specify that + keys must have a maximum length of 64 characters, values have a maximum + length of 512 characters, and the total number of metadata items must not + exceed 16.""" + + tags: frozenset[str] | None = None + """User-specified tags associated with the request. + + Note that not all providers support attaching tags to inference + requests, and the ones that do might place restrictions on length of + the tag collection as well as restrictions on individual tag value + length. For example, Vercel AI Gateway limits the number of tags to + 10 and the length of each tag to 64 characters. 
+ """ + + output: OutputParams | None = None + """Model output configuration.""" + + cache: CacheParams | None = None + """Prompt cache parameters.""" + + routing: RoutingParams | None = None + """Request routing parameters.""" + + context_management: ContextManagementParams | None = None + """Context management parameters.""" + + provider_params: ProviderParamsMap | None = None + """Provider-specific typed request parameters keyed by params type.""" + + extra_headers: Mapping[str, str | Unset] | None = None + """Extra headers to pass to the provider API.""" + + extra_query: Mapping[str, Any] | None = None + """Extra URL query string arguments to pass to the provider API.""" + + extra_body: Mapping[str, Any] | None = None + """Extra body arguments to pass to the provider API.""" + + def with_temperature( + self, temperature: float | ModelProviderDefault + ) -> Self: + temp_sampling_params: SamplerParamsMap = { + TemperatureSamplerParams: TemperatureSamplerParams( + temperature=temperature + ) + } + + if type(self.sampling) is ModelProviderDefault: + sampling = temp_sampling_params + else: + sampling = self.sampling | temp_sampling_params + + return dataclasses.replace(self, sampling=sampling) + + def with_reasoning_effort( + self, effort: str | ModelProviderDefault | None + ) -> Self: + return dataclasses.replace( + self, + reasoning=ReasoningParams(effort=effort), + ) + + def with_provider_params(self, *provider_params: object) -> Self: + params = dict(self.provider_params or {}) + for value in provider_params: + params[type(value)] = value + return dataclasses.replace(self, provider_params=params) diff --git a/src/ai/providers/ai_gateway/__init__.py b/src/ai/providers/ai_gateway/__init__.py index 3f9e71bd..759afb8a 100644 --- a/src/ai/providers/ai_gateway/__init__.py +++ b/src/ai/providers/ai_gateway/__init__.py @@ -9,11 +9,13 @@ model = ai.get_model("gateway:anthropic/claude-sonnet-4") ids = await ai.get_provider("vercel").list_models() - # Provider-specific 
request options pass through as raw Gateway body fields. + # Raw provider-specific request options pass through in extra_body. async with ai.stream( model, msgs, - params={"providerOptions": {"anthropic": {"speed": "fast"}}}, + params=ai.InferenceRequestParams( + extra_body={"providerOptions": {"anthropic": {"speed": "fast"}}} + ), tools=[anthropic_tools.web_search(max_uses=5)], ) as s: ... @@ -30,12 +32,15 @@ """ from . import errors, tools +from .params import GatewayParams, ProviderTimeoutsParams from .protocol import GatewayV3Protocol from .provider import GatewayProvider __all__ = [ + "GatewayParams", "GatewayProvider", "GatewayV3Protocol", + "ProviderTimeoutsParams", "errors", "tools", ] diff --git a/src/ai/providers/ai_gateway/client/_client.py b/src/ai/providers/ai_gateway/client/_client.py index d90972c8..206f0b2e 100644 --- a/src/ai/providers/ai_gateway/client/_client.py +++ b/src/ai/providers/ai_gateway/client/_client.py @@ -189,6 +189,7 @@ async def stream( streaming: bool = False, accept: str | None = None, headers: dict[str, str] | None = None, + query: Mapping[str, Any] | None = None, timeout: httpx.Timeout | float | None = None, ) -> AsyncIterator[httpx.Response]: request_headers = self.model_headers( @@ -206,6 +207,7 @@ async def stream( self.url(path), json=body, headers=request_headers, + params=query, ) if timeout is None else self._http.stream( @@ -213,6 +215,7 @@ async def stream( self.url(path), json=body, headers=request_headers, + params=query, timeout=timeout, ) ) diff --git a/src/ai/providers/ai_gateway/params.py b/src/ai/providers/ai_gateway/params.py new file mode 100644 index 00000000..55283dc5 --- /dev/null +++ b/src/ai/providers/ai_gateway/params.py @@ -0,0 +1,37 @@ +from collections.abc import Iterable, Mapping +from dataclasses import dataclass +from typing import Any + + +@dataclass(frozen=True, kw_only=True) +class ProviderTimeoutsParams: + """Gateway per-provider timeout configuration.""" + + byok: Mapping[str, int] | None = 
None + """Per-provider BYOK attempt timeouts in milliseconds.""" + + +@dataclass(frozen=True, kw_only=True) +class GatewayParams: + """Vercel AI Gateway-specific request parameters.""" + + quota_entity_id: str | None = None + """Gateway quota bucket/entity identifier.""" + + zero_data_retention: bool | None = None + """Request zero-data-retention routing.""" + + hipaa_compliant: bool | None = None + """Require HIPAA-compliant providers.""" + + disallow_prompt_training: bool | None = None + """Require providers that do not train on prompts.""" + + byok: Mapping[str, Iterable[Mapping[str, Any]]] | None = None + """Request-supplied BYOK credentials, keyed by provider.""" + + provider_timeouts: ProviderTimeoutsParams | None = None + """Per-provider routing timeout configuration.""" + + +__all__ = ["GatewayParams", "ProviderTimeoutsParams"] diff --git a/src/ai/providers/ai_gateway/protocol.py b/src/ai/providers/ai_gateway/protocol.py index a19e8a8d..992ef2d7 100644 --- a/src/ai/providers/ai_gateway/protocol.py +++ b/src/ai/providers/ai_gateway/protocol.py @@ -6,19 +6,21 @@ import base64 import json -from collections.abc import AsyncGenerator, Mapping, Sequence -from typing import Any +from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence +from typing import Any, TypeVar import httpx import pydantic from ... import types from ...models import core +from ...models.core import params as params_ from .. import base from ..anthropic import tools as anthropic_tools from ..openai import tools as openai_tools from . import client as gateway_client from . import errors +from . import params as gateway_params from . 
import tools as gateway_tools from .client import errors as client_errors @@ -27,6 +29,30 @@ # --------------------------------------------------------------------------- +_ProviderParamsT = TypeVar("_ProviderParamsT") + + +def _provider_params_value( + value: Mapping[type[Any], Any] | None, + params_type: type[_ProviderParamsT], + *, + provider: str, +) -> _ProviderParamsT | None: + if value is None: + return None + if not isinstance(value, Mapping): + raise TypeError(f"{provider} provider_params must be a mapping") + provider_params = value.get(params_type) + if provider_params is None: + return None + if not isinstance(provider_params, params_type): + raise TypeError( + f"{provider} provider_params[{params_type.__name__}] " + f"must be {params_type.__name__}" + ) + return provider_params + + def _extract_prompt(messages: list[types.messages.Message]) -> str: """Concatenate all text from user/system messages into one prompt.""" parts: list[str] = [] @@ -248,11 +274,10 @@ async def _build_request_body( messages: list[types.messages.Message], tools: Sequence[types.tools.Tool] | None = None, output_type: type[Any] | None = None, - params: Any = None, + params: Mapping[str, Any] | None = None, ) -> dict[str, Any]: """Build the ``LanguageModelV3CallOptions`` request body.""" - stream_params = _coerce_params(params) - body: dict[str, Any] = dict(stream_params) + body: dict[str, Any] = dict(params or {}) body["prompt"] = await _messages_to_prompt(messages) if tools: body["tools"] = [_tool_to_v3(tool) for tool in tools] @@ -265,12 +290,423 @@ async def _build_request_body( return body -def _coerce_params(value: Any) -> dict[str, Any]: +def _is_default(value: object) -> bool: + return isinstance(value, params_.ModelProviderDefault) + + +def _not_default(value: object) -> bool: + return not _is_default(value) + + +def _seed_value(seed: object) -> int | None: + if seed is None or seed == -1 or isinstance(seed, params_.RandomSeed): + return None + if isinstance(seed, int): + 
return seed + if isinstance(seed, params_.ModelProviderDefault): + return None + raise TypeError("seed must be an int, RANDOM, DEFAULT, or None") + + +def _filter_extra_headers( + headers: Mapping[str, str | params_.Unset] | None, +) -> dict[str, str] | None: + if headers is None: + return None + return { + key: value + for key, value in headers.items() + if not isinstance(value, params_.Unset) + } + + +def _provider_from_model_id(model_id: str) -> str | None: + provider, sep, _ = model_id.partition("/") + return provider if sep else None + + +def _body_provider_options( + body: dict[str, Any], provider: str +) -> dict[str, Any]: + provider_options = body.setdefault("providerOptions", {}) + if not isinstance(provider_options, dict): + raise TypeError("providerOptions must be a dict") + options = provider_options.setdefault(provider, {}) + if not isinstance(options, dict): + raise TypeError(f"providerOptions.{provider} must be a dict") + return options + + +def _sequence(value: Iterable[str] | None) -> list[str] | None: + if value is None: + return None + return list(value) + + +def _target_to_inference_region( + target: params_.RoutingTarget, +) -> dict[str, str]: + if target is params_.GLOBAL: + return {"scope": "global"} + if isinstance(target, params_.GeoRegion): + return {"geoRegion": str(target)} + return {"providerRegion": str(target)} + + +def _apply_provider_target( + body: dict[str, Any], + *, + provider: str | None, + target: params_.RoutingTarget, +) -> None: + target_provider = provider or "gateway" + options = _body_provider_options(body, target_provider) + target_value = "global" if target is params_.GLOBAL else str(target) + if target_provider == "anthropic" and isinstance(target, params_.GeoRegion): + options["inferenceGeo"] = target_value + else: + options["region"] = target_value + + +def _routing_to_gateway_options( + routing: params_.RoutingParams, +) -> dict[str, Any]: + options: dict[str, Any] = {} + for key, value in { + "only": 
sorted(routing.provider_allowlist) + if routing.provider_allowlist is not None + else None, + "order": _sequence(routing.provider_order), + "sort": routing.provider_ranking, + "models": _sequence(routing.fallback_models), + }.items(): + if value is not None: + options[key] = value + + if routing.routing_target is not None: + target = routing.routing_target + gateway_target = ( + target.gateway + if isinstance(target, params_.RoutingTargetChain) + else target + ) + options["inferenceRegion"] = _target_to_inference_region(gateway_target) + + return options + + +def _apply_gateway_routing( + body: dict[str, Any], + routing: params_.RoutingParams | None, + *, + provider: str | None, +) -> None: + if routing is None: + return + _body_provider_options(body, "gateway").update( + _routing_to_gateway_options(routing) + ) + if isinstance(routing.routing_target, params_.RoutingTargetChain): + _apply_provider_target( + body, + provider=provider, + target=routing.routing_target.provider, + ) + + +def _apply_gateway_params( + body: dict[str, Any], value: gateway_params.GatewayParams | None +) -> None: + if value is None: + return + options = _body_provider_options(body, "gateway") + for key, option in { + "quotaEntityId": value.quota_entity_id, + "zeroDataRetention": value.zero_data_retention, + "hipaaCompliant": value.hipaa_compliant, + "disallowPromptTraining": value.disallow_prompt_training, + }.items(): + if option is not None: + options[key] = option + + if value.byok is not None: + options["byok"] = { + provider: [dict(credential) for credential in credentials] + for provider, credentials in value.byok.items() + } + + if value.provider_timeouts is not None: + provider_timeouts: dict[str, Any] = {} + if value.provider_timeouts.byok is not None: + provider_timeouts["byok"] = dict(value.provider_timeouts.byok) + if provider_timeouts: + options["providerTimeouts"] = provider_timeouts + + +def _merge_extra_body( + body: dict[str, Any], extra_body: Mapping[str, Any] +) -> None: 
+ extra = dict(extra_body) + provider_options = extra.pop("providerOptions", None) + body.update(extra) + if provider_options is None: + return + if not isinstance(provider_options, Mapping): + raise TypeError("extra_body.providerOptions must be a mapping") + existing = body.setdefault("providerOptions", {}) + if not isinstance(existing, dict): + raise TypeError("providerOptions must be a dict") + for provider, options in provider_options.items(): + if not isinstance(provider, str): + raise TypeError("providerOptions keys must be strings") + if not isinstance(options, Mapping): + raise TypeError(f"providerOptions.{provider} must be a mapping") + current = existing.setdefault(provider, {}) + if not isinstance(current, dict): + raise TypeError(f"providerOptions.{provider} must be a dict") + current.update(options) + + +def _gateway_tool_choice( + tool_choice: params_.ToolChoiceMode + | params_.ToolRef + | params_.ToolSelection, + body: dict[str, Any], +) -> str | dict[str, str]: + if isinstance(tool_choice, params_.ToolChoiceMode): + return tool_choice.value + if isinstance(tool_choice, params_.ToolRef): + return {"type": "tool", "toolName": str(tool_choice)} + body["activeTools"] = sorted(str(tool) for tool in tool_choice.tools) + return tool_choice.mode.value + + +def _apply_gateway_reasoning( + body: dict[str, Any], + request_params: params_.InferenceRequestParams, + *, + provider: str | None, +) -> None: + reasoning = request_params.reasoning + output = request_params.output + effort: str | params_.ModelProviderDefault | None = params_.DEFAULT + if not isinstance(reasoning, params_.ModelProviderDefault): + effort = reasoning.effort + summary = params_.DEFAULT if output is None else output.reasoning_summary + if _is_default(effort) and _is_default(summary): + return + if provider == "openai": + options = _body_provider_options(body, "openai") + if _not_default(effort): + options["reasoningEffort"] = effort + if _not_default(summary): + options["reasoningSummary"] 
= summary + return + if provider == "anthropic": + options = _body_provider_options(body, "anthropic") + if _not_default(effort): + if effort is None: + options["thinking"] = {"type": "disabled"} + else: + options["effort"] = effort + if _not_default(summary): + if summary is None: + options["thinking"] = {"type": "disabled"} + else: + thinking = dict(options.get("thinking") or {}) + thinking.setdefault("type", "adaptive") + thinking["display"] = summary + options["thinking"] = thinking + return + body["reasoning"] = { + key: value + for key, value in { + "effort": effort, + "summary": summary, + }.items() + if _not_default(value) + } + + +def _apply_gateway_context_management( + body: dict[str, Any], + request_params: params_.InferenceRequestParams, + *, + provider: str | None, +) -> None: + context_management = request_params.context_management + if context_management is None or context_management.compaction is None: + return + threshold = context_management.compaction.value + if provider == "openai": + _body_provider_options(body, "openai")["contextManagement"] = [ + {"type": "compaction", "compactThreshold": threshold} + ] + return + if provider == "anthropic": + _body_provider_options(body, "anthropic")["contextManagement"] = { + "edits": [ + { + "type": "compact_20260112", + "trigger": {"type": "input_tokens", "value": threshold}, + } + ] + } + return + raise ValueError( + "AI Gateway context management requires an OpenAI or Anthropic model" + ) + + +def _apply_gateway_sampling( + body: dict[str, Any], + request_params: params_.InferenceRequestParams, +) -> None: + sampling = request_params.sampling + if isinstance(sampling, params_.ModelProviderDefault): + return + for sampler in sampling.values(): + match sampler: + case params_.TemperatureSamplerParams(temperature=temperature): + if _not_default(temperature): + body["temperature"] = temperature + case params_.TopPSamplerParams(top_p=top_p): + if _not_default(top_p): + body["topP"] = top_p + case 
params_.SeedSamplerParams(seed=seed): + if _not_default(seed): + value = _seed_value(seed) + if value is not None: + body["seed"] = value + case params_.TopKSamplerParams(top_k=top_k): + if _not_default(top_k): + body["topK"] = top_k + case params_.MinPSamplerParams(min_p=min_p): + if _not_default(min_p) and min_p is not None: + raise ValueError("AI Gateway does not support min_p") + case params_.RepetitionPenaltyParams() as repetition: + if _not_default(repetition.frequency_penalty): + body["frequencyPenalty"] = repetition.frequency_penalty + if _not_default(repetition.presence_penalty): + body["presencePenalty"] = repetition.presence_penalty + if ( + _not_default(repetition.repetition_penalty) + and repetition.repetition_penalty is not None + ): + raise ValueError( + "AI Gateway does not support repetition_penalty" + ) + if ( + _not_default(repetition.consideration_window) + and repetition.consideration_window is not None + ): + raise ValueError( + "AI Gateway does not support consideration_window" + ) + + +def _gateway_request_options( + value: params_.InferenceRequestParams | None, + *, + model_id: str, +) -> tuple[dict[str, Any], dict[str, str] | None, dict[str, Any] | None]: if value is None: - return {} - if isinstance(value, Mapping): - return dict(value) - raise TypeError("ai-gateway stream params must be a dict") + return {}, None, None + if not isinstance(value, params_.InferenceRequestParams): + raise TypeError( + "ai-gateway stream params must be InferenceRequestParams" + ) + + body: dict[str, Any] = {} + provider = _provider_from_model_id(model_id) + _apply_gateway_routing(body, value.routing, provider=provider) + _apply_gateway_params( + body, + _provider_params_value( + value.provider_params, + gateway_params.GatewayParams, + provider="ai-gateway", + ), + ) + _apply_gateway_sampling(body, value) + _apply_gateway_reasoning(body, value, provider=provider) + _apply_gateway_context_management(body, value, provider=provider) + + if value.tool_calling is 
not None: + tool_calling = value.tool_calling + if _not_default(tool_calling.max_tool_calls): + body["maxToolCalls"] = tool_calling.max_tool_calls + if _not_default(tool_calling.parallel_tool_calls): + body["parallelToolCalls"] = tool_calling.parallel_tool_calls + body["toolChoice"] = _gateway_tool_choice( + tool_calling.tool_choice, + body, + ) + + if value.provider_service is not None: + service = value.provider_service + target_provider = provider or "gateway" + options = _body_provider_options(body, target_provider) + if _not_default(service.service_tier): + options["serviceTier"] = service.service_tier + + if value.safety_identifier is not None: + _body_provider_options(body, "gateway")["user"] = ( + value.safety_identifier + ) + + if value.metadata is not None: + body["metadata"] = dict(value.metadata) + if provider in {"openai", "anthropic"}: + _body_provider_options(body, provider)["metadata"] = dict( + value.metadata + ) + + if value.tags is not None: + _body_provider_options(body, "gateway")["tags"] = sorted(value.tags) + + if value.output is not None: + output = value.output + if output.max_tokens is not None: + body["maxOutputTokens"] = output.max_tokens + if output.include is not None: + if provider == "openai": + _body_provider_options(body, "openai")["include"] = sorted( + output.include + ) + else: + body["include"] = sorted(output.include) + if ( + _not_default(output.text_verbosity) + and output.text_verbosity is not None + ): + raise ValueError("AI Gateway does not support text verbosity") + + if value.cache is not None: + cache = value.cache + if _not_default(cache.mode): + _body_provider_options(body, "gateway")["caching"] = cache.mode + if provider == "openai": + options = _body_provider_options(body, "openai") + if cache.key is not None: + options["promptCacheKey"] = cache.key + if _not_default(cache.retention): + options["promptCacheRetention"] = cache.retention + elif _not_default(cache.retention) or cache.key is not None: + options = 
_body_provider_options(body, "gateway") + if cache.key is not None: + options["cacheKey"] = cache.key + if _not_default(cache.retention): + options["cacheRetention"] = cache.retention + + if value.extra_body is not None: + _merge_extra_body(body, value.extra_body) + + return ( + body, + _filter_extra_headers(value.extra_headers), + dict(value.extra_query) if value.extra_query is not None else None, + ) # --------------------------------------------------------------------------- @@ -506,10 +942,13 @@ async def stream( *, tools: Sequence[types.tools.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, - params: Any = None, + params: params_.InferenceRequestParams | None = None, ) -> AsyncGenerator[types.events.Event]: """Stream an LLM response through the AI Gateway v3 protocol.""" - stream_params = _coerce_params(params) + stream_params, extra_headers, extra_query = _gateway_request_options( + params, + model_id=model.id, + ) body = await _build_request_body( messages, tools=tools, @@ -524,6 +963,8 @@ async def stream( model=model, model_type="language", streaming=True, + headers=extra_headers, + query=extra_query, ) as response: yield types.events.StreamStart() streamed_tool_ids: set[str] = set() @@ -684,7 +1125,7 @@ def stream( *, tools: Sequence[types.tools.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, - params: Any = None, + params: params_.InferenceRequestParams | None = None, provider: str, ) -> AsyncGenerator[types.events.Event]: _ = provider diff --git a/src/ai/providers/ai_gateway/provider.py b/src/ai/providers/ai_gateway/provider.py index 71f5b25e..cf92b4f6 100644 --- a/src/ai/providers/ai_gateway/provider.py +++ b/src/ai/providers/ai_gateway/provider.py @@ -84,7 +84,7 @@ def stream( *, tools: Sequence[tools_.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, - params: Any = None, + params: params_.InferenceRequestParams | None = None, ) -> AsyncGenerator[events.Event]: """Stream via 
the AI Gateway v3 protocol.""" return super().stream( diff --git a/src/ai/providers/anthropic/protocol.py b/src/ai/providers/anthropic/protocol.py index 05351b21..ea7896e5 100644 --- a/src/ai/providers/anthropic/protocol.py +++ b/src/ai/providers/anthropic/protocol.py @@ -8,18 +8,20 @@ import base64 import json -from collections.abc import AsyncGenerator, Mapping, Sequence from typing import TYPE_CHECKING, Any, cast import pydantic from ... import types +from ...models.core import params as params_ from ...types import events from .. import base from . import _sdk, errors from . import tools as anthropic_tools if TYPE_CHECKING: + from collections.abc import AsyncGenerator, Mapping, Sequence + import anthropic from ...models import core @@ -342,12 +344,246 @@ def _to_content_list(content: Any) -> list[dict[str, Any]]: return [{"type": "text", "text": content}] -def _coerce_params(value: Any) -> dict[str, Any]: +def _is_default(value: object) -> bool: + return isinstance(value, params_.ModelProviderDefault) + + +def _not_default(value: object) -> bool: + return not _is_default(value) + + +def _seed_value(seed: object) -> int | None: + if seed is None or seed == -1 or isinstance(seed, params_.RandomSeed): + return None + if isinstance(seed, int): + return seed + if isinstance(seed, params_.ModelProviderDefault): + return None + raise TypeError("seed must be an int, RANDOM, DEFAULT, or None") + + +def _filter_extra_headers( + headers: Mapping[str, str | params_.Unset] | None, +) -> dict[str, str] | None: + if headers is None: + return None + return { + key: value + for key, value in headers.items() + if not isinstance(value, params_.Unset) + } + + +def _extra_body(api_kwargs: dict[str, Any]) -> dict[str, Any]: + extra_body = api_kwargs.get("extra_body") + if not isinstance(extra_body, dict): + extra_body = {} + api_kwargs["extra_body"] = extra_body + return extra_body + + +def _apply_output_config( + api_kwargs: dict[str, Any], + values: Mapping[str, Any], +) -> None: 
+ output_config = dict(api_kwargs.get("output_config") or {}) + output_config.update(values) + api_kwargs["output_config"] = output_config + + +def _anthropic_tool_choice( + tool_choice: params_.ToolChoiceMode + | params_.ToolRef + | params_.ToolSelection, +) -> dict[str, Any]: + if isinstance(tool_choice, params_.ToolChoiceMode): + match tool_choice: + case params_.ToolChoiceMode.AUTO: + return {"type": "auto"} + case params_.ToolChoiceMode.REQUIRED: + return {"type": "any"} + case params_.ToolChoiceMode.NONE: + return {"type": "none"} + if isinstance(tool_choice, params_.ToolRef): + return {"type": "tool", "name": str(tool_choice)} + if len(tool_choice.tools) == 1 and tool_choice.mode in { + params_.ToolChoiceMode.AUTO, + params_.ToolChoiceMode.REQUIRED, + }: + return {"type": "tool", "name": str(next(iter(tool_choice.tools)))} + raise ValueError("Anthropic does not support allowed tool subsets") + + +def _apply_sampling( + api_kwargs: dict[str, Any], + request_params: params_.InferenceRequestParams, +) -> None: + sampling = request_params.sampling + if isinstance(sampling, params_.ModelProviderDefault): + return + for sampler in sampling.values(): + match sampler: + case params_.TemperatureSamplerParams(temperature=temperature): + if _not_default(temperature): + api_kwargs["temperature"] = temperature + case params_.TopPSamplerParams(top_p=top_p): + if _not_default(top_p): + api_kwargs["top_p"] = top_p + case params_.TopKSamplerParams(top_k=top_k): + if _not_default(top_k): + if top_k is None: + raise ValueError("Anthropic top_k cannot be None") + api_kwargs["top_k"] = top_k + case params_.SeedSamplerParams(seed=seed): + if _not_default(seed) and _seed_value(seed) is not None: + raise ValueError("Anthropic does not support seed") + case params_.MinPSamplerParams(min_p=min_p): + if _not_default(min_p) and min_p is not None: + raise ValueError("Anthropic does not support min_p") + case params_.RepetitionPenaltyParams() as repetition: + unsupported = { + 
def _apply_anthropic_params(
    api_kwargs: dict[str, Any],
    request_params: params_.InferenceRequestParams,
    *,
    provider: str,
) -> None:
    """Translate *request_params* into Anthropic SDK keyword arguments.

    Results are written into *api_kwargs* in place. Settings the Anthropic
    protocol cannot express are rejected rather than dropped.

    Args:
        api_kwargs: Keyword-argument dict being assembled for the SDK call.
        request_params: Provider-neutral inference parameters.
        provider: Accepted for signature parity with the other adapters;
            not used by this function.

    Raises:
        ValueError: For ``max_tool_calls``, output ``include``, text
            verbosity, cache keys, or parallel-tool-call control combined
            with a ``none`` tool choice. Helpers called from here may raise
            for further unsupported settings.
    """
    _ = provider
    _apply_sampling(api_kwargs, request_params)

    reasoning = request_params.reasoning
    output = request_params.output
    summary = params_.DEFAULT if output is None else output.reasoning_summary
    if not isinstance(reasoning, params_.ModelProviderDefault) and _not_default(
        reasoning.effort
    ):
        if reasoning.effort is None:
            # An explicit ``None`` effort switches extended thinking off.
            api_kwargs["thinking"] = {"type": "disabled"}
        else:
            _apply_output_config(api_kwargs, {"effort": reasoning.effort})
    if _not_default(summary):
        if summary is None:
            api_kwargs["thinking"] = {"type": "disabled"}
        else:
            # Keep any thinking config set above; only default the type.
            thinking = dict(api_kwargs.get("thinking") or {})
            thinking.setdefault("type", "adaptive")
            thinking["display"] = summary
            api_kwargs["thinking"] = thinking

    if request_params.tool_calling is not None:
        tool_calling = request_params.tool_calling
        if (
            _not_default(tool_calling.max_tool_calls)
            and tool_calling.max_tool_calls is not None
        ):
            raise ValueError("Anthropic does not support max_tool_calls")
        tool_choice = _anthropic_tool_choice(tool_calling.tool_choice)
        # Anthropic expresses parallelism as a flag on ``tool_choice``; the
        # flag is only meaningful when a tool-calling section is present,
        # so no separate fallback branch is needed.
        if _not_default(tool_calling.parallel_tool_calls):
            if tool_choice["type"] == "none":
                raise ValueError(
                    "Anthropic cannot set parallel tool calls with tool none"
                )
            tool_choice["disable_parallel_tool_use"] = (
                not tool_calling.parallel_tool_calls
            )
        api_kwargs["tool_choice"] = tool_choice

    if request_params.provider_service is not None:
        service = request_params.provider_service
        if _not_default(service.service_tier):
            api_kwargs["service_tier"] = service.service_tier

    # The safety identifier travels as ``metadata.user_id`` and overrides a
    # caller-supplied ``user_id`` entry.
    metadata: dict[str, str] = {}
    if request_params.metadata is not None:
        metadata.update(request_params.metadata)
    if request_params.safety_identifier is not None:
        metadata["user_id"] = request_params.safety_identifier
    if metadata:
        api_kwargs["metadata"] = metadata

    context_management = request_params.context_management
    compaction = (
        None if context_management is None else context_management.compaction
    )
    if compaction is not None:
        _extra_body(api_kwargs)["context_management"] = {
            "edits": [
                {
                    "type": "compact_20260112",
                    "trigger": {
                        "type": "input_tokens",
                        "value": compaction.value,
                    },
                }
            ]
        }

    if request_params.output is not None:
        output = request_params.output
        if output.max_tokens is not None:
            api_kwargs["max_tokens"] = output.max_tokens
        if output.include is not None:
            raise ValueError("Anthropic does not support output include")
        if (
            _not_default(output.text_verbosity)
            and output.text_verbosity is not None
        ):
            raise ValueError("Anthropic does not support text verbosity")

    if request_params.cache is not None:
        cache = request_params.cache
        if cache.key is not None:
            raise ValueError("Anthropic does not support cache keys")
        cache_control: dict[str, Any] = {"type": "ephemeral"}
        if _not_default(cache.retention):
            cache_control["ttl"] = cache.retention
        api_kwargs["cache_control"] = cache_control

    extra_headers = _filter_extra_headers(request_params.extra_headers)
    if extra_headers is not None:
        api_kwargs["extra_headers"] = extra_headers
    if compaction is not None:
        # Added after the caller's headers are in place so both are kept.
        _add_builtin_beta_headers(
            api_kwargs,
            {"compact-2026-01-12", "context-management-2025-06-27"},
        )
    if request_params.extra_query is not None:
        api_kwargs["extra_query"] = dict(request_params.extra_query)
    if request_params.extra_body is not None:
        # Applied last so raw fields win over anything derived above.
        _extra_body(api_kwargs).update(request_params.extra_body)
@@ -407,6 +643,9 @@ async def stream( """ anthropic_sdk = _sdk.import_sdk(provider=provider) stream_params = _coerce_params(params) + if params is not None: + stream_params = {} + _apply_anthropic_params(stream_params, params, provider=provider) system_prompt, anthropic_messages = await _messages_to_anthropic(messages) custom_tools, builtin_tools = _split_tools(tools or ()) @@ -621,7 +860,7 @@ def stream( *, tools: Sequence[types.tools.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, - params: Any = None, + params: params_.InferenceRequestParams | None = None, provider: str, ) -> AsyncGenerator[events.Event]: return stream( diff --git a/src/ai/providers/anthropic/provider.py b/src/ai/providers/anthropic/provider.py index 87042a48..ef5ac5fb 100644 --- a/src/ai/providers/anthropic/provider.py +++ b/src/ai/providers/anthropic/provider.py @@ -21,6 +21,7 @@ import pydantic from ...models.core import model as model_ + from ...models.core import params as params_ from ...types import events from ...types import messages as messages_ from ...types import tools as tools_ @@ -136,7 +137,7 @@ def stream( *, tools: Sequence[tools_.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, - params: Any = None, + params: params_.InferenceRequestParams | None = None, ) -> AsyncGenerator[events.Event]: """Stream via the Anthropic messages protocol.""" return super().stream( diff --git a/src/ai/providers/base.py b/src/ai/providers/base.py index 91fd33b9..28d1af10 100644 --- a/src/ai/providers/base.py +++ b/src/ai/providers/base.py @@ -38,7 +38,7 @@ def stream( *, tools: Sequence[tools_.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, - params: Any = None, + params: params_.InferenceRequestParams | None = None, provider: str, ) -> AsyncGenerator[events.Event]: """Stream a language-model response using *client*.""" @@ -212,7 +212,7 @@ def stream( *, tools: Sequence[tools_.Tool] | None = None, output_type: 
type[pydantic.BaseModel] | None = None, - params: Any = None, + params: params_.InferenceRequestParams | None = None, ) -> AsyncGenerator[events.Event]: """Stream a language-model response from this provider.""" selected_protocol = model.protocol or self.protocol diff --git a/src/ai/providers/openai/protocol.py b/src/ai/providers/openai/protocol.py index 41be970e..e6c4c9b6 100644 --- a/src/ai/providers/openai/protocol.py +++ b/src/ai/providers/openai/protocol.py @@ -14,6 +14,7 @@ from ... import errors as ai_errors from ... import types from ...models import core +from ...models.core import params as params_ from .. import base from . import _sdk, errors from . import tools as openai_tools @@ -208,12 +209,288 @@ async def _messages_to_openai( # --------------------------------------------------------------------------- -def _coerce_params(value: Any) -> dict[str, Any]: +def _is_default(value: object) -> bool: + return isinstance(value, params_.ModelProviderDefault) + + +def _not_default(value: object) -> bool: + return not _is_default(value) + + +def _seed_value(seed: object) -> int | None: + if seed is None or seed == -1 or isinstance(seed, params_.RandomSeed): + return None + if isinstance(seed, int): + return seed + if isinstance(seed, params_.ModelProviderDefault): + return None + raise TypeError("seed must be an int, RANDOM, DEFAULT, or None") + + +def _filter_extra_headers( + headers: Mapping[str, str | params_.Unset] | None, +) -> dict[str, str] | None: + if headers is None: + return None + return { + key: value + for key, value in headers.items() + if not isinstance(value, params_.Unset) + } + + +def _tool_ref_name(tool_ref: params_.ToolRef | str) -> str: + return str(tool_ref) + + +def _tools_for_openai_allowed( + tools: frozenset[params_.ToolRef], + *, + responses: bool, +) -> list[dict[str, Any]]: + if responses: + return [ + {"type": "function", "name": _tool_ref_name(tool)} + for tool in sorted(tools) + ] + return [ + {"type": "function", "function": 
{"name": _tool_ref_name(tool)}} + for tool in sorted(tools) + ] + + +def _openai_tool_choice( + tool_choice: params_.ToolChoiceMode + | params_.ToolRef + | params_.ToolSelection, + *, + responses: bool, +) -> Any: + if isinstance(tool_choice, params_.ToolChoiceMode): + return tool_choice.value + if isinstance(tool_choice, params_.ToolRef): + if responses: + return {"type": "function", "name": _tool_ref_name(tool_choice)} + return { + "type": "function", + "function": {"name": _tool_ref_name(tool_choice)}, + } + if tool_choice.mode == params_.ToolChoiceMode.NONE: + raise ValueError("OpenAI allowed tool selection does not support none") + if responses: + return { + "type": "allowed_tools", + "mode": tool_choice.mode.value, + "tools": _tools_for_openai_allowed( + tool_choice.tools, + responses=True, + ), + } + return { + "type": "allowed_tools", + "allowed_tools": { + "mode": tool_choice.mode.value, + "tools": _tools_for_openai_allowed( + tool_choice.tools, + responses=False, + ), + }, + } + + +def _apply_sampling( + api_kwargs: dict[str, Any], + request_params: params_.InferenceRequestParams, + *, + chat: bool, +) -> None: + sampling = request_params.sampling + if isinstance(sampling, params_.ModelProviderDefault): + return + for sampler in sampling.values(): + match sampler: + case params_.TemperatureSamplerParams(temperature=temperature): + if _not_default(temperature): + api_kwargs["temperature"] = temperature + case params_.TopPSamplerParams(top_p=top_p): + if _not_default(top_p): + api_kwargs["top_p"] = top_p + case params_.SeedSamplerParams(seed=seed): + if _not_default(seed): + value = _seed_value(seed) + if value is not None: + if chat: + api_kwargs["seed"] = value + else: + raise ValueError( + "OpenAI responses does not support seed" + ) + case params_.TopKSamplerParams(top_k=top_k): + if _not_default(top_k) and top_k is not None: + raise ValueError("OpenAI does not support top_k") + case params_.MinPSamplerParams(min_p=min_p): + if _not_default(min_p) and 
min_p is not None: + raise ValueError("OpenAI does not support min_p") + case params_.RepetitionPenaltyParams() as repetition: + if _not_default(repetition.frequency_penalty): + api_kwargs["frequency_penalty"] = ( + repetition.frequency_penalty + ) + if _not_default(repetition.presence_penalty): + api_kwargs["presence_penalty"] = repetition.presence_penalty + if ( + _not_default(repetition.repetition_penalty) + and repetition.repetition_penalty is not None + ): + raise ValueError( + "OpenAI does not support repetition_penalty" + ) + if ( + _not_default(repetition.consideration_window) + and repetition.consideration_window is not None + ): + raise ValueError( + "OpenAI does not support consideration_window" + ) + _ = chat + + +def _apply_common_openai_params( + api_kwargs: dict[str, Any], + request_params: params_.InferenceRequestParams, + *, + provider: str, + responses: bool, +) -> None: + _ = provider + _apply_sampling(api_kwargs, request_params, chat=not responses) + + reasoning = request_params.reasoning + output = request_params.output + effort: str | params_.ModelProviderDefault | None = params_.DEFAULT + if not isinstance(reasoning, params_.ModelProviderDefault): + effort = reasoning.effort + summary = params_.DEFAULT if output is None else output.reasoning_summary + if _not_default(effort) or _not_default(summary): + if responses: + reasoning_kwargs: dict[str, Any] = dict( + api_kwargs.get("reasoning") or {} + ) + if _not_default(effort): + reasoning_kwargs["effort"] = ( + "none" if effort is None else effort + ) + if _not_default(summary): + reasoning_kwargs["summary"] = summary + api_kwargs["reasoning"] = reasoning_kwargs + else: + if _not_default(effort): + api_kwargs["reasoning_effort"] = ( + "none" if effort is None else effort + ) + if _not_default(summary) and summary is not None: + raise ValueError( + "OpenAI chat completions does not support reasoning summary" + ) + + if request_params.tool_calling is not None: + tool_calling = 
request_params.tool_calling + if _not_default(tool_calling.max_tool_calls): + if responses: + api_kwargs["max_tool_calls"] = tool_calling.max_tool_calls + elif tool_calling.max_tool_calls is not None: + raise ValueError( + "OpenAI chat completions does not support max_tool_calls" + ) + if _not_default(tool_calling.parallel_tool_calls): + api_kwargs["parallel_tool_calls"] = tool_calling.parallel_tool_calls + api_kwargs["tool_choice"] = _openai_tool_choice( + tool_calling.tool_choice, + responses=responses, + ) + + if request_params.provider_service is not None: + service = request_params.provider_service + if _not_default(service.service_tier): + api_kwargs["service_tier"] = service.service_tier + + if request_params.safety_identifier is not None: + api_kwargs["safety_identifier"] = request_params.safety_identifier + if request_params.metadata is not None: + api_kwargs["metadata"] = dict(request_params.metadata) + + if request_params.context_management is not None: + context_management = request_params.context_management + if context_management.compaction is not None: + if responses: + api_kwargs["context_management"] = [ + { + "type": "compaction", + "compact_threshold": ( + context_management.compaction.value + ), + } + ] + else: + raise ValueError( + "OpenAI chat completions does not support " + "context management" + ) + + if request_params.output is not None: + output = request_params.output + if output.max_tokens is not None: + api_kwargs[ + "max_output_tokens" if responses else "max_completion_tokens" + ] = output.max_tokens + if output.include is not None: + if responses: + api_kwargs["include"] = sorted(output.include) + else: + raise ValueError( + "OpenAI chat completions does not support output include" + ) + if _not_default(output.text_verbosity): + if responses: + text_config = dict(api_kwargs.get("text") or {}) + text_config["verbosity"] = output.text_verbosity + api_kwargs["text"] = text_config + elif output.text_verbosity is not None: + raise 
ValueError( + "OpenAI chat completions does not support text verbosity" + ) + + if request_params.cache is not None: + cache = request_params.cache + if cache.key is not None: + api_kwargs["prompt_cache_key"] = cache.key + if _not_default(cache.retention): + api_kwargs["prompt_cache_retention"] = cache.retention + + extra_headers = _filter_extra_headers(request_params.extra_headers) + if extra_headers is not None: + api_kwargs["extra_headers"] = extra_headers + if request_params.extra_query is not None: + api_kwargs["extra_query"] = dict(request_params.extra_query) + if request_params.extra_body is not None: + api_kwargs["extra_body"] = dict(request_params.extra_body) + + +def _coerce_params( + value: params_.InferenceRequestParams | None, +) -> dict[str, Any]: if value is None: return {} - if isinstance(value, Mapping): - return dict(value) - raise TypeError("openai stream params must be a dict") + if isinstance(value, params_.InferenceRequestParams): + api_kwargs: dict[str, Any] = {} + _apply_common_openai_params( + api_kwargs, + value, + provider="openai", + responses=False, + ) + return api_kwargs + raise TypeError("openai stream params must be InferenceRequestParams") async def stream( @@ -223,7 +500,7 @@ async def stream( *, tools: Sequence[types.tools.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, - params: Any = None, + params: params_.InferenceRequestParams | None = None, provider: str, ) -> AsyncGenerator[types.events.Event]: """Stream through the OpenAI chat completions protocol.""" @@ -237,6 +514,14 @@ async def stream( ) stream_params = _coerce_params(params) + if params is not None: + stream_params = {} + _apply_common_openai_params( + stream_params, + params, + provider=provider, + responses=False, + ) openai_messages = await _messages_to_openai(messages) openai_tools = _tools_to_openai(tools) if tools else None @@ -396,7 +681,7 @@ def stream( *, tools: Sequence[types.tools.Tool] | None = None, output_type: 
type[pydantic.BaseModel] | None = None, - params: Any = None, + params: params_.InferenceRequestParams | None = None, provider: str, ) -> AsyncGenerator[types.events.Event]: return stream( @@ -431,12 +716,25 @@ def stream( ) -def _coerce_responses_params(value: Any) -> dict[str, Any]: +def _coerce_responses_params( + value: params_.InferenceRequestParams | None, + *, + provider: str, +) -> dict[str, Any]: if value is None: return {} - if isinstance(value, Mapping): - return dict(value) - raise TypeError("openai responses stream params must be a dict") + if isinstance(value, params_.InferenceRequestParams): + api_kwargs: dict[str, Any] = {} + _apply_common_openai_params( + api_kwargs, + value, + provider=provider, + responses=True, + ) + return api_kwargs + raise TypeError( + "openai responses stream params must be InferenceRequestParams" + ) def _json_dumps(value: Any) -> str: @@ -940,11 +1238,11 @@ async def _stream_responses( *, tools: Sequence[types.tools.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, - params: Any = None, + params: params_.InferenceRequestParams | None = None, provider: str, ) -> AsyncGenerator[types.events.Event]: openai_sdk = _sdk.import_sdk(provider=provider) - stream_params = _coerce_responses_params(params) + stream_params = _coerce_responses_params(params, provider=provider) protected = sorted(_RESPONSES_PROTECTED_PARAMS & stream_params.keys()) if protected: raise ValueError( @@ -1452,7 +1750,7 @@ def stream( *, tools: Sequence[types.tools.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, - params: Any = None, + params: params_.InferenceRequestParams | None = None, provider: str, ) -> AsyncGenerator[types.events.Event]: return _stream_responses( diff --git a/src/ai/providers/openai/provider.py b/src/ai/providers/openai/provider.py index 92b93044..1541ef4e 100644 --- a/src/ai/providers/openai/provider.py +++ b/src/ai/providers/openai/provider.py @@ -21,6 +21,7 @@ import pydantic from 
...models.core import model as model_ + from ...models.core import params as params_ from ...types import events from ...types import messages as messages_ from ...types import tools as tools_ @@ -108,6 +109,7 @@ def _make_sdk_client( api_key=self.api_key or "", default_headers=self.headers, http_client=http_client, + _enforce_credentials=False, ) @property @@ -134,7 +136,7 @@ def stream( *, tools: Sequence[tools_.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, - params: Any = None, + params: params_.InferenceRequestParams | None = None, ) -> AsyncGenerator[events.Event]: """Stream via this provider's configured OpenAI-compatible protocol.""" return super().stream( diff --git a/tests/conftest.py b/tests/conftest.py index 92ef235d..3f742f05 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -45,7 +45,7 @@ def stream( *, tools: Sequence[ai.tools.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, - params: Any = None, + params: models.InferenceRequestParams | None = None, ) -> AsyncGenerator[events_.Event]: if model.protocol is not None: return model.protocol.stream( @@ -76,7 +76,7 @@ async def generate( self, model: models.Model, messages: list[messages_.Message], - params: Any, + params: models.GenerateParams, ) -> messages_.Message: if model.protocol is not None: return await model.protocol.generate( @@ -227,7 +227,7 @@ async def generate( self, model: models.Model, messages: list[messages_.Message], - params: Any = None, + params: models.GenerateParams, ) -> messages_.Message: if self._call_index >= len(self._responses): raise RuntimeError( diff --git a/tests/models/core/test_api.py b/tests/models/core/test_api.py index b4a0d97b..e7c740f0 100644 --- a/tests/models/core/test_api.py +++ b/tests/models/core/test_api.py @@ -33,6 +33,30 @@ def _provider_metadata_marker( return marker +def test_inference_request_params_with_provider_params() -> None: + class GatewayParams: + pass + + class OpenAIParams: + pass + + 
original_gateway = GatewayParams() + replacement_gateway = GatewayParams() + openai = OpenAIParams() + + base = ai.InferenceRequestParams( + provider_params={GatewayParams: original_gateway} + ) + updated = base.with_provider_params(replacement_gateway, openai) + + assert base.provider_params == {GatewayParams: original_gateway} + assert updated is not base + assert updated.provider_params == { + GatewayParams: replacement_gateway, + OpenAIParams: openai, + } + + async def test_stream_aggregates_registered_adapter_events() -> None: mock = mock_llm([[text_msg("Hello world")]]) @@ -193,7 +217,7 @@ async def _spy_stream( MOCK_PROVIDER._stream_impl = _spy_stream - params = {"raw": "ok"} + params = models.InferenceRequestParams(extra_body={"raw": "ok"}) async with models.stream( MOCK_MODEL, [ai.user_message("Hi")], @@ -257,7 +281,7 @@ def stream( *, tools: Sequence[ai.tools.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, - params: Any = None, + params: models.InferenceRequestParams | None = None, provider: str, ) -> AsyncGenerator[events_.Event]: _ = client, model, messages, tools, output_type, params, provider @@ -296,7 +320,7 @@ async def test_generate_dispatches_to_provider() -> None: async def _generate( model: models.Model, messages: list[messages_.Message], - params: Any = None, + params: models.GenerateParams, ) -> messages_.Message: nonlocal called called = True diff --git a/tests/providers/ai_gateway/test_stream.py b/tests/providers/ai_gateway/test_stream.py index 10962a6a..de8ac02f 100644 --- a/tests/providers/ai_gateway/test_stream.py +++ b/tests/providers/ai_gateway/test_stream.py @@ -16,7 +16,7 @@ from __future__ import annotations import json -from typing import Any +from typing import Any, cast import httpx import pytest @@ -24,6 +24,7 @@ import ai from ai import models from ai.models.core import model as model_ +from ai.providers.ai_gateway import GatewayParams, ProviderTimeoutsParams from ai.types import events, messages from 
.conftest import mock_client, mock_model, sse, user_msg @@ -345,22 +346,30 @@ def handler(req: httpx.Request) -> httpx.Response: httpx.MockTransport(handler), model_id="anthropic/claude-sonnet-4", ) - request_params = { - "providerOptions": { - "gateway": { - "order": ["bedrock", "anthropic"], - "zeroDataRetention": True, - }, - "anthropic": { - "speed": "fast", - "futureAnthropicField": True, - }, - "google": { - "thinkingConfig": {"budgetTokens": 1024}, + request_params = ai.InferenceRequestParams( + sampling={ai.SeedSamplerParams: ai.SeedSamplerParams(seed=123)}, + output=ai.OutputParams(reasoning_summary="detailed"), + reasoning=ai.ReasoningParams(effort="high"), + context_management=ai.ContextManagementParams( + compaction=ai.TokenThreshold(120_000) + ), + routing=ai.RoutingParams(provider_order=("bedrock", "anthropic")), + provider_params={ + GatewayParams: GatewayParams(zero_data_retention=True) + }, + extra_body={ + "providerOptions": { + "anthropic": { + "speed": "fast", + "futureAnthropicField": True, + }, + "google": { + "thinkingConfig": {"budgetTokens": 1024}, + }, }, + "futureGatewayField": True, }, - "futureGatewayField": True, - } + ) async with models.stream( model, [user_msg("Hi")], @@ -375,15 +384,165 @@ def handler(req: httpx.Request) -> httpx.Response: "zeroDataRetention": True, }, "anthropic": { + "effort": "high", "speed": "fast", "futureAnthropicField": True, + "contextManagement": { + "edits": [ + { + "type": "compact_20260112", + "trigger": { + "type": "input_tokens", + "value": 120_000, + }, + } + ] + }, + "thinking": {"type": "adaptive", "display": "detailed"}, }, "google": { "thinkingConfig": {"budgetTokens": 1024}, }, } + assert captured_body["seed"] == 123 assert captured_body["futureGatewayField"] is True + async def test_gateway_omits_random_seed(self) -> None: + captured_body: dict[str, Any] = {} + + def handler(req: httpx.Request) -> httpx.Response: + captured_body.update(json.loads(req.content)) + return httpx.Response( + 200, + 
text=sse( + {"type": "finish", "finishReason": "stop", "usage": {}} + ), + ) + + model = mock_model( + httpx.MockTransport(handler), + model_id="openai/gpt-5.4", + ) + request_params = ai.InferenceRequestParams( + sampling={ai.SeedSamplerParams: ai.SeedSamplerParams(seed=-1)} + ) + async with models.stream( + model, + [user_msg("Hi")], + params=request_params, + ) as stream: + async for _ in stream: + pass + + assert "seed" not in captured_body + + async def test_gateway_routing_params_map_to_provider_options( + self, + ) -> None: + captured_body: dict[str, Any] = {} + + def handler(req: httpx.Request) -> httpx.Response: + captured_body.update(json.loads(req.content)) + return httpx.Response( + 200, + text=sse( + {"type": "finish", "finishReason": "stop", "usage": {}} + ), + ) + + model = mock_model( + httpx.MockTransport(handler), + model_id="anthropic/claude-sonnet-4", + ) + byok: dict[str, list[dict[str, Any]]] = { + "anthropic": [{"apiKey": "sk-test"}] + } + request_params = ai.InferenceRequestParams( + cache=ai.CacheParams(mode="auto"), + safety_identifier="user_123", + tags=frozenset({"team:search", "env:prod"}), + routing=ai.RoutingParams( + provider_allowlist=frozenset({"anthropic", "bedrock"}), + provider_order=("bedrock", "anthropic"), + provider_ranking=ai.ProviderRankingStrategy.LATENCY, + fallback_models=("openai/gpt-5-mini",), + routing_target=ai.CloudRegion("us-east-1"), + ), + provider_params={ + GatewayParams: GatewayParams( + quota_entity_id="quota_123", + zero_data_retention=True, + hipaa_compliant=True, + disallow_prompt_training=True, + byok=byok, + provider_timeouts=ProviderTimeoutsParams( + byok={"anthropic": 5000} + ), + ), + }, + ) + async with models.stream( + model, + [user_msg("Hi")], + params=request_params, + ) as stream: + async for _ in stream: + pass + + assert captured_body["providerOptions"]["gateway"] == { + "caching": "auto", + "only": ["anthropic", "bedrock"], + "order": ["bedrock", "anthropic"], + "sort": "latency", + 
"models": ["openai/gpt-5-mini"], + "user": "user_123", + "tags": ["env:prod", "team:search"], + "quotaEntityId": "quota_123", + "zeroDataRetention": True, + "hipaaCompliant": True, + "disallowPromptTraining": True, + "inferenceRegion": {"providerRegion": "us-east-1"}, + "byok": {"anthropic": [{"apiKey": "sk-test"}]}, + "providerTimeouts": {"byok": {"anthropic": 5000}}, + } + + async def test_gateway_openai_context_management_maps_to_provider_options( + self, + ) -> None: + captured_body: dict[str, Any] = {} + + def handler(req: httpx.Request) -> httpx.Response: + captured_body.update(json.loads(req.content)) + return httpx.Response( + 200, + text=sse( + {"type": "finish", "finishReason": "stop", "usage": {}} + ), + ) + + model = mock_model( + httpx.MockTransport(handler), + model_id="openai/gpt-5.4", + ) + request_params = ai.InferenceRequestParams( + context_management=ai.ContextManagementParams( + compaction=ai.TokenThreshold(120_000) + ) + ) + async with models.stream( + model, + [user_msg("Hi")], + params=request_params, + ) as stream: + async for _ in stream: + pass + + assert captured_body["providerOptions"]["openai"] == { + "contextManagement": [ + {"type": "compaction", "compactThreshold": 120_000} + ] + } + async def test_gateway_rejects_non_dict_params(self) -> None: def handler(req: httpx.Request) -> httpx.Response: raise AssertionError("request should not be sent") @@ -392,13 +551,14 @@ def handler(req: httpx.Request) -> httpx.Response: httpx.MockTransport(handler), model_id="openai/gpt-5.4", ) - with pytest.raises(TypeError, match="dict"): + with pytest.raises(TypeError, match="InferenceRequestParams"): async with models.stream( model, [user_msg("Hi")], - params=[ - {"providerOptions": {"openai": {"serviceTier": "auto"}}} - ], + params=cast( + Any, + [{"providerOptions": {"openai": {"serviceTier": "auto"}}}], + ), ) as stream: async for _ in stream: pass diff --git a/tests/providers/anthropic/test_adapter.py b/tests/providers/anthropic/test_adapter.py 
index 01d23da6..a2167d98 100644 --- a/tests/providers/anthropic/test_adapter.py +++ b/tests/providers/anthropic/test_adapter.py @@ -53,7 +53,7 @@ async def _drain(stream: Any) -> None: pass -async def test_raw_params_pass_through_to_sdk_kwargs( +async def test_params_translate_to_sdk_kwargs( monkeypatch: pytest.MonkeyPatch, ) -> None: fake, captured = _patch_client(monkeypatch) @@ -63,41 +63,54 @@ async def test_raw_params_pass_through_to_sdk_kwargs( fake, _MODEL, [ai.user_message("Hi")], - params={ - "max_tokens": 123, - "speed": "fast", - "thinking": {"type": "disabled"}, - "output_config": { - "effort": "high", - "task_budget": {"type": "tokens", "total": 20000}, - }, - "tool_choice": { - "type": "auto", - "disable_parallel_tool_use": True, + params=ai.InferenceRequestParams( + output=ai.OutputParams(max_tokens=123, reasoning_summary=None), + reasoning=ai.ReasoningParams(effort="high"), + context_management=ai.ContextManagementParams( + compaction=ai.TokenThreshold(120_000) + ), + tool_calling=ai.ToolCallingParams( + tool_choice=ai.ToolChoiceMode.AUTO, + parallel_tool_calls=False, + ), + extra_body={ + "speed": "fast", + "future_option": {"enabled": True}, }, - "extra_body": {"future_option": {"enabled": True}}, - "extra_headers": {"x-anthropic-feature": "enabled"}, - }, + extra_headers={"x-anthropic-feature": "enabled"}, + ), provider="anthropic", ) ) assert captured["max_tokens"] == 123 - assert captured["speed"] == "fast" assert captured["thinking"] == {"type": "disabled"} assert captured["output_config"] == { "effort": "high", - "task_budget": {"type": "tokens", "total": 20000}, } assert captured["tool_choice"] == { "type": "auto", "disable_parallel_tool_use": True, } - assert captured["extra_body"] == {"future_option": {"enabled": True}} - assert captured["extra_headers"] == {"x-anthropic-feature": "enabled"} + assert captured["extra_body"] == { + "context_management": { + "edits": [ + { + "type": "compact_20260112", + "trigger": {"type": "input_tokens", 
"value": 120_000}, + } + ] + }, + "speed": "fast", + "future_option": {"enabled": True}, + } + assert captured["extra_headers"] == { + "anthropic-beta": "compact-2026-01-12,context-management-2025-06-27", + "x-anthropic-feature": "enabled", + } -async def test_non_dict_params_rejected_by_adapter( +async def test_non_inference_params_rejected_by_adapter( monkeypatch: pytest.MonkeyPatch, ) -> None: fake, _ = _patch_client(monkeypatch) @@ -106,14 +119,53 @@ async def test_non_dict_params_rejected_by_adapter( fake, _MODEL, [ai.user_message("Hi")], - params=[{"speed": "fast"}], + params=cast(Any, [{"speed": "fast"}]), provider="anthropic", ) - with pytest.raises(TypeError, match="dict"): + with pytest.raises(TypeError, match="InferenceRequestParams"): await _drain(stream) +async def test_seed_rejected_by_adapter( + monkeypatch: pytest.MonkeyPatch, +) -> None: + fake, _ = _patch_client(monkeypatch) + + stream = protocol.stream( + fake, + _MODEL, + [ai.user_message("Hi")], + params=ai.InferenceRequestParams( + sampling={ai.SeedSamplerParams: ai.SeedSamplerParams(seed=123)} + ), + provider="anthropic", + ) + + with pytest.raises(ValueError, match="seed"): + await _drain(stream) + + +async def test_random_seed_omitted_by_adapter( + monkeypatch: pytest.MonkeyPatch, +) -> None: + fake, captured = _patch_client(monkeypatch) + + await _drain( + protocol.stream( + fake, + _MODEL, + [ai.user_message("Hi")], + params=ai.InferenceRequestParams( + sampling={ai.SeedSamplerParams: ai.SeedSamplerParams(seed=-1)} + ), + provider="anthropic", + ) + ) + + assert "seed" not in captured + + async def test_reasoning_signature_round_trips_from_provider_metadata( monkeypatch: pytest.MonkeyPatch, ) -> None: diff --git a/tests/providers/anthropic/test_tools.py b/tests/providers/anthropic/test_tools.py index 509a4fbf..7c1e5008 100644 --- a/tests/providers/anthropic/test_tools.py +++ b/tests/providers/anthropic/test_tools.py @@ -30,7 +30,7 @@ async def _capture_tools( monkeypatch: 
pytest.MonkeyPatch, tools: list[Any], *, - params: dict[str, Any] | None = None, + params: ai.InferenceRequestParams | None = None, ) -> dict[str, Any]: _ = monkeypatch captured: dict[str, Any] = {} @@ -170,7 +170,9 @@ async def test_user_anthropic_beta_header_wins( anthropic_tools.web_search(), anthropic_tools.web_fetch(), ], - params={"extra_headers": {"anthropic-beta": "custom-beta-2026-01-01"}}, + params=ai.InferenceRequestParams( + extra_headers={"anthropic-beta": "custom-beta-2026-01-01"} + ), ) assert _beta_header(captured) == "custom-beta-2026-01-01" diff --git a/tests/providers/openai/test_adapter.py b/tests/providers/openai/test_adapter.py index 902c24ad..30d6ba86 100644 --- a/tests/providers/openai/test_adapter.py +++ b/tests/providers/openai/test_adapter.py @@ -111,7 +111,10 @@ async def close(self) -> None: self.closed = True -_MODEL = ai.Model("gpt-5.4", provider=ai.get_provider("openai")) +_MODEL = ai.Model( + "gpt-5.4", + provider=ai.get_provider("openai", api_key="sk-test"), +) def _patch( @@ -157,7 +160,7 @@ async def test_responses_request_uses_responses_input() -> None: assert "messages" not in captured -async def test_responses_raw_params_and_structured_output() -> None: +async def test_responses_params_and_structured_output() -> None: fake, captured = _patch_responses() await _drain( @@ -166,25 +169,76 @@ async def test_responses_raw_params_and_structured_output() -> None: _MODEL, [ai.user_message("Hi")], output_type=_Answer, - params={ - "reasoning": {"effort": "high"}, - "include": ["file_search_call.results"], - "text": {"verbosity": "low"}, - "extra_headers": {"x-openai-feature": "enabled"}, - }, + params=ai.InferenceRequestParams( + reasoning=ai.ReasoningParams(effort="high"), + context_management=ai.ContextManagementParams( + compaction=ai.TokenThreshold(120_000) + ), + output=ai.OutputParams( + include=frozenset({"file_search_call.results"}), + reasoning_summary="auto", + text_verbosity="low", + ), + extra_body={"future_option": True}, 
+ extra_headers={"x-openai-feature": "enabled"}, + ), provider="openai", ) ) - assert captured["reasoning"] == {"effort": "high"} + assert captured["reasoning"] == {"effort": "high", "summary": "auto"} + assert captured["context_management"] == [ + {"type": "compaction", "compact_threshold": 120_000} + ] assert captured["include"] == ["file_search_call.results"] assert captured["extra_headers"] == {"x-openai-feature": "enabled"} + assert captured["extra_body"] == {"future_option": True} assert captured["text"]["verbosity"] == "low" assert captured["text"]["format"]["type"] == "json_schema" assert captured["text"]["format"]["name"] == "_Answer" assert captured["text"]["format"]["strict"] is True +async def test_chat_rejects_text_verbosity( + monkeypatch: pytest.MonkeyPatch, +) -> None: + fake, _ = _patch(monkeypatch) + + with pytest.raises(ValueError, match="text verbosity"): + await _drain( + protocol.stream( + fake, + _MODEL, + [ai.user_message("Hi")], + params=ai.InferenceRequestParams( + output=ai.OutputParams(text_verbosity="low") + ), + provider="openai", + ) + ) + + +async def test_chat_rejects_context_management( + monkeypatch: pytest.MonkeyPatch, +) -> None: + fake, _ = _patch(monkeypatch) + + with pytest.raises(ValueError, match="context management"): + await _drain( + protocol.stream( + fake, + _MODEL, + [ai.user_message("Hi")], + params=ai.InferenceRequestParams( + context_management=ai.ContextManagementParams( + compaction=ai.TokenThreshold(120_000) + ) + ), + provider="openai", + ) + ) + + async def test_responses_tools_convert_function_and_provider_tools() -> None: fake, captured = _patch_responses() @@ -421,7 +475,7 @@ async def test_system_messages_use_openai_system_role( assert captured["messages"][0] == {"role": "system", "content": "rules"} -async def test_raw_params_pass_through_to_sdk_kwargs( +async def test_params_translate_to_sdk_kwargs( monkeypatch: pytest.MonkeyPatch, ) -> None: fake, captured = _patch(monkeypatch) @@ -431,27 +485,71 @@ 
async def test_raw_params_pass_through_to_sdk_kwargs( fake, _MODEL, [ai.user_message("Hi")], - params={ - "logprobs": 3, - "verbosity": "low", - "max_completion_tokens": 128, - "extra_body": {"future_option": True}, - "extra_headers": {"x-openai-feature": "enabled"}, - "stream_options": {"include_usage": False, "custom": True}, - }, + params=ai.InferenceRequestParams( + sampling={ + ai.TemperatureSamplerParams: ai.TemperatureSamplerParams( + temperature=0.2 + ), + ai.TopPSamplerParams: ai.TopPSamplerParams(top_p=0.9), + ai.SeedSamplerParams: ai.SeedSamplerParams(seed=123), + }, + output=ai.OutputParams(max_tokens=128), + provider_service=ai.ProviderServiceParams(service_tier="auto"), + extra_body={"future_option": True, "verbosity": "low"}, + extra_headers={"x-openai-feature": "enabled"}, + ), provider="openai", ) ) - assert captured["logprobs"] == 3 - assert captured["verbosity"] == "low" + assert captured["temperature"] == 0.2 + assert captured["top_p"] == 0.9 + assert captured["seed"] == 123 assert captured["max_completion_tokens"] == 128 - assert captured["extra_body"] == {"future_option": True} + assert captured["service_tier"] == "auto" + assert captured["extra_body"] == {"future_option": True, "verbosity": "low"} assert captured["extra_headers"] == {"x-openai-feature": "enabled"} - assert captured["stream_options"] == { - "include_usage": False, - "custom": True, - } + + +async def test_chat_omits_explicit_random_seed( + monkeypatch: pytest.MonkeyPatch, +) -> None: + fake, captured = _patch(monkeypatch) + + await _drain( + protocol.stream( + fake, + _MODEL, + [ai.user_message("Hi")], + params=ai.InferenceRequestParams( + sampling={ + ai.SeedSamplerParams: ai.SeedSamplerParams(seed=ai.RANDOM) + } + ), + provider="openai", + ) + ) + + assert "seed" not in captured + + +async def test_responses_rejects_seed() -> None: + fake, _ = _patch_responses() + + with pytest.raises(ValueError, match="seed"): + await _drain( + protocol.OpenAIResponsesProtocol().stream( + 
fake, + _MODEL, + [ai.user_message("Hi")], + params=ai.InferenceRequestParams( + sampling={ + ai.SeedSamplerParams: ai.SeedSamplerParams(seed=123) + } + ), + provider="openai", + ) + ) async def test_strict_json_schema_flows_into_response_format( @@ -472,7 +570,7 @@ async def test_strict_json_schema_flows_into_response_format( assert captured["response_format"]["json_schema"]["strict"] is True -async def test_non_dict_params_rejected_by_adapter( +async def test_non_inference_params_rejected_by_adapter( monkeypatch: pytest.MonkeyPatch, ) -> None: fake, _ = _patch(monkeypatch) @@ -481,11 +579,11 @@ async def test_non_dict_params_rejected_by_adapter( fake, _MODEL, [ai.user_message("Hi")], - params=[{"reasoning_effort": "high"}], + params=cast(Any, [{"reasoning_effort": "high"}]), provider="openai", ) - with pytest.raises(TypeError, match="dict"): + with pytest.raises(TypeError, match="InferenceRequestParams"): await _drain(stream) diff --git a/tests/types/test_integrity.py b/tests/types/test_integrity.py index d65eecc9..9a7db8b3 100644 --- a/tests/types/test_integrity.py +++ b/tests/types/test_integrity.py @@ -630,7 +630,7 @@ async def test_generate_sanitizes_internal_messages() -> None: async def _spy_gen( model: models.Model, messages: list[messages.Message], - params: Any, + params: models.GenerateParams, ) -> messages.Message: received.append(list(messages)) return sentinel diff --git a/uv.lock b/uv.lock index 0f3f5c68..e168bba7 100644 --- a/uv.lock +++ b/uv.lock @@ -106,7 +106,7 @@ wheels = [ [[package]] name = "anthropic" -version = "0.83.0" +version = "0.103.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -118,9 +118,9 @@ dependencies = [ { name = "sniffio" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/db/e5/02cd2919ec327b24234abb73082e6ab84c451182cc3cc60681af700f4c63/anthropic-0.83.0.tar.gz", hash = 
"sha256:a8732c68b41869266c3034541a31a29d8be0f8cd0a714f9edce3128b351eceb4", size = 534058, upload-time = "2026-02-19T19:26:38.904Z" } +sdist = { url = "https://files.pythonhosted.org/packages/fb/57/0b758b08cf4606c94d63a997d67a0063f7438efbaf81cfedd0d7c0c69d67/anthropic-0.103.1.tar.gz", hash = "sha256:21c12f4fc0fdd87a2e80d58479cd0af640062b3cfb82bbfa01c7977acd4defeb", size = 848877, upload-time = "2026-05-19T15:43:27.698Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/5f/75/b9d58e4e2a4b1fc3e75ffbab978f999baf8b7c4ba9f96e60edb918ba386b/anthropic-0.83.0-py3-none-any.whl", hash = "sha256:f069ef508c73b8f9152e8850830d92bd5ef185645dbacf234bb213344a274810", size = 456991, upload-time = "2026-02-19T19:26:40.114Z" }, + { url = "https://files.pythonhosted.org/packages/ad/ec/cf357cf571377a39552c1530390a9b79bbdb6ea463f48fbe4e3624141e3b/anthropic-0.103.1-py3-none-any.whl", hash = "sha256:b9a523fac34e64caf6ee55fdbda213950e6a744b906fce100d34909aad2cd8f4", size = 832551, upload-time = "2026-05-19T15:43:29.663Z" }, ] [[package]] @@ -741,7 +741,7 @@ wheels = [ [[package]] name = "openai" -version = "2.14.0" +version = "2.37.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -753,9 +753,9 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d8/b1/12fe1c196bea326261718eb037307c1c1fe1dedc2d2d4de777df822e6238/openai-2.14.0.tar.gz", hash = "sha256:419357bedde9402d23bf8f2ee372fca1985a73348debba94bddff06f19459952", size = 626938, upload-time = "2025-12-19T03:28:45.742Z" } +sdist = { url = "https://files.pythonhosted.org/packages/32/50/5901f01ef14e6c27788beb91e54fef5d6204fb5fb9e97402fc8a14de2e32/openai-2.37.0.tar.gz", hash = "sha256:f4bc562cc5f3a43d40d678105572d9d44765f6e0f50c125f63055419b72f4bd9", size = 754706, upload-time = "2026-05-15T22:30:35.428Z" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/27/4b/7c1a00c2c3fbd004253937f7520f692a9650767aa73894d7a34f0d65d3f4/openai-2.14.0-py3-none-any.whl", hash = "sha256:7ea40aca4ffc4c4a776e77679021b47eec1160e341f42ae086ba949c9dcc9183", size = 1067558, upload-time = "2025-12-19T03:28:43.727Z" }, + { url = "https://files.pythonhosted.org/packages/ed/4c/bce61680d0699a78a405fd9a67989b175ba020590428831aab2ab1d2be7c/openai-2.37.0-py3-none-any.whl", hash = "sha256:814633888b8f3b1ffd6615697c6e4ef93632d08b7c2e28c8c5ef3556e5a10107", size = 1303238, upload-time = "2026-05-15T22:30:32.767Z" }, ] [[package]]