Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 2 additions & 12 deletions examples/temporal-direct/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,14 +131,7 @@ async def llm_call_activity(params: LLMParams) -> LLMResult:
"""Call the LLM, drain the stream, return the final message."""
model = ai.get_model(params.model_id)
messages = [ai.messages.Message.model_validate(m) for m in params.messages]
tools = [
ai.Tool(
kind="function",
name=t["name"],
args=ai.tools.FunctionToolArgs.model_validate(t["args"]),
)
for t in params.tool_schemas
]
tools = [ai.Tool.model_validate(t) for t in params.tool_schemas]

async with ai.stream(model, messages, tools=tools) as s:
async for _event in s:
Expand All @@ -162,10 +155,7 @@ class WeatherAgent(ai.Agent):
async def loop(
self, context: ai.Context
) -> AsyncGenerator[ai.events.AgentEvent]:
tool_schemas = [
{"name": t.name, "args": t.args.model_dump(mode="json")}
for t in context.tools
]
tool_schemas = [t.model_dump(mode="json") for t in context.tools]

while context.keep_running():
# 1. LLM call via activity → complete message
Expand Down
2 changes: 1 addition & 1 deletion examples/tools_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
get_weather = ai.Tool(
kind="function",
name="get_weather",
args=ai.tools.FunctionToolArgs(
spec=ai.tools.ToolSpec(
description="Get the current weather for a city.",
params={
"type": "object",
Expand Down
2 changes: 1 addition & 1 deletion src/ai/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ def wrap(fn: Any) -> AgentTool:
tool_decl = Tool(
kind="function",
name=fn.__name__,
args=types.tools.FunctionToolArgs(
spec=types.tools.ToolSpec(
description=inspect.getdoc(fn) or "",
params=validator.model_json_schema(),
),
Expand Down
2 changes: 1 addition & 1 deletion src/ai/agents/mcp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def _mcp_tool_to_native(
tool = Tool(
kind="function",
name=name,
args=types.tools.FunctionToolArgs(
spec=types.tools.ToolSpec(
description=mcp_tool.description or "",
params=mcp_tool.inputSchema,
),
Expand Down
67 changes: 35 additions & 32 deletions src/ai/providers/ai_gateway/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,15 @@

import httpx
import pydantic
from pydantic.alias_generators import to_camel

from ... import types
from ...models import core
from ...models.core import params as params_
from .. import base
from ..anthropic import tools as anthropic_tools
from ..openai import tools as openai_tools
from . import client as gateway_client
from . import errors
from . import params as gateway_params
from . import tools as gateway_tools
from .client import errors as client_errors

# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -276,48 +274,53 @@ async def _messages_to_prompt(
return result


# Free-form payload fields whose keys are data, not config structure —
# their subtrees must reach the wire verbatim, never camelized.
_OPAQUE_ARG_KEYS: dict[str, frozenset[str]] = {
"openai.mcp": frozenset({"headers", "allowed_tools"}),
"openai.file_search": frozenset({"filters"}),
"openai.tool_search": frozenset({"parameters", "execution"}),
}


def _camelize(value: Any) -> Any:
if isinstance(value, dict):
return {to_camel(k): _camelize(v) for k, v in value.items()}
if isinstance(value, list):
return [_camelize(v) for v in value]
return value


def _tool_to_v3(tool: types.tools.Tool) -> dict[str, Any]:
"""Convert a tool schema blob to the v3 wire format."""
if tool.kind == "provider":
cfg = tool.tool_config
tool_id = cfg.id if cfg is not None else None
if tool_id is None:
raise TypeError(
f"provider tool {tool.name!r} has no tool_config id"
)
opaque = _OPAQUE_ARG_KEYS.get(tool_id, frozenset())
return {
"type": "provider",
"id": _provider_tool_id(tool),
"id": tool_id,
"name": tool.name,
"args": tool.args.model_dump(
mode="json",
by_alias=True,
exclude_none=True,
),
"args": {
to_camel(k): v if k in opaque else _camelize(v)
for k, v in (cfg.args if cfg is not None else {}).items()
},
}
args = tool.args
if not isinstance(args, types.tools.FunctionToolArgs):
raise TypeError(f"function tool {tool.name!r} has invalid args")
spec = tool.spec
if spec is None:
raise TypeError(f"function tool {tool.name!r} has no spec")
return {
"type": "function",
"name": tool.name,
"description": args.description or "",
"inputSchema": args.params,
"description": spec.description or "",
"inputSchema": spec.params,
}


def _provider_tool_id(tool: types.tools.Tool) -> str:
if isinstance(tool.args, anthropic_tools.AnthropicProviderArgs):
return f"anthropic.{tool.args.anthropic_type}"
if isinstance(tool.args, openai_tools.OpenAIProviderArgs):
return tool.args.openai_id

match tool.args:
case gateway_tools.PerplexitySearchArgs():
return "gateway.perplexity_search"
case gateway_tools.ParallelSearchArgs():
return "gateway.parallel_search"
case _:
raise TypeError(
f"provider tool {tool.name!r} has unsupported args "
f"{type(tool.args).__name__}"
)


async def _build_request_body(
messages: list[types.messages.Message],
tools: Sequence[types.tools.Tool] | None = None,
Expand Down
72 changes: 31 additions & 41 deletions src/ai/providers/ai_gateway/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing import Literal
from typing import Any, Literal

import pydantic
from pydantic.alias_generators import to_camel
Expand Down Expand Up @@ -43,26 +43,18 @@ class FetchPolicy(pydantic.BaseModel):
max_age_seconds: int | None = None


class PerplexitySearchArgs(pydantic.BaseModel):
model_config = _CONFIG_MODEL

max_results: int | None = None
max_tokens_per_page: int | None = None
max_tokens: int | None = None
country: str | None = None
search_domain_filter: list[str] | None = None
search_language_filter: list[str] | None = None
search_recency_filter: Literal["day", "week", "month", "year"] | None = None
def _dump[M: pydantic.BaseModel](
model_type: type[M], value: M | dict[str, object] | None
) -> dict[str, Any] | None:
if value is None:
return None
if isinstance(value, dict):
value = model_type.model_validate(value)
return value.model_dump(mode="json", exclude_none=True)


class ParallelSearchArgs(pydantic.BaseModel):
model_config = _CONFIG_MODEL

mode: Literal["one-shot", "agentic"] | None = None
max_results: int | None = None
source_policy: SourcePolicy | None = None
excerpts: Excerpts | None = None
fetch_policy: FetchPolicy | None = None
def _dict_filter_none(**args: Any) -> dict[str, Any]:
return {k: v for k, v in args.items() if v is not None}


def perplexity_search(
Expand All @@ -79,14 +71,17 @@ def perplexity_search(
return types.tools.Tool(
kind="provider",
name="perplexity_search",
args=PerplexitySearchArgs(
max_results=max_results,
max_tokens_per_page=max_tokens_per_page,
max_tokens=max_tokens,
country=country,
search_domain_filter=search_domain_filter,
search_language_filter=search_language_filter,
search_recency_filter=search_recency_filter,
tool_config=types.tools.ToolConfig(
id="gateway.perplexity_search",
args=_dict_filter_none(
max_results=max_results,
max_tokens_per_page=max_tokens_per_page,
max_tokens=max_tokens,
country=country,
search_domain_filter=search_domain_filter,
search_language_filter=search_language_filter,
search_recency_filter=search_recency_filter,
),
),
)

Expand All @@ -102,27 +97,22 @@ def parallel_search(
return types.tools.Tool(
kind="provider",
name="parallel_search",
args=ParallelSearchArgs(
mode=mode,
max_results=max_results,
source_policy=SourcePolicy.model_validate(source_policy)
if isinstance(source_policy, dict)
else source_policy,
excerpts=Excerpts.model_validate(excerpts)
if isinstance(excerpts, dict)
else excerpts,
fetch_policy=FetchPolicy.model_validate(fetch_policy)
if isinstance(fetch_policy, dict)
else fetch_policy,
tool_config=types.tools.ToolConfig(
id="gateway.parallel_search",
args=_dict_filter_none(
mode=mode,
max_results=max_results,
source_policy=_dump(SourcePolicy, source_policy),
excerpts=_dump(Excerpts, excerpts),
fetch_policy=_dump(FetchPolicy, fetch_policy),
),
),
)


__all__ = [
"Excerpts",
"FetchPolicy",
"ParallelSearchArgs",
"PerplexitySearchArgs",
"SourcePolicy",
"parallel_search",
"perplexity_search",
Expand Down
44 changes: 25 additions & 19 deletions src/ai/providers/anthropic/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,14 @@ def _custom_tools_to_anthropic(
"""Convert host-executed tools to Anthropic tool schema format."""
result: list[dict[str, Any]] = []
for tool in tools:
args = tool.args
if not isinstance(args, types.tools.FunctionToolArgs):
raise TypeError(f"function tool {tool.name!r} has invalid args")
spec = tool.spec
if spec is None:
raise TypeError(f"function tool {tool.name!r} has no spec")
result.append(
{
"name": tool.name,
"description": args.description or "",
"input_schema": args.params,
"description": spec.description or "",
"input_schema": spec.params,
}
)
return result
Expand All @@ -98,27 +98,33 @@ def _builtin_tools_to_anthropic(
Returns ``(wire_tools, beta_headers)``. Beta headers are merged into
the ``anthropic-beta`` request header by the caller.

Provider tool schemas keep args in the snake_case shape the native
Provider tool configs keep args in the snake_case shape the native
Anthropic API expects.
"""
wire: list[dict[str, Any]] = []
betas: set[str] = set()
for tool in builtin:
args_model = tool.args
if not isinstance(args_model, anthropic_tools.AnthropicProviderArgs):
cfg = tool.tool_config
tool_id = cfg.id if cfg is not None else None
if (
cfg is None
or tool_id is None
or not tool_id.startswith("anthropic.")
):
raise ValueError(
"AnthropicModel does not support provider args "
f"{type(args_model).__name__}"
"AnthropicModel does not support provider tool "
f"{tool_id or tool.name!r}"
)
args = args_model.model_dump(mode="json", exclude_none=True)
block: dict[str, Any] = {
"type": args_model.anthropic_type,
"name": tool.name,
**args,
}
wire.append(block)
if args_model.anthropic_beta is not None:
betas.add(args_model.anthropic_beta)
wire.append(
{
"type": tool_id.removeprefix("anthropic."),
"name": tool.name,
**cfg.args,
}
)
beta = anthropic_tools.BETA_HEADERS.get(tool_id)
if beta is not None:
betas.add(beta)

return wire, betas

Expand Down
Loading
Loading