From 0baf787fc6801274a484675bcd395cc83668dddb Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Mon, 15 Jun 2026 17:34:42 -0700 Subject: [PATCH 1/3] Split Tool.args into tool spec and tool config Delete provider-specific dataclasses and only leave factories for built-in tools. --- examples/temporal-direct/main.py | 14 +- examples/tools_schema.py | 2 +- src/ai/agents/agent.py | 2 +- src/ai/agents/mcp/client.py | 2 +- src/ai/providers/ai_gateway/protocol.py | 67 +++--- src/ai/providers/ai_gateway/tools.py | 85 ++++---- src/ai/providers/anthropic/protocol.py | 44 ++-- src/ai/providers/anthropic/tools.py | 178 +++++----------- src/ai/providers/openai/protocol.py | 65 +++--- src/ai/providers/openai/tools.py | 261 ++++++----------------- src/ai/types/tools.py | 33 ++- tests/agents/mcp/test_client.py | 14 +- tests/agents/test_tools.py | 12 +- tests/providers/ai_gateway/test_tools.py | 69 ++++-- tests/providers/openai/test_adapter.py | 2 +- tests/types/test_tools.py | 78 +++++++ 16 files changed, 421 insertions(+), 507 deletions(-) create mode 100644 tests/types/test_tools.py diff --git a/examples/temporal-direct/main.py b/examples/temporal-direct/main.py index 7e30a462..98ec7ca4 100644 --- a/examples/temporal-direct/main.py +++ b/examples/temporal-direct/main.py @@ -131,14 +131,7 @@ async def llm_call_activity(params: LLMParams) -> LLMResult: """Call the LLM, drain the stream, return the final message.""" model = ai.get_model(params.model_id) messages = [ai.messages.Message.model_validate(m) for m in params.messages] - tools = [ - ai.Tool( - kind="function", - name=t["name"], - args=ai.tools.FunctionToolArgs.model_validate(t["args"]), - ) - for t in params.tool_schemas - ] + tools = [ai.Tool.model_validate(t) for t in params.tool_schemas] async with ai.stream(model, messages, tools=tools) as s: async for _event in s: @@ -162,10 +155,7 @@ class WeatherAgent(ai.Agent): async def loop( self, context: ai.Context ) -> AsyncGenerator[ai.events.AgentEvent]: - tool_schemas = [ - {"name": t.name, "args": t.args.model_dump(mode="json")} - for t in context.tools - ] + tool_schemas = [t.model_dump(mode="json") for t in context.tools] while context.keep_running(): # 1. LLM call via activity → complete message diff --git a/examples/tools_schema.py b/examples/tools_schema.py index 9a8da5db..e702c2f6 100644 --- a/examples/tools_schema.py +++ b/examples/tools_schema.py @@ -10,7 +10,7 @@ get_weather = ai.Tool( kind="function", name="get_weather", - args=ai.tools.FunctionToolArgs( + spec=ai.tools.ToolSpec( description="Get the current weather for a city.", params={ "type": "object", diff --git a/src/ai/agents/agent.py b/src/ai/agents/agent.py index 06bea9e7..e5eb3f6f 100644 --- a/src/ai/agents/agent.py +++ b/src/ai/agents/agent.py @@ -495,7 +495,7 @@ def wrap(fn: Any) -> AgentTool: tool_decl = Tool( kind="function", name=fn.__name__, - args=types.tools.FunctionToolArgs( + spec=types.tools.ToolSpec( description=inspect.getdoc(fn) or "", params=validator.model_json_schema(), ), diff --git a/src/ai/agents/mcp/client.py b/src/ai/agents/mcp/client.py index 38e6b4a4..d5022a80 100644 --- a/src/ai/agents/mcp/client.py +++ b/src/ai/agents/mcp/client.py @@ -301,7 +301,7 @@ def _mcp_tool_to_native( tool = Tool( kind="function", name=name, - args=types.tools.FunctionToolArgs( + spec=types.tools.ToolSpec( description=mcp_tool.description or "", params=mcp_tool.inputSchema, ), diff --git a/src/ai/providers/ai_gateway/protocol.py b/src/ai/providers/ai_gateway/protocol.py index 03acc199..cebbb5a5 100644 --- a/src/ai/providers/ai_gateway/protocol.py +++ b/src/ai/providers/ai_gateway/protocol.py @@ -11,17 +11,15 @@ import httpx import pydantic +from pydantic.alias_generators import to_camel from ... import types from ...models import core from ...models.core import params as params_ from .. import base -from ..anthropic import tools as anthropic_tools -from ..openai import tools as openai_tools from . import client as gateway_client from . import errors from . import params as gateway_params -from . import tools as gateway_tools from .client import errors as client_errors # --------------------------------------------------------------------------- @@ -276,48 +274,53 @@ async def _messages_to_prompt( return result +# Free-form payload fields whose keys are data, not config structure — +# their subtrees must reach the wire verbatim, never camelized. +_OPAQUE_ARG_KEYS: dict[str, frozenset[str]] = { + "openai.mcp": frozenset({"headers", "allowed_tools"}), + "openai.file_search": frozenset({"filters"}), + "openai.tool_search": frozenset({"parameters", "execution"}), +} + + +def _camelize(value: Any) -> Any: + if isinstance(value, dict): + return {to_camel(k): _camelize(v) for k, v in value.items()} + if isinstance(value, list): + return [_camelize(v) for v in value] + return value + + def _tool_to_v3(tool: types.tools.Tool) -> dict[str, Any]: """Convert a tool schema blob to the v3 wire format.""" if tool.kind == "provider": + cfg = tool.tool_config + tool_id = cfg.id if cfg is not None else None + if tool_id is None: + raise TypeError( + f"provider tool {tool.name!r} has no tool_config id" + ) + opaque = _OPAQUE_ARG_KEYS.get(tool_id, frozenset()) return { "type": "provider", - "id": _provider_tool_id(tool), + "id": tool_id, "name": tool.name, - "args": tool.args.model_dump( - mode="json", - by_alias=True, - exclude_none=True, - ), + "args": { + to_camel(k): v if k in opaque else _camelize(v) + for k, v in (cfg.args if cfg is not None else {}).items() + }, } - args = tool.args - if not isinstance(args, types.tools.FunctionToolArgs): - raise TypeError(f"function tool {tool.name!r} has invalid args") + spec = tool.spec + if spec is None: + raise TypeError(f"function tool {tool.name!r} has no spec") return { "type": "function", "name": tool.name, - "description": args.description or "", - "inputSchema": args.params, + "description": spec.description or "", + "inputSchema": spec.params, } -def _provider_tool_id(tool: types.tools.Tool) -> str: - if isinstance(tool.args, anthropic_tools.AnthropicProviderArgs): - return f"anthropic.{tool.args.anthropic_type}" - if isinstance(tool.args, openai_tools.OpenAIProviderArgs): - return tool.args.openai_id - - match tool.args: - case gateway_tools.PerplexitySearchArgs(): - return "gateway.perplexity_search" - case gateway_tools.ParallelSearchArgs(): - return "gateway.parallel_search" - case _: - raise TypeError( - f"provider tool {tool.name!r} has unsupported args " - f"{type(tool.args).__name__}" - ) - - async def _build_request_body( messages: list[types.messages.Message], tools: Sequence[types.tools.Tool] | None = None, diff --git a/src/ai/providers/ai_gateway/tools.py b/src/ai/providers/ai_gateway/tools.py index 41beccfd..910ee502 100644 --- a/src/ai/providers/ai_gateway/tools.py +++ b/src/ai/providers/ai_gateway/tools.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Literal +from typing import Any, Literal import pydantic from pydantic.alias_generators import to_camel @@ -43,26 +43,25 @@ class FetchPolicy(pydantic.BaseModel): max_age_seconds: int | None = None -class PerplexitySearchArgs(pydantic.BaseModel): - model_config = _CONFIG_MODEL - - max_results: int | None = None - max_tokens_per_page: int | None = None - max_tokens: int | None = None - country: str | None = None - search_domain_filter: list[str] | None = None - search_language_filter: list[str] | None = None - search_recency_filter: Literal["day", "week", "month", "year"] | None = None - +def _provider_tool(name: str, id: str, **args: Any) -> types.tools.Tool: + return types.tools.Tool( + kind="provider", + name=name, + tool_config=types.tools.ToolConfig( + id=id, + args={k: v for k, v in args.items() if v is not None}, + ), + ) -class ParallelSearchArgs(pydantic.BaseModel): - model_config = _CONFIG_MODEL - mode: Literal["one-shot", "agentic"] | None = None - max_results: int | None = None - source_policy: SourcePolicy | None = None - excerpts: Excerpts | None = None - fetch_policy: FetchPolicy | None = None +def _dump[M: pydantic.BaseModel]( + model_type: type[M], value: M | dict[str, object] | None +) -> dict[str, Any] | None: + if value is None: + return None + if isinstance(value, dict): + value = model_type.model_validate(value) + return value.model_dump(mode="json", exclude_none=True) def perplexity_search( @@ -76,18 +75,16 @@ def perplexity_search( search_recency_filter: Literal["day", "week", "month", "year"] | None = None, ) -> types.tools.Tool: - return types.tools.Tool( - kind="provider", - name="perplexity_search", - args=PerplexitySearchArgs( - max_results=max_results, - max_tokens_per_page=max_tokens_per_page, - max_tokens=max_tokens, - country=country, - search_domain_filter=search_domain_filter, - search_language_filter=search_language_filter, - search_recency_filter=search_recency_filter, - ), + return _provider_tool( + "perplexity_search", + "gateway.perplexity_search", + max_results=max_results, + max_tokens_per_page=max_tokens_per_page, + max_tokens=max_tokens, + country=country, + search_domain_filter=search_domain_filter, + search_language_filter=search_language_filter, + search_recency_filter=search_recency_filter, ) @@ -99,30 +96,20 @@ def parallel_search( excerpts: Excerpts | dict[str, object] | None = None, fetch_policy: FetchPolicy | dict[str, object] | None = None, ) -> types.tools.Tool: - return types.tools.Tool( - kind="provider", - name="parallel_search", - args=ParallelSearchArgs( - mode=mode, - max_results=max_results, - source_policy=SourcePolicy.model_validate(source_policy) - if isinstance(source_policy, dict) - else source_policy, - excerpts=Excerpts.model_validate(excerpts) - if isinstance(excerpts, dict) - else excerpts, - fetch_policy=FetchPolicy.model_validate(fetch_policy) - if isinstance(fetch_policy, dict) - else fetch_policy, - ), + return _provider_tool( + "parallel_search", + "gateway.parallel_search", + mode=mode, + max_results=max_results, + source_policy=_dump(SourcePolicy, source_policy), + excerpts=_dump(Excerpts, excerpts), + fetch_policy=_dump(FetchPolicy, fetch_policy), ) __all__ = [ "Excerpts", "FetchPolicy", - "ParallelSearchArgs", - "PerplexitySearchArgs", "SourcePolicy", "parallel_search", "perplexity_search", diff --git a/src/ai/providers/anthropic/protocol.py b/src/ai/providers/anthropic/protocol.py index 0a927d9d..542b9e13 100644 --- a/src/ai/providers/anthropic/protocol.py +++ b/src/ai/providers/anthropic/protocol.py @@ -77,14 +77,14 @@ def _custom_tools_to_anthropic( """Convert host-executed tools to Anthropic tool schema format.""" result: list[dict[str, Any]] = [] for tool in tools: - args = tool.args - if not isinstance(args, types.tools.FunctionToolArgs): - raise TypeError(f"function tool {tool.name!r} has invalid args") + spec = tool.spec + if spec is None: + raise TypeError(f"function tool {tool.name!r} has no spec") result.append( { "name": tool.name, - "description": args.description or "", - "input_schema": args.params, + "description": spec.description or "", + "input_schema": spec.params, } ) return result @@ -98,27 +98,33 @@ def _builtin_tools_to_anthropic( Returns ``(wire_tools, beta_headers)``. Beta headers are merged into the ``anthropic-beta`` request header by the caller. - Provider tool schemas keep args in the snake_case shape the native + Provider tool configs keep args in the snake_case shape the native Anthropic API expects. """ wire: list[dict[str, Any]] = [] betas: set[str] = set() for tool in builtin: - args_model = tool.args - if not isinstance(args_model, anthropic_tools.AnthropicProviderArgs): + cfg = tool.tool_config + tool_id = cfg.id if cfg is not None else None + if ( + cfg is None + or tool_id is None + or not tool_id.startswith("anthropic.") + ): raise ValueError( - "AnthropicModel does not support provider args " - f"{type(args_model).__name__}" + "AnthropicModel does not support provider tool " + f"{tool_id or tool.name!r}" ) - args = args_model.model_dump(mode="json", exclude_none=True) - block: dict[str, Any] = { - "type": args_model.anthropic_type, - "name": tool.name, - **args, - } - wire.append(block) - if args_model.anthropic_beta is not None: - betas.add(args_model.anthropic_beta) + wire.append( + { + "type": tool_id.removeprefix("anthropic."), + "name": tool.name, + **cfg.args, + } + ) + beta = anthropic_tools.BETA_HEADERS.get(tool_id) + if beta is not None: + betas.add(beta) return wire, betas diff --git a/src/ai/providers/anthropic/tools.py b/src/ai/providers/anthropic/tools.py index ec8864e7..4a95a4d5 100644 --- a/src/ai/providers/anthropic/tools.py +++ b/src/ai/providers/anthropic/tools.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import ClassVar, Literal +from typing import Any, Literal import pydantic from pydantic.alias_generators import to_camel @@ -15,6 +15,16 @@ alias_generator=to_camel, ) +# Beta request headers per provider tool id, merged into the +# ``anthropic-beta`` header by the adapter. +BETA_HEADERS: dict[str, str] = { + "anthropic.web_search_20260209": "code-execution-web-tools-2026-02-09", + "anthropic.web_fetch_20260209": "code-execution-web-tools-2026-02-09", + "anthropic.computer_20251124": "computer-use-2025-11-24", + "anthropic.bash_20250124": "computer-use-2025-01-24", + "anthropic.memory_20250818": "context-management-2025-06-27", +} + class UserLocation(pydantic.BaseModel): """Approximate user location for geographically relevant search results.""" @@ -36,78 +46,15 @@ class Citations(pydantic.BaseModel): enabled: bool -class AnthropicProviderArgs(pydantic.BaseModel): - """Base for Anthropic provider-executed tool args.""" - - model_config = _CONFIG_MODEL - - anthropic_type: ClassVar[str] - anthropic_beta: ClassVar[str | None] = None - - -class WebSearchArgs(AnthropicProviderArgs): - anthropic_type: ClassVar[str] = "web_search_20260209" - anthropic_beta: ClassVar[str | None] = "code-execution-web-tools-2026-02-09" - - model_config = _CONFIG_MODEL - - max_uses: int | None = None - allowed_domains: list[str] | None = None - blocked_domains: list[str] | None = None - user_location: UserLocation | None = None - - -class WebFetchArgs(AnthropicProviderArgs): - anthropic_type: ClassVar[str] = "web_fetch_20260209" - anthropic_beta: ClassVar[str | None] = "code-execution-web-tools-2026-02-09" - - model_config = _CONFIG_MODEL - - max_uses: int | None = None - allowed_domains: list[str] | None = None - blocked_domains: list[str] | None = None - citations: Citations | None = None - max_content_tokens: int | None = None - - -class CodeExecutionArgs(AnthropicProviderArgs): - anthropic_type: ClassVar[str] = "code_execution_20260120" - - model_config = _CONFIG_MODEL - - -class ComputerUseArgs(AnthropicProviderArgs): - anthropic_type: ClassVar[str] = "computer_20251124" - anthropic_beta: ClassVar[str | None] = "computer-use-2025-11-24" - - model_config = _CONFIG_MODEL - - display_width_px: int - display_height_px: int - display_number: int | None = None - enable_zoom: bool | None = None - - -class TextEditorArgs(AnthropicProviderArgs): - anthropic_type: ClassVar[str] = "text_editor_20250728" - - model_config = _CONFIG_MODEL - - max_characters: int | None = None - - -class BashArgs(AnthropicProviderArgs): - anthropic_type: ClassVar[str] = "bash_20250124" - anthropic_beta: ClassVar[str | None] = "computer-use-2025-01-24" - - model_config = _CONFIG_MODEL - - -class MemoryArgs(AnthropicProviderArgs): - anthropic_type: ClassVar[str] = "memory_20250818" - anthropic_beta: ClassVar[str | None] = "context-management-2025-06-27" - - model_config = _CONFIG_MODEL +def _provider_tool(name: str, id: str, **args: Any) -> types.tools.Tool: + return types.tools.Tool( + kind="provider", + name=name, + tool_config=types.tools.ToolConfig( + id=id, + args={k: v for k, v in args.items() if v is not None}, + ), + ) def _check_domains( @@ -129,17 +76,16 @@ def web_search( blocked_domains: list[str] | None = None, user_location: UserLocation | None = None, ) -> types.tools.Tool: - args = WebSearchArgs( + _check_domains("web_search", allowed_domains, blocked_domains) + return _provider_tool( + "web_search", + "anthropic.web_search_20260209", max_uses=max_uses, allowed_domains=allowed_domains, blocked_domains=blocked_domains, - user_location=user_location, - ) - _check_domains("web_search", args.allowed_domains, args.blocked_domains) - return types.tools.Tool( - kind="provider", - name="web_search", - args=args, + user_location=user_location.model_dump(mode="json", exclude_none=True) + if user_location is not None + else None, ) @@ -151,29 +97,24 @@ def web_fetch( citations: Citations | bool | None = None, max_content_tokens: int | None = None, ) -> types.tools.Tool: - args = WebFetchArgs( + _check_domains("web_fetch", allowed_domains, blocked_domains) + if isinstance(citations, bool): + citations = Citations(enabled=citations) + return _provider_tool( + "web_fetch", + "anthropic.web_fetch_20260209", max_uses=max_uses, allowed_domains=allowed_domains, blocked_domains=blocked_domains, - citations=Citations(enabled=citations) - if isinstance(citations, bool) - else citations, + citations=citations.model_dump(mode="json", exclude_none=True) + if citations is not None + else None, max_content_tokens=max_content_tokens, ) - _check_domains("web_fetch", args.allowed_domains, args.blocked_domains) - return types.tools.Tool( - kind="provider", - name="web_fetch", - args=args, - ) def code_execution() -> types.tools.Tool: - return types.tools.Tool( - kind="provider", - name="code_execution", - args=CodeExecutionArgs(), - ) + return _provider_tool("code_execution", "anthropic.code_execution_20260120") def computer_use( @@ -183,53 +124,36 @@ def computer_use( display_number: int | None = None, enable_zoom: bool | None = None, ) -> types.tools.Tool: - return types.tools.Tool( - kind="provider", - name="computer", - args=ComputerUseArgs( - display_width_px=display_width_px, - display_height_px=display_height_px, - display_number=display_number, - enable_zoom=enable_zoom, - ), + return _provider_tool( + "computer", + "anthropic.computer_20251124", + display_width_px=display_width_px, + display_height_px=display_height_px, + display_number=display_number, + enable_zoom=enable_zoom, ) def text_editor(*, max_characters: int | None = None) -> types.tools.Tool: - return types.tools.Tool( - kind="provider", - name="str_replace_based_edit_tool", - args=TextEditorArgs(max_characters=max_characters), + return _provider_tool( + "str_replace_based_edit_tool", + "anthropic.text_editor_20250728", + max_characters=max_characters, ) def bash() -> types.tools.Tool: - return types.tools.Tool( - kind="provider", - name="bash", - args=BashArgs(), - ) + return _provider_tool("bash", "anthropic.bash_20250124") def memory() -> types.tools.Tool: - return types.tools.Tool( - kind="provider", - name="memory", - args=MemoryArgs(), - ) + return _provider_tool("memory", "anthropic.memory_20250818") __all__ = [ - "AnthropicProviderArgs", - "BashArgs", + "BETA_HEADERS", "Citations", - "CodeExecutionArgs", - "ComputerUseArgs", - "MemoryArgs", - "TextEditorArgs", "UserLocation", - "WebFetchArgs", - "WebSearchArgs", "bash", "code_execution", "computer_use", diff --git a/src/ai/providers/openai/protocol.py b/src/ai/providers/openai/protocol.py index a5320f3e..9c1581eb 100644 --- a/src/ai/providers/openai/protocol.py +++ b/src/ai/providers/openai/protocol.py @@ -17,7 +17,6 @@ from ...models.core import params as params_ from .. import base from . import _sdk, errors -from . import tools as openai_tools if TYPE_CHECKING: import openai @@ -40,16 +39,16 @@ def _tools_to_openai( for tool in tools: if tool.kind == "provider": continue - args = tool.args - if not isinstance(args, types.tools.FunctionToolArgs): - raise TypeError(f"function tool {tool.name!r} has invalid args") + spec = tool.spec + if spec is None: + raise TypeError(f"function tool {tool.name!r} has no spec") result.append( { "type": "function", "function": { "name": tool.name, - "description": args.description or "", - "parameters": args.params, + "description": spec.description or "", + "parameters": spec.params, }, } ) @@ -777,10 +776,6 @@ def _json_dumps(value: Any) -> str: return json.dumps(value, separators=(",", ":"), default=str) -def _model_dump(value: pydantic.BaseModel) -> dict[str, Any]: - return value.model_dump(exclude_none=True) - - def _openai_metadata(part: Any) -> dict[str, Any]: metadata = getattr(part, "provider_metadata", None) if not isinstance(metadata, Mapping): @@ -1040,56 +1035,53 @@ def _tools_to_responses( for tool in tools: if tool.kind == "function": - args = tool.args - if not isinstance(args, types.tools.FunctionToolArgs): - raise TypeError(f"function tool {tool.name!r} has invalid args") + spec = tool.spec + if spec is None: + raise TypeError(f"function tool {tool.name!r} has no spec") result.append( { "type": "function", "name": tool.name, - "description": args.description or "", - "parameters": args.params, + "description": spec.description or "", + "parameters": spec.params, } ) continue - args = tool.args - tool_id = getattr(type(args), "openai_id", None) - if not isinstance(args, openai_tools.OpenAIProviderArgs): + cfg = tool.tool_config + tool_id = cfg.id if cfg is not None else None + if tool_id is None or not tool_id.startswith("openai."): raise TypeError( f"provider tool {tool.name!r} is not an OpenAI tool" ) + args = dict(cfg.args) if cfg is not None else {} match tool_id: case "openai.web_search": - result.append({"type": "web_search", **_model_dump(args)}) + result.append({"type": "web_search", **args}) case "openai.web_search_preview": - result.append( - {"type": "web_search_preview", **_model_dump(args)} - ) + result.append({"type": "web_search_preview", **args}) case "openai.file_search": - data = _model_dump(args) - ranking = data.pop("ranking", None) + ranking = args.pop("ranking", None) if ranking is not None: - data["ranking_options"] = ranking - result.append({"type": "file_search", **data}) + args["ranking_options"] = ranking + result.append({"type": "file_search", **args}) case "openai.code_interpreter": - data = _model_dump(args) - if "container" not in data: - data["container"] = {"type": "auto"} - result.append({"type": "code_interpreter", **data}) + if "container" not in args: + args["container"] = {"type": "auto"} + result.append({"type": "code_interpreter", **args}) case "openai.image_generation": - result.append({"type": "image_generation", **_model_dump(args)}) + result.append({"type": "image_generation", **args}) case "openai.local_shell": result.append({"type": "local_shell"}) case "openai.shell": - result.append({"type": "shell", **_model_dump(args)}) + result.append({"type": "shell", **args}) case "openai.apply_patch": result.append({"type": "apply_patch"}) case "openai.mcp": - result.append({"type": "mcp", **_model_dump(args)}) + result.append({"type": "mcp", **args}) case "openai.tool_search": - result.append({"type": "tool_search", **_model_dump(args)}) + result.append({"type": "tool_search", **args}) case _: raise NotImplementedError( f"unsupported OpenAI provider tool {tool_id}" @@ -1143,8 +1135,9 @@ def _image_media_type( tools: Sequence[types.tools.Tool], ) -> str: for tool in tools: - if isinstance(tool.args, openai_tools.ImageGenerationArgs): - fmt = str(tool.args.output_format or "png") + cfg = tool.tool_config + if cfg is not None and cfg.id == "openai.image_generation": + fmt = str(cfg.args.get("output_format") or "png") return "image/jpeg" if fmt == "jpeg" else f"image/{fmt}" text = params.get("text") if isinstance(text, Mapping): diff --git a/src/ai/providers/openai/tools.py b/src/ai/providers/openai/tools.py index a5ac708f..a3a10a9b 100644 --- a/src/ai/providers/openai/tools.py +++ b/src/ai/providers/openai/tools.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, ClassVar, Literal +from typing import Any, Literal import pydantic from pydantic.alias_generators import to_camel @@ -50,111 +50,21 @@ class CodeInterpreterContainer(pydantic.BaseModel): file_ids: list[str] | None = None -class OpenAIProviderArgs(pydantic.BaseModel): - """Base for OpenAI provider-executed tool args.""" - - model_config = _CONFIG_MODEL - - openai_id: ClassVar[str] - - -class WebSearchArgs(OpenAIProviderArgs): - openai_id: ClassVar[str] = "openai.web_search" - - model_config = _CONFIG_MODEL - - external_web_access: bool | None = None - filters: WebSearchFilters | None = None - search_context_size: Literal["low", "medium", "high"] | None = None - user_location: WebSearchUserLocation | None = None - - -class WebSearchPreviewArgs(OpenAIProviderArgs): - openai_id: ClassVar[str] = "openai.web_search_preview" - - model_config = _CONFIG_MODEL - - search_context_size: Literal["low", "medium", "high"] | None = None - user_location: WebSearchUserLocation | None = None - - -class FileSearchArgs(OpenAIProviderArgs): - openai_id: ClassVar[str] = "openai.file_search" - - model_config = _CONFIG_MODEL - - vector_store_ids: list[str] - max_num_results: int | None = None - ranking: FileSearchRanking | None = None - filters: dict[str, Any] | None = None - - -class CodeInterpreterArgs(OpenAIProviderArgs): - openai_id: ClassVar[str] = "openai.code_interpreter" - - model_config = _CONFIG_MODEL - - container: CodeInterpreterContainer | str | None = None - - -class ImageGenerationArgs(OpenAIProviderArgs): - openai_id: ClassVar[str] = "openai.image_generation" - - model_config = _CONFIG_MODEL - - background: Literal["transparent", "opaque", "auto"] | None = None - input_fidelity: Literal["high", "low"] | None = None - model: str | None = None - moderation: Literal["auto", "low"] | None = None - output_compression: int | None = None - output_format: Literal["png", "webp", "jpeg"] | None = None - partial_images: int | None = None - quality: Literal["low", "medium", "high", "auto"] | None = None - size: str | None = None - - -class LocalShellArgs(OpenAIProviderArgs): - openai_id: ClassVar[str] = "openai.local_shell" - - model_config = _CONFIG_MODEL - - -class ShellArgs(OpenAIProviderArgs): - openai_id: ClassVar[str] = "openai.shell" - - model_config = _CONFIG_MODEL - - environment: str | None = None - - -class ApplyPatchArgs(OpenAIProviderArgs): - openai_id: ClassVar[str] = "openai.apply_patch" - - model_config = _CONFIG_MODEL - - -class McpArgs(OpenAIProviderArgs): - openai_id: ClassVar[str] = "openai.mcp" - - model_config = _CONFIG_MODEL - - server_label: str - server_url: str | None = None - connector_id: str | None = None - authorization: str | None = None - headers: dict[str, str] | None = None - allowed_tools: list[str] | dict[str, Any] | None = None - server_description: str | None = None - - -class ToolSearchArgs(OpenAIProviderArgs): - openai_id: ClassVar[str] = "openai.tool_search" +def _provider_tool(name: str, id: str, **args: Any) -> types.tools.Tool: + return types.tools.Tool( + kind="provider", + name=name, + tool_config=types.tools.ToolConfig( + id=id, + args={k: v for k, v in args.items() if v is not None}, + ), + ) - model_config = _CONFIG_MODEL - description: str | None = None - parameters: dict[str, Any] | None = None - execution: dict[str, Any] | None = None +def _dump(model: pydantic.BaseModel | None) -> dict[str, Any] | None: + if model is None: + return None + return model.model_dump(mode="json", exclude_none=True) def web_search( @@ -164,15 +74,13 @@ def web_search( search_context_size: Literal["low", "medium", "high"] | None = None, user_location: WebSearchUserLocation | None = None, ) -> types.tools.Tool: - return types.tools.Tool( - kind="provider", - name="web_search", - args=WebSearchArgs( - external_web_access=external_web_access, - filters=filters, - search_context_size=search_context_size, - user_location=user_location, - ), + return _provider_tool( + "web_search", + "openai.web_search", + external_web_access=external_web_access, + filters=_dump(filters), + search_context_size=search_context_size, + user_location=_dump(user_location), ) @@ -181,13 +89,11 @@ def web_search_preview( search_context_size: Literal["low", "medium", "high"] | None = None, user_location: WebSearchUserLocation | None = None, ) -> types.tools.Tool: - return types.tools.Tool( - kind="provider", - name="web_search_preview", - args=WebSearchPreviewArgs( - search_context_size=search_context_size, - user_location=user_location, - ), + return _provider_tool( + "web_search_preview", + "openai.web_search_preview", + search_context_size=search_context_size, + user_location=_dump(user_location), ) @@ -198,15 +104,13 @@ def file_search( ranking: FileSearchRanking | None = None, filters: dict[str, Any] | None = None, ) -> types.tools.Tool: - return types.tools.Tool( - kind="provider", - name="file_search", - args=FileSearchArgs( - vector_store_ids=vector_store_ids, - max_num_results=max_num_results, - ranking=ranking, - filters=filters, - ), + return _provider_tool( + "file_search", + "openai.file_search", + vector_store_ids=vector_store_ids, + max_num_results=max_num_results, + ranking=_dump(ranking), + filters=filters, ) @@ -214,10 +118,12 @@ def code_interpreter( *, container: CodeInterpreterContainer | str | None = None, ) -> types.tools.Tool: - return types.tools.Tool( - kind="provider", - name="code_interpreter", - args=CodeInterpreterArgs(container=container), + return _provider_tool( + "code_interpreter", + "openai.code_interpreter", + container=_dump(container) + if isinstance(container, CodeInterpreterContainer) + else container, ) @@ -233,43 +139,31 @@ def image_generation( quality: Literal["low", "medium", "high", "auto"] | None = None, size: str | None = None, ) -> types.tools.Tool: - return types.tools.Tool( - kind="provider", - name="image_generation", - args=ImageGenerationArgs( - background=background, - input_fidelity=input_fidelity, - model=model, - moderation=moderation, - output_compression=output_compression, - output_format=output_format, - partial_images=partial_images, - quality=quality, - size=size, - ), + return _provider_tool( + "image_generation", + "openai.image_generation", + background=background, + input_fidelity=input_fidelity, + model=model, + moderation=moderation, + output_compression=output_compression, + output_format=output_format, + partial_images=partial_images, + quality=quality, + size=size, ) def local_shell() -> types.tools.Tool: - return types.tools.Tool( - kind="provider", name="local_shell", args=LocalShellArgs() - ) + return _provider_tool("local_shell", "openai.local_shell") def shell(*, environment: str | None = None) -> types.tools.Tool: - return types.tools.Tool( - kind="provider", - name="shell", - args=ShellArgs(environment=environment), - ) + return _provider_tool("shell", "openai.shell", environment=environment) def apply_patch() -> types.tools.Tool: - return types.tools.Tool( - kind="provider", - name="apply_patch", - args=ApplyPatchArgs(), - ) + return _provider_tool("apply_patch", "openai.apply_patch") def mcp( @@ -282,18 +176,16 @@ def mcp( allowed_tools: list[str] | dict[str, Any] | None = None, server_description: str | None = None, ) -> types.tools.Tool: - return types.tools.Tool( - kind="provider", - name="mcp", - args=McpArgs( - server_label=server_label, - server_url=server_url, - connector_id=connector_id, - authorization=authorization, - headers=headers, - allowed_tools=allowed_tools, - server_description=server_description, - ), + return _provider_tool( + "mcp", + "openai.mcp", + server_label=server_label, + server_url=server_url, + connector_id=connector_id, + authorization=authorization, + headers=headers, + allowed_tools=allowed_tools, + server_description=server_description, ) @@ -303,32 +195,19 @@ def tool_search( parameters: dict[str, Any] | None = None, execution: dict[str, Any] | None = None, ) -> types.tools.Tool: - return types.tools.Tool( - kind="provider", - name="tool_search", - args=ToolSearchArgs( - description=description, - parameters=parameters, - execution=execution, - ), + return _provider_tool( + "tool_search", + "openai.tool_search", + description=description, + parameters=parameters, + execution=execution, ) __all__ = [ - "ApplyPatchArgs", - "CodeInterpreterArgs", "CodeInterpreterContainer", - "FileSearchArgs", "FileSearchRanking", - "ImageGenerationArgs", - "LocalShellArgs", - "McpArgs", - "OpenAIProviderArgs", - "ShellArgs", - "ToolSearchArgs", - "WebSearchArgs", "WebSearchFilters", - "WebSearchPreviewArgs", "WebSearchUserLocation", "apply_patch", "code_interpreter", diff --git a/src/ai/types/tools.py b/src/ai/types/tools.py index 3016909d..039a0a0c 100644 --- a/src/ai/types/tools.py +++ b/src/ai/types/tools.py @@ -12,30 +12,47 @@ class ToolApproval(pydantic.BaseModel): reason: str | None = None -class FunctionToolArgs(pydantic.BaseModel): - description: str | None = pydantic.Field(default=None) +class ToolSpec(pydantic.BaseModel): + """Model-facing declaration of a host-executed function tool.""" + + description: str | None = None params: dict[str, Any] +class ToolConfig(pydantic.BaseModel): + """Execution configuration for a tool. + + For provider-executed tools ``id`` is the canonical provider tool id + (e.g. ``"anthropic.web_search_20260209"``, ``"openai.mcp"``) and + ``args`` holds the provider wire arguments as plain snake_case data. + """ + + id: str | None = None + args: dict[str, Any] = pydantic.Field(default_factory=dict) + + class Tool(pydantic.BaseModel): kind: Literal["function", "provider"] name: str - args: pydantic.BaseModel + spec: ToolSpec | None = None + tool_config: ToolConfig | None = None require_approval: bool = False @pydantic.model_validator(mode="after") - def validate_args_shape(self) -> Self: + def validate_shape(self) -> Self: match self.kind: case "function": - if not isinstance(self.args, FunctionToolArgs): + if self.spec is None: raise ValueError( - "function tools require args=FunctionToolArgs(...)" + "function tools require spec=ToolSpec(...)" ) case "provider": - if isinstance(self.args, FunctionToolArgs): + if self.spec is not None: + raise ValueError("provider tools cannot have spec") + if self.tool_config is None or self.tool_config.id is None: raise ValueError( - "provider tools cannot use FunctionToolArgs" + "provider tools require tool_config with an id" ) return self diff --git a/tests/agents/mcp/test_client.py b/tests/agents/mcp/test_client.py index f6db7b0c..41d6a039 100644 --- a/tests/agents/mcp/test_client.py +++ b/tests/agents/mcp/test_client.py @@ -53,7 +53,7 @@ def test_mcp_tool_to_native_basic() -> None: ) assert native.name == "mcp_basic_test" - assert _function_args(native).description == "Echo input" + assert _spec(native).description == "Echo input" def test_mcp_tool_to_native_with_prefix() -> None: @@ -73,8 +73,8 @@ def test_mcp_tool_to_native_schema_preserved() -> None: mcp_tool, "test:key", _noop_transport_factory, None ) - assert _function_args(native).params == mcp_tool.inputSchema - assert _function_args(native).description == "Echo input" + assert _spec(native).params == mcp_tool.inputSchema + assert _spec(native).description == "Echo input" async def test_get_http_tools_raises_installation_error_without_mcp( @@ -144,7 +144,7 @@ async def fake_fn(**kwargs: str) -> str: assert llm.call_count == 2 -def _function_args(tool: ai.AgentTool) -> ai.tools.FunctionToolArgs: - args = tool.tool.args - assert isinstance(args, ai.tools.FunctionToolArgs) - return args +def _spec(tool: ai.AgentTool) -> ai.tools.ToolSpec: + spec = tool.tool.spec + assert spec is not None + return spec diff --git a/tests/agents/test_tools.py b/tests/agents/test_tools.py index f28650f0..61a71181 100644 --- a/tests/agents/test_tools.py +++ b/tests/agents/test_tools.py @@ -20,7 +20,7 @@ async def greet(name: str, count: int) -> str: return f"Hello {name}" * count assert greet.name == "greet" - assert _function_args(greet).description == "Say hello." + assert _spec(greet).description == "Say hello." props = _schema(greet)["properties"] assert props["name"]["type"] == "string" assert props["count"]["type"] == "integer" @@ -226,13 +226,13 @@ def _required(tool: ai.AgentTool) -> list[str]: def _schema(tool: ai.AgentTool) -> dict[str, Any]: - return _function_args(tool).params + return _spec(tool).params -def _function_args(tool: ai.AgentTool) -> ai.tools.FunctionToolArgs: - args = tool.tool.args - assert isinstance(args, ai.tools.FunctionToolArgs) - return args +def _spec(tool: ai.AgentTool) -> ai.tools.ToolSpec: + spec = tool.tool.spec + assert spec is not None + return spec # Module-level model so get_type_hints() can resolve it when @ai.tool diff --git a/tests/providers/ai_gateway/test_tools.py b/tests/providers/ai_gateway/test_tools.py index 8137cd2f..705fe4e0 100644 --- a/tests/providers/ai_gateway/test_tools.py +++ b/tests/providers/ai_gateway/test_tools.py @@ -11,8 +11,6 @@ from typing import Any import httpx -import pydantic -import pytest from ai import types from ai.providers.ai_gateway import tools as gateway_tools @@ -140,28 +138,67 @@ async def test_gateway_parallel_search_serializes(self) -> None: } ] - async def test_unknown_provider_args_rejected(self) -> None: - """Provider-executed tools need a registered args type.""" + async def test_openai_mcp_opaque_payloads_not_camelized(self) -> None: + """Free-form payloads (headers, allowed_tools) reach the wire + verbatim while structured config keys are camelized.""" + captured: dict[str, Any] = {} + model = mock_model(httpx.MockTransport(_capture_body_handler(captured))) - class UnknownArgs(pydantic.BaseModel): - value: str + async for _ in model.provider.stream( + model, + [user_msg("hi")], + tools=[ + openai_tools.mcp( + server_label="my-server", + server_url="https://mcp.example.com", + headers={"x_api_key": "secret"}, + allowed_tools={"tool_names": ["my_tool"]}, + ), + ], + ): + pass - def handler(req: httpx.Request) -> httpx.Response: - raise AssertionError("request should not be sent") + assert captured["tools"] == [ + { + "type": "provider", + "id": "openai.mcp", + "name": "mcp", + "args": { + "serverLabel": "my-server", + "serverUrl": "https://mcp.example.com", + "headers": {"x_api_key": "secret"}, + "allowedTools": {"tool_names": ["my_tool"]}, + }, + } + ] - model = mock_model(httpx.MockTransport(handler)) - stream = model.provider.stream( + async def test_unknown_provider_id_passes_through(self) -> None: + """Ids without a local factory are forwarded verbatim — the + gateway, not this adapter, owns the set of supported tools.""" + captured: dict[str, Any] = {} + model = mock_model(httpx.MockTransport(_capture_body_handler(captured))) + + async for _ in model.provider.stream( model, [user_msg("hi")], tools=[ types.tools.Tool( kind="provider", - name="bad", - args=UnknownArgs(value="x"), + name="frobnicate", + tool_config=types.tools.ToolConfig( + id="gateway.frobnicate", + args={"max_uses": 1}, + ), ) ], - ) + ): + pass - with pytest.raises(TypeError, match="unsupported args"): - async for _ in stream: - pass + assert captured["tools"] == [ + { + "type": "provider", + "id": "gateway.frobnicate", + "name": "frobnicate", + "args": {"maxUses": 1}, + } + ] diff --git a/tests/providers/openai/test_adapter.py b/tests/providers/openai/test_adapter.py index 30d6ba86..06d97177 100644 --- a/tests/providers/openai/test_adapter.py +++ b/tests/providers/openai/test_adapter.py @@ -251,7 +251,7 @@ async def test_responses_tools_convert_function_and_provider_tools() -> None: tools.Tool( kind="function", name="weather", - args=tools.FunctionToolArgs( + spec=tools.ToolSpec( description="Get weather", params={ "type": "object", diff --git a/tests/types/test_tools.py b/tests/types/test_tools.py new file mode 100644 index 00000000..7ff2a08d --- /dev/null +++ b/tests/types/test_tools.py @@ -0,0 +1,78 @@ +"""Tool data model: shape validation and JSON roundtripping.""" + +from __future__ import annotations + +import pydantic +import pytest + +from ai import types +from ai.providers.anthropic import tools as anthropic_tools +from ai.providers.openai import tools as openai_tools + + +def _roundtrip(tool: types.tools.Tool) -> types.tools.Tool: + return types.tools.Tool.model_validate_json(tool.model_dump_json()) + + +def test_function_tool_roundtrips_through_json() -> None: + tool = types.tools.Tool( + kind="function", + name="get_weather", + spec=types.tools.ToolSpec( + description="Get the weather.", + params={ + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + ), + require_approval=True, + ) + + assert _roundtrip(tool) == tool + + +def test_provider_tool_roundtrips_through_json() -> None: + tool = anthropic_tools.web_search( + max_uses=3, + allowed_domains=["example.com"], + user_location=anthropic_tools.UserLocation(city="SF", country="US"), + ) + + assert _roundtrip(tool) == tool + + +def test_provider_tool_with_opaque_payload_roundtrips() -> None: + tool = openai_tools.mcp( + server_label="my-server", + headers={"x_api_key": "secret"}, + ) + + assert _roundtrip(tool) == tool + + +def test_function_tool_requires_spec() -> None: + with pytest.raises(pydantic.ValidationError, match="require spec"): + types.tools.Tool(kind="function", name="bad") + + +def test_provider_tool_rejects_spec() -> None: + with pytest.raises(pydantic.ValidationError, match="cannot have spec"): + types.tools.Tool( + kind="provider", + name="bad", + spec=types.tools.ToolSpec(params={}), + tool_config=types.tools.ToolConfig(id="x.y"), + ) + + +def test_provider_tool_requires_tool_config_id() -> None: + with pytest.raises(pydantic.ValidationError, match="tool_config"): + types.tools.Tool(kind="provider", name="bad") + + with pytest.raises(pydantic.ValidationError, match="tool_config"): + types.tools.Tool( + kind="provider", + name="bad", + tool_config=types.tools.ToolConfig(), + ) From 5374f1fd402aed3294ee0cdc8f66d8e64d3a554c Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Tue, 16 Jun 2026 08:17:25 -0700 Subject: [PATCH 2/3] Inline provider tool definitions --- src/ai/providers/ai_gateway/tools.py | 65 +++++---- src/ai/providers/anthropic/tools.py | 133 +++++++++++------ src/ai/providers/openai/tools.py | 205 ++++++++++++++++++--------- 3 files changed, 262 insertions(+), 141 deletions(-) diff --git a/src/ai/providers/ai_gateway/tools.py b/src/ai/providers/ai_gateway/tools.py index 910ee502..515774dc 100644 --- a/src/ai/providers/ai_gateway/tools.py +++ b/src/ai/providers/ai_gateway/tools.py @@ -43,17 +43,6 @@ class FetchPolicy(pydantic.BaseModel): max_age_seconds: int | None = None -def _provider_tool(name: str, id: str, **args: Any) -> types.tools.Tool: - return types.tools.Tool( - kind="provider", - name=name, - tool_config=types.tools.ToolConfig( - id=id, - args={k: v for k, v in args.items() if v is not None}, - ), - ) - - def _dump[M: pydantic.BaseModel]( model_type: type[M], value: M | dict[str, object] | None ) -> dict[str, Any] | None: @@ -75,16 +64,25 @@ def perplexity_search( search_recency_filter: Literal["day", "week", "month", "year"] | None = None, ) -> types.tools.Tool: - return _provider_tool( - "perplexity_search", - "gateway.perplexity_search", - max_results=max_results, - max_tokens_per_page=max_tokens_per_page, - max_tokens=max_tokens, - country=country, - search_domain_filter=search_domain_filter, - search_language_filter=search_language_filter, - search_recency_filter=search_recency_filter, + return types.tools.Tool( + kind="provider", + name="perplexity_search", + tool_config=types.tools.ToolConfig( + id="gateway.perplexity_search", + args={ + k: v + for k, v in { + "max_results": max_results, + "max_tokens_per_page": max_tokens_per_page, + "max_tokens": max_tokens, + "country": country, + "search_domain_filter": search_domain_filter, + "search_language_filter": search_language_filter, + "search_recency_filter": search_recency_filter, + }.items() + if v is not None + }, + ), ) @@ -96,14 +94,23 @@ def parallel_search( excerpts: Excerpts | dict[str, object] | None = None, fetch_policy: FetchPolicy | dict[str, object] | None = None, ) -> types.tools.Tool: - return _provider_tool( - "parallel_search", - "gateway.parallel_search", - mode=mode, - max_results=max_results, - source_policy=_dump(SourcePolicy, source_policy), - excerpts=_dump(Excerpts, excerpts), - fetch_policy=_dump(FetchPolicy, fetch_policy), + return types.tools.Tool( + kind="provider", + name="parallel_search", + tool_config=types.tools.ToolConfig( + id="gateway.parallel_search", + args={ + k: v + for k, v in { + "mode": mode, + "max_results": max_results, + "source_policy": _dump(SourcePolicy, source_policy), + "excerpts": _dump(Excerpts, excerpts), + "fetch_policy": _dump(FetchPolicy, fetch_policy), + }.items() + if v is not None + }, + ), ) diff --git a/src/ai/providers/anthropic/tools.py b/src/ai/providers/anthropic/tools.py index 4a95a4d5..173de91f 100644 --- a/src/ai/providers/anthropic/tools.py +++ b/src/ai/providers/anthropic/tools.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Literal +from typing import Literal import pydantic from pydantic.alias_generators import to_camel @@ -46,17 +46,6 @@ class Citations(pydantic.BaseModel): enabled: bool -def _provider_tool(name: str, id: str, **args: Any) -> types.tools.Tool: - return types.tools.Tool( - kind="provider", - name=name, - tool_config=types.tools.ToolConfig( - id=id, - args={k: v for k, v in args.items() if v is not None}, - ), - ) - - def _check_domains( tool_name: str, allowed_domains: list[str] | None, @@ -77,15 +66,27 @@ def web_search( user_location: UserLocation | None = None, ) -> types.tools.Tool: _check_domains("web_search", allowed_domains, blocked_domains) - return _provider_tool( - "web_search", - "anthropic.web_search_20260209", - max_uses=max_uses, - allowed_domains=allowed_domains, - blocked_domains=blocked_domains, - user_location=user_location.model_dump(mode="json", exclude_none=True) - if user_location is not None - else None, + return types.tools.Tool( + kind="provider", + name="web_search", + tool_config=types.tools.ToolConfig( + id="anthropic.web_search_20260209", + args={ + k: v + for k, v in { + "max_uses": max_uses, + "allowed_domains": allowed_domains, + "blocked_domains": blocked_domains, + "user_location": user_location.model_dump( + mode="json", + exclude_none=True, + ) + if user_location is not None + else None, + }.items() + if v is not None + }, + ), ) @@ -100,21 +101,39 @@ def web_fetch( _check_domains("web_fetch", allowed_domains, blocked_domains) if isinstance(citations, bool): citations = Citations(enabled=citations) - return _provider_tool( - "web_fetch", - "anthropic.web_fetch_20260209", - max_uses=max_uses, - allowed_domains=allowed_domains, - blocked_domains=blocked_domains, - citations=citations.model_dump(mode="json", exclude_none=True) - if citations is not None - else None, - max_content_tokens=max_content_tokens, + return types.tools.Tool( + kind="provider", + name="web_fetch", + tool_config=types.tools.ToolConfig( + id="anthropic.web_fetch_20260209", + args={ + k: v + for k, v in { + "max_uses": max_uses, + "allowed_domains": allowed_domains, + "blocked_domains": blocked_domains, + "citations": citations.model_dump( + mode="json", + exclude_none=True, + ) + if citations is not None + else None, + "max_content_tokens": max_content_tokens, + }.items() + if v is not None + }, + ), ) def code_execution() -> types.tools.Tool: - return _provider_tool("code_execution", "anthropic.code_execution_20260120") + return types.tools.Tool( + kind="provider", + name="code_execution", + tool_config=types.tools.ToolConfig( + id="anthropic.code_execution_20260120" + ), + ) def computer_use( @@ -124,30 +143,54 @@ def computer_use( display_number: int | None = None, enable_zoom: bool | None = None, ) -> types.tools.Tool: - return _provider_tool( - "computer", - "anthropic.computer_20251124", - display_width_px=display_width_px, - display_height_px=display_height_px, - display_number=display_number, - enable_zoom=enable_zoom, + return types.tools.Tool( + kind="provider", + name="computer", + tool_config=types.tools.ToolConfig( + id="anthropic.computer_20251124", + args={ + k: v + for k, v in { + "display_width_px": display_width_px, + "display_height_px": display_height_px, + "display_number": display_number, + "enable_zoom": enable_zoom, + }.items() + if v is not None + }, + ), ) def text_editor(*, max_characters: int | None = None) -> types.tools.Tool: - return _provider_tool( - "str_replace_based_edit_tool", - "anthropic.text_editor_20250728", - max_characters=max_characters, + return types.tools.Tool( + kind="provider", + name="str_replace_based_edit_tool", + tool_config=types.tools.ToolConfig( + id="anthropic.text_editor_20250728", + args={ + k: v + for k, v in {"max_characters": max_characters}.items() + if v is not None + }, + ), ) def bash() -> types.tools.Tool: - return _provider_tool("bash", "anthropic.bash_20250124") + return types.tools.Tool( + kind="provider", + name="bash", + tool_config=types.tools.ToolConfig(id="anthropic.bash_20250124"), + ) def memory() -> types.tools.Tool: - return _provider_tool("memory", "anthropic.memory_20250818") + return types.tools.Tool( + kind="provider", + name="memory", + tool_config=types.tools.ToolConfig(id="anthropic.memory_20250818"), + ) __all__ = [ diff --git a/src/ai/providers/openai/tools.py b/src/ai/providers/openai/tools.py index a3a10a9b..da37bab4 100644 --- a/src/ai/providers/openai/tools.py +++ b/src/ai/providers/openai/tools.py @@ -50,17 +50,6 @@ class CodeInterpreterContainer(pydantic.BaseModel): file_ids: list[str] | None = None -def _provider_tool(name: str, id: str, **args: Any) -> types.tools.Tool: - return types.tools.Tool( - kind="provider", - name=name, - tool_config=types.tools.ToolConfig( - id=id, - args={k: v for k, v in args.items() if v is not None}, - ), - ) - - def _dump(model: pydantic.BaseModel | None) -> dict[str, Any] | None: if model is None: return None @@ -74,13 +63,22 @@ def web_search( search_context_size: Literal["low", "medium", "high"] | None = None, user_location: WebSearchUserLocation | None = None, ) -> types.tools.Tool: - return _provider_tool( - "web_search", - "openai.web_search", - external_web_access=external_web_access, - filters=_dump(filters), - search_context_size=search_context_size, - user_location=_dump(user_location), + return types.tools.Tool( + kind="provider", + name="web_search", + tool_config=types.tools.ToolConfig( + id="openai.web_search", + args={ + k: v + for k, v in { + "external_web_access": external_web_access, + "filters": _dump(filters), + "search_context_size": search_context_size, + "user_location": _dump(user_location), + }.items() + if v is not None + }, + ), ) @@ -89,11 +87,20 @@ def web_search_preview( search_context_size: Literal["low", "medium", "high"] | None = None, user_location: WebSearchUserLocation | None = None, ) -> types.tools.Tool: - return _provider_tool( - "web_search_preview", - "openai.web_search_preview", - search_context_size=search_context_size, - user_location=_dump(user_location), + return types.tools.Tool( + kind="provider", + name="web_search_preview", + tool_config=types.tools.ToolConfig( + id="openai.web_search_preview", + args={ + k: v + for k, v in { + "search_context_size": search_context_size, + "user_location": _dump(user_location), + }.items() + if v is not None + }, + ), ) @@ -104,13 +111,22 @@ def file_search( ranking: FileSearchRanking | None = None, filters: dict[str, Any] | None = None, ) -> types.tools.Tool: - return _provider_tool( - "file_search", - "openai.file_search", - vector_store_ids=vector_store_ids, - max_num_results=max_num_results, - ranking=_dump(ranking), - filters=filters, + return types.tools.Tool( + kind="provider", + name="file_search", + tool_config=types.tools.ToolConfig( + id="openai.file_search", + args={ + k: v + for k, v in { + "vector_store_ids": vector_store_ids, + "max_num_results": max_num_results, + "ranking": _dump(ranking), + "filters": filters, + }.items() + if v is not None + }, + ), ) @@ -118,12 +134,21 @@ def code_interpreter( *, container: CodeInterpreterContainer | str | None = None, ) -> types.tools.Tool: - return _provider_tool( - "code_interpreter", - "openai.code_interpreter", - container=_dump(container) - if isinstance(container, CodeInterpreterContainer) - else container, + return types.tools.Tool( + kind="provider", + name="code_interpreter", + tool_config=types.tools.ToolConfig( + id="openai.code_interpreter", + args={ + k: v + for k, v in { + "container": _dump(container) + if isinstance(container, CodeInterpreterContainer) + else container, + }.items() + if v is not None + }, + ), ) @@ -139,31 +164,59 @@ def image_generation( quality: Literal["low", "medium", "high", "auto"] | None = None, size: str | None = None, ) -> types.tools.Tool: - return _provider_tool( - "image_generation", - "openai.image_generation", - background=background, - input_fidelity=input_fidelity, - model=model, - moderation=moderation, - output_compression=output_compression, - output_format=output_format, - partial_images=partial_images, - quality=quality, - size=size, + return types.tools.Tool( + kind="provider", + name="image_generation", + tool_config=types.tools.ToolConfig( + id="openai.image_generation", + args={ + k: v + for k, v in { + "background": background, + "input_fidelity": input_fidelity, + "model": model, + "moderation": moderation, + "output_compression": output_compression, + "output_format": output_format, + "partial_images": partial_images, + "quality": quality, + "size": size, + }.items() + if v is not None + }, + ), ) def local_shell() -> types.tools.Tool: - return _provider_tool("local_shell", "openai.local_shell") + return types.tools.Tool( + kind="provider", + name="local_shell", + tool_config=types.tools.ToolConfig(id="openai.local_shell"), + ) def shell(*, environment: str | None = None) -> types.tools.Tool: - return _provider_tool("shell", "openai.shell", environment=environment) + return types.tools.Tool( + kind="provider", + name="shell", + tool_config=types.tools.ToolConfig( + id="openai.shell", + args={ + k: v + for k, v in {"environment": environment}.items() + if v is not None + }, + ), + ) def apply_patch() -> types.tools.Tool: - return _provider_tool("apply_patch", "openai.apply_patch") + return types.tools.Tool( + kind="provider", + name="apply_patch", + tool_config=types.tools.ToolConfig(id="openai.apply_patch"), + ) def mcp( @@ -176,16 +229,25 @@ def mcp( allowed_tools: list[str] | dict[str, Any] | None = None, server_description: str | None = None, ) -> types.tools.Tool: - return _provider_tool( - "mcp", - "openai.mcp", - server_label=server_label, - server_url=server_url, - connector_id=connector_id, - authorization=authorization, - headers=headers, - allowed_tools=allowed_tools, - server_description=server_description, + return types.tools.Tool( + kind="provider", + name="mcp", + tool_config=types.tools.ToolConfig( + id="openai.mcp", + args={ + k: v + for k, v in { + "server_label": server_label, + "server_url": server_url, + "connector_id": connector_id, + "authorization": authorization, + "headers": headers, + "allowed_tools": allowed_tools, + "server_description": server_description, + }.items() + if v is not None + }, + ), ) @@ -195,12 +257,21 @@ def tool_search( parameters: dict[str, Any] | None = None, execution: dict[str, Any] | None = None, ) -> types.tools.Tool: - return _provider_tool( - "tool_search", - "openai.tool_search", - description=description, - parameters=parameters, - execution=execution, + return types.tools.Tool( + kind="provider", + name="tool_search", + tool_config=types.tools.ToolConfig( + id="openai.tool_search", + args={ + k: v + for k, v in { + "description": description, + "parameters": parameters, + "execution": execution, + }.items() + if v is not None + }, + ), ) From 037d479a22e068134922887a6440368dbc54c499 Mon Sep 17 00:00:00 2001 From: Andrey Buzin Date: Thu, 18 Jun 2026 19:28:25 -0700 Subject: [PATCH 3/3] Refactor tool config declaration --- src/ai/providers/ai_gateway/tools.py | 44 +++++---- src/ai/providers/anthropic/tools.py | 82 ++++++++--------- src/ai/providers/openai/tools.py | 130 +++++++++++---------------- 3 files changed, 106 insertions(+), 150 deletions(-) diff --git a/src/ai/providers/ai_gateway/tools.py b/src/ai/providers/ai_gateway/tools.py index 515774dc..bba7e0e2 100644 --- a/src/ai/providers/ai_gateway/tools.py +++ b/src/ai/providers/ai_gateway/tools.py @@ -53,6 +53,10 @@ def _dump[M: pydantic.BaseModel]( return value.model_dump(mode="json", exclude_none=True) +def _dict_filter_none(**args: Any) -> dict[str, Any]: + return {k: v for k, v in args.items() if v is not None} + + def perplexity_search( *, max_results: int | None = None, @@ -69,19 +73,15 @@ def perplexity_search( name="perplexity_search", tool_config=types.tools.ToolConfig( id="gateway.perplexity_search", - args={ - k: v - for k, v in { - "max_results": max_results, - "max_tokens_per_page": max_tokens_per_page, - "max_tokens": max_tokens, - "country": country, - "search_domain_filter": search_domain_filter, - "search_language_filter": search_language_filter, - "search_recency_filter": search_recency_filter, - }.items() - if v is not None - }, + args=_dict_filter_none( + max_results=max_results, + max_tokens_per_page=max_tokens_per_page, + max_tokens=max_tokens, + country=country, + search_domain_filter=search_domain_filter, + search_language_filter=search_language_filter, + search_recency_filter=search_recency_filter, + ), ), ) @@ -99,17 +99,13 @@ def parallel_search( name="parallel_search", tool_config=types.tools.ToolConfig( id="gateway.parallel_search", - args={ - k: v - for k, v in { - "mode": mode, - "max_results": max_results, - "source_policy": _dump(SourcePolicy, source_policy), - "excerpts": _dump(Excerpts, excerpts), - "fetch_policy": _dump(FetchPolicy, fetch_policy), - }.items() - if v is not None - }, + args=_dict_filter_none( + mode=mode, + max_results=max_results, + source_policy=_dump(SourcePolicy, source_policy), + excerpts=_dump(Excerpts, excerpts), + fetch_policy=_dump(FetchPolicy, fetch_policy), + ), ), ) diff --git a/src/ai/providers/anthropic/tools.py b/src/ai/providers/anthropic/tools.py index 173de91f..7930adbc 100644 --- a/src/ai/providers/anthropic/tools.py +++ b/src/ai/providers/anthropic/tools.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Literal +from typing import Any, Literal import pydantic from pydantic.alias_generators import to_camel @@ -58,6 +58,10 @@ def _check_domains( ) +def _dict_filter_none(**args: Any) -> dict[str, Any]: + return {k: v for k, v in args.items() if v is not None} + + def web_search( *, max_uses: int | None = None, @@ -71,21 +75,17 @@ def web_search( name="web_search", tool_config=types.tools.ToolConfig( id="anthropic.web_search_20260209", - args={ - k: v - for k, v in { - "max_uses": max_uses, - "allowed_domains": allowed_domains, - "blocked_domains": blocked_domains, - "user_location": user_location.model_dump( - mode="json", - exclude_none=True, - ) - if user_location is not None - else None, - }.items() - if v is not None - }, + args=_dict_filter_none( + max_uses=max_uses, + allowed_domains=allowed_domains, + blocked_domains=blocked_domains, + user_location=user_location.model_dump( + mode="json", + exclude_none=True, + ) + if user_location is not None + else None, + ), ), ) @@ -106,22 +106,18 @@ def web_fetch( name="web_fetch", tool_config=types.tools.ToolConfig( id="anthropic.web_fetch_20260209", - args={ - k: v - for k, v in { - "max_uses": max_uses, - "allowed_domains": allowed_domains, - "blocked_domains": blocked_domains, - "citations": citations.model_dump( - mode="json", - exclude_none=True, - ) - if citations is not None - else None, - "max_content_tokens": max_content_tokens, - }.items() - if v is not None - }, + args=_dict_filter_none( + max_uses=max_uses, + allowed_domains=allowed_domains, + blocked_domains=blocked_domains, + citations=citations.model_dump( + mode="json", + exclude_none=True, + ) + if citations is not None + else None, + max_content_tokens=max_content_tokens, + ), ), ) @@ -148,16 +144,12 @@ def computer_use( name="computer", tool_config=types.tools.ToolConfig( id="anthropic.computer_20251124", - args={ - k: v - for k, v in { - "display_width_px": display_width_px, - "display_height_px": display_height_px, - "display_number": display_number, - "enable_zoom": enable_zoom, - }.items() - if v is not None - }, + args=_dict_filter_none( + display_width_px=display_width_px, + display_height_px=display_height_px, + display_number=display_number, + enable_zoom=enable_zoom, + ), ), ) @@ -168,11 +160,7 @@ def text_editor(*, max_characters: int | None = None) -> types.tools.Tool: name="str_replace_based_edit_tool", tool_config=types.tools.ToolConfig( id="anthropic.text_editor_20250728", - args={ - k: v - for k, v in {"max_characters": max_characters}.items() - if v is not None - }, + args=_dict_filter_none(max_characters=max_characters), ), ) diff --git a/src/ai/providers/openai/tools.py b/src/ai/providers/openai/tools.py index da37bab4..d89ecbcf 100644 --- a/src/ai/providers/openai/tools.py +++ b/src/ai/providers/openai/tools.py @@ -56,6 +56,10 @@ def _dump(model: pydantic.BaseModel | None) -> dict[str, Any] | None: return model.model_dump(mode="json", exclude_none=True) +def _dict_filter_none(**args: Any) -> dict[str, Any]: + return {k: v for k, v in args.items() if v is not None} + + def web_search( *, external_web_access: bool | None = None, @@ -68,16 +72,12 @@ def web_search( name="web_search", tool_config=types.tools.ToolConfig( id="openai.web_search", - args={ - k: v - for k, v in { - "external_web_access": external_web_access, - "filters": _dump(filters), - "search_context_size": search_context_size, - "user_location": _dump(user_location), - }.items() - if v is not None - }, + args=_dict_filter_none( + external_web_access=external_web_access, + filters=_dump(filters), + search_context_size=search_context_size, + user_location=_dump(user_location), + ), ), ) @@ -92,14 +92,10 @@ def web_search_preview( name="web_search_preview", tool_config=types.tools.ToolConfig( id="openai.web_search_preview", - args={ - k: v - for k, v in { - "search_context_size": search_context_size, - "user_location": _dump(user_location), - }.items() - if v is not None - }, + args=_dict_filter_none( + search_context_size=search_context_size, + user_location=_dump(user_location), + ), ), ) @@ -116,16 +112,12 @@ def file_search( name="file_search", tool_config=types.tools.ToolConfig( id="openai.file_search", - args={ - k: v - for k, v in { - "vector_store_ids": vector_store_ids, - "max_num_results": max_num_results, - "ranking": _dump(ranking), - "filters": filters, - }.items() - if v is not None - }, + args=_dict_filter_none( + vector_store_ids=vector_store_ids, + max_num_results=max_num_results, + ranking=_dump(ranking), + filters=filters, + ), ), ) @@ -139,15 +131,11 @@ def code_interpreter( name="code_interpreter", tool_config=types.tools.ToolConfig( id="openai.code_interpreter", - args={ - k: v - for k, v in { - "container": _dump(container) - if isinstance(container, CodeInterpreterContainer) - else container, - }.items() - if v is not None - }, + args=_dict_filter_none( + container=_dump(container) + if isinstance(container, CodeInterpreterContainer) + else container, + ), ), ) @@ -169,21 +157,17 @@ def image_generation( name="image_generation", tool_config=types.tools.ToolConfig( id="openai.image_generation", - args={ - k: v - for k, v in { - "background": background, - "input_fidelity": input_fidelity, - "model": model, - "moderation": moderation, - "output_compression": output_compression, - "output_format": output_format, - "partial_images": partial_images, - "quality": quality, - "size": size, - }.items() - if v is not None - }, + args=_dict_filter_none( + background=background, + input_fidelity=input_fidelity, + model=model, + moderation=moderation, + output_compression=output_compression, + output_format=output_format, + partial_images=partial_images, + quality=quality, + size=size, + ), ), ) @@ -202,11 +186,7 @@ def shell(*, environment: str | None = None) -> types.tools.Tool: name="shell", tool_config=types.tools.ToolConfig( id="openai.shell", - args={ - k: v - for k, v in {"environment": environment}.items() - if v is not None - }, + args=_dict_filter_none(environment=environment), ), ) @@ -234,19 +214,15 @@ def mcp( name="mcp", tool_config=types.tools.ToolConfig( id="openai.mcp", - args={ - k: v - for k, v in { - "server_label": server_label, - "server_url": server_url, - "connector_id": connector_id, - "authorization": authorization, - "headers": headers, - "allowed_tools": allowed_tools, - "server_description": server_description, - }.items() - if v is not None - }, + args=_dict_filter_none( + server_label=server_label, + server_url=server_url, + connector_id=connector_id, + authorization=authorization, + headers=headers, + allowed_tools=allowed_tools, + server_description=server_description, + ), ), ) @@ -262,15 +238,11 @@ def tool_search( name="tool_search", tool_config=types.tools.ToolConfig( id="openai.tool_search", - args={ - k: v - for k, v in { - "description": description, - "parameters": parameters, - "execution": execution, - }.items() - if v is not None - }, + args=_dict_filter_none( + description=description, + parameters=parameters, + execution=execution, + ), ), )