From 6f3fa97ab738611059290a05812213665015c373 Mon Sep 17 00:00:00 2001 From: Graham Neubig <398875+neubig@users.noreply.github.com> Date: Tue, 9 Jun 2026 15:08:32 -0400 Subject: [PATCH] Count tokens with tokenizer chat templates --- openhands-sdk/openhands/sdk/llm/llm.py | 119 +++++++++++++++++++++++++ tests/sdk/llm/test_llm.py | 115 ++++++++++++++++++++++++ 2 files changed, 234 insertions(+) diff --git a/openhands-sdk/openhands/sdk/llm/llm.py b/openhands-sdk/openhands/sdk/llm/llm.py index 8ccf8bdd3a..ecb5649e7d 100644 --- a/openhands-sdk/openhands/sdk/llm/llm.py +++ b/openhands-sdk/openhands/sdk/llm/llm.py @@ -1,6 +1,7 @@ from __future__ import annotations import copy +import importlib import json import os import threading @@ -500,6 +501,7 @@ class LLM(BaseModel, RetryMixin, NonNativeToolCallingMixin): # Runtime-only private attrs _model_info: Any = PrivateAttr(default=None) _tokenizer: Any = PrivateAttr(default=None) + _chat_template_tokenizer: Any = PrivateAttr(default=None) _telemetry: Telemetry | None = PrivateAttr(default=None) _is_subscription: bool = PrivateAttr(default=False) _litellm_provider: str | None = PrivateAttr(default=None) @@ -576,6 +578,9 @@ def _post_init(self): # Tokenizer if self.custom_tokenizer: self._tokenizer = create_pretrained_tokenizer(self.custom_tokenizer) + self._chat_template_tokenizer = self._load_chat_template_tokenizer( + self.custom_tokenizer + ) # Capabilities + model info self._init_model_info_and_caps() @@ -2305,6 +2310,11 @@ def get_token_count( cc_tools = [] try: + template_count = self._get_chat_template_token_count( + formatted_messages, cc_tools + ) + if template_count is not None: + return template_count return int( token_counter( model=self.model, @@ -2325,6 +2335,115 @@ def get_token_count( ) return 0 + def _get_chat_template_token_count( + self, + formatted_messages: list[dict], + tools: list[ChatCompletionToolParam], + ) -> int | None: + """Count tokens with a tokenizer chat template when one is available. + + LiteLLM's generic token counter estimates OpenAI-style chat/tool overhead. + Local OpenAI-compatible servers commonly apply the model tokenizer's chat + template before tokenization, which can differ substantially once tool + schemas are rendered into the prompt. If a caller configured a tokenizer + that supports ``apply_chat_template``, prefer that exact rendered prompt + shape for condenser token checks and fall back to LiteLLM otherwise. + """ + tokenizer = self._chat_template_tokenizer or self._tokenizer + if isinstance(tokenizer, dict): + tokenizer = tokenizer.get("tokenizer") + if tokenizer is None or not hasattr(tokenizer, "apply_chat_template"): + return None + + try: + template_messages = self._messages_for_chat_template(formatted_messages) + kwargs: dict[str, Any] = { + "tokenize": True, + "add_generation_prompt": True, + } + if tools: + kwargs["tools"] = tools + tokenized = tokenizer.apply_chat_template(template_messages, **kwargs) + return self._count_tokenized_output(tokenized, tokenizer) + except Exception: + logger.debug( + "Chat-template token counting failed; falling back to LiteLLM", + exc_info=True, + ) + return None + + @staticmethod + def _count_tokenized_output(tokenized: Any, tokenizer: Any) -> int: + if isinstance(tokenized, str): + encoded = tokenizer.encode(tokenized) + return LLM._count_tokenized_output(encoded, tokenizer) + if hasattr(tokenized, "shape") and len(tokenized.shape) > 0: + return int(tokenized.shape[-1]) + if hasattr(tokenized, "ids"): + return len(tokenized.ids) + if isinstance(tokenized, dict) and "input_ids" in tokenized: + return LLM._count_tokenized_output(tokenized["input_ids"], tokenizer) + get_input_ids = getattr(tokenized, "get", None) + if callable(get_input_ids): + input_ids = get_input_ids("input_ids") + if input_ids is not None: + return LLM._count_tokenized_output(input_ids, tokenizer) + encodings = getattr(tokenized, "encodings", None) + if encodings: + return LLM._count_tokenized_output(encodings[0], tokenizer) + if isinstance(tokenized, Sequence): + if tokenized and hasattr(tokenized[0], "ids"): + return LLM._count_tokenized_output(tokenized[0], tokenizer) + if tokenized and isinstance(tokenized[0], Sequence): + return len(tokenized[0]) + return len(tokenized) + raise TypeError(f"Unsupported tokenized output: {type(tokenized).__name__}") + + @staticmethod + def _messages_for_chat_template(messages: list[dict]) -> list[dict]: + template_messages = copy.deepcopy(messages) + for message in template_messages: + content = message.get("content") + if not isinstance(content, list): + continue + text_parts: list[str] = [] + for block in content: + if not isinstance(block, dict) or block.get("type") != "text": + text_parts = [] + break + text_parts.append(str(block.get("text", ""))) + if text_parts: + message["content"] = "".join(text_parts) + return template_messages + + @staticmethod + def _load_chat_template_tokenizer(identifier: str) -> Any | None: + try: + transformers = importlib.import_module("transformers") + except ModuleNotFoundError: + return None + except Exception: + logger.debug("Unable to import transformers", exc_info=True) + return None + + auto_tokenizer = getattr(transformers, "AutoTokenizer", None) + if auto_tokenizer is None: + return None + + try: + tokenizer = auto_tokenizer.from_pretrained(identifier) + except Exception: + logger.debug( + "Unable to load chat-template tokenizer for %s", + identifier, + exc_info=True, + ) + return None + + if hasattr(tokenizer, "apply_chat_template"): + return tokenizer + return None + @classmethod def from_persisted(cls, data: Any, *, context: dict[str, Any] | None = None) -> LLM: """Load a persisted LLM profile payload, applying schema migrations.""" diff --git a/tests/sdk/llm/test_llm.py b/tests/sdk/llm/test_llm.py index c9d341b7a8..585a23ced4 100644 --- a/tests/sdk/llm/test_llm.py +++ b/tests/sdk/llm/test_llm.py @@ -365,6 +365,121 @@ def test_llm_token_counting_includes_tools(mock_token_counter, default_llm): assert "message" in kwargs["tools"][0]["function"]["parameters"]["properties"] +def test_llm_load_chat_template_tokenizer_prefers_transformers(monkeypatch): + """The optional chat-template tokenizer uses Transformers when available.""" + + class FakeTokenizer: + def apply_chat_template(self, messages, **kwargs): + return [] + + class FakeAutoTokenizer: + loaded_identifier = None + + @classmethod + def from_pretrained(cls, identifier): + cls.loaded_identifier = identifier + return FakeTokenizer() + + class FakeTransformers: + AutoTokenizer = FakeAutoTokenizer + + def fake_import_module(name): + if name == "transformers": + return FakeTransformers + raise ModuleNotFoundError(name) + + monkeypatch.setattr( + "openhands.sdk.llm.llm.importlib.import_module", fake_import_module + ) + + tokenizer = LLM._load_chat_template_tokenizer("model-with-template") + + assert isinstance(tokenizer, FakeTokenizer) + assert FakeAutoTokenizer.loaded_identifier == "model-with-template" + + +@patch("openhands.sdk.llm.llm.token_counter") +def test_llm_token_counting_prefers_chat_template_tokenizer( + mock_token_counter, default_llm +): + """Token counting uses apply_chat_template when the tokenizer supports it.""" + + class FakeChatTemplateTokenizer: + def __init__(self): + self.calls = [] + + def apply_chat_template(self, messages, **kwargs): + self.calls.append((messages, kwargs)) + return list(range(321)) + + tokenizer = FakeChatTemplateTokenizer() + default_llm._chat_template_tokenizer = tokenizer + messages = [Message(role="user", content=[TextContent(text="Hello")])] + tools = list(FinishTool.create()) + + token_count = default_llm.get_token_count( + messages, + tools=tools, + add_security_risk_prediction=True, + ) + + assert token_count == 321 + mock_token_counter.assert_not_called() + applied_messages, kwargs = tokenizer.calls[0] + assert applied_messages[0]["role"] == "user" + assert applied_messages[0]["content"] == "Hello" + assert kwargs["tokenize"] is True + assert kwargs["add_generation_prompt"] is True + assert kwargs["tools"][0]["function"]["name"] == "finish" + assert "message" in kwargs["tools"][0]["function"]["parameters"]["properties"] + + +def test_llm_count_tokenized_output_handles_encoding_objects(default_llm): + """Token counting handles Hugging Face BatchEncoding/Encoding shapes.""" + + class FakeEncoding: + def __init__(self): + self.ids = list(range(321)) + + class FakeBatchEncoding: + def __init__(self): + self.encodings = [FakeEncoding()] + + def get(self, key): + if key == "input_ids": + return list(range(321)) + return None + + class FakeChatTemplateTokenizer: + def apply_chat_template(self, messages, **kwargs): + return FakeBatchEncoding() + + default_llm._chat_template_tokenizer = FakeChatTemplateTokenizer() + messages = [Message(role="user", content=[TextContent(text="Hello")])] + + assert default_llm.get_token_count(messages) == 321 + + +@patch("openhands.sdk.llm.llm.token_counter") +def test_llm_token_counting_falls_back_when_chat_template_fails( + mock_token_counter, default_llm +): + """A broken tokenizer chat template must not break token counting.""" + + class BrokenChatTemplateTokenizer: + def apply_chat_template(self, messages, **kwargs): + raise RuntimeError("template unavailable") + + default_llm._chat_template_tokenizer = BrokenChatTemplateTokenizer() + mock_token_counter.return_value = 123 + messages = [Message(role="user", content=[TextContent(text="Hello")])] + + token_count = default_llm.get_token_count(messages) + + assert token_count == 123 + mock_token_counter.assert_called_once() + + @patch("openhands.sdk.llm.llm.token_counter") def test_llm_token_counting_mocks_tools_for_non_native_models(mock_token_counter): """Test token counting prompt-mocks tools when native tool calling is disabled."""