diff --git a/src/ai/types/tools.py b/src/ai/types/tools.py index 3016909d..a83e5974 100644 --- a/src/ai/types/tools.py +++ b/src/ai/types/tools.py @@ -20,9 +20,26 @@ class FunctionToolArgs(pydantic.BaseModel): class Tool(pydantic.BaseModel): kind: Literal["function", "provider"] name: str - args: pydantic.BaseModel + args: pydantic.SerializeAsAny[pydantic.BaseModel] require_approval: bool = False + # we're using the same type for regular and provider-side tools. + # because of that args can be either FunctionToolArgs or some + # provider-specific type. + @pydantic.model_validator(mode="before") + @classmethod + def validate_args_input(cls, data: Any) -> Any: + if ( + isinstance(data, dict) + and data.get("kind") == "function" + and isinstance(data.get("args"), dict) + ): + return { + **data, + "args": FunctionToolArgs.model_validate(data["args"]), + } + return data + @pydantic.model_validator(mode="after") def validate_args_shape(self) -> Self: match self.kind: diff --git a/tests/types/test_tools.py b/tests/types/test_tools.py new file mode 100644 index 00000000..5933107f --- /dev/null +++ b/tests/types/test_tools.py @@ -0,0 +1,32 @@ +"""Focused tests for tool model serialization.""" + +from __future__ import annotations + +from ai.types import tools + + +def test_function_tool_args_round_trip_through_json_dump() -> None: + tool = tools.Tool( + kind="function", + name="weather", + args=tools.FunctionToolArgs( + description="Get weather", + params={ + "type": "object", + "properties": {"city": {"type": "string"}}, + }, + ), + ) + + data = tool.model_dump(mode="json") + restored = tools.Tool.model_validate(data) + + assert data["args"] == { + "description": "Get weather", + "params": { + "type": "object", + "properties": {"city": {"type": "string"}}, + }, + } + assert isinstance(restored.args, tools.FunctionToolArgs) + assert restored.args.description == "Get weather"