From 7b80a78b088bea4a3610aaa54b3ac99ed96fe3f3 Mon Sep 17 00:00:00 2001
From: markknoffler
Date: Wed, 21 Jan 2026 03:10:48 +0530
Subject: [PATCH] Harden tool-call JSON parsing

Avoid greedy brace regex parsing and ignore non-tool JSON blocks to
prevent KeyError when tool_name is missing.
---
 gemma/gm/tools/_manager.py | 39 ++++++++++++++++++++++++--------------
 1 file changed, 25 insertions(+), 14 deletions(-)

diff --git a/gemma/gm/tools/_manager.py b/gemma/gm/tools/_manager.py
index 57dd9b09..252d883c 100644
--- a/gemma/gm/tools/_manager.py
+++ b/gemma/gm/tools/_manager.py
@@ -20,7 +20,7 @@
 import dataclasses
 import functools
 import json
-import re
+from typing import Any
 
 from etils import epy
 from gemma.gm.tools import _tools
@@ -100,7 +100,10 @@ def maybe_execute_tool(self, model_output: str) -> _tools.ToolOutput | None:
     if not tool_kwargs:
       return None
 
-    tool_name = tool_kwargs.pop('tool_name')
+    tool_name = tool_kwargs.pop('tool_name', None)
+    # If the model output contained JSON but it wasn't a tool call, ignore it.
+    if not isinstance(tool_name, str) or not tool_name:
+      return None
     if tool_name not in self.name_to_tool:
       return _tools.ToolOutput(
           text=f'Unknown (or unregistered) tool: {tool_name}.'
@@ -122,18 +125,26 @@ def maybe_execute_tool(self, model_output: str) -> _tools.ToolOutput | None:
     return tool_result
 
 
-def _parse_tool_call(model_output: str) -> dict[str, str] | None:
-  """Parses the tool call from the model output."""
-  # This regex finds the first '{' and the last '}'
-  match = re.search(r'\{.*\}', model_output)
-
-  if not match:
-    return None
-  json_string = match.group(0)
-  try:
-    return json.loads(json_string)
-  except json.JSONDecodeError:
-    return None
+def _parse_tool_call(model_output: str) -> dict[str, Any] | None:
+  """Parses a tool call dict from the model output.
+
+  The model may emit arbitrary JSON (e.g. when asked to output structured data).
+  We should only treat JSON as a tool call if it looks like a tool call (i.e.
+  it is a JSON object containing a `tool_name` field).
+  """
+  # Avoid regex-based extraction: it is brittle (greedy brace matching).
+  decoder = json.JSONDecoder()
+  for i, ch in enumerate(model_output):
+    if ch != '{':
+      continue
+    try:
+      obj, _ = decoder.raw_decode(model_output, i)
+    except json.JSONDecodeError:
+      continue
+    # Only consider JSON objects that look like a tool call.
+    if isinstance(obj, dict) and isinstance(obj.get('tool_name'), str):
+      return obj
+  return None
 
 
 def _format_tool_example(