diff --git a/examples/coding_agent_minimal.py b/examples/coding_agent_minimal.py index e5cf41bc..b5c4f1e4 100644 --- a/examples/coding_agent_minimal.py +++ b/examples/coding_agent_minimal.py @@ -3,7 +3,6 @@ import asyncio import json import sys -import typing import ai @@ -16,9 +15,7 @@ with sed, make sure to double check the result. """ -STREAM_PARAMS: dict[str, typing.Any] = { - "providerOptions": {"gateway": {"caching": "auto"}}, -} +STREAM_PARAMS = ai.InferenceRequestParams(cache=ai.CacheParams(mode="auto")) @ai.tool diff --git a/examples/model_params.py b/examples/model_params.py index 75d0c65a..ec231766 100644 --- a/examples/model_params.py +++ b/examples/model_params.py @@ -13,12 +13,12 @@ async def main() -> None: - params = { - "providerOptions": { - "gateway": {"sort": "cost"}, - "anthropic": {"speed": "fast"}, - } - } + params = ai.InferenceRequestParams( + routing=ai.RoutingParams( + provider_ranking=ai.ProviderRankingStrategy.COST + ), + extra_body={"providerOptions": {"anthropic": {"speed": "fast"}}}, + ) async with ai.stream(model, messages, params=params) as stream: async for event in stream: if isinstance(event, ai.events.TextDelta): diff --git a/examples/prompt_caching.py b/examples/prompt_caching.py index 140e734d..1dcb0006 100644 --- a/examples/prompt_caching.py +++ b/examples/prompt_caching.py @@ -142,7 +142,7 @@ async def _run(user_text: str) -> ai.types.usage.Usage | None: ai.system_message(SYSTEM_PROMPT), ai.user_message(user_text), ] - params = {"providerOptions": {"gateway": {"caching": "auto"}}} + params = ai.InferenceRequestParams(cache=ai.CacheParams(mode="auto")) async with agent.run(model, messages, params=params) as stream: async for _event in stream: pass diff --git a/src/ai/__init__.py b/src/ai/__init__.py index 066d9aaa..d0bce301 100644 --- a/src/ai/__init__.py +++ b/src/ai/__init__.py @@ -48,11 +48,41 @@ UnsupportedProviderError, ) from .models import ( + DEFAULT, + GLOBAL, + RANDOM, + UNSET, + CacheParams, + CloudRegion, + 
ContextManagementParams, + GeoRegion, ImageParams, + InferenceRequestParams, + MinPSamplerParams, Model, + ModelProviderDefault, + OutputParams, Provider, ProviderProtocol, + ProviderRankingStrategy, + ProviderServiceParams, + RandomSeed, + ReasoningParams, + RepetitionPenaltyParams, + RoutingParams, + RoutingTarget, + RoutingTargetChain, + SeedSamplerParams, Stream, + TemperatureSamplerParams, + TokenThreshold, + ToolCallingParams, + ToolChoiceMode, + ToolRef, + ToolSelection, + TopKSamplerParams, + TopPSamplerParams, + Unset, VideoParams, generate, get_model, @@ -72,18 +102,27 @@ ) __all__ = [ - # Models (from models/) + "DEFAULT", + "GLOBAL", + "RANDOM", + "UNSET", "AIError", - # Agents — primary API "Agent", - # Agents — tools "AgentTool", + "CacheParams", + "CloudRegion", "ConfigurationError", "Context", + "ContextManagementParams", + "GeoRegion", "HTTPErrorContext", "ImageParams", + "InferenceRequestParams", "InstallationError", + "MinPSamplerParams", "Model", + "ModelProviderDefault", + "OutputParams", "Provider", "ProviderAPIError", "ProviderAuthenticationError", @@ -99,20 +138,38 @@ "ProviderOverloadedError", "ProviderPermissionDeniedError", "ProviderProtocol", + "ProviderRankingStrategy", "ProviderRateLimitError", "ProviderRequestTooLargeError", "ProviderResponseError", + "ProviderServiceParams", "ProviderServiceUnavailableError", "ProviderStatusError", "ProviderTimeoutError", "ProviderUnprocessableEntityError", + "RandomSeed", + "ReasoningParams", + "RepetitionPenaltyParams", + "RoutingParams", + "RoutingTarget", + "RoutingTargetChain", + "SeedSamplerParams", "Stream", "StreamingStatusTool", "StreamingTextTool", "SubAgentTool", + "TemperatureSamplerParams", + "TokenThreshold", "Tool", "ToolCall", + "ToolCallingParams", + "ToolChoiceMode", + "ToolRef", "ToolRunner", + "ToolSelection", + "TopKSamplerParams", + "TopPSamplerParams", + "Unset", "UnsupportedProviderError", "VideoParams", "abort_pending_hook", @@ -120,13 +177,11 @@ "assistant_message", 
"cancel_hook", "errors", - # Submodules "events", "file_part", "generate", "get_model", "get_provider", - # Agents — hooks "hook", "mcp", "messages", @@ -143,9 +198,7 @@ "tool_result", "tool_result_part", "tools", - # Builders (from types/builders) "user_message", "util", - # Agents — composition "yield_from", ] diff --git a/src/ai/_types.py b/src/ai/_types.py new file mode 100644 index 00000000..e9646855 --- /dev/null +++ b/src/ai/_types.py @@ -0,0 +1,10 @@ +from collections.abc import Iterator +from typing import Protocol, TypeVar + +_T_co = TypeVar("_T_co", covariant=True) + + +class Collection(Protocol[_T_co]): + def __contains__(self, value: object, /) -> bool: ... + def __iter__(self) -> Iterator[_T_co]: ... + def __len__(self) -> int: ... diff --git a/src/ai/agents/_middleware.py b/src/ai/agents/_middleware.py index 4aa0245f..df1f4fe9 100644 --- a/src/ai/agents/_middleware.py +++ b/src/ai/agents/_middleware.py @@ -43,6 +43,7 @@ import pydantic from ..models.core.model import Model + from ..models.core.params import GenerateParams from ..types import events as events_ from ..types.tools import Tool from .agent import Context @@ -71,7 +72,7 @@ class GenerateContext: model: Model messages: list[messages_.Message] - params: Any + params: GenerateParams def __post_init__(self) -> None: object.__setattr__(self, "messages", list(self.messages)) diff --git a/src/ai/agents/agent.py b/src/ai/agents/agent.py index 90bfd449..6d7c658f 100644 --- a/src/ai/agents/agent.py +++ b/src/ai/agents/agent.py @@ -813,7 +813,9 @@ class Context(pydantic.BaseModel): output_type: type[pydantic.BaseModel] | None = pydantic.Field( default=None, exclude=True, repr=False ) - params: Any = pydantic.Field(default=None, exclude=True, repr=False) + params: models.InferenceRequestParams | None = pydantic.Field( + default=None, exclude=True, repr=False + ) _agent_tools_by_name: dict[str, AgentTool] = pydantic.PrivateAttr( default_factory=dict @@ -1184,7 +1186,7 @@ def run( model: models.Model, 
messages: list[types.messages.Message], *, - params: Any = None, + params: models.InferenceRequestParams | None = None, _middleware: list[middleware_._Middleware] | None = None, ) -> AbstractAsyncContextManager[AgentStream[str]]: ... @overload @@ -1194,7 +1196,7 @@ def run[T: pydantic.BaseModel]( messages: list[types.messages.Message], *, output_type: type[T], - params: Any = None, + params: models.InferenceRequestParams | None = None, _middleware: list[middleware_._Middleware] | None = None, ) -> AbstractAsyncContextManager[AgentStream[T]]: ... def run( @@ -1203,7 +1205,7 @@ def run( messages: list[types.messages.Message], *, output_type: type[pydantic.BaseModel] | None = None, - params: Any = None, + params: models.InferenceRequestParams | None = None, _middleware: list[middleware_._Middleware] | None = None, ) -> AbstractAsyncContextManager[AgentStream[Any]]: """Run the agent loop, yielding events to the consumer. @@ -1243,7 +1245,7 @@ async def _run( messages: list[types.messages.Message], *, output_type: type[pydantic.BaseModel] | None, - params: Any, + params: models.InferenceRequestParams | None, _middleware: list[middleware_._Middleware] | None, ) -> AsyncIterator[AgentStream[Any]]: context = Context( diff --git a/src/ai/models/__init__.py b/src/ai/models/__init__.py index f2a2cb4f..29cbb800 100644 --- a/src/ai/models/__init__.py +++ b/src/ai/models/__init__.py @@ -45,23 +45,85 @@ stream, ) from .core.model import Model, get_model -from .core.params import GenerateParams, ImageParams, VideoParams +from .core.params import ( + DEFAULT, + GLOBAL, + RANDOM, + UNSET, + CacheParams, + CloudRegion, + ContextManagementParams, + GenerateParams, + GeoRegion, + ImageParams, + InferenceRequestParams, + MinPSamplerParams, + ModelProviderDefault, + OutputParams, + ProviderRankingStrategy, + ProviderServiceParams, + RandomSeed, + ReasoningParams, + RepetitionPenaltyParams, + RoutingParams, + RoutingTarget, + RoutingTargetChain, + SeedSamplerParams, + 
TemperatureSamplerParams, + TokenThreshold, + ToolCallingParams, + ToolChoiceMode, + ToolRef, + ToolSelection, + TopKSamplerParams, + TopPSamplerParams, + Unset, + VideoParams, +) __all__ = [ - # Core types + "DEFAULT", + "GLOBAL", + "RANDOM", + "UNSET", + "CacheParams", + "CloudRegion", + "ContextManagementParams", "Executor", "GenerateExecutor", "GenerateParams", "GenerateRequest", + "GeoRegion", "ImageParams", + "InferenceRequestParams", + "MinPSamplerParams", "Model", + "ModelProviderDefault", + "OutputParams", "Provider", "ProviderProtocol", + "ProviderRankingStrategy", + "ProviderServiceParams", + "RandomSeed", + "ReasoningParams", + "RepetitionPenaltyParams", + "RoutingParams", + "RoutingTarget", + "RoutingTargetChain", + "SeedSamplerParams", "Stream", "StreamExecutor", "StreamRequest", + "TemperatureSamplerParams", + "TokenThreshold", + "ToolCallingParams", + "ToolChoiceMode", + "ToolRef", + "ToolSelection", + "TopKSamplerParams", + "TopPSamplerParams", + "Unset", "VideoParams", - # Public API "generate", "get_model", "probe", diff --git a/src/ai/models/core/__init__.py b/src/ai/models/core/__init__.py index 03eecffc..a826cfec 100644 --- a/src/ai/models/core/__init__.py +++ b/src/ai/models/core/__init__.py @@ -14,19 +14,83 @@ stream, ) from .model import Model, get_model -from .params import GenerateParams, ImageParams, VideoParams +from .params import ( + DEFAULT, + GLOBAL, + RANDOM, + UNSET, + CacheParams, + CloudRegion, + ContextManagementParams, + GenerateParams, + GeoRegion, + ImageParams, + InferenceRequestParams, + MinPSamplerParams, + ModelProviderDefault, + OutputParams, + ProviderRankingStrategy, + ProviderServiceParams, + RandomSeed, + ReasoningParams, + RepetitionPenaltyParams, + RoutingParams, + RoutingTarget, + RoutingTargetChain, + SeedSamplerParams, + TemperatureSamplerParams, + TokenThreshold, + ToolCallingParams, + ToolChoiceMode, + ToolRef, + ToolSelection, + TopKSamplerParams, + TopPSamplerParams, + Unset, + VideoParams, +) __all__ = [ + 
"DEFAULT", + "GLOBAL", + "RANDOM", + "UNSET", + "CacheParams", + "CloudRegion", + "ContextManagementParams", "Executor", "GenerateExecutor", "GenerateParams", "GenerateRequest", + "GeoRegion", "ImageParams", + "InferenceRequestParams", + "MinPSamplerParams", "Model", + "ModelProviderDefault", + "OutputParams", "Provider", + "ProviderRankingStrategy", + "ProviderServiceParams", + "RandomSeed", + "ReasoningParams", + "RepetitionPenaltyParams", + "RoutingParams", + "RoutingTarget", + "RoutingTargetChain", + "SeedSamplerParams", "Stream", "StreamExecutor", "StreamRequest", + "TemperatureSamplerParams", + "TokenThreshold", + "ToolCallingParams", + "ToolChoiceMode", + "ToolRef", + "ToolSelection", + "TopKSamplerParams", + "TopPSamplerParams", + "Unset", "VideoParams", "generate", "get_model", diff --git a/src/ai/models/core/api.py b/src/ai/models/core/api.py index c95e906d..9916daff 100644 --- a/src/ai/models/core/api.py +++ b/src/ai/models/core/api.py @@ -41,7 +41,7 @@ class StreamRequest: messages: list[types.messages.Message] tools: Sequence[types.tools.Tool] | None = None output_type: type[pydantic.BaseModel] | None = None - params: Any = None + params: params_.InferenceRequestParams | None = None @dataclasses.dataclass(frozen=True) @@ -369,14 +369,14 @@ def tools(self) -> list[types.tools.Tool]: ... @property def output_type(self) -> type[pydantic.BaseModel] | None: ... @property - def params(self) -> Any: ... + def params(self) -> params_.InferenceRequestParams | None: ... @overload def stream( *, context: StreamContext, - params: Any = None, + params: params_.InferenceRequestParams | None = None, executor: StreamExecutor = _default_executor, ) -> AbstractAsyncContextManager[Stream[str]]: ... 
@overload @@ -384,7 +384,7 @@ def stream[T: pydantic.BaseModel]( *, context: StreamContext, output_type: type[T], - params: Any = None, + params: params_.InferenceRequestParams | None = None, executor: StreamExecutor = _default_executor, ) -> AbstractAsyncContextManager[Stream[T]]: ... @overload @@ -393,7 +393,7 @@ def stream( messages: list[types.messages.Message], *, tools: Sequence[types.tools.Tool] | None = None, - params: Any = None, + params: params_.InferenceRequestParams | None = None, executor: StreamExecutor = _default_executor, ) -> AbstractAsyncContextManager[Stream[str]]: ... @overload @@ -403,7 +403,7 @@ def stream[T: pydantic.BaseModel]( *, tools: Sequence[types.tools.Tool] | None = None, output_type: type[T], - params: Any = None, + params: params_.InferenceRequestParams | None = None, executor: StreamExecutor = _default_executor, ) -> AbstractAsyncContextManager[Stream[T]]: ... def stream( @@ -413,7 +413,7 @@ def stream( context: StreamContext | None = None, tools: Sequence[types.tools.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, - params: Any = None, + params: params_.InferenceRequestParams | None = None, executor: StreamExecutor = _default_executor, ) -> AbstractAsyncContextManager[Stream[Any]]: """Stream an LLM response. 
@@ -464,7 +464,7 @@ async def _stream( messages: list[types.messages.Message], tools: Sequence[types.tools.Tool] | None, output_type: type[pydantic.BaseModel] | None, - params: Any, + params: params_.InferenceRequestParams | None, executor: StreamExecutor, ) -> AsyncIterator[Stream[Any]]: if messages and messages[-1].replay: diff --git a/src/ai/models/core/params.py b/src/ai/models/core/params.py index 9969c31c..ee7d906f 100644 --- a/src/ai/models/core/params.py +++ b/src/ai/models/core/params.py @@ -1,7 +1,39 @@ -from typing import Any +import dataclasses +import enum +from collections.abc import Iterable, Mapping +from dataclasses import dataclass +from typing import Any, Self, final import pydantic + +@final +class ModelProviderDefault: + """Sentinel for params: default value used by the model/provider.""" + + +DEFAULT = ModelProviderDefault() +"""Sentinel for params: default value used by the model/provider.""" + + +@final +class Unset: + """Sentinel for a value that should be unset/omitted in a request.""" + + +UNSET = Unset() +"""Sentinel for a value that should be unset/omitted in a request.""" + + +@final +class RandomSeed: + """Sentinel for explicitly random seed selection.""" + + +RANDOM = RandomSeed() +"""Sentinel requesting provider-random seed selection.""" + + _PARAMS_CONFIG = pydantic.ConfigDict(frozen=True, populate_by_name=True) @@ -40,3 +72,345 @@ class VideoParams(pydantic.BaseModel): GenerateParams = ImageParams | VideoParams + + +@dataclass(frozen=True, kw_only=True) +class TemperatureSamplerParams: + """Temperature sampling controls.""" + + temperature: float | ModelProviderDefault = DEFAULT + + +@dataclass(frozen=True, kw_only=True) +class TopKSamplerParams: + """Top-k sampling controls.""" + + top_k: int | ModelProviderDefault | None = DEFAULT + + +@dataclass(frozen=True, kw_only=True) +class TopPSamplerParams: + """Nucleus sampling controls.""" + + top_p: float | ModelProviderDefault | None = DEFAULT + + +@dataclass(frozen=True, 
kw_only=True) +class MinPSamplerParams: + """Minimum probability sampling controls.""" + + min_p: float | ModelProviderDefault | None = DEFAULT + + +@dataclass(frozen=True, kw_only=True) +class RepetitionPenaltyParams: + """Penalty controls for repeated or overrepresented tokens.""" + + repetition_penalty: float | ModelProviderDefault | None = DEFAULT + frequency_penalty: float | ModelProviderDefault | None = DEFAULT + presence_penalty: float | ModelProviderDefault | None = DEFAULT + consideration_window: int | ModelProviderDefault | None = DEFAULT + + +@dataclass(frozen=True, kw_only=True) +class SeedSamplerParams: + """Random seed controls for sampling.""" + + seed: int | RandomSeed | ModelProviderDefault | None = DEFAULT + + +SamplerParams = ( + TemperatureSamplerParams + | TopKSamplerParams + | TopPSamplerParams + | MinPSamplerParams + | RepetitionPenaltyParams + | SeedSamplerParams +) + + +type SamplerParamsMap = dict[type[SamplerParams], SamplerParams] + + +type ProviderParamsMap = dict[type[Any], Any] + + +class ToolChoiceMode(enum.StrEnum): + """Built-in policies for model tool selection.""" + + AUTO = "auto" + NONE = "none" + REQUIRED = "required" + + +class ToolRef(str): + """A reference to a specific tool (by tool name).""" + + def __repr__(self) -> str: + return f"" + + +@dataclass(frozen=True, kw_only=True, init=False) +class ToolSelection: + """Tool subset paired with a tool choice policy.""" + + tools: frozenset[ToolRef] + mode: ToolChoiceMode + + def __init__( + self, + tools: Iterable[ToolRef] | Iterable[str], + *, + mode: ToolChoiceMode, + ) -> None: + object.__setattr__(self, "tools", frozenset(ToolRef(s) for s in tools)) + object.__setattr__(self, "mode", mode) + + +@dataclass(frozen=True, kw_only=True) +class ToolCallingParams: + """Tool calling parameters.""" + + max_tool_calls: int | ModelProviderDefault | None = DEFAULT + """The maximum number of tool calls the model may make.""" + + parallel_tool_calls: bool | ModelProviderDefault = DEFAULT 
+ """Whether the model may call multiple tools in parallel.""" + + tool_choice: ToolChoiceMode | ToolRef | ToolSelection + """Tool choice policy. + + * `ToolChoiceMode.AUTO`: the model can choose whether and what tools to call + * `ToolChoiceMode.REQUIRED`: the model must call (some) tool + * `ToolChoiceMode.NONE`: the model must not call tools + * `ToolRef("tool-name")`: the model must call the specified tool + * `ToolSelection(tools, mode)`: the model should treat the specified set of + tools according to mode. + """ + + +@dataclass(frozen=True, kw_only=True) +class ReasoningParams: + """Model reasoning/thinking options.""" + + effort: str | ModelProviderDefault | None = DEFAULT + """Provider-specific reasoning/thinking effort level. + + None means reasoning is disabled.""" + + +@dataclass(frozen=True, kw_only=True) +class ProviderServiceParams: + """Provider service parameters (service tier).""" + + service_tier: str | ModelProviderDefault = DEFAULT + """Provider-specific service tier.""" + + +@dataclass(frozen=True) +class TokenThreshold: + """Token count used as a trigger threshold.""" + + value: int + """Token count threshold.""" + + +@dataclass(frozen=True, kw_only=True) +class ContextManagementParams: + """Server-side context management parameters.""" + + compaction: TokenThreshold | None = None + """Compaction trigger threshold.""" + + +@dataclass(frozen=True, kw_only=True) +class OutputParams: + """Model output options.""" + + max_tokens: int | None = None + """The maximum number of tokens to generate before stopping.""" + + include: frozenset[str] | None = None + """Additional provider-specific data to include in the model response.""" + + text_verbosity: str | ModelProviderDefault | None = DEFAULT + """Provider-specific text verbosity level.""" + + reasoning_summary: str | ModelProviderDefault | None = DEFAULT + """Provider-specific reasoning summary emission level. 
+ + None means "disabled".""" + + +@dataclass(frozen=True, kw_only=True) +class CacheParams: + """Provider prompt caching behavior.""" + + mode: str | ModelProviderDefault = DEFAULT + """Provider-specific cache mode.""" + + retention: str | ModelProviderDefault = DEFAULT + """Provider-specific cache retention period (time-to-live).""" + + key: str | None = None + """Custom cache key component. Support is provider-specific.""" + + +@final +class GlobalRoutingTarget: + """Sentinel for globally scoped request routing.""" + + def __repr__(self) -> str: + return "GLOBAL" + + +GLOBAL = GlobalRoutingTarget() +"""Sentinel requesting globally scoped request routing.""" + + +class GeoRegion(str): + """A broad geography, e.g. ``us`` or ``eu``.""" + + +class CloudRegion(str): + """A specific cloud/provider region, e.g. ``us-east-1``.""" + + +type RoutingTarget = GlobalRoutingTarget | GeoRegion | CloudRegion + + +@dataclass(frozen=True, kw_only=True) +class RoutingTargetChain: + """Separate Gateway and provider routing targets.""" + + gateway: RoutingTarget + provider: RoutingTarget + + +type RoutingTargetParam = RoutingTarget | RoutingTargetChain + + +class ProviderRankingStrategy(enum.StrEnum): + """Provider ranking strategy.""" + + COST = "cost" + TTFT = "ttft" + TPS = "tps" + PRICE = "price" + LATENCY = "latency" + THROUGHPUT = "throughput" + + +@dataclass(frozen=True, kw_only=True) +class RoutingParams: + """Inference request routing options.""" + + routing_target: RoutingTargetParam | None = None + """Request (geo-/region-) routing target.""" + + provider_allowlist: frozenset[str] | None = None + """Restrict gateway routing to these providers.""" + + provider_order: tuple[str, ...] | None = None + """Preferred provider order.""" + + provider_ranking: ProviderRankingStrategy | None = None + """Dynamic provider sorting strategy.""" + + fallback_models: tuple[str, ...] 
| None = None + """Fallback models to try after the requested model.""" + + +@dataclass(frozen=True, kw_only=True) +class InferenceRequestParams: + """Model inference request parameters.""" + + sampling: SamplerParamsMap | ModelProviderDefault = DEFAULT + """Advanced token sampling parameters (e.g. temperature, top_p, min_p, etc.).""" + + reasoning: ReasoningParams | ModelProviderDefault = DEFAULT + """Model reasoning parameters.""" + + tool_calling: ToolCallingParams | None = None + """Tool calling parameters.""" + + provider_service: ProviderServiceParams | None = None + """Provider-specific service parameters (service tier etc).""" + + safety_identifier: str | None = None + """A stable identifier used for safety monitoring and abuse detection.""" + + metadata: Mapping[str, str] | None = None + """User-specified metadata associated with the request. + + Note that not all providers support attaching metadata to inference + requests, and the ones that do might place restrictions on length of + metadata both in terms of overall length and in terms of individual items. + For example, OpenAI and Open Responses-compatible providers specify that + keys must have a maximum length of 64 characters, values have a maximum + length of 512 characters, and the total number of metadata items must not + exceed 16.""" + + tags: frozenset[str] | None = None + """User-specified tags associated with the request. + + Note that not all providers support attaching tags to inference + requests, and the ones that do might place restrictions on length of + the tags collections as well as restrictions on individual tag value + length. For example, Vercel AI Gateway limits the number of tags to + 10 and the length of each tag to 64 characters. 
+ """ + + output: OutputParams | None = None + """Model output configuration.""" + + cache: CacheParams | None = None + """Prompt cache parameters.""" + + routing: RoutingParams | None = None + """Request routing parameters.""" + + context_management: ContextManagementParams | None = None + """Context management parameters.""" + + provider_params: ProviderParamsMap | None = None + """Provider-specific typed request parameters keyed by params type.""" + + extra_headers: Mapping[str, str | Unset] | None = None + """Extra headers to pass to the provider API.""" + + extra_query: Mapping[str, Any] | None = None + """Extra URL query string arguments to pass to the provider API.""" + + extra_body: Mapping[str, Any] | None = None + """Extra body arguments to pass to the provider API.""" + + def with_temperature( + self, temperature: float | ModelProviderDefault + ) -> Self: + temp_sampling_params: SamplerParamsMap = { + TemperatureSamplerParams: TemperatureSamplerParams( + temperature=temperature + ) + } + + if type(self.sampling) is ModelProviderDefault: + sampling = temp_sampling_params + else: + sampling = self.sampling | temp_sampling_params + + return dataclasses.replace(self, sampling=sampling) + + def with_reasoning_effort( + self, effort: str | ModelProviderDefault | None + ) -> Self: + return dataclasses.replace( + self, + reasoning=ReasoningParams(effort=effort), + ) + + def with_provider_params(self, *provider_params: object) -> Self: + params = dict(self.provider_params or {}) + for value in provider_params: + params[type(value)] = value + return dataclasses.replace(self, provider_params=params) diff --git a/src/ai/providers/ai_gateway/__init__.py b/src/ai/providers/ai_gateway/__init__.py index 3f9e71bd..759afb8a 100644 --- a/src/ai/providers/ai_gateway/__init__.py +++ b/src/ai/providers/ai_gateway/__init__.py @@ -9,11 +9,13 @@ model = ai.get_model("gateway:anthropic/claude-sonnet-4") ids = await ai.get_provider("vercel").list_models() - # Provider-specific 
request options pass through as raw Gateway body fields. + # Raw provider-specific request options pass through in extra_body. async with ai.stream( model, msgs, - params={"providerOptions": {"anthropic": {"speed": "fast"}}}, + params=ai.InferenceRequestParams( + extra_body={"providerOptions": {"anthropic": {"speed": "fast"}}} + ), tools=[anthropic_tools.web_search(max_uses=5)], ) as s: ... @@ -30,12 +32,15 @@ """ from . import errors, tools +from .params import GatewayParams, ProviderTimeoutsParams from .protocol import GatewayV3Protocol from .provider import GatewayProvider __all__ = [ + "GatewayParams", "GatewayProvider", "GatewayV3Protocol", + "ProviderTimeoutsParams", "errors", "tools", ] diff --git a/src/ai/providers/ai_gateway/client/_client.py b/src/ai/providers/ai_gateway/client/_client.py index d90972c8..206f0b2e 100644 --- a/src/ai/providers/ai_gateway/client/_client.py +++ b/src/ai/providers/ai_gateway/client/_client.py @@ -189,6 +189,7 @@ async def stream( streaming: bool = False, accept: str | None = None, headers: dict[str, str] | None = None, + query: Mapping[str, Any] | None = None, timeout: httpx.Timeout | float | None = None, ) -> AsyncIterator[httpx.Response]: request_headers = self.model_headers( @@ -206,6 +207,7 @@ async def stream( self.url(path), json=body, headers=request_headers, + params=query, ) if timeout is None else self._http.stream( @@ -213,6 +215,7 @@ async def stream( self.url(path), json=body, headers=request_headers, + params=query, timeout=timeout, ) ) diff --git a/src/ai/providers/ai_gateway/params.py b/src/ai/providers/ai_gateway/params.py new file mode 100644 index 00000000..55283dc5 --- /dev/null +++ b/src/ai/providers/ai_gateway/params.py @@ -0,0 +1,37 @@ +from collections.abc import Iterable, Mapping +from dataclasses import dataclass +from typing import Any + + +@dataclass(frozen=True, kw_only=True) +class ProviderTimeoutsParams: + """Gateway per-provider timeout configuration.""" + + byok: Mapping[str, int] | None = 
None + """Per-provider BYOK attempt timeouts in milliseconds.""" + + +@dataclass(frozen=True, kw_only=True) +class GatewayParams: + """Vercel AI Gateway-specific request parameters.""" + + quota_entity_id: str | None = None + """Gateway quota bucket/entity identifier.""" + + zero_data_retention: bool | None = None + """Request zero-data-retention routing.""" + + hipaa_compliant: bool | None = None + """Require HIPAA-compliant providers.""" + + disallow_prompt_training: bool | None = None + """Require providers that do not train on prompts.""" + + byok: Mapping[str, Iterable[Mapping[str, Any]]] | None = None + """Request-supplied BYOK credentials, keyed by provider.""" + + provider_timeouts: ProviderTimeoutsParams | None = None + """Per-provider routing timeout configuration.""" + + +__all__ = ["GatewayParams", "ProviderTimeoutsParams"] diff --git a/src/ai/providers/ai_gateway/protocol.py b/src/ai/providers/ai_gateway/protocol.py index a19e8a8d..992ef2d7 100644 --- a/src/ai/providers/ai_gateway/protocol.py +++ b/src/ai/providers/ai_gateway/protocol.py @@ -6,19 +6,21 @@ import base64 import json -from collections.abc import AsyncGenerator, Mapping, Sequence -from typing import Any +from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence +from typing import Any, TypeVar import httpx import pydantic from ... import types from ...models import core +from ...models.core import params as params_ from .. import base from ..anthropic import tools as anthropic_tools from ..openai import tools as openai_tools from . import client as gateway_client from . import errors +from . import params as gateway_params from . 
import tools as gateway_tools from .client import errors as client_errors @@ -27,6 +29,30 @@ # --------------------------------------------------------------------------- +_ProviderParamsT = TypeVar("_ProviderParamsT") + + +def _provider_params_value( + value: Mapping[type[Any], Any] | None, + params_type: type[_ProviderParamsT], + *, + provider: str, +) -> _ProviderParamsT | None: + if value is None: + return None + if not isinstance(value, Mapping): + raise TypeError(f"{provider} provider_params must be a mapping") + provider_params = value.get(params_type) + if provider_params is None: + return None + if not isinstance(provider_params, params_type): + raise TypeError( + f"{provider} provider_params[{params_type.__name__}] " + f"must be {params_type.__name__}" + ) + return provider_params + + def _extract_prompt(messages: list[types.messages.Message]) -> str: """Concatenate all text from user/system messages into one prompt.""" parts: list[str] = [] @@ -248,11 +274,10 @@ async def _build_request_body( messages: list[types.messages.Message], tools: Sequence[types.tools.Tool] | None = None, output_type: type[Any] | None = None, - params: Any = None, + params: Mapping[str, Any] | None = None, ) -> dict[str, Any]: """Build the ``LanguageModelV3CallOptions`` request body.""" - stream_params = _coerce_params(params) - body: dict[str, Any] = dict(stream_params) + body: dict[str, Any] = dict(params or {}) body["prompt"] = await _messages_to_prompt(messages) if tools: body["tools"] = [_tool_to_v3(tool) for tool in tools] @@ -265,12 +290,423 @@ async def _build_request_body( return body -def _coerce_params(value: Any) -> dict[str, Any]: +def _is_default(value: object) -> bool: + return isinstance(value, params_.ModelProviderDefault) + + +def _not_default(value: object) -> bool: + return not _is_default(value) + + +def _seed_value(seed: object) -> int | None: + if seed is None or seed == -1 or isinstance(seed, params_.RandomSeed): + return None + if isinstance(seed, int): + 
return seed + if isinstance(seed, params_.ModelProviderDefault): + return None + raise TypeError("seed must be an int, RANDOM, DEFAULT, or None") + + +def _filter_extra_headers( + headers: Mapping[str, str | params_.Unset] | None, +) -> dict[str, str] | None: + if headers is None: + return None + return { + key: value + for key, value in headers.items() + if not isinstance(value, params_.Unset) + } + + +def _provider_from_model_id(model_id: str) -> str | None: + provider, sep, _ = model_id.partition("/") + return provider if sep else None + + +def _body_provider_options( + body: dict[str, Any], provider: str +) -> dict[str, Any]: + provider_options = body.setdefault("providerOptions", {}) + if not isinstance(provider_options, dict): + raise TypeError("providerOptions must be a dict") + options = provider_options.setdefault(provider, {}) + if not isinstance(options, dict): + raise TypeError(f"providerOptions.{provider} must be a dict") + return options + + +def _sequence(value: Iterable[str] | None) -> list[str] | None: + if value is None: + return None + return list(value) + + +def _target_to_inference_region( + target: params_.RoutingTarget, +) -> dict[str, str]: + if target is params_.GLOBAL: + return {"scope": "global"} + if isinstance(target, params_.GeoRegion): + return {"geoRegion": str(target)} + return {"providerRegion": str(target)} + + +def _apply_provider_target( + body: dict[str, Any], + *, + provider: str | None, + target: params_.RoutingTarget, +) -> None: + target_provider = provider or "gateway" + options = _body_provider_options(body, target_provider) + target_value = "global" if target is params_.GLOBAL else str(target) + if target_provider == "anthropic" and isinstance(target, params_.GeoRegion): + options["inferenceGeo"] = target_value + else: + options["region"] = target_value + + +def _routing_to_gateway_options( + routing: params_.RoutingParams, +) -> dict[str, Any]: + options: dict[str, Any] = {} + for key, value in { + "only": 
sorted(routing.provider_allowlist) + if routing.provider_allowlist is not None + else None, + "order": _sequence(routing.provider_order), + "sort": routing.provider_ranking, + "models": _sequence(routing.fallback_models), + }.items(): + if value is not None: + options[key] = value + + if routing.routing_target is not None: + target = routing.routing_target + gateway_target = ( + target.gateway + if isinstance(target, params_.RoutingTargetChain) + else target + ) + options["inferenceRegion"] = _target_to_inference_region(gateway_target) + + return options + + +def _apply_gateway_routing( + body: dict[str, Any], + routing: params_.RoutingParams | None, + *, + provider: str | None, +) -> None: + if routing is None: + return + _body_provider_options(body, "gateway").update( + _routing_to_gateway_options(routing) + ) + if isinstance(routing.routing_target, params_.RoutingTargetChain): + _apply_provider_target( + body, + provider=provider, + target=routing.routing_target.provider, + ) + + +def _apply_gateway_params( + body: dict[str, Any], value: gateway_params.GatewayParams | None +) -> None: + if value is None: + return + options = _body_provider_options(body, "gateway") + for key, option in { + "quotaEntityId": value.quota_entity_id, + "zeroDataRetention": value.zero_data_retention, + "hipaaCompliant": value.hipaa_compliant, + "disallowPromptTraining": value.disallow_prompt_training, + }.items(): + if option is not None: + options[key] = option + + if value.byok is not None: + options["byok"] = { + provider: [dict(credential) for credential in credentials] + for provider, credentials in value.byok.items() + } + + if value.provider_timeouts is not None: + provider_timeouts: dict[str, Any] = {} + if value.provider_timeouts.byok is not None: + provider_timeouts["byok"] = dict(value.provider_timeouts.byok) + if provider_timeouts: + options["providerTimeouts"] = provider_timeouts + + +def _merge_extra_body( + body: dict[str, Any], extra_body: Mapping[str, Any] +) -> None: 
+ extra = dict(extra_body) + provider_options = extra.pop("providerOptions", None) + body.update(extra) + if provider_options is None: + return + if not isinstance(provider_options, Mapping): + raise TypeError("extra_body.providerOptions must be a mapping") + existing = body.setdefault("providerOptions", {}) + if not isinstance(existing, dict): + raise TypeError("providerOptions must be a dict") + for provider, options in provider_options.items(): + if not isinstance(provider, str): + raise TypeError("providerOptions keys must be strings") + if not isinstance(options, Mapping): + raise TypeError(f"providerOptions.{provider} must be a mapping") + current = existing.setdefault(provider, {}) + if not isinstance(current, dict): + raise TypeError(f"providerOptions.{provider} must be a dict") + current.update(options) + + +def _gateway_tool_choice( + tool_choice: params_.ToolChoiceMode + | params_.ToolRef + | params_.ToolSelection, + body: dict[str, Any], +) -> str | dict[str, str]: + if isinstance(tool_choice, params_.ToolChoiceMode): + return tool_choice.value + if isinstance(tool_choice, params_.ToolRef): + return {"type": "tool", "toolName": str(tool_choice)} + body["activeTools"] = sorted(str(tool) for tool in tool_choice.tools) + return tool_choice.mode.value + + +def _apply_gateway_reasoning( + body: dict[str, Any], + request_params: params_.InferenceRequestParams, + *, + provider: str | None, +) -> None: + reasoning = request_params.reasoning + output = request_params.output + effort: str | params_.ModelProviderDefault | None = params_.DEFAULT + if not isinstance(reasoning, params_.ModelProviderDefault): + effort = reasoning.effort + summary = params_.DEFAULT if output is None else output.reasoning_summary + if _is_default(effort) and _is_default(summary): + return + if provider == "openai": + options = _body_provider_options(body, "openai") + if _not_default(effort): + options["reasoningEffort"] = effort + if _not_default(summary): + options["reasoningSummary"] 
= summary + return + if provider == "anthropic": + options = _body_provider_options(body, "anthropic") + if _not_default(effort): + if effort is None: + options["thinking"] = {"type": "disabled"} + else: + options["effort"] = effort + if _not_default(summary): + if summary is None: + options["thinking"] = {"type": "disabled"} + else: + thinking = dict(options.get("thinking") or {}) + thinking.setdefault("type", "adaptive") + thinking["display"] = summary + options["thinking"] = thinking + return + body["reasoning"] = { + key: value + for key, value in { + "effort": effort, + "summary": summary, + }.items() + if _not_default(value) + } + + +def _apply_gateway_context_management( + body: dict[str, Any], + request_params: params_.InferenceRequestParams, + *, + provider: str | None, +) -> None: + context_management = request_params.context_management + if context_management is None or context_management.compaction is None: + return + threshold = context_management.compaction.value + if provider == "openai": + _body_provider_options(body, "openai")["contextManagement"] = [ + {"type": "compaction", "compactThreshold": threshold} + ] + return + if provider == "anthropic": + _body_provider_options(body, "anthropic")["contextManagement"] = { + "edits": [ + { + "type": "compact_20260112", + "trigger": {"type": "input_tokens", "value": threshold}, + } + ] + } + return + raise ValueError( + "AI Gateway context management requires an OpenAI or Anthropic model" + ) + + +def _apply_gateway_sampling( + body: dict[str, Any], + request_params: params_.InferenceRequestParams, +) -> None: + sampling = request_params.sampling + if isinstance(sampling, params_.ModelProviderDefault): + return + for sampler in sampling.values(): + match sampler: + case params_.TemperatureSamplerParams(temperature=temperature): + if _not_default(temperature): + body["temperature"] = temperature + case params_.TopPSamplerParams(top_p=top_p): + if _not_default(top_p): + body["topP"] = top_p + case 
params_.SeedSamplerParams(seed=seed): + if _not_default(seed): + value = _seed_value(seed) + if value is not None: + body["seed"] = value + case params_.TopKSamplerParams(top_k=top_k): + if _not_default(top_k): + body["topK"] = top_k + case params_.MinPSamplerParams(min_p=min_p): + if _not_default(min_p) and min_p is not None: + raise ValueError("AI Gateway does not support min_p") + case params_.RepetitionPenaltyParams() as repetition: + if _not_default(repetition.frequency_penalty): + body["frequencyPenalty"] = repetition.frequency_penalty + if _not_default(repetition.presence_penalty): + body["presencePenalty"] = repetition.presence_penalty + if ( + _not_default(repetition.repetition_penalty) + and repetition.repetition_penalty is not None + ): + raise ValueError( + "AI Gateway does not support repetition_penalty" + ) + if ( + _not_default(repetition.consideration_window) + and repetition.consideration_window is not None + ): + raise ValueError( + "AI Gateway does not support consideration_window" + ) + + +def _gateway_request_options( + value: params_.InferenceRequestParams | None, + *, + model_id: str, +) -> tuple[dict[str, Any], dict[str, str] | None, dict[str, Any] | None]: if value is None: - return {} - if isinstance(value, Mapping): - return dict(value) - raise TypeError("ai-gateway stream params must be a dict") + return {}, None, None + if not isinstance(value, params_.InferenceRequestParams): + raise TypeError( + "ai-gateway stream params must be InferenceRequestParams" + ) + + body: dict[str, Any] = {} + provider = _provider_from_model_id(model_id) + _apply_gateway_routing(body, value.routing, provider=provider) + _apply_gateway_params( + body, + _provider_params_value( + value.provider_params, + gateway_params.GatewayParams, + provider="ai-gateway", + ), + ) + _apply_gateway_sampling(body, value) + _apply_gateway_reasoning(body, value, provider=provider) + _apply_gateway_context_management(body, value, provider=provider) + + if value.tool_calling is 
not None: + tool_calling = value.tool_calling + if _not_default(tool_calling.max_tool_calls): + body["maxToolCalls"] = tool_calling.max_tool_calls + if _not_default(tool_calling.parallel_tool_calls): + body["parallelToolCalls"] = tool_calling.parallel_tool_calls + body["toolChoice"] = _gateway_tool_choice( + tool_calling.tool_choice, + body, + ) + + if value.provider_service is not None: + service = value.provider_service + target_provider = provider or "gateway" + options = _body_provider_options(body, target_provider) + if _not_default(service.service_tier): + options["serviceTier"] = service.service_tier + + if value.safety_identifier is not None: + _body_provider_options(body, "gateway")["user"] = ( + value.safety_identifier + ) + + if value.metadata is not None: + body["metadata"] = dict(value.metadata) + if provider in {"openai", "anthropic"}: + _body_provider_options(body, provider)["metadata"] = dict( + value.metadata + ) + + if value.tags is not None: + _body_provider_options(body, "gateway")["tags"] = sorted(value.tags) + + if value.output is not None: + output = value.output + if output.max_tokens is not None: + body["maxOutputTokens"] = output.max_tokens + if output.include is not None: + if provider == "openai": + _body_provider_options(body, "openai")["include"] = sorted( + output.include + ) + else: + body["include"] = sorted(output.include) + if ( + _not_default(output.text_verbosity) + and output.text_verbosity is not None + ): + raise ValueError("AI Gateway does not support text verbosity") + + if value.cache is not None: + cache = value.cache + if _not_default(cache.mode): + _body_provider_options(body, "gateway")["caching"] = cache.mode + if provider == "openai": + options = _body_provider_options(body, "openai") + if cache.key is not None: + options["promptCacheKey"] = cache.key + if _not_default(cache.retention): + options["promptCacheRetention"] = cache.retention + elif _not_default(cache.retention) or cache.key is not None: + options = 
_body_provider_options(body, "gateway") + if cache.key is not None: + options["cacheKey"] = cache.key + if _not_default(cache.retention): + options["cacheRetention"] = cache.retention + + if value.extra_body is not None: + _merge_extra_body(body, value.extra_body) + + return ( + body, + _filter_extra_headers(value.extra_headers), + dict(value.extra_query) if value.extra_query is not None else None, + ) # --------------------------------------------------------------------------- @@ -506,10 +942,13 @@ async def stream( *, tools: Sequence[types.tools.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, - params: Any = None, + params: params_.InferenceRequestParams | None = None, ) -> AsyncGenerator[types.events.Event]: """Stream an LLM response through the AI Gateway v3 protocol.""" - stream_params = _coerce_params(params) + stream_params, extra_headers, extra_query = _gateway_request_options( + params, + model_id=model.id, + ) body = await _build_request_body( messages, tools=tools, @@ -524,6 +963,8 @@ async def stream( model=model, model_type="language", streaming=True, + headers=extra_headers, + query=extra_query, ) as response: yield types.events.StreamStart() streamed_tool_ids: set[str] = set() @@ -684,7 +1125,7 @@ def stream( *, tools: Sequence[types.tools.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, - params: Any = None, + params: params_.InferenceRequestParams | None = None, provider: str, ) -> AsyncGenerator[types.events.Event]: _ = provider diff --git a/src/ai/providers/ai_gateway/provider.py b/src/ai/providers/ai_gateway/provider.py index 71f5b25e..cf92b4f6 100644 --- a/src/ai/providers/ai_gateway/provider.py +++ b/src/ai/providers/ai_gateway/provider.py @@ -84,7 +84,7 @@ def stream( *, tools: Sequence[tools_.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, - params: Any = None, + params: params_.InferenceRequestParams | None = None, ) -> AsyncGenerator[events.Event]: """Stream via 
the AI Gateway v3 protocol.""" return super().stream( diff --git a/src/ai/providers/anthropic/protocol.py b/src/ai/providers/anthropic/protocol.py index 05351b21..ea7896e5 100644 --- a/src/ai/providers/anthropic/protocol.py +++ b/src/ai/providers/anthropic/protocol.py @@ -8,18 +8,20 @@ import base64 import json -from collections.abc import AsyncGenerator, Mapping, Sequence from typing import TYPE_CHECKING, Any, cast import pydantic from ... import types +from ...models.core import params as params_ from ...types import events from .. import base from . import _sdk, errors from . import tools as anthropic_tools if TYPE_CHECKING: + from collections.abc import AsyncGenerator, Mapping, Sequence + import anthropic from ...models import core @@ -342,12 +344,246 @@ def _to_content_list(content: Any) -> list[dict[str, Any]]: return [{"type": "text", "text": content}] -def _coerce_params(value: Any) -> dict[str, Any]: +def _is_default(value: object) -> bool: + return isinstance(value, params_.ModelProviderDefault) + + +def _not_default(value: object) -> bool: + return not _is_default(value) + + +def _seed_value(seed: object) -> int | None: + if seed is None or seed == -1 or isinstance(seed, params_.RandomSeed): + return None + if isinstance(seed, int): + return seed + if isinstance(seed, params_.ModelProviderDefault): + return None + raise TypeError("seed must be an int, RANDOM, DEFAULT, or None") + + +def _filter_extra_headers( + headers: Mapping[str, str | params_.Unset] | None, +) -> dict[str, str] | None: + if headers is None: + return None + return { + key: value + for key, value in headers.items() + if not isinstance(value, params_.Unset) + } + + +def _extra_body(api_kwargs: dict[str, Any]) -> dict[str, Any]: + extra_body = api_kwargs.get("extra_body") + if not isinstance(extra_body, dict): + extra_body = {} + api_kwargs["extra_body"] = extra_body + return extra_body + + +def _apply_output_config( + api_kwargs: dict[str, Any], + values: Mapping[str, Any], +) -> None: 
+ output_config = dict(api_kwargs.get("output_config") or {}) + output_config.update(values) + api_kwargs["output_config"] = output_config + + +def _anthropic_tool_choice( + tool_choice: params_.ToolChoiceMode + | params_.ToolRef + | params_.ToolSelection, +) -> dict[str, Any]: + if isinstance(tool_choice, params_.ToolChoiceMode): + match tool_choice: + case params_.ToolChoiceMode.AUTO: + return {"type": "auto"} + case params_.ToolChoiceMode.REQUIRED: + return {"type": "any"} + case params_.ToolChoiceMode.NONE: + return {"type": "none"} + if isinstance(tool_choice, params_.ToolRef): + return {"type": "tool", "name": str(tool_choice)} + if len(tool_choice.tools) == 1 and tool_choice.mode in { + params_.ToolChoiceMode.AUTO, + params_.ToolChoiceMode.REQUIRED, + }: + return {"type": "tool", "name": str(next(iter(tool_choice.tools)))} + raise ValueError("Anthropic does not support allowed tool subsets") + + +def _apply_sampling( + api_kwargs: dict[str, Any], + request_params: params_.InferenceRequestParams, +) -> None: + sampling = request_params.sampling + if isinstance(sampling, params_.ModelProviderDefault): + return + for sampler in sampling.values(): + match sampler: + case params_.TemperatureSamplerParams(temperature=temperature): + if _not_default(temperature): + api_kwargs["temperature"] = temperature + case params_.TopPSamplerParams(top_p=top_p): + if _not_default(top_p): + api_kwargs["top_p"] = top_p + case params_.TopKSamplerParams(top_k=top_k): + if _not_default(top_k): + if top_k is None: + raise ValueError("Anthropic top_k cannot be None") + api_kwargs["top_k"] = top_k + case params_.SeedSamplerParams(seed=seed): + if _not_default(seed) and _seed_value(seed) is not None: + raise ValueError("Anthropic does not support seed") + case params_.MinPSamplerParams(min_p=min_p): + if _not_default(min_p) and min_p is not None: + raise ValueError("Anthropic does not support min_p") + case params_.RepetitionPenaltyParams() as repetition: + unsupported = { + 
"repetition_penalty": repetition.repetition_penalty, + "frequency_penalty": repetition.frequency_penalty, + "presence_penalty": repetition.presence_penalty, + "consideration_window": repetition.consideration_window, + } + for key, value in unsupported.items(): + if _not_default(value) and value is not None: + raise ValueError(f"Anthropic does not support {key}") + + +def _apply_anthropic_params( + api_kwargs: dict[str, Any], + request_params: params_.InferenceRequestParams, + *, + provider: str, +) -> None: + _ = provider + disable_parallel_tool_use = None + _apply_sampling(api_kwargs, request_params) + + reasoning = request_params.reasoning + output = request_params.output + summary = params_.DEFAULT if output is None else output.reasoning_summary + if not isinstance(reasoning, params_.ModelProviderDefault) and _not_default( + reasoning.effort + ): + if reasoning.effort is None: + api_kwargs["thinking"] = {"type": "disabled"} + else: + _apply_output_config(api_kwargs, {"effort": reasoning.effort}) + if _not_default(summary): + if summary is None: + api_kwargs["thinking"] = {"type": "disabled"} + else: + thinking = dict(api_kwargs.get("thinking") or {}) + thinking.setdefault("type", "adaptive") + thinking["display"] = summary + api_kwargs["thinking"] = thinking + + if request_params.tool_calling is not None: + tool_calling = request_params.tool_calling + if ( + _not_default(tool_calling.max_tool_calls) + and tool_calling.max_tool_calls is not None + ): + raise ValueError("Anthropic does not support max_tool_calls") + tool_choice = _anthropic_tool_choice(tool_calling.tool_choice) + if _not_default(tool_calling.parallel_tool_calls): + disable_parallel_tool_use = not tool_calling.parallel_tool_calls + if disable_parallel_tool_use is not None: + if tool_choice["type"] == "none": + raise ValueError( + "Anthropic cannot set parallel tool calls with tool none" + ) + tool_choice["disable_parallel_tool_use"] = disable_parallel_tool_use + api_kwargs["tool_choice"] = 
tool_choice + elif disable_parallel_tool_use is not None: + api_kwargs["tool_choice"] = { + "type": "auto", + "disable_parallel_tool_use": disable_parallel_tool_use, + } + + if request_params.provider_service is not None: + service = request_params.provider_service + if _not_default(service.service_tier): + api_kwargs["service_tier"] = service.service_tier + + metadata: dict[str, str] = {} + if request_params.metadata is not None: + metadata.update(request_params.metadata) + if request_params.safety_identifier is not None: + metadata["user_id"] = request_params.safety_identifier + if metadata: + api_kwargs["metadata"] = metadata + + if request_params.context_management is not None: + context_management = request_params.context_management + if context_management.compaction is not None: + _extra_body(api_kwargs)["context_management"] = { + "edits": [ + { + "type": "compact_20260112", + "trigger": { + "type": "input_tokens", + "value": context_management.compaction.value, + }, + } + ] + } + + if request_params.output is not None: + output = request_params.output + if output.max_tokens is not None: + api_kwargs["max_tokens"] = output.max_tokens + if output.include is not None: + raise ValueError("Anthropic does not support output include") + if ( + _not_default(output.text_verbosity) + and output.text_verbosity is not None + ): + raise ValueError("Anthropic does not support text verbosity") + + if request_params.cache is not None: + cache = request_params.cache + if cache.key is not None: + raise ValueError("Anthropic does not support cache keys") + cache_control: dict[str, Any] = {"type": "ephemeral"} + if _not_default(cache.retention): + cache_control["ttl"] = cache.retention + api_kwargs["cache_control"] = cache_control + + extra_headers = _filter_extra_headers(request_params.extra_headers) + if extra_headers is not None: + api_kwargs["extra_headers"] = extra_headers + if ( + request_params.context_management is not None + and 
request_params.context_management.compaction is not None + ): + _add_builtin_beta_headers( + api_kwargs, + {"compact-2026-01-12", "context-management-2025-06-27"}, + ) + if request_params.extra_query is not None: + api_kwargs["extra_query"] = dict(request_params.extra_query) + if request_params.extra_body is not None: + _extra_body(api_kwargs).update(request_params.extra_body) + + +def _coerce_params( + value: params_.InferenceRequestParams | None, +) -> dict[str, Any]: if value is None: return {} - if isinstance(value, Mapping): - return dict(value) - raise TypeError("anthropic stream params must be a dict") + if isinstance(value, params_.InferenceRequestParams): + api_kwargs: dict[str, Any] = {} + _apply_anthropic_params( + api_kwargs, + value, + provider=PROVIDER_NAME, + ) + return api_kwargs + raise TypeError("anthropic stream params must be InferenceRequestParams") def _add_builtin_beta_headers( @@ -393,7 +629,7 @@ async def stream( *, tools: Sequence[types.tools.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, - params: Any = None, + params: params_.InferenceRequestParams | None = None, provider: str, ) -> AsyncGenerator[events.Event]: """Stream through the Anthropic messages protocol using *sdk_client*. 
@@ -407,6 +643,9 @@ async def stream( """ anthropic_sdk = _sdk.import_sdk(provider=provider) stream_params = _coerce_params(params) + if params is not None: + stream_params = {} + _apply_anthropic_params(stream_params, params, provider=provider) system_prompt, anthropic_messages = await _messages_to_anthropic(messages) custom_tools, builtin_tools = _split_tools(tools or ()) @@ -621,7 +860,7 @@ def stream( *, tools: Sequence[types.tools.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, - params: Any = None, + params: params_.InferenceRequestParams | None = None, provider: str, ) -> AsyncGenerator[events.Event]: return stream( diff --git a/src/ai/providers/anthropic/provider.py b/src/ai/providers/anthropic/provider.py index 87042a48..ef5ac5fb 100644 --- a/src/ai/providers/anthropic/provider.py +++ b/src/ai/providers/anthropic/provider.py @@ -21,6 +21,7 @@ import pydantic from ...models.core import model as model_ + from ...models.core import params as params_ from ...types import events from ...types import messages as messages_ from ...types import tools as tools_ @@ -136,7 +137,7 @@ def stream( *, tools: Sequence[tools_.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, - params: Any = None, + params: params_.InferenceRequestParams | None = None, ) -> AsyncGenerator[events.Event]: """Stream via the Anthropic messages protocol.""" return super().stream( diff --git a/src/ai/providers/base.py b/src/ai/providers/base.py index 91fd33b9..28d1af10 100644 --- a/src/ai/providers/base.py +++ b/src/ai/providers/base.py @@ -38,7 +38,7 @@ def stream( *, tools: Sequence[tools_.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, - params: Any = None, + params: params_.InferenceRequestParams | None = None, provider: str, ) -> AsyncGenerator[events.Event]: """Stream a language-model response using *client*.""" @@ -212,7 +212,7 @@ def stream( *, tools: Sequence[tools_.Tool] | None = None, output_type: 
type[pydantic.BaseModel] | None = None, - params: Any = None, + params: params_.InferenceRequestParams | None = None, ) -> AsyncGenerator[events.Event]: """Stream a language-model response from this provider.""" selected_protocol = model.protocol or self.protocol diff --git a/src/ai/providers/openai/protocol.py b/src/ai/providers/openai/protocol.py index 41be970e..e6c4c9b6 100644 --- a/src/ai/providers/openai/protocol.py +++ b/src/ai/providers/openai/protocol.py @@ -14,6 +14,7 @@ from ... import errors as ai_errors from ... import types from ...models import core +from ...models.core import params as params_ from .. import base from . import _sdk, errors from . import tools as openai_tools @@ -208,12 +209,288 @@ async def _messages_to_openai( # --------------------------------------------------------------------------- -def _coerce_params(value: Any) -> dict[str, Any]: +def _is_default(value: object) -> bool: + return isinstance(value, params_.ModelProviderDefault) + + +def _not_default(value: object) -> bool: + return not _is_default(value) + + +def _seed_value(seed: object) -> int | None: + if seed is None or seed == -1 or isinstance(seed, params_.RandomSeed): + return None + if isinstance(seed, int): + return seed + if isinstance(seed, params_.ModelProviderDefault): + return None + raise TypeError("seed must be an int, RANDOM, DEFAULT, or None") + + +def _filter_extra_headers( + headers: Mapping[str, str | params_.Unset] | None, +) -> dict[str, str] | None: + if headers is None: + return None + return { + key: value + for key, value in headers.items() + if not isinstance(value, params_.Unset) + } + + +def _tool_ref_name(tool_ref: params_.ToolRef | str) -> str: + return str(tool_ref) + + +def _tools_for_openai_allowed( + tools: frozenset[params_.ToolRef], + *, + responses: bool, +) -> list[dict[str, Any]]: + if responses: + return [ + {"type": "function", "name": _tool_ref_name(tool)} + for tool in sorted(tools) + ] + return [ + {"type": "function", "function": 
{"name": _tool_ref_name(tool)}} + for tool in sorted(tools) + ] + + +def _openai_tool_choice( + tool_choice: params_.ToolChoiceMode + | params_.ToolRef + | params_.ToolSelection, + *, + responses: bool, +) -> Any: + if isinstance(tool_choice, params_.ToolChoiceMode): + return tool_choice.value + if isinstance(tool_choice, params_.ToolRef): + if responses: + return {"type": "function", "name": _tool_ref_name(tool_choice)} + return { + "type": "function", + "function": {"name": _tool_ref_name(tool_choice)}, + } + if tool_choice.mode == params_.ToolChoiceMode.NONE: + raise ValueError("OpenAI allowed tool selection does not support none") + if responses: + return { + "type": "allowed_tools", + "mode": tool_choice.mode.value, + "tools": _tools_for_openai_allowed( + tool_choice.tools, + responses=True, + ), + } + return { + "type": "allowed_tools", + "allowed_tools": { + "mode": tool_choice.mode.value, + "tools": _tools_for_openai_allowed( + tool_choice.tools, + responses=False, + ), + }, + } + + +def _apply_sampling( + api_kwargs: dict[str, Any], + request_params: params_.InferenceRequestParams, + *, + chat: bool, +) -> None: + sampling = request_params.sampling + if isinstance(sampling, params_.ModelProviderDefault): + return + for sampler in sampling.values(): + match sampler: + case params_.TemperatureSamplerParams(temperature=temperature): + if _not_default(temperature): + api_kwargs["temperature"] = temperature + case params_.TopPSamplerParams(top_p=top_p): + if _not_default(top_p): + api_kwargs["top_p"] = top_p + case params_.SeedSamplerParams(seed=seed): + if _not_default(seed): + value = _seed_value(seed) + if value is not None: + if chat: + api_kwargs["seed"] = value + else: + raise ValueError( + "OpenAI responses does not support seed" + ) + case params_.TopKSamplerParams(top_k=top_k): + if _not_default(top_k) and top_k is not None: + raise ValueError("OpenAI does not support top_k") + case params_.MinPSamplerParams(min_p=min_p): + if _not_default(min_p) and 
min_p is not None: + raise ValueError("OpenAI does not support min_p") + case params_.RepetitionPenaltyParams() as repetition: + if _not_default(repetition.frequency_penalty): + api_kwargs["frequency_penalty"] = ( + repetition.frequency_penalty + ) + if _not_default(repetition.presence_penalty): + api_kwargs["presence_penalty"] = repetition.presence_penalty + if ( + _not_default(repetition.repetition_penalty) + and repetition.repetition_penalty is not None + ): + raise ValueError( + "OpenAI does not support repetition_penalty" + ) + if ( + _not_default(repetition.consideration_window) + and repetition.consideration_window is not None + ): + raise ValueError( + "OpenAI does not support consideration_window" + ) + _ = chat + + +def _apply_common_openai_params( + api_kwargs: dict[str, Any], + request_params: params_.InferenceRequestParams, + *, + provider: str, + responses: bool, +) -> None: + _ = provider + _apply_sampling(api_kwargs, request_params, chat=not responses) + + reasoning = request_params.reasoning + output = request_params.output + effort: str | params_.ModelProviderDefault | None = params_.DEFAULT + if not isinstance(reasoning, params_.ModelProviderDefault): + effort = reasoning.effort + summary = params_.DEFAULT if output is None else output.reasoning_summary + if _not_default(effort) or _not_default(summary): + if responses: + reasoning_kwargs: dict[str, Any] = dict( + api_kwargs.get("reasoning") or {} + ) + if _not_default(effort): + reasoning_kwargs["effort"] = ( + "none" if effort is None else effort + ) + if _not_default(summary): + reasoning_kwargs["summary"] = summary + api_kwargs["reasoning"] = reasoning_kwargs + else: + if _not_default(effort): + api_kwargs["reasoning_effort"] = ( + "none" if effort is None else effort + ) + if _not_default(summary) and summary is not None: + raise ValueError( + "OpenAI chat completions does not support reasoning summary" + ) + + if request_params.tool_calling is not None: + tool_calling = 
request_params.tool_calling + if _not_default(tool_calling.max_tool_calls): + if responses: + api_kwargs["max_tool_calls"] = tool_calling.max_tool_calls + elif tool_calling.max_tool_calls is not None: + raise ValueError( + "OpenAI chat completions does not support max_tool_calls" + ) + if _not_default(tool_calling.parallel_tool_calls): + api_kwargs["parallel_tool_calls"] = tool_calling.parallel_tool_calls + api_kwargs["tool_choice"] = _openai_tool_choice( + tool_calling.tool_choice, + responses=responses, + ) + + if request_params.provider_service is not None: + service = request_params.provider_service + if _not_default(service.service_tier): + api_kwargs["service_tier"] = service.service_tier + + if request_params.safety_identifier is not None: + api_kwargs["safety_identifier"] = request_params.safety_identifier + if request_params.metadata is not None: + api_kwargs["metadata"] = dict(request_params.metadata) + + if request_params.context_management is not None: + context_management = request_params.context_management + if context_management.compaction is not None: + if responses: + api_kwargs["context_management"] = [ + { + "type": "compaction", + "compact_threshold": ( + context_management.compaction.value + ), + } + ] + else: + raise ValueError( + "OpenAI chat completions does not support " + "context management" + ) + + if request_params.output is not None: + output = request_params.output + if output.max_tokens is not None: + api_kwargs[ + "max_output_tokens" if responses else "max_completion_tokens" + ] = output.max_tokens + if output.include is not None: + if responses: + api_kwargs["include"] = sorted(output.include) + else: + raise ValueError( + "OpenAI chat completions does not support output include" + ) + if _not_default(output.text_verbosity): + if responses: + text_config = dict(api_kwargs.get("text") or {}) + text_config["verbosity"] = output.text_verbosity + api_kwargs["text"] = text_config + elif output.text_verbosity is not None: + raise 
ValueError( + "OpenAI chat completions does not support text verbosity" + ) + + if request_params.cache is not None: + cache = request_params.cache + if cache.key is not None: + api_kwargs["prompt_cache_key"] = cache.key + if _not_default(cache.retention): + api_kwargs["prompt_cache_retention"] = cache.retention + + extra_headers = _filter_extra_headers(request_params.extra_headers) + if extra_headers is not None: + api_kwargs["extra_headers"] = extra_headers + if request_params.extra_query is not None: + api_kwargs["extra_query"] = dict(request_params.extra_query) + if request_params.extra_body is not None: + api_kwargs["extra_body"] = dict(request_params.extra_body) + + +def _coerce_params( + value: params_.InferenceRequestParams | None, +) -> dict[str, Any]: if value is None: return {} - if isinstance(value, Mapping): - return dict(value) - raise TypeError("openai stream params must be a dict") + if isinstance(value, params_.InferenceRequestParams): + api_kwargs: dict[str, Any] = {} + _apply_common_openai_params( + api_kwargs, + value, + provider="openai", + responses=False, + ) + return api_kwargs + raise TypeError("openai stream params must be InferenceRequestParams") async def stream( @@ -223,7 +500,7 @@ async def stream( *, tools: Sequence[types.tools.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, - params: Any = None, + params: params_.InferenceRequestParams | None = None, provider: str, ) -> AsyncGenerator[types.events.Event]: """Stream through the OpenAI chat completions protocol.""" @@ -237,6 +514,14 @@ async def stream( ) stream_params = _coerce_params(params) + if params is not None: + stream_params = {} + _apply_common_openai_params( + stream_params, + params, + provider=provider, + responses=False, + ) openai_messages = await _messages_to_openai(messages) openai_tools = _tools_to_openai(tools) if tools else None @@ -396,7 +681,7 @@ def stream( *, tools: Sequence[types.tools.Tool] | None = None, output_type: 
type[pydantic.BaseModel] | None = None, - params: Any = None, + params: params_.InferenceRequestParams | None = None, provider: str, ) -> AsyncGenerator[types.events.Event]: return stream( @@ -431,12 +716,25 @@ def stream( ) -def _coerce_responses_params(value: Any) -> dict[str, Any]: +def _coerce_responses_params( + value: params_.InferenceRequestParams | None, + *, + provider: str, +) -> dict[str, Any]: if value is None: return {} - if isinstance(value, Mapping): - return dict(value) - raise TypeError("openai responses stream params must be a dict") + if isinstance(value, params_.InferenceRequestParams): + api_kwargs: dict[str, Any] = {} + _apply_common_openai_params( + api_kwargs, + value, + provider=provider, + responses=True, + ) + return api_kwargs + raise TypeError( + "openai responses stream params must be InferenceRequestParams" + ) def _json_dumps(value: Any) -> str: @@ -940,11 +1238,11 @@ async def _stream_responses( *, tools: Sequence[types.tools.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, - params: Any = None, + params: params_.InferenceRequestParams | None = None, provider: str, ) -> AsyncGenerator[types.events.Event]: openai_sdk = _sdk.import_sdk(provider=provider) - stream_params = _coerce_responses_params(params) + stream_params = _coerce_responses_params(params, provider=provider) protected = sorted(_RESPONSES_PROTECTED_PARAMS & stream_params.keys()) if protected: raise ValueError( @@ -1452,7 +1750,7 @@ def stream( *, tools: Sequence[types.tools.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, - params: Any = None, + params: params_.InferenceRequestParams | None = None, provider: str, ) -> AsyncGenerator[types.events.Event]: return _stream_responses( diff --git a/src/ai/providers/openai/provider.py b/src/ai/providers/openai/provider.py index 92b93044..1541ef4e 100644 --- a/src/ai/providers/openai/provider.py +++ b/src/ai/providers/openai/provider.py @@ -21,6 +21,7 @@ import pydantic from 
...models.core import model as model_ + from ...models.core import params as params_ from ...types import events from ...types import messages as messages_ from ...types import tools as tools_ @@ -108,6 +109,7 @@ def _make_sdk_client( api_key=self.api_key or "", default_headers=self.headers, http_client=http_client, + _enforce_credentials=False, ) @property @@ -134,7 +136,7 @@ def stream( *, tools: Sequence[tools_.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, - params: Any = None, + params: params_.InferenceRequestParams | None = None, ) -> AsyncGenerator[events.Event]: """Stream via this provider's configured OpenAI-compatible protocol.""" return super().stream( diff --git a/tests/conftest.py b/tests/conftest.py index 92ef235d..3f742f05 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -45,7 +45,7 @@ def stream( *, tools: Sequence[ai.tools.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, - params: Any = None, + params: models.InferenceRequestParams | None = None, ) -> AsyncGenerator[events_.Event]: if model.protocol is not None: return model.protocol.stream( @@ -76,7 +76,7 @@ async def generate( self, model: models.Model, messages: list[messages_.Message], - params: Any, + params: models.GenerateParams, ) -> messages_.Message: if model.protocol is not None: return await model.protocol.generate( @@ -227,7 +227,7 @@ async def generate( self, model: models.Model, messages: list[messages_.Message], - params: Any = None, + params: models.GenerateParams, ) -> messages_.Message: if self._call_index >= len(self._responses): raise RuntimeError( diff --git a/tests/models/core/test_api.py b/tests/models/core/test_api.py index b4a0d97b..e7c740f0 100644 --- a/tests/models/core/test_api.py +++ b/tests/models/core/test_api.py @@ -33,6 +33,30 @@ def _provider_metadata_marker( return marker +def test_inference_request_params_with_provider_params() -> None: + class GatewayParams: + pass + + class OpenAIParams: + pass + + 
original_gateway = GatewayParams() + replacement_gateway = GatewayParams() + openai = OpenAIParams() + + base = ai.InferenceRequestParams( + provider_params={GatewayParams: original_gateway} + ) + updated = base.with_provider_params(replacement_gateway, openai) + + assert base.provider_params == {GatewayParams: original_gateway} + assert updated is not base + assert updated.provider_params == { + GatewayParams: replacement_gateway, + OpenAIParams: openai, + } + + async def test_stream_aggregates_registered_adapter_events() -> None: mock = mock_llm([[text_msg("Hello world")]]) @@ -193,7 +217,7 @@ async def _spy_stream( MOCK_PROVIDER._stream_impl = _spy_stream - params = {"raw": "ok"} + params = models.InferenceRequestParams(extra_body={"raw": "ok"}) async with models.stream( MOCK_MODEL, [ai.user_message("Hi")], @@ -257,7 +281,7 @@ def stream( *, tools: Sequence[ai.tools.Tool] | None = None, output_type: type[pydantic.BaseModel] | None = None, - params: Any = None, + params: models.InferenceRequestParams | None = None, provider: str, ) -> AsyncGenerator[events_.Event]: _ = client, model, messages, tools, output_type, params, provider @@ -296,7 +320,7 @@ async def test_generate_dispatches_to_provider() -> None: async def _generate( model: models.Model, messages: list[messages_.Message], - params: Any = None, + params: models.GenerateParams, ) -> messages_.Message: nonlocal called called = True diff --git a/tests/providers/ai_gateway/test_stream.py b/tests/providers/ai_gateway/test_stream.py index 10962a6a..de8ac02f 100644 --- a/tests/providers/ai_gateway/test_stream.py +++ b/tests/providers/ai_gateway/test_stream.py @@ -16,7 +16,7 @@ from __future__ import annotations import json -from typing import Any +from typing import Any, cast import httpx import pytest @@ -24,6 +24,7 @@ import ai from ai import models from ai.models.core import model as model_ +from ai.providers.ai_gateway import GatewayParams, ProviderTimeoutsParams from ai.types import events, messages from 
.conftest import mock_client, mock_model, sse, user_msg @@ -345,22 +346,30 @@ def handler(req: httpx.Request) -> httpx.Response: httpx.MockTransport(handler), model_id="anthropic/claude-sonnet-4", ) - request_params = { - "providerOptions": { - "gateway": { - "order": ["bedrock", "anthropic"], - "zeroDataRetention": True, - }, - "anthropic": { - "speed": "fast", - "futureAnthropicField": True, - }, - "google": { - "thinkingConfig": {"budgetTokens": 1024}, + request_params = ai.InferenceRequestParams( + sampling={ai.SeedSamplerParams: ai.SeedSamplerParams(seed=123)}, + output=ai.OutputParams(reasoning_summary="detailed"), + reasoning=ai.ReasoningParams(effort="high"), + context_management=ai.ContextManagementParams( + compaction=ai.TokenThreshold(120_000) + ), + routing=ai.RoutingParams(provider_order=("bedrock", "anthropic")), + provider_params={ + GatewayParams: GatewayParams(zero_data_retention=True) + }, + extra_body={ + "providerOptions": { + "anthropic": { + "speed": "fast", + "futureAnthropicField": True, + }, + "google": { + "thinkingConfig": {"budgetTokens": 1024}, + }, }, + "futureGatewayField": True, }, - "futureGatewayField": True, - } + ) async with models.stream( model, [user_msg("Hi")], @@ -375,15 +384,165 @@ def handler(req: httpx.Request) -> httpx.Response: "zeroDataRetention": True, }, "anthropic": { + "effort": "high", "speed": "fast", "futureAnthropicField": True, + "contextManagement": { + "edits": [ + { + "type": "compact_20260112", + "trigger": { + "type": "input_tokens", + "value": 120_000, + }, + } + ] + }, + "thinking": {"type": "adaptive", "display": "detailed"}, }, "google": { "thinkingConfig": {"budgetTokens": 1024}, }, } + assert captured_body["seed"] == 123 assert captured_body["futureGatewayField"] is True + async def test_gateway_omits_random_seed(self) -> None: + captured_body: dict[str, Any] = {} + + def handler(req: httpx.Request) -> httpx.Response: + captured_body.update(json.loads(req.content)) + return httpx.Response( + 200, + 
text=sse( + {"type": "finish", "finishReason": "stop", "usage": {}} + ), + ) + + model = mock_model( + httpx.MockTransport(handler), + model_id="openai/gpt-5.4", + ) + request_params = ai.InferenceRequestParams( + sampling={ai.SeedSamplerParams: ai.SeedSamplerParams(seed=-1)} + ) + async with models.stream( + model, + [user_msg("Hi")], + params=request_params, + ) as stream: + async for _ in stream: + pass + + assert "seed" not in captured_body + + async def test_gateway_routing_params_map_to_provider_options( + self, + ) -> None: + captured_body: dict[str, Any] = {} + + def handler(req: httpx.Request) -> httpx.Response: + captured_body.update(json.loads(req.content)) + return httpx.Response( + 200, + text=sse( + {"type": "finish", "finishReason": "stop", "usage": {}} + ), + ) + + model = mock_model( + httpx.MockTransport(handler), + model_id="anthropic/claude-sonnet-4", + ) + byok: dict[str, list[dict[str, Any]]] = { + "anthropic": [{"apiKey": "sk-test"}] + } + request_params = ai.InferenceRequestParams( + cache=ai.CacheParams(mode="auto"), + safety_identifier="user_123", + tags=frozenset({"team:search", "env:prod"}), + routing=ai.RoutingParams( + provider_allowlist=frozenset({"anthropic", "bedrock"}), + provider_order=("bedrock", "anthropic"), + provider_ranking=ai.ProviderRankingStrategy.LATENCY, + fallback_models=("openai/gpt-5-mini",), + routing_target=ai.CloudRegion("us-east-1"), + ), + provider_params={ + GatewayParams: GatewayParams( + quota_entity_id="quota_123", + zero_data_retention=True, + hipaa_compliant=True, + disallow_prompt_training=True, + byok=byok, + provider_timeouts=ProviderTimeoutsParams( + byok={"anthropic": 5000} + ), + ), + }, + ) + async with models.stream( + model, + [user_msg("Hi")], + params=request_params, + ) as stream: + async for _ in stream: + pass + + assert captured_body["providerOptions"]["gateway"] == { + "caching": "auto", + "only": ["anthropic", "bedrock"], + "order": ["bedrock", "anthropic"], + "sort": "latency", + 
"models": ["openai/gpt-5-mini"], + "user": "user_123", + "tags": ["env:prod", "team:search"], + "quotaEntityId": "quota_123", + "zeroDataRetention": True, + "hipaaCompliant": True, + "disallowPromptTraining": True, + "inferenceRegion": {"providerRegion": "us-east-1"}, + "byok": {"anthropic": [{"apiKey": "sk-test"}]}, + "providerTimeouts": {"byok": {"anthropic": 5000}}, + } + + async def test_gateway_openai_context_management_maps_to_provider_options( + self, + ) -> None: + captured_body: dict[str, Any] = {} + + def handler(req: httpx.Request) -> httpx.Response: + captured_body.update(json.loads(req.content)) + return httpx.Response( + 200, + text=sse( + {"type": "finish", "finishReason": "stop", "usage": {}} + ), + ) + + model = mock_model( + httpx.MockTransport(handler), + model_id="openai/gpt-5.4", + ) + request_params = ai.InferenceRequestParams( + context_management=ai.ContextManagementParams( + compaction=ai.TokenThreshold(120_000) + ) + ) + async with models.stream( + model, + [user_msg("Hi")], + params=request_params, + ) as stream: + async for _ in stream: + pass + + assert captured_body["providerOptions"]["openai"] == { + "contextManagement": [ + {"type": "compaction", "compactThreshold": 120_000} + ] + } + async def test_gateway_rejects_non_dict_params(self) -> None: def handler(req: httpx.Request) -> httpx.Response: raise AssertionError("request should not be sent") @@ -392,13 +551,14 @@ def handler(req: httpx.Request) -> httpx.Response: httpx.MockTransport(handler), model_id="openai/gpt-5.4", ) - with pytest.raises(TypeError, match="dict"): + with pytest.raises(TypeError, match="InferenceRequestParams"): async with models.stream( model, [user_msg("Hi")], - params=[ - {"providerOptions": {"openai": {"serviceTier": "auto"}}} - ], + params=cast( + Any, + [{"providerOptions": {"openai": {"serviceTier": "auto"}}}], + ), ) as stream: async for _ in stream: pass diff --git a/tests/providers/anthropic/test_adapter.py b/tests/providers/anthropic/test_adapter.py 
index 01d23da6..a2167d98 100644 --- a/tests/providers/anthropic/test_adapter.py +++ b/tests/providers/anthropic/test_adapter.py @@ -53,7 +53,7 @@ async def _drain(stream: Any) -> None: pass -async def test_raw_params_pass_through_to_sdk_kwargs( +async def test_params_translate_to_sdk_kwargs( monkeypatch: pytest.MonkeyPatch, ) -> None: fake, captured = _patch_client(monkeypatch) @@ -63,41 +63,54 @@ async def test_raw_params_pass_through_to_sdk_kwargs( fake, _MODEL, [ai.user_message("Hi")], - params={ - "max_tokens": 123, - "speed": "fast", - "thinking": {"type": "disabled"}, - "output_config": { - "effort": "high", - "task_budget": {"type": "tokens", "total": 20000}, - }, - "tool_choice": { - "type": "auto", - "disable_parallel_tool_use": True, + params=ai.InferenceRequestParams( + output=ai.OutputParams(max_tokens=123, reasoning_summary=None), + reasoning=ai.ReasoningParams(effort="high"), + context_management=ai.ContextManagementParams( + compaction=ai.TokenThreshold(120_000) + ), + tool_calling=ai.ToolCallingParams( + tool_choice=ai.ToolChoiceMode.AUTO, + parallel_tool_calls=False, + ), + extra_body={ + "speed": "fast", + "future_option": {"enabled": True}, }, - "extra_body": {"future_option": {"enabled": True}}, - "extra_headers": {"x-anthropic-feature": "enabled"}, - }, + extra_headers={"x-anthropic-feature": "enabled"}, + ), provider="anthropic", ) ) assert captured["max_tokens"] == 123 - assert captured["speed"] == "fast" assert captured["thinking"] == {"type": "disabled"} assert captured["output_config"] == { "effort": "high", - "task_budget": {"type": "tokens", "total": 20000}, } assert captured["tool_choice"] == { "type": "auto", "disable_parallel_tool_use": True, } - assert captured["extra_body"] == {"future_option": {"enabled": True}} - assert captured["extra_headers"] == {"x-anthropic-feature": "enabled"} + assert captured["extra_body"] == { + "context_management": { + "edits": [ + { + "type": "compact_20260112", + "trigger": {"type": "input_tokens", 
"value": 120_000}, + } + ] + }, + "speed": "fast", + "future_option": {"enabled": True}, + } + assert captured["extra_headers"] == { + "anthropic-beta": "compact-2026-01-12,context-management-2025-06-27", + "x-anthropic-feature": "enabled", + } -async def test_non_dict_params_rejected_by_adapter( +async def test_non_inference_params_rejected_by_adapter( monkeypatch: pytest.MonkeyPatch, ) -> None: fake, _ = _patch_client(monkeypatch) @@ -106,14 +119,53 @@ async def test_non_dict_params_rejected_by_adapter( fake, _MODEL, [ai.user_message("Hi")], - params=[{"speed": "fast"}], + params=cast(Any, [{"speed": "fast"}]), provider="anthropic", ) - with pytest.raises(TypeError, match="dict"): + with pytest.raises(TypeError, match="InferenceRequestParams"): await _drain(stream) +async def test_seed_rejected_by_adapter( + monkeypatch: pytest.MonkeyPatch, +) -> None: + fake, _ = _patch_client(monkeypatch) + + stream = protocol.stream( + fake, + _MODEL, + [ai.user_message("Hi")], + params=ai.InferenceRequestParams( + sampling={ai.SeedSamplerParams: ai.SeedSamplerParams(seed=123)} + ), + provider="anthropic", + ) + + with pytest.raises(ValueError, match="seed"): + await _drain(stream) + + +async def test_random_seed_omitted_by_adapter( + monkeypatch: pytest.MonkeyPatch, +) -> None: + fake, captured = _patch_client(monkeypatch) + + await _drain( + protocol.stream( + fake, + _MODEL, + [ai.user_message("Hi")], + params=ai.InferenceRequestParams( + sampling={ai.SeedSamplerParams: ai.SeedSamplerParams(seed=-1)} + ), + provider="anthropic", + ) + ) + + assert "seed" not in captured + + async def test_reasoning_signature_round_trips_from_provider_metadata( monkeypatch: pytest.MonkeyPatch, ) -> None: diff --git a/tests/providers/anthropic/test_tools.py b/tests/providers/anthropic/test_tools.py index 509a4fbf..7c1e5008 100644 --- a/tests/providers/anthropic/test_tools.py +++ b/tests/providers/anthropic/test_tools.py @@ -30,7 +30,7 @@ async def _capture_tools( monkeypatch: 
pytest.MonkeyPatch, tools: list[Any], *, - params: dict[str, Any] | None = None, + params: ai.InferenceRequestParams | None = None, ) -> dict[str, Any]: _ = monkeypatch captured: dict[str, Any] = {} @@ -170,7 +170,9 @@ async def test_user_anthropic_beta_header_wins( anthropic_tools.web_search(), anthropic_tools.web_fetch(), ], - params={"extra_headers": {"anthropic-beta": "custom-beta-2026-01-01"}}, + params=ai.InferenceRequestParams( + extra_headers={"anthropic-beta": "custom-beta-2026-01-01"} + ), ) assert _beta_header(captured) == "custom-beta-2026-01-01" diff --git a/tests/providers/openai/test_adapter.py b/tests/providers/openai/test_adapter.py index 902c24ad..30d6ba86 100644 --- a/tests/providers/openai/test_adapter.py +++ b/tests/providers/openai/test_adapter.py @@ -111,7 +111,10 @@ async def close(self) -> None: self.closed = True -_MODEL = ai.Model("gpt-5.4", provider=ai.get_provider("openai")) +_MODEL = ai.Model( + "gpt-5.4", + provider=ai.get_provider("openai", api_key="sk-test"), +) def _patch( @@ -157,7 +160,7 @@ async def test_responses_request_uses_responses_input() -> None: assert "messages" not in captured -async def test_responses_raw_params_and_structured_output() -> None: +async def test_responses_params_and_structured_output() -> None: fake, captured = _patch_responses() await _drain( @@ -166,25 +169,76 @@ async def test_responses_raw_params_and_structured_output() -> None: _MODEL, [ai.user_message("Hi")], output_type=_Answer, - params={ - "reasoning": {"effort": "high"}, - "include": ["file_search_call.results"], - "text": {"verbosity": "low"}, - "extra_headers": {"x-openai-feature": "enabled"}, - }, + params=ai.InferenceRequestParams( + reasoning=ai.ReasoningParams(effort="high"), + context_management=ai.ContextManagementParams( + compaction=ai.TokenThreshold(120_000) + ), + output=ai.OutputParams( + include=frozenset({"file_search_call.results"}), + reasoning_summary="auto", + text_verbosity="low", + ), + extra_body={"future_option": True}, 
+ extra_headers={"x-openai-feature": "enabled"}, + ), provider="openai", ) ) - assert captured["reasoning"] == {"effort": "high"} + assert captured["reasoning"] == {"effort": "high", "summary": "auto"} + assert captured["context_management"] == [ + {"type": "compaction", "compact_threshold": 120_000} + ] assert captured["include"] == ["file_search_call.results"] assert captured["extra_headers"] == {"x-openai-feature": "enabled"} + assert captured["extra_body"] == {"future_option": True} assert captured["text"]["verbosity"] == "low" assert captured["text"]["format"]["type"] == "json_schema" assert captured["text"]["format"]["name"] == "_Answer" assert captured["text"]["format"]["strict"] is True +async def test_chat_rejects_text_verbosity( + monkeypatch: pytest.MonkeyPatch, +) -> None: + fake, _ = _patch(monkeypatch) + + with pytest.raises(ValueError, match="text verbosity"): + await _drain( + protocol.stream( + fake, + _MODEL, + [ai.user_message("Hi")], + params=ai.InferenceRequestParams( + output=ai.OutputParams(text_verbosity="low") + ), + provider="openai", + ) + ) + + +async def test_chat_rejects_context_management( + monkeypatch: pytest.MonkeyPatch, +) -> None: + fake, _ = _patch(monkeypatch) + + with pytest.raises(ValueError, match="context management"): + await _drain( + protocol.stream( + fake, + _MODEL, + [ai.user_message("Hi")], + params=ai.InferenceRequestParams( + context_management=ai.ContextManagementParams( + compaction=ai.TokenThreshold(120_000) + ) + ), + provider="openai", + ) + ) + + async def test_responses_tools_convert_function_and_provider_tools() -> None: fake, captured = _patch_responses() @@ -421,7 +475,7 @@ async def test_system_messages_use_openai_system_role( assert captured["messages"][0] == {"role": "system", "content": "rules"} -async def test_raw_params_pass_through_to_sdk_kwargs( +async def test_params_translate_to_sdk_kwargs( monkeypatch: pytest.MonkeyPatch, ) -> None: fake, captured = _patch(monkeypatch) @@ -431,27 +485,71 @@ 
async def test_raw_params_pass_through_to_sdk_kwargs( fake, _MODEL, [ai.user_message("Hi")], - params={ - "logprobs": 3, - "verbosity": "low", - "max_completion_tokens": 128, - "extra_body": {"future_option": True}, - "extra_headers": {"x-openai-feature": "enabled"}, - "stream_options": {"include_usage": False, "custom": True}, - }, + params=ai.InferenceRequestParams( + sampling={ + ai.TemperatureSamplerParams: ai.TemperatureSamplerParams( + temperature=0.2 + ), + ai.TopPSamplerParams: ai.TopPSamplerParams(top_p=0.9), + ai.SeedSamplerParams: ai.SeedSamplerParams(seed=123), + }, + output=ai.OutputParams(max_tokens=128), + provider_service=ai.ProviderServiceParams(service_tier="auto"), + extra_body={"future_option": True, "verbosity": "low"}, + extra_headers={"x-openai-feature": "enabled"}, + ), provider="openai", ) ) - assert captured["logprobs"] == 3 - assert captured["verbosity"] == "low" + assert captured["temperature"] == 0.2 + assert captured["top_p"] == 0.9 + assert captured["seed"] == 123 assert captured["max_completion_tokens"] == 128 - assert captured["extra_body"] == {"future_option": True} + assert captured["service_tier"] == "auto" + assert captured["extra_body"] == {"future_option": True, "verbosity": "low"} assert captured["extra_headers"] == {"x-openai-feature": "enabled"} - assert captured["stream_options"] == { - "include_usage": False, - "custom": True, - } + + +async def test_chat_omits_explicit_random_seed( + monkeypatch: pytest.MonkeyPatch, +) -> None: + fake, captured = _patch(monkeypatch) + + await _drain( + protocol.stream( + fake, + _MODEL, + [ai.user_message("Hi")], + params=ai.InferenceRequestParams( + sampling={ + ai.SeedSamplerParams: ai.SeedSamplerParams(seed=ai.RANDOM) + } + ), + provider="openai", + ) + ) + + assert "seed" not in captured + + +async def test_responses_rejects_seed() -> None: + fake, _ = _patch_responses() + + with pytest.raises(ValueError, match="seed"): + await _drain( + protocol.OpenAIResponsesProtocol().stream( + 
fake, + _MODEL, + [ai.user_message("Hi")], + params=ai.InferenceRequestParams( + sampling={ + ai.SeedSamplerParams: ai.SeedSamplerParams(seed=123) + } + ), + provider="openai", + ) + ) async def test_strict_json_schema_flows_into_response_format( @@ -472,7 +570,7 @@ async def test_strict_json_schema_flows_into_response_format( assert captured["response_format"]["json_schema"]["strict"] is True -async def test_non_dict_params_rejected_by_adapter( +async def test_non_inference_params_rejected_by_adapter( monkeypatch: pytest.MonkeyPatch, ) -> None: fake, _ = _patch(monkeypatch) @@ -481,11 +579,11 @@ async def test_non_dict_params_rejected_by_adapter( fake, _MODEL, [ai.user_message("Hi")], - params=[{"reasoning_effort": "high"}], + params=cast(Any, [{"reasoning_effort": "high"}]), provider="openai", ) - with pytest.raises(TypeError, match="dict"): + with pytest.raises(TypeError, match="InferenceRequestParams"): await _drain(stream) diff --git a/tests/types/test_integrity.py b/tests/types/test_integrity.py index d65eecc9..9a7db8b3 100644 --- a/tests/types/test_integrity.py +++ b/tests/types/test_integrity.py @@ -630,7 +630,7 @@ async def test_generate_sanitizes_internal_messages() -> None: async def _spy_gen( model: models.Model, messages: list[messages.Message], - params: Any, + params: models.GenerateParams, ) -> messages.Message: received.append(list(messages)) return sentinel diff --git a/uv.lock b/uv.lock index 0f3f5c68..e168bba7 100644 --- a/uv.lock +++ b/uv.lock @@ -106,7 +106,7 @@ wheels = [ [[package]] name = "anthropic" -version = "0.83.0" +version = "0.103.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -118,9 +118,9 @@ dependencies = [ { name = "sniffio" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/db/e5/02cd2919ec327b24234abb73082e6ab84c451182cc3cc60681af700f4c63/anthropic-0.83.0.tar.gz", hash = 
"sha256:a8732c68b41869266c3034541a31a29d8be0f8cd0a714f9edce3128b351eceb4", size = 534058, upload-time = "2026-02-19T19:26:38.904Z" } +sdist = { url = "https://files.pythonhosted.org/packages/fb/57/0b758b08cf4606c94d63a997d67a0063f7438efbaf81cfedd0d7c0c69d67/anthropic-0.103.1.tar.gz", hash = "sha256:21c12f4fc0fdd87a2e80d58479cd0af640062b3cfb82bbfa01c7977acd4defeb", size = 848877, upload-time = "2026-05-19T15:43:27.698Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/5f/75/b9d58e4e2a4b1fc3e75ffbab978f999baf8b7c4ba9f96e60edb918ba386b/anthropic-0.83.0-py3-none-any.whl", hash = "sha256:f069ef508c73b8f9152e8850830d92bd5ef185645dbacf234bb213344a274810", size = 456991, upload-time = "2026-02-19T19:26:40.114Z" }, + { url = "https://files.pythonhosted.org/packages/ad/ec/cf357cf571377a39552c1530390a9b79bbdb6ea463f48fbe4e3624141e3b/anthropic-0.103.1-py3-none-any.whl", hash = "sha256:b9a523fac34e64caf6ee55fdbda213950e6a744b906fce100d34909aad2cd8f4", size = 832551, upload-time = "2026-05-19T15:43:29.663Z" }, ] [[package]] @@ -741,7 +741,7 @@ wheels = [ [[package]] name = "openai" -version = "2.14.0" +version = "2.37.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -753,9 +753,9 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d8/b1/12fe1c196bea326261718eb037307c1c1fe1dedc2d2d4de777df822e6238/openai-2.14.0.tar.gz", hash = "sha256:419357bedde9402d23bf8f2ee372fca1985a73348debba94bddff06f19459952", size = 626938, upload-time = "2025-12-19T03:28:45.742Z" } +sdist = { url = "https://files.pythonhosted.org/packages/32/50/5901f01ef14e6c27788beb91e54fef5d6204fb5fb9e97402fc8a14de2e32/openai-2.37.0.tar.gz", hash = "sha256:f4bc562cc5f3a43d40d678105572d9d44765f6e0f50c125f63055419b72f4bd9", size = 754706, upload-time = "2026-05-15T22:30:35.428Z" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/27/4b/7c1a00c2c3fbd004253937f7520f692a9650767aa73894d7a34f0d65d3f4/openai-2.14.0-py3-none-any.whl", hash = "sha256:7ea40aca4ffc4c4a776e77679021b47eec1160e341f42ae086ba949c9dcc9183", size = 1067558, upload-time = "2025-12-19T03:28:43.727Z" }, + { url = "https://files.pythonhosted.org/packages/ed/4c/bce61680d0699a78a405fd9a67989b175ba020590428831aab2ab1d2be7c/openai-2.37.0-py3-none-any.whl", hash = "sha256:814633888b8f3b1ffd6615697c6e4ef93632d08b7c2e28c8c5ef3556e5a10107", size = 1303238, upload-time = "2026-05-15T22:30:32.767Z" }, ] [[package]]