205 changes: 205 additions & 0 deletions omlx/adapter/harmony.py
@@ -20,15 +20,24 @@
- commentary: Tool/function calls (non-streaming only)
"""

import json
import logging
import re
from dataclasses import dataclass, field
from typing import Any

from openai_harmony import (
Author,
Conversation,
DeveloperContent,
HarmonyEncoding,
HarmonyEncodingName,
Message,
ReasoningEffort,
Role,
StreamableParser,
SystemContent,
ToolDescription,
load_harmony_encoding,
)

@@ -108,6 +117,202 @@ def preprocess_harmony_messages(
return result


_REASONING_EFFORT_MAP = {
"low": ReasoningEffort.LOW,
"medium": ReasoningEffort.MEDIUM,
"high": ReasoningEffort.HIGH,
}

# Cached and shared across all render_harmony_prompt() calls. The encoding
# is stateless, so reuse is safe; unlike parse_tool_calls_from_tokens below,
# which reloads the encoding on every call, we pay the load cost only once.
_GPT_OSS_ENCODING: HarmonyEncoding | None = None


def _get_gpt_oss_encoding() -> HarmonyEncoding:
global _GPT_OSS_ENCODING
if _GPT_OSS_ENCODING is None:
_GPT_OSS_ENCODING = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
return _GPT_OSS_ENCODING


def _extract_text(content: Any) -> str:
"""Flatten OpenAI content (str or list of blocks) into plain text."""
if isinstance(content, str):
return content
if not isinstance(content, list):
return "" if content is None else str(content)
parts: list[str] = []
for block in content:
if isinstance(block, dict):
if block.get("type") == "text":
parts.append(block.get("text", ""))
elif "text" in block:
parts.append(block["text"])
elif isinstance(block, str):
parts.append(block)
return "".join(parts)


def _tool_description_from_openai(tool: dict[str, Any]) -> ToolDescription | None:
"""Convert an OpenAI tool spec (``{"type":"function","function":{...}}``) to a ToolDescription."""
spec = tool.get("function") if tool.get("type") == "function" else tool
if not isinstance(spec, dict):
return None
name = spec.get("name")
if not name:
return None
description = spec.get("description", "") or ""
parameters = spec.get("parameters") or spec.get("input_schema") or {
"type": "object",
"properties": {},
}
return ToolDescription.new(name, description, parameters=parameters)


def render_harmony_prompt(
messages: list[dict[str, Any]],
tools: list[dict[str, Any]] | None = None,
chat_template_kwargs: dict[str, Any] | None = None,
) -> str:
"""Render chat messages + tools into a Harmony-format prompt string.

Used as a tokenizer-free fallback for gpt-oss models whose packaged
tokenizer does not ship with a ``chat_template``. The official
``openai-harmony`` library does the heavy lifting; this wrapper
translates the OpenAI-style inputs oMLX already uses.

Args:
messages: OpenAI chat messages. ``system`` becomes developer
instructions; ``assistant`` with a ``tool_calls`` field is
emitted on the commentary channel; ``tool`` messages are
rendered as tool responses.
tools: Optional OpenAI-format function tools. Also accepts raw
function specs (``{"name", "description", "parameters"}``).
chat_template_kwargs: Optional template-style kwargs. Supports
``reasoning_effort`` ("low"/"medium"/"high") and
``conversation_start_date``.

Returns:
A decoded Harmony prompt string ready to feed into ``generate``.
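
    Example (illustrative; ``get_weather`` is a made-up tool, and the exact
    rendered text depends on the installed ``openai-harmony`` version)::

        prompt = render_harmony_prompt(
            [{"role": "user", "content": "Weather in Paris?"}],
            tools=[{
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Look up current weather for a city.",
                    "parameters": {
                        "type": "object",
                        "properties": {"city": {"type": "string"}},
                    },
                },
            }],
            chat_template_kwargs={"reasoning_effort": "low"},
        )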
"""
ct = chat_template_kwargs or {}

system_content = SystemContent.new()
effort = ct.get("reasoning_effort")
if isinstance(effort, str):
mapped = _REASONING_EFFORT_MAP.get(effort.lower())
if mapped is not None:
system_content = system_content.with_reasoning_effort(mapped)
start_date = ct.get("conversation_start_date")
if isinstance(start_date, str) and start_date:
system_content = system_content.with_conversation_start_date(start_date)

system_texts: list[str] = []
convo_msgs: list[Message] = []
# Maps tool_call_id -> function name for resolving ``role=tool`` messages
# whose ``name`` field is omitted (OpenAI spec allows this).
tool_call_names: dict[str, str] = {}

for msg in messages:
if not isinstance(msg, dict):
continue
role = msg.get("role")

if role == "system":
text = _extract_text(msg.get("content"))
if text:
system_texts.append(text)
continue

if role == "user":
text = _extract_text(msg.get("content"))
convo_msgs.append(Message.from_role_and_content(Role.USER, text))
continue

if role == "assistant":
text = _extract_text(msg.get("content"))
if text:
convo_msgs.append(
Message.from_role_and_content(Role.ASSISTANT, text)
.with_channel("final")
)
for tc in msg.get("tool_calls") or []:
if not isinstance(tc, dict):
continue
fn = tc.get("function")
if not isinstance(fn, dict):
continue
name = fn.get("name") or ""
if not name:
continue
tc_id = tc.get("id")
if isinstance(tc_id, str):
tool_call_names[tc_id] = name
args = fn.get("arguments")
if isinstance(args, (dict, list)):
args_str = json.dumps(args, ensure_ascii=False)
else:
args_str = args or ""
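                # Harmony encodes prior tool calls as assistant turns on the
                # commentary channel, addressed to ``functions.<name>`` with a
                # ``<|constrain|> json`` content type, mirroring how gpt-oss
                # emits them during generation.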
convo_msgs.append(
Message.from_role_and_content(Role.ASSISTANT, args_str)
.with_channel("commentary")
.with_recipient(f"functions.{name}")
.with_content_type("<|constrain|> json")
)
continue

if role == "tool":
text = _extract_text(msg.get("content"))
# Prefer an explicit ``name``; fall back to the function name
# recorded when the matching assistant tool_call was emitted.
name = msg.get("name")
if not name:
tc_id = msg.get("tool_call_id")
if isinstance(tc_id, str):
name = tool_call_names.get(tc_id)
if not name:
# No recoverable name — skip rather than fabricate one, which
# would confuse the model about which function it called.
logger.warning(
"Skipping tool message: no ``name`` and no matching "
"tool_call_id in earlier assistant message."
)
continue
convo_msgs.append(
Message.from_author_and_content(
Author.new(Role.TOOL, f"functions.{name}"),
text,
).with_channel("commentary")
)
continue

# Unknown roles: drop silently — openai-harmony would reject them.

developer_content: DeveloperContent | None = None
instructions = "\n\n".join(t for t in system_texts if t)
if instructions:
developer_content = DeveloperContent.new().with_instructions(instructions)
if tools:
tool_descs = [td for t in tools if (td := _tool_description_from_openai(t))]
if tool_descs:
            developer_content = (
                developer_content or DeveloperContent.new()
            ).with_function_tools(tool_descs)

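    # Harmony conversations open with the system message; developer
    # instructions and tool definitions follow, then the chat history.
    # render_conversation_for_completion() appends the assistant header so
    # the model continues on the assistant turn.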
conv_messages: list[Message] = [
Message.from_role_and_content(Role.SYSTEM, system_content),
]
if developer_content is not None:
conv_messages.append(Message.from_role_and_content(Role.DEVELOPER, developer_content))
conv_messages.extend(convo_msgs)

encoding = _get_gpt_oss_encoding()
tokens = encoding.render_conversation_for_completion(
Conversation.from_messages(conv_messages), Role.ASSISTANT
)
return encoding.decode(tokens)


def _get_special_token_ids(tokenizer: Any) -> set[int]:
"""
Get special token IDs from model tokenizer.
17 changes: 16 additions & 1 deletion omlx/engine/batched.py
@@ -21,12 +21,13 @@

# Optional Harmony adapter import
try:
from ..adapter.harmony import preprocess_harmony_messages
from ..adapter.harmony import preprocess_harmony_messages, render_harmony_prompt

HAS_HARMONY_ADAPTER = True
except ImportError:
HAS_HARMONY_ADAPTER = False
preprocess_harmony_messages = None # type: ignore
render_harmony_prompt = None # type: ignore


class BatchedEngine(BaseEngine):
@@ -307,6 +308,20 @@ def _apply_chat_template(
chat_template_kwargs: Optional kwargs passed to tokenizer.apply_chat_template
(e.g. enable_thinking, reasoning_effort). Overrides global _enable_thinking.
"""
# Fallback for gpt-oss models whose tokenizer has no chat_template:
# render the Harmony prompt directly via openai-harmony. Models that
# do ship a chat_template continue to use it (preserves any
# model-specific tuning the template author intended).
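        # ``getattr`` with a None default covers tokenizers that lack the
        # ``chat_template`` attribute entirely as well as those that set it
        # to None or an empty string.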
if (
self.model_type == "gpt_oss"
and HAS_HARMONY_ADAPTER
and render_harmony_prompt is not None
and not getattr(self._tokenizer, "chat_template", None)
):
return render_harmony_prompt(
messages, tools=tools, chat_template_kwargs=chat_template_kwargs
)

if hasattr(self._tokenizer, 'apply_chat_template'):
is_partial = detect_and_strip_partial(messages)
template_kwargs = {