Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 25 additions & 14 deletions gemma/gm/tools/_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import dataclasses
import functools
import json
import re
from typing import Any

from etils import epy
from gemma.gm.tools import _tools
Expand Down Expand Up @@ -100,7 +100,10 @@ def maybe_execute_tool(self, model_output: str) -> _tools.ToolOutput | None:
if not tool_kwargs:
return None

tool_name = tool_kwargs.pop('tool_name')
tool_name = tool_kwargs.pop('tool_name', None)
# If the model output contained JSON but it wasn't a tool call, ignore it.
if not isinstance(tool_name, str) or not tool_name:
return None
if tool_name not in self.name_to_tool:
return _tools.ToolOutput(
text=f'Unknown (or unregistered) tool: {tool_name}.'
Expand All @@ -122,18 +125,26 @@ def maybe_execute_tool(self, model_output: str) -> _tools.ToolOutput | None:
return tool_result


def _parse_tool_call(model_output: str) -> dict[str, str] | None:
"""Parses the tool call from the model output."""
# This regex finds the first '{' and the last '}'
match = re.search(r'\{.*\}', model_output)

if not match:
return None
json_string = match.group(0)
try:
return json.loads(json_string)
except json.JSONDecodeError:
return None
def _parse_tool_call(model_output: str) -> dict[str, Any] | None:
"""Parses a tool call dict from the model output.

The model may emit arbitrary JSON (e.g. when asked to output structured data).
We should only treat JSON as a tool call if it looks like a tool call (i.e.
it is a JSON object containing a `tool_name` field).
"""
# Avoid regex-based extraction: it is brittle (greedy brace matching).
decoder = json.JSONDecoder()
for i, ch in enumerate(model_output):
if ch != '{':
continue
try:
obj, _ = decoder.raw_decode(model_output, i)
except json.JSONDecodeError:
continue
# Only consider JSON objects that look like a tool call.
if isinstance(obj, dict) and isinstance(obj.get('tool_name'), str):
return obj
return None


def _format_tool_example(
Expand Down