Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions openhands-sdk/openhands/sdk/llm/llm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import copy
import importlib
import json
import os
import threading
Expand Down Expand Up @@ -500,6 +501,7 @@ class LLM(BaseModel, RetryMixin, NonNativeToolCallingMixin):
# Runtime-only private attrs
_model_info: Any = PrivateAttr(default=None)
_tokenizer: Any = PrivateAttr(default=None)
_chat_template_tokenizer: Any = PrivateAttr(default=None)
_telemetry: Telemetry | None = PrivateAttr(default=None)
_is_subscription: bool = PrivateAttr(default=False)
_litellm_provider: str | None = PrivateAttr(default=None)
Expand Down Expand Up @@ -576,6 +578,9 @@ def _post_init(self):
# Tokenizer
if self.custom_tokenizer:
self._tokenizer = create_pretrained_tokenizer(self.custom_tokenizer)
self._chat_template_tokenizer = self._load_chat_template_tokenizer(
self.custom_tokenizer
)

# Capabilities + model info
self._init_model_info_and_caps()
Expand Down Expand Up @@ -2305,6 +2310,11 @@ def get_token_count(
cc_tools = []

try:
template_count = self._get_chat_template_token_count(
formatted_messages, cc_tools
)
if template_count is not None:
return template_count
return int(
token_counter(
model=self.model,
Expand All @@ -2325,6 +2335,115 @@ def get_token_count(
)
return 0

def _get_chat_template_token_count(
self,
formatted_messages: list[dict],
tools: list[ChatCompletionToolParam],
) -> int | None:
"""Count tokens with a tokenizer chat template when one is available.

LiteLLM's generic token counter estimates OpenAI-style chat/tool overhead.
Local OpenAI-compatible servers commonly apply the model tokenizer's chat
template before tokenization, which can differ substantially once tool
schemas are rendered into the prompt. If a caller configured a tokenizer
that supports ``apply_chat_template``, prefer that exact rendered prompt
shape for condenser token checks and fall back to LiteLLM otherwise.
"""
tokenizer = self._chat_template_tokenizer or self._tokenizer
if isinstance(tokenizer, dict):
tokenizer = tokenizer.get("tokenizer")
if tokenizer is None or not hasattr(tokenizer, "apply_chat_template"):
return None

try:
template_messages = self._messages_for_chat_template(formatted_messages)
kwargs: dict[str, Any] = {
"tokenize": True,
"add_generation_prompt": True,
}
if tools:
kwargs["tools"] = tools
tokenized = tokenizer.apply_chat_template(template_messages, **kwargs)
return self._count_tokenized_output(tokenized, tokenizer)
except Exception:
logger.debug(
"Chat-template token counting failed; falling back to LiteLLM",
exc_info=True,
)
return None

@staticmethod
def _count_tokenized_output(tokenized: Any, tokenizer: Any) -> int:
if isinstance(tokenized, str):
encoded = tokenizer.encode(tokenized)
return LLM._count_tokenized_output(encoded, tokenizer)
if hasattr(tokenized, "shape") and len(tokenized.shape) > 0:
return int(tokenized.shape[-1])
if hasattr(tokenized, "ids"):
return len(tokenized.ids)
if isinstance(tokenized, dict) and "input_ids" in tokenized:
Comment thread
neubig marked this conversation as resolved.
return LLM._count_tokenized_output(tokenized["input_ids"], tokenizer)
get_input_ids = getattr(tokenized, "get", None)
if callable(get_input_ids):
input_ids = get_input_ids("input_ids")
if input_ids is not None:
return LLM._count_tokenized_output(input_ids, tokenizer)
encodings = getattr(tokenized, "encodings", None)
if encodings:
return LLM._count_tokenized_output(encodings[0], tokenizer)
if isinstance(tokenized, Sequence):
if tokenized and hasattr(tokenized[0], "ids"):
return LLM._count_tokenized_output(tokenized[0], tokenizer)
if tokenized and isinstance(tokenized[0], Sequence):
return len(tokenized[0])
return len(tokenized)
raise TypeError(f"Unsupported tokenized output: {type(tokenized).__name__}")

@staticmethod
def _messages_for_chat_template(messages: list[dict]) -> list[dict]:
template_messages = copy.deepcopy(messages)
for message in template_messages:
content = message.get("content")
if not isinstance(content, list):
continue
text_parts: list[str] = []
for block in content:
if not isinstance(block, dict) or block.get("type") != "text":
text_parts = []
break
text_parts.append(str(block.get("text", "")))
if text_parts:
message["content"] = "".join(text_parts)
return template_messages

@staticmethod
def _load_chat_template_tokenizer(identifier: str) -> Any | None:
try:
transformers = importlib.import_module("transformers")
except ModuleNotFoundError:
return None
except Exception:
logger.debug("Unable to import transformers", exc_info=True)
return None

auto_tokenizer = getattr(transformers, "AutoTokenizer", None)
if auto_tokenizer is None:
return None

try:
tokenizer = auto_tokenizer.from_pretrained(identifier)
except Exception:
logger.debug(
"Unable to load chat-template tokenizer for %s",
identifier,
exc_info=True,
)
return None

if hasattr(tokenizer, "apply_chat_template"):
return tokenizer
return None

@classmethod
def from_persisted(cls, data: Any, *, context: dict[str, Any] | None = None) -> LLM:
"""Load a persisted LLM profile payload, applying schema migrations."""
Expand Down
115 changes: 115 additions & 0 deletions tests/sdk/llm/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,121 @@ def test_llm_token_counting_includes_tools(mock_token_counter, default_llm):
assert "message" in kwargs["tools"][0]["function"]["parameters"]["properties"]


def test_llm_load_chat_template_tokenizer_prefers_transformers(monkeypatch):
    """A tokenizer from a stubbed ``transformers`` module is returned as-is."""

    class FakeTokenizer:
        def apply_chat_template(self, messages, **kwargs):
            return []

    class FakeAutoTokenizer:
        # Remembers the identifier handed to ``from_pretrained``.
        loaded_identifier = None

        @classmethod
        def from_pretrained(cls, identifier):
            cls.loaded_identifier = identifier
            return FakeTokenizer()

    class FakeTransformers:
        AutoTokenizer = FakeAutoTokenizer

    def fake_import_module(name):
        # Only the stubbed ``transformers`` module is importable here.
        if name != "transformers":
            raise ModuleNotFoundError(name)
        return FakeTransformers

    monkeypatch.setattr(
        "openhands.sdk.llm.llm.importlib.import_module", fake_import_module
    )

    loaded = LLM._load_chat_template_tokenizer("model-with-template")

    assert isinstance(loaded, FakeTokenizer)
    assert FakeAutoTokenizer.loaded_identifier == "model-with-template"


@patch("openhands.sdk.llm.llm.token_counter")
def test_llm_token_counting_prefers_chat_template_tokenizer(
    mock_token_counter, default_llm
):
    """With a chat-template tokenizer set, LiteLLM's counter is never used."""
    recorded_calls = []

    class FakeChatTemplateTokenizer:
        def apply_chat_template(self, messages, **kwargs):
            # Keep the arguments so the rendered prompt shape can be checked.
            recorded_calls.append((messages, kwargs))
            return list(range(321))

    default_llm._chat_template_tokenizer = FakeChatTemplateTokenizer()

    token_count = default_llm.get_token_count(
        [Message(role="user", content=[TextContent(text="Hello")])],
        tools=list(FinishTool.create()),
        add_security_risk_prediction=True,
    )

    assert token_count == 321
    mock_token_counter.assert_not_called()

    rendered_messages, template_kwargs = recorded_calls[0]
    first_message = rendered_messages[0]
    assert first_message["role"] == "user"
    assert first_message["content"] == "Hello"
    assert template_kwargs["tokenize"] is True
    assert template_kwargs["add_generation_prompt"] is True

    finish_function = template_kwargs["tools"][0]["function"]
    assert finish_function["name"] == "finish"
    assert "message" in finish_function["parameters"]["properties"]


def test_llm_count_tokenized_output_handles_encoding_objects(default_llm):
    """Objects shaped like Hugging Face BatchEncoding/Encoding are countable."""
    expected_tokens = 321

    class FakeEncoding:
        def __init__(self):
            self.ids = list(range(expected_tokens))

    class FakeBatchEncoding:
        # Exposes both ``get("input_ids")`` and ``encodings[0].ids``.
        def __init__(self):
            self.encodings = [FakeEncoding()]

        def get(self, key):
            return list(range(expected_tokens)) if key == "input_ids" else None

    class FakeChatTemplateTokenizer:
        def apply_chat_template(self, messages, **kwargs):
            return FakeBatchEncoding()

    default_llm._chat_template_tokenizer = FakeChatTemplateTokenizer()

    token_count = default_llm.get_token_count(
        [Message(role="user", content=[TextContent(text="Hello")])]
    )

    assert token_count == expected_tokens


@patch("openhands.sdk.llm.llm.token_counter")
def test_llm_token_counting_falls_back_when_chat_template_fails(
    mock_token_counter, default_llm
):
    """If the chat template raises, the count comes from LiteLLM instead."""
    mock_token_counter.return_value = 123

    class BrokenChatTemplateTokenizer:
        def apply_chat_template(self, messages, **kwargs):
            raise RuntimeError("template unavailable")

    default_llm._chat_template_tokenizer = BrokenChatTemplateTokenizer()

    token_count = default_llm.get_token_count(
        [Message(role="user", content=[TextContent(text="Hello")])]
    )

    # The failure is swallowed and the generic counter is consulted once.
    mock_token_counter.assert_called_once()
    assert token_count == 123


@patch("openhands.sdk.llm.llm.token_counter")
def test_llm_token_counting_mocks_tools_for_non_native_models(mock_token_counter):
"""Test token counting prompt-mocks tools when native tool calling is disabled."""
Expand Down
Loading