diff --git a/crates/higgs-engine/src/chat_template.rs b/crates/higgs-engine/src/chat_template.rs index b3f27b23..908cdde3 100644 --- a/crates/higgs-engine/src/chat_template.rs +++ b/crates/higgs-engine/src/chat_template.rs @@ -1,3 +1,4 @@ +use minijinja::value::Kwargs; use minijinja::{Environment, Value}; use serde::Serialize; @@ -171,9 +172,107 @@ impl ChatTemplateRenderer { } } -/// Custom tojson filter for minijinja (used by HF chat templates). +/// Normalise a tool-call JSON object so Qwen-Hermes-style chat templates +/// can render it without crashing on `tool_call.arguments|items`. +/// +/// Two transformations are applied in place: +/// +/// 1. **Flatten `function.{name,arguments}` to top level.** The `OpenAI` +/// request shape nests them under `function`; Qwen's +/// `chat_template.jinja` references `tool_call.name` and +/// `tool_call.arguments` directly. After this call, both shapes are +/// accessible. +/// 2. **Coerce `arguments` to a mapping.** `OpenAI` sends +/// `function.arguments` as a JSON-encoded string, but Qwen's template +/// iterates it via `|items`. A string that parses to a JSON object is +/// replaced by that object; anything that does not resolve to an object +/// (unparseable strings, or JSON that isn't an object) is coerced to an +/// empty object `{}` by [`normalize_arguments_value`] so the template +/// can't raise `cannot convert value into pairs`. The original string +/// does NOT survive when it isn't object-shaped. +/// +/// Other fields (`id`, `type`, …) are preserved unchanged. Callers that +/// already supply the flat shape pay only the cost of a `serde_json::Value` +/// match. +pub fn normalize_tool_call_for_template(tc: &mut serde_json::Value) { + let Some(obj) = tc.as_object_mut() else { + return; + }; + + // Promote `function.name` / `function.arguments` to the top level. + if let Some(function) = obj.get("function").cloned() { + if let Some(func_obj) = function.as_object() { + if !obj.contains_key("name") { + if let Some(name) = func_obj.get("name") { + obj.insert("name".to_owned(), name.clone()); + } + } + if !obj.contains_key("arguments") { + if let Some(arguments) = func_obj.get("arguments") { + obj.insert("arguments".to_owned(), arguments.clone()); + } + } + } + } + + // Normalize the top-level `arguments` (used by Qwen-flat templates). + if let Some(args) = obj.get_mut("arguments") { + normalize_arguments_value(args); + } + + // Normalize the nested `function.arguments` too. Qwen's + // `chat_template.jinja` lines 107-108 rebind `tool_call` to + // `tool_call.function` when present, so if we only normalised the + // top-level copy the template still walks into a string and crashes + // at `|items`. Templates that don't rebind are unaffected. + if let Some(function) = obj.get_mut("function") { + if let Some(func_obj) = function.as_object_mut() { + if let Some(nested_args) = func_obj.get_mut("arguments") { + normalize_arguments_value(nested_args); + } + } + } +} + +/// Coerce a `tool_call.arguments` (or `function.arguments`) value into +/// the mapping shape that `chat_template.jinja:120` requires. +/// +/// 1. If it's a JSON-string, try to parse it back to a `Value`. +/// 2. If the result still isn't an object (null, bool, number, array, +/// or unparseable string), coerce to an empty object so the +/// template's `|items` doesn't raise. A warn is logged so the +/// pathological shape is visible. +fn normalize_arguments_value(args: &mut serde_json::Value) { + if let Some(s) = args.as_str() { + if let Ok(parsed) = serde_json::from_str::(s) { + *args = parsed; + } + } + if args.is_object() { + return; + } + let shape = match args { + serde_json::Value::Null => "null", + serde_json::Value::Bool(_) => "bool", + serde_json::Value::Number(_) => "number", + serde_json::Value::String(_) => "string", + serde_json::Value::Array(_) => "array", + // `is_object()` already returned for this case above. + serde_json::Value::Object(_) => "object", + }; + tracing::warn!( + shape, + "tool_call arguments not a mapping after normalization; coercing to empty object so the chat template can render" + ); + *args = serde_json::Value::Object(serde_json::Map::new()); +} + +/// `tojson` filter. `_kwargs` absorbs keyword arguments HF chat templates pass +/// — notably `tojson(ensure_ascii=false)` (e.g. `MiniCPM5`). `serde_json` already +/// emits UTF-8, which matches `ensure_ascii=false`, so the kwarg is accepted and +/// ignored rather than failing the render with "too many arguments". #[allow(clippy::needless_pass_by_value)] -fn tojson_filter(value: Value) -> Result { +fn tojson_filter(value: Value, _kwargs: Kwargs) -> Result { let serialized = serde_json::to_string(&value).map_err(|e| { minijinja::Error::new( minijinja::ErrorKind::InvalidOperation, @@ -185,7 +284,13 @@ fn tojson_filter(value: Value) -> Result { } #[cfg(test)] -#[allow(clippy::panic, clippy::unwrap_used)] +#[allow( + clippy::panic, + clippy::unwrap_used, + clippy::expect_used, + clippy::shadow_unrelated, + clippy::shadow_reuse +)] mod tests { use super::*; @@ -240,6 +345,20 @@ mod tests { assert_eq!(result, r#""hello""#); } + /// HF templates (e.g. `MiniCPM5` at `chat:6`) call `tojson(ensure_ascii=…)`. + /// The filter must accept the kwarg instead of failing with "too many + /// arguments"; the value is ignored since `serde_json` emits UTF-8. + #[test] + fn test_tojson_filter_accepts_ensure_ascii_kwarg() { + let env = tojson_env(r"{{ value | tojson(ensure_ascii=false) }}"); + let tmpl = env.get_template("test").unwrap(); + let result = tmpl + .render(minijinja::context! { value => "café" }) + .unwrap(); + // UTF-8 preserved (not \u-escaped), and the call did not error. + assert_eq!(result, "\"café\""); + } + #[test] fn test_invalid_template_syntax_returns_error() { assert!(ChatTemplateRenderer::new("{%- invalid syntax %}}}").is_err()); @@ -593,4 +712,168 @@ TOOLS:{{ tools | length }} .unwrap(); assert!(ChatTemplateRenderer::try_from_model_dir(dir.path()).is_err()); } + + // ----------------------------------------------------------------- + // normalize_tool_call_for_template + // ----------------------------------------------------------------- + // + // Invariants asserted, one test per shape we observed in production: + // + // 1. `OpenAI` shape (name/arguments nested under `function`, + // arguments as JSON-encoded STRING) → after normalize, top-level + // name and arguments-as-mapping. This is the case that crashed + // Qwen's `chat_template.jinja:120` with "cannot convert value + // into pairs". + // 2. Qwen-flat shape (top-level name/arguments, arguments already + // an object) → no-op, identity. + // 3. Non-JSON string in `function.arguments` → flattened but kept + // as string (template can decide what to do). + // 4. Non-object input (string, null, array) → no-op, can't panic. + + fn parsed(s: &str) -> serde_json::Value { + serde_json::from_str(s).unwrap() + } + + #[test] + fn normalize_openai_shape_to_qwen_flat() { + let mut tc = parsed( + r#"{ + "id": "call_0", + "type": "function", + "function": { "name": "get_weather", "arguments": "{\"city\":\"Paris\"}" } + }"#, + ); + normalize_tool_call_for_template(&mut tc); + + // Top-level name and arguments are present. + assert_eq!(tc.get("name").and_then(|v| v.as_str()), Some("get_weather")); + // arguments is now an OBJECT, not a string. + let args = tc.get("arguments").unwrap(); + assert!( + args.is_object(), + "expected arguments to be an object, got {args:?}" + ); + assert_eq!(args.get("city").and_then(|v| v.as_str()), Some("Paris")); + // id and type preserved. + assert_eq!(tc.get("id").and_then(|v| v.as_str()), Some("call_0")); + assert_eq!(tc.get("type").and_then(|v| v.as_str()), Some("function")); + } + + #[test] + fn normalize_qwen_flat_shape_is_noop() { + let original = parsed(r#"{ "name": "search", "arguments": { "q": "rust" } }"#); + let mut tc = original.clone(); + normalize_tool_call_for_template(&mut tc); + assert_eq!(tc, original, "already-flat shape must be a no-op"); + } + + #[test] + fn normalize_unparseable_string_arguments_coerced_to_empty_object() { + // Unparseable string arguments are coerced to `{}` so the chat + // template's `|items` doesn't blow up. The model loses the + // pathological arguments, which is strictly better than the + // entire conversation 500-ing. + let mut tc = parsed( + r#"{ + "function": { "name": "f", "arguments": "this is not json" } + }"#, + ); + normalize_tool_call_for_template(&mut tc); + assert_eq!(tc.get("name").and_then(|v| v.as_str()), Some("f")); + assert_eq!(tc.get("arguments"), Some(&parsed("{}"))); + } + + #[test] + fn normalize_non_object_is_noop() { + let mut s = parsed(r#""not a tool call""#); + normalize_tool_call_for_template(&mut s); + assert_eq!(s, parsed(r#""not a tool call""#)); + + let mut n = parsed("null"); + normalize_tool_call_for_template(&mut n); + assert_eq!(n, parsed("null")); + + let mut a = parsed("[1, 2, 3]"); + normalize_tool_call_for_template(&mut a); + assert_eq!(a, parsed("[1, 2, 3]")); + } + + /// Qwen's `chat_template.jinja:107-108` rebinds `tool_call` to + /// `tool_call.function` when the latter is defined. If we only + /// normalised the hoisted top-level `arguments` and left + /// `function.arguments` as the original JSON-encoded string, the + /// rebinding would walk straight into a string and the template + /// would crash at `|items`. This test pins both paths. + #[test] + fn normalize_handles_qwen_rebind_to_function() { + let mut tc = parsed( + r#"{ + "id": "call_0", + "type": "function", + "function": { "name": "f", "arguments": "{\"city\":\"London\"}" } + }"#, + ); + normalize_tool_call_for_template(&mut tc); + + // Top-level arguments — Qwen-flat templates see this. + let top_args = tc.get("arguments").unwrap(); + assert!( + top_args.is_object(), + "top-level arguments must be a mapping" + ); + assert_eq!( + top_args.get("city").and_then(|v| v.as_str()), + Some("London") + ); + + // Nested function.arguments — Qwen's standard template walks this + // after rebinding via `set tool_call = tool_call.function`. + let func_args = tc + .get("function") + .and_then(|f| f.get("arguments")) + .expect("function.arguments must still be present"); + assert!( + func_args.is_object(), + "nested function.arguments must ALSO be a mapping, got {func_args:?}" + ); + assert_eq!( + func_args.get("city").and_then(|v| v.as_str()), + Some("London") + ); + } + + /// Arguments shaped as something other than an object after normalization + /// must be coerced to an empty object so the chat template's + /// `tool_call.arguments|items` can render. Without this, Qwen's + /// `chat_template.jinja:120` raises `cannot convert value into pairs` + /// when prior conversation turns carried weird tool-call shapes. + #[test] + fn arguments_coerced_to_empty_object_when_not_mapping() { + // Null arguments → empty object. + let mut tc = parsed(r#"{ "name": "f", "arguments": null }"#); + normalize_tool_call_for_template(&mut tc); + assert_eq!(tc.get("arguments"), Some(&parsed("{}"))); + + // Array arguments → empty object. + let mut tc = parsed(r#"{ "name": "f", "arguments": [1, 2, 3] }"#); + normalize_tool_call_for_template(&mut tc); + assert_eq!(tc.get("arguments"), Some(&parsed("{}"))); + + // Number arguments → empty object. + let mut tc = parsed(r#"{ "name": "f", "arguments": 42 }"#); + normalize_tool_call_for_template(&mut tc); + assert_eq!(tc.get("arguments"), Some(&parsed("{}"))); + + // Unparseable string arguments → empty object (the model can't + // express what it wanted; better than a 500). + let mut tc = parsed(r#"{ "name": "f", "arguments": "this is not json" }"#); + normalize_tool_call_for_template(&mut tc); + assert_eq!(tc.get("arguments"), Some(&parsed("{}"))); + + // Valid-JSON-string-that-parses-to-array → coerced via the + // second pass (parse succeeds, result is still not an object). + let mut tc = parsed(r#"{ "name": "f", "arguments": "[1,2,3]" }"#); + normalize_tool_call_for_template(&mut tc); + assert_eq!(tc.get("arguments"), Some(&parsed("{}"))); + } } diff --git a/crates/higgs-engine/src/simple.rs b/crates/higgs-engine/src/simple.rs index 88c0425f..28cc7bf5 100644 --- a/crates/higgs-engine/src/simple.rs +++ b/crates/higgs-engine/src/simple.rs @@ -200,6 +200,10 @@ pub struct SimpleEngine { template: Option, model_name: String, eos_token_ids: Vec, + /// Control tokens stripped from decoded output (EOS + `<|…|>` chat + /// delimiters + classic sentinels), while content-bearing special tokens + /// (tool-call markup, ``) are preserved. See [`Self::decode_tokens`]. + decode_skip_ids: std::collections::HashSet, /// Whether to enable thinking mode (Qwen3.5 `` tags). enable_thinking: bool, /// Token ID for ``, resolved from the tokenizer at load time. @@ -247,6 +251,25 @@ impl SimpleEngine { let eos_token_ids = extract_eos_tokens(model_dir); + // Control tokens that must never surface in decoded text. `decode_tokens` + // keeps content-bearing special tokens (so tool-call markup like + // MiniCPM's ``/`` reaches the parser) but strips these: + // the EOS set, the `<|…|>` chat-control delimiters, and classic + // sentinels. Content tokens like ``, ``, `` + // do not match and are preserved. + let decode_skip_ids: std::collections::HashSet = { + let mut ids: std::collections::HashSet = eos_token_ids.iter().copied().collect(); + for (id, added) in tokenizer.get_added_tokens_decoder() { + let content = added.content.as_str(); + let is_control = (content.starts_with("<|") && content.ends_with("|>")) + || matches!(content, "" | "" | "" | "" | ""); + if is_control { + ids.insert(id); + } + } + ids + }; + // Auto-detect thinking mode: Qwen3.5 models support tags. // Override with HIGGS_ENABLE_THINKING=0/1, off/true, yes/no etc. let mut enable_thinking = std::env::var("HIGGS_ENABLE_THINKING") @@ -277,10 +300,7 @@ impl SimpleEngine { enable_thinking = false; } if enable_thinking { - tracing::info!( - think_close_token, - "Thinking mode enabled (Qwen3.5 model detected)" - ); + tracing::info!(think_close_token, "Thinking mode enabled"); } set_wired_limit_to_max(raise_wired_limit); @@ -386,6 +406,7 @@ impl SimpleEngine { template, model_name, eos_token_ids, + decode_skip_ids, enable_thinking, think_close_token, gen_prompt_suffix_len, @@ -700,10 +721,30 @@ impl SimpleEngine { } /// Decode the token buffer and return the text, mapping tokenizer errors. + /// + /// Decodes WITHOUT skipping special tokens so content-bearing markup + /// survives — notably models (e.g. `MiniCPM5`) that encode their tool-call + /// structure (``, ``, …) as special tokens, which the + /// tool parser needs to see. Control tokens (EOS) are filtered out first so + /// they never leak into visible text. Plain text contains no special + /// tokens and decodes identically either way, so normal responses are + /// unaffected. fn decode_tokens(&self, tokens: &[u32]) -> Result { - self.tokenizer - .decode(tokens, true) - .map_err(|e| EngineError::Tokenization(e.to_string())) + let decode = |ids: &[u32]| { + self.tokenizer + .decode(ids, false) + .map_err(|e| EngineError::Tokenization(e.to_string())) + }; + // Fast path: no control token present, decode the slice as-is. + if !tokens.iter().any(|id| self.decode_skip_ids.contains(id)) { + return decode(tokens); + } + let filtered: Vec = tokens + .iter() + .copied() + .filter(|id| !self.decode_skip_ids.contains(id)) + .collect(); + decode(&filtered) } /// The model's hidden dimension (embedding output size). @@ -2162,18 +2203,31 @@ pub(crate) fn extract_eos_tokens(model_dir: &Path) -> Vec { } /// Detect whether a model supports thinking mode based on `model_type`. +/// Whether the model *supports* a thinking toggle (capability, not default). +/// +/// The per-request default — e.g. Qwen3.6 reasons off unless asked — is decided +/// separately by `model_defaults_to_non_thinking` in the router; this only +/// answers "can it think at all". +/// +/// Signals, in order: +/// 1. the chat template exposes an `enable_thinking` switch — the model +/// author's own marker, which covers Qwen3.5/3.6, `MiniCPM5`, and future +/// reasoning models without hardcoding model types; or +/// 2. a known reasoning `model_type`. +/// +/// The caller additionally requires a single-token ``, so a stray +/// mention can't enable thinking for a model that wasn't trained for it. fn detect_thinking_support(model_dir: &Path) -> bool { - let config_path = model_dir.join("config.json"); - let config_str = match std::fs::read_to_string(&config_path) { - Ok(s) => s, - Err(_) => return false, + if chat_template_mentions_enable_thinking(model_dir) { + return true; + } + let Ok(config_str) = std::fs::read_to_string(model_dir.join("config.json")) else { + return false; }; - let config: serde_json::Value = match serde_json::from_str(&config_str) { - Ok(v) => v, - Err(_) => return false, + let Ok(config) = serde_json::from_str::(&config_str) else { + return false; }; - // Qwen3.5 models (qwen3_5, qwen3_5_moe) support tags. - // Check both top-level and nested text_config for VLM wrappers. + // Check both top-level and nested text_config (VLM wrappers). let model_type = config .get("model_type") .and_then(|v| v.as_str()) @@ -2186,11 +2240,41 @@ fn detect_thinking_support(model_dir: &Path) -> bool { matches!(model_type, Some("qwen3_5" | "qwen3_5_moe")) } +/// Whether the model's chat template references the `enable_thinking` toggle, +/// read from `chat_template.jinja` or `tokenizer_config.json`'s `chat_template` +/// (a string, or a `{name, template}` array). +fn chat_template_mentions_enable_thinking(model_dir: &Path) -> bool { + const MARKER: &str = "enable_thinking"; + if let Ok(jinja) = std::fs::read_to_string(model_dir.join("chat_template.jinja")) { + return jinja.contains(MARKER); + } + let Ok(cfg_str) = std::fs::read_to_string(model_dir.join("tokenizer_config.json")) else { + return false; + }; + let Ok(cfg) = serde_json::from_str::(&cfg_str) else { + return false; + }; + let template = cfg.get("chat_template"); + if let Some(s) = template.and_then(|v| v.as_str()) { + return s.contains(MARKER); + } + if let Some(arr) = template.and_then(|v| v.as_array()) { + return arr.iter().any(|entry| { + entry + .get("template") + .and_then(|t| t.as_str()) + .is_some_and(|t| t.contains(MARKER)) + }); + } + false +} + #[cfg(test)] #[allow(clippy::panic, clippy::unwrap_used)] mod tests { use super::{ - check_stop_sequences, derive_model_name, estimate_paged_kv_blocks, parse_enabled_flag, + check_stop_sequences, derive_model_name, detect_thinking_support, estimate_paged_kv_blocks, + parse_enabled_flag, }; use std::path::Path; @@ -2199,6 +2283,44 @@ mod tests { std::fs::write(dir.join("config.json"), json).unwrap(); } + // --- detect_thinking_support tests --- + + /// MiniCPM5-style: a non-reasoning `model_type` (llama) but a chat template + /// that exposes the `enable_thinking` switch ⇒ thinking-capable. + #[test] + fn detect_thinking_from_template_marker() { + let dir = tempfile::tempdir().unwrap(); + write_config(dir.path(), r#"{"model_type": "llama"}"#); + std::fs::write( + dir.path().join("chat_template.jinja"), + "{%- if enable_thinking %}\n{%- endif %}", + ) + .unwrap(); + assert!(detect_thinking_support(dir.path())); + } + + /// Qwen3.5 reasoning `model_type` is detected even without a template file. + #[test] + fn detect_thinking_from_model_type() { + let dir = tempfile::tempdir().unwrap(); + write_config(dir.path(), r#"{"model_type": "qwen3_5_moe"}"#); + assert!(detect_thinking_support(dir.path())); + } + + /// A plain Llama (no reasoning `model_type`, no `enable_thinking` in the + /// template) must NOT be treated as a thinking model. + #[test] + fn no_thinking_for_plain_llama() { + let dir = tempfile::tempdir().unwrap(); + write_config(dir.path(), r#"{"model_type": "llama"}"#); + std::fs::write( + dir.path().join("chat_template.jinja"), + "{%- for m in messages %}{{ m.content }}{%- endfor %}", + ) + .unwrap(); + assert!(!detect_thinking_support(dir.path())); + } + // --- derive_model_name tests --- #[test] diff --git a/crates/higgs-engine/src/tool_parser.rs b/crates/higgs-engine/src/tool_parser.rs index 00c37866..a2f03a3b 100644 --- a/crates/higgs-engine/src/tool_parser.rs +++ b/crates/higgs-engine/src/tool_parser.rs @@ -1,13 +1,32 @@ //! Parse tool calls from model-generated text. //! -//! Qwen models emit tool calls in a specific XML-like format: +//! Qwen models wrap tool calls in `` tags, but the +//! payload *inside* the tags comes in two shapes depending on the model +//! generation: +//! +//! Legacy JSON (Qwen2.5 / Qwen3): //! ```text //! //! {"name": "function_name", "arguments": {"arg1": "value1"}} //! //! ``` //! -//! This module extracts those structured tool calls from the raw text. +//! XML function/parameter (Qwen3.5 / Qwen3.6 — what their +//! `chat_template.jinja` instructs the model to emit): +//! ```text +//! +//! +//! +//! value1 +//! +//! +//! +//! ``` +//! +//! This module extracts structured tool calls from either shape. The XML form +//! emits every value as a raw string, so values are coerced to JSON types +//! using the request's declared tool schema ([`ToolSchema`]) when available, +//! falling back to best-effort parsing otherwise. /// A parsed tool call extracted from model output. #[derive(Debug, Clone)] @@ -28,10 +47,31 @@ pub struct ToolParseResult { const TOOL_CALL_OPEN: &str = ""; const TOOL_CALL_CLOSE: &str = ""; +/// Hard cap on bytes buffered while inside an unclosed ``. +/// +/// Without a cap, a model that emits `` and never closes the tag +/// would grow `buffer` until OOM — flagged CRITICAL on the closed upstream +/// PR #63. On overflow the tracker abandons the parse, emits `` +/// plus the buffered bytes as visible content (preserving the "never +/// silently drop tokens" invariant), and resets so subsequent well-formed +/// tool calls in the same stream still parse. +const MAX_INSIDE_TOOL_CALL_BYTES: usize = 1024 * 1024; + /// Parse model output text for Qwen-format tool calls. /// +/// `schema` carries the request's declared tool parameter types so XML-format +/// values can be coerced; pass `None` for best-effort coercion. +/// /// Returns the non-tool-call text and any extracted tool calls. -pub fn parse_tool_calls(text: &str) -> ToolParseResult { +pub fn parse_tool_calls(text: &str, schema: Option<&ToolSchema>) -> ToolParseResult { + // MiniCPM5 emits bare `` with no `` + // wrapper. When there's no wrapper but a function opener is present, take + // that path; otherwise fall through to the `` scanner (which + // covers both the JSON and Qwen ` ToolParseResult { let raw_block = after_open.get(..end_pos).unwrap_or_default(); let call_content = raw_block.trim(); - if let Some(parsed) = try_parse_tool_call(call_content) { + if let Some(parsed) = parse_tool_call_block(call_content, schema) { tool_calls.push(parsed); } else { result_text.push_str(TOOL_CALL_OPEN); @@ -90,8 +130,617 @@ fn try_parse_tool_call(content: &str) -> Option { Some(ParsedToolCall { name, arguments }) } +const FUNCTION_OPEN: &str = "VALUE` +// with no `` wrapper and optional ``-wrapped values. +// `FUNCTION_CLOSE` (``) is shared with the Qwen XML form above. +const MINICPM_FUNCTION_OPEN: &str = " Option { + match s { + "string" => Some(Self::Str), + "integer" => Some(Self::Integer), + "number" => Some(Self::Number), + "boolean" => Some(Self::Boolean), + "object" => Some(Self::Object), + "array" => Some(Self::Array), + _ => None, + } + } +} + +/// Per-request tool parameter types, keyed by `function name → parameter +/// name → declared type`. +/// +/// Built from the `OpenAI` `tools` array so the XML tool-call parser can +/// coerce raw string parameter values to the JSON types the client declared. +pub struct ToolSchema { + params: std::collections::HashMap>, +} + +impl ToolSchema { + /// Build a [`ToolSchema`] from the request's `OpenAI` tool definitions. + /// + /// Each tool is either `{"type":"function","function":{...}}` or a bare + /// function object. Returns `None` when no function declares a typed + /// `parameters.properties` map — callers then use best-effort coercion. + #[must_use] + pub fn from_tools(tools: Option<&[serde_json::Value]>) -> Option { + let tool_list = tools?; + let mut params: std::collections::HashMap< + String, + std::collections::HashMap, + > = std::collections::HashMap::new(); + + for tool in tool_list { + let function = tool.get("function").unwrap_or(tool); + let Some(name) = function.get("name").and_then(serde_json::Value::as_str) else { + continue; + }; + let Some(properties) = function + .get("parameters") + .and_then(|p| p.get("properties")) + .and_then(serde_json::Value::as_object) + else { + continue; + }; + + let param_types: std::collections::HashMap = properties + .iter() + .filter_map(|(param, spec)| { + let ty = spec + .get("type") + .and_then(serde_json::Value::as_str) + .and_then(ParamType::from_schema_str)?; + Some((param.clone(), ty)) + }) + .collect(); + + if !param_types.is_empty() { + params.insert(name.to_owned(), param_types); + } + } + + if params.is_empty() { + return None; + } + Some(Self { params }) + } + + fn param_type(&self, function: &str, param: &str) -> Option { + self.params.get(function)?.get(param).copied() + } +} + +/// Coerce a raw XML parameter string into a JSON value using its declared +/// schema type, falling back to best-effort JSON parsing when the type is +/// unknown or absent. +fn coerce_param_value(raw: &str, declared: Option) -> serde_json::Value { + use serde_json::Value; + let as_string = || Value::String(raw.to_owned()); + let parsed_if = |pred: fn(&Value) -> bool| { + serde_json::from_str::(raw) + .ok() + .filter(pred) + .unwrap_or_else(|| Value::String(raw.to_owned())) + }; + match declared { + Some(ParamType::Str) => as_string(), + // `integer` must reject fractional values — `is_number` accepts floats. + Some(ParamType::Integer) => parsed_if(|v| v.is_i64() || v.is_u64()), + Some(ParamType::Number) => parsed_if(Value::is_number), + Some(ParamType::Boolean) => match raw.trim() { + "true" => Value::Bool(true), + "false" => Value::Bool(false), + _ => as_string(), + }, + Some(ParamType::Object) => parsed_if(Value::is_object), + Some(ParamType::Array) => parsed_if(Value::is_array), + // No schema for this parameter: parse if it's valid JSON (so `42` + // becomes a number), otherwise keep the raw string (so `London` + // stays a string). + None => serde_json::from_str::(raw).unwrap_or_else(|_| as_string()), + } +} + +/// Strip a single leading and trailing newline — the wrapping the template +/// adds around `` values — preserving any intentional inner or +/// edge whitespace. +fn strip_one_wrapping_newline(s: &str) -> &str { + let without_lead = s + .strip_prefix("\r\n") + .or_else(|| s.strip_prefix('\n')) + .unwrap_or(s); + without_lead + .strip_suffix("\r\n") + .or_else(|| without_lead.strip_suffix('\n')) + .unwrap_or(without_lead) +} + +/// Parse the Qwen XML tool-call body (the text between `` and +/// ``): a single `` block containing +/// zero or more `…` entries. +/// +/// Returns `None` when no well-formed `` opener is present so the +/// caller can fall back to JSON parsing / verbatim preservation. The template +/// never nests more than one function per ``, so only the first is +/// parsed. +fn parse_xml_tool_call(content: &str, schema: Option<&ToolSchema>) -> Option { + let open = content.find(FUNCTION_OPEN)?; + let after_open = content.get(open + FUNCTION_OPEN.len()..)?; + let name_end = after_open.find('>')?; + let name = after_open.get(..name_end)?.trim().to_owned(); + if name.is_empty() { + return None; + } + + // Body between the `>` of `` and the matching + // `` (or end of content if the closer is absent). + let body_all = after_open.get(name_end + 1..).unwrap_or_default(); + let body = body_all + .find(FUNCTION_CLOSE) + .and_then(|i| body_all.get(..i)) + .unwrap_or(body_all); + + let mut map = serde_json::Map::new(); + let mut rest = body; + while let Some(p_open) = rest.find(PARAM_OPEN) { + let after_p = rest.get(p_open + PARAM_OPEN.len()..).unwrap_or_default(); + let Some(key_end) = after_p.find('>') else { + break; + }; + let key = after_p.get(..key_end).unwrap_or_default().trim().to_owned(); + let value_region = after_p.get(key_end + 1..).unwrap_or_default(); + let (raw_value, consumed) = value_region.find(PARAM_CLOSE).map_or_else( + || (value_region, value_region.len()), + |close| { + ( + value_region.get(..close).unwrap_or_default(), + close + PARAM_CLOSE.len(), + ) + }, + ); + + if !key.is_empty() { + let value = strip_one_wrapping_newline(raw_value); + let declared = schema.and_then(|s| s.param_type(&name, &key)); + map.insert(key, coerce_param_value(value, declared)); + } + + // Advance past this whole `…` entry. + let advance = p_open + PARAM_OPEN.len() + key_end + 1 + consumed; + rest = rest.get(advance..).unwrap_or_default(); + } + + Some(ParsedToolCall { + name, + arguments: serde_json::Value::Object(map), + }) +} + +/// Parse one `` block body, dispatching on shape: the Qwen XML +/// `` form vs the legacy JSON-object form. +fn parse_tool_call_block(content: &str, schema: Option<&ToolSchema>) -> Option { + if content.trim_start().starts_with(FUNCTION_OPEN) { + parse_xml_tool_call(content, schema) + } else { + try_parse_tool_call(content) + } +} + +/// Byte offset of the `` that closes a `MiniCPM` function block in +/// `s`, skipping any `` spans whose content may itself contain +/// a literal ``. +/// +/// Returns `None` when the block is not yet terminated: either no closer has +/// arrived, or scanning is parked inside an unclosed CDATA span (the caller +/// should wait for more input). +fn minicpm_function_end(s: &str) -> Option { + let mut i = 0; + loop { + let rest = s.get(i..)?; + let next_close = rest.find(FUNCTION_CLOSE); + // A CDATA span that opens before the next close tag must be skipped + // whole, otherwise a `` inside it would close early. + if let Some(d) = rest.find(CDATA_OPEN) { + if next_close.is_none_or(|c| d < c) { + let after_open = d + CDATA_OPEN.len(); + let close = rest.get(after_open..)?.find(CDATA_CLOSE)?; + i += after_open + close + CDATA_CLOSE.len(); + continue; + } + } + return next_close.map(|c| i + c); + } +} + +/// Extract one `MiniCPM` `` value from `vr` — the text immediately after +/// the param tag's `>`. Returns `(value, rest_after_)`. A +/// `` wrapper yields its verbatim content; otherwise the value is +/// the text up to ``. Both returned slices borrow `vr`. +fn extract_param_value(vr: &str) -> (&str, &str) { + if let Some(stripped) = vr.strip_prefix(CDATA_OPEN) { + if let Some(close) = stripped.find(CDATA_CLOSE) { + let value = stripped.get(..close).unwrap_or_default(); + let tail = stripped + .get(close + CDATA_CLOSE.len()..) + .unwrap_or_default(); + let after = tail + .find(MINICPM_PARAM_CLOSE) + .and_then(|i| tail.get(i + MINICPM_PARAM_CLOSE.len()..)) + .unwrap_or_default(); + return (value, after); + } + return (stripped, ""); + } + vr.find(MINICPM_PARAM_CLOSE).map_or((vr, ""), |i| { + ( + vr.get(..i).unwrap_or_default(), + vr.get(i + MINICPM_PARAM_CLOSE.len()..).unwrap_or_default(), + ) + }) +} + +/// Parse a single `MiniCPM` function block (`…` up to, but +/// not including, the closing ``). +/// +/// Returns `None` when no `name="…"` attribute is present so the caller can +/// preserve the text verbatim. +fn parse_minicpm_function(block: &str, schema: Option<&ToolSchema>) -> Option { + // Read `name="…"` only from the opening `` tag (before its + // closing `>`). Scanning the whole block would let a malformed payload + // like `…` be parsed as a tool call named `x` + // instead of being preserved verbatim. + let tag_close = block.find('>')?; + let open_tag = block.get(..tag_close)?; + let name_attr = open_tag.find(NAME_ATTR)?; + let after_attr = open_tag.get(name_attr + NAME_ATTR.len()..)?; + let name_end = after_attr.find('"')?; + let name = after_attr.get(..name_end)?.to_owned(); + if name.is_empty() { + return None; + } + // Params start after the `>` that closes the `` open tag. + let mut rest = block.get(tag_close + 1..).unwrap_or_default(); + + let mut map = serde_json::Map::new(); + while let Some(p_open) = rest.find(MINICPM_PARAM_OPEN) { + let after_p = rest + .get(p_open + MINICPM_PARAM_OPEN.len()..) + .unwrap_or_default(); + let Some(key_end) = after_p.find('"') else { + break; + }; + let key = after_p.get(..key_end).unwrap_or_default().to_owned(); + let after_key = after_p.get(key_end + 1..).unwrap_or_default(); + let Some(gt) = after_key.find('>') else { + break; + }; + let value_region = after_key.get(gt + 1..).unwrap_or_default(); + let (raw_value, after) = extract_param_value(value_region); + if !key.is_empty() { + let declared = schema.and_then(|s| s.param_type(&name, &key)); + map.insert(key, coerce_param_value(raw_value, declared)); + } + rest = after; + } + + Some(ParsedToolCall { + name, + arguments: serde_json::Value::Object(map), + }) +} + +/// Scan text for one or more bare `MiniCPM` `` blocks +/// (no `` wrapper). Text outside the blocks is preserved as visible +/// content; unparseable or unterminated blocks are preserved verbatim. +fn parse_minicpm_tool_calls(text: &str, schema: Option<&ToolSchema>) -> ToolParseResult { + let mut result_text = String::new(); + let mut tool_calls = Vec::new(); + let mut remaining = text; + + loop { + let Some(start) = remaining.find(MINICPM_FUNCTION_OPEN) else { + result_text.push_str(remaining); + break; + }; + result_text.push_str(remaining.get(..start).unwrap_or_default()); + let block_region = remaining.get(start..).unwrap_or_default(); + + let Some(end) = minicpm_function_end(block_region) else { + result_text.push_str(block_region); + break; + }; + + let block = block_region.get(..end).unwrap_or_default(); + if let Some(parsed) = parse_minicpm_function(block, schema) { + tool_calls.push(parsed); + } else { + result_text.push_str(block); + result_text.push_str(FUNCTION_CLOSE); + } + remaining = block_region + .get(end + FUNCTION_CLOSE.len()..) + .unwrap_or_default(); + } + + ToolParseResult { + text: result_text.trim().to_owned(), + tool_calls, + } +} + +/// One chunk of streaming output from [`StreamingToolCallTracker::process`] +/// or [`StreamingToolCallTracker::flush`]. +/// +/// `visible` is the text that should be forwarded to the client as a normal +/// content delta. `new_tool_calls` are any tool calls that became complete +/// during this chunk — the route layer turns them into `ToolCallDelta` SSE +/// events. +#[derive(Debug, Default)] +pub struct StreamingToolOutput { + /// Text to forward to the client as a normal content delta. + pub visible: String, + /// Tool calls that became complete during this chunk; the route layer + /// emits each as a `tool_calls` SSE delta. + pub new_tool_calls: Vec, +} + +/// Longest opener token. In the scanning state the tracker keeps this many +/// bytes at the buffer tail so a `` or ` MINICPM_FUNCTION_OPEN.len() { + TOOL_CALL_OPEN.len() +} else { + MINICPM_FUNCTION_OPEN.len() +}; + +/// Which kind of tool-call block the tracker is currently inside. +#[derive(Clone, Copy, PartialEq, Eq)] +enum Inside { + /// Scanning for the next opener. + None, + /// Inside a `` block (JSON or Qwen `` block. + Function, +} + +/// State machine that buffers streaming text chunks and extracts tool-call +/// blocks on the fly — `` (JSON or Qwen ``. +/// +/// Designed to be cheap: when `active = false` (no tools in the request), +/// `process` is a single allocation per chunk and `flush` is a no-op. +/// +/// When active, it retains a small tail so an opener can't straddle a chunk +/// boundary; once a complete block is buffered it is parsed and emitted as a +/// [`ParsedToolCall`]. Text before/after blocks streams out verbatim. +/// +/// Invariants: +/// - **Never silently drops tokens.** Unclosed tags at `flush` are re-emitted +/// as visible content rather than discarded. +/// - **UTF-8 safe.** Tail-flushes walk back to the previous char boundary +/// so a partial multi-byte sequence is never split. +/// - **Pure passthrough when inactive.** Zero parsing cost on requests +/// that did not pass `tools` to the chat route. +pub struct StreamingToolCallTracker { + buffer: String, + inside: Inside, + completed_count: usize, + active: bool, + schema: Option, +} + +impl StreamingToolCallTracker { + /// `schema` carries the request's declared tool parameter types so + /// XML-format values can be coerced; pass `None` for best-effort. + pub const fn new(active: bool, schema: Option) -> Self { + Self { + buffer: String::new(), + inside: Inside::None, + completed_count: 0, + active, + schema, + } + } + + pub const fn completed_count(&self) -> usize { + self.completed_count + } + + pub const fn has_tool_calls(&self) -> bool { + self.completed_count > 0 + } + + /// In the scanning state, advance to the next opener — entering + /// `ToolCall`/`Function` — or flush all-but-tail and signal "wait". + /// Returns `true` to keep looping, `false` to break (need more input). + fn scan_for_opener(&mut self, out: &mut StreamingToolOutput) -> bool { + let tc = self.buffer.find(TOOL_CALL_OPEN); + let fc = self.buffer.find(MINICPM_FUNCTION_OPEN); + // Enter whichever opener appears first; `(pos, is_tool_call)`. + let pick = match (tc, fc) { + (Some(t), Some(f)) => Some(if f < t { (f, false) } else { (t, true) }), + (Some(t), None) => Some((t, true)), + (None, Some(f)) => Some((f, false)), + (None, None) => None, + }; + let Some((pos, is_tool_call)) = pick else { + // No opener yet — flush all but a tail large enough to hold a + // split opener, walking back to a UTF-8 char boundary. + if self.buffer.len() > MAX_OPENER_LEN { + let target_len = self.buffer.len() - MAX_OPENER_LEN; + let mut safe_len = target_len; + while safe_len > 0 && !self.buffer.is_char_boundary(safe_len) { + safe_len -= 1; + } + out.visible + .push_str(self.buffer.get(..safe_len).unwrap_or_default()); + self.buffer = self.buffer.get(safe_len..).unwrap_or_default().to_owned(); + } + return false; + }; + out.visible + .push_str(self.buffer.get(..pos).unwrap_or_default()); + if is_tool_call { + // Strip the `` opener; the inner body is parsed at the closer. + self.buffer = self + .buffer + .get(pos + TOOL_CALL_OPEN.len()..) + .unwrap_or_default() + .to_owned(); + self.inside = Inside::ToolCall; + } else { + // Keep the ` StreamingToolOutput { + if !self.active { + return StreamingToolOutput { + visible: text.to_owned(), + new_tool_calls: Vec::new(), + }; + } + + self.buffer.push_str(text); + let mut out = StreamingToolOutput::default(); + + loop { + match self.inside { + Inside::ToolCall => { + // Seek ``; once seen, parse the inner block + // (JSON or Qwen ` MAX_INSIDE_TOOL_CALL_BYTES { + // Overflow guard: opener seen, closer never arrived. + let leftover = std::mem::take(&mut self.buffer); + out.visible.push_str(TOOL_CALL_OPEN); + out.visible.push_str(&leftover); + self.inside = Inside::None; + break; + } else { + break; + } + } + Inside::Function => { + // The ``. + if let Some(end) = minicpm_function_end(&self.buffer) { + let block = self.buffer.get(..end).unwrap_or_default(); + if let Some(parsed) = parse_minicpm_function(block, self.schema.as_ref()) { + out.new_tool_calls.push(parsed); + self.completed_count += 1; + } else { + out.visible.push_str(block); + out.visible.push_str(FUNCTION_CLOSE); + } + self.buffer = self + .buffer + .get(end + FUNCTION_CLOSE.len()..) + .unwrap_or_default() + .to_owned(); + self.inside = Inside::None; + } else if self.buffer.len() > MAX_INSIDE_TOOL_CALL_BYTES { + // Overflow guard: ` { + if !self.scan_for_opener(&mut out) { + break; + } + } + } + } + + out + } + + /// Drain everything still buffered. Call this when the model stream + /// ends. Any unclosed `` block is emitted as visible content + /// (with its opener prepended) so no tokens silently vanish. + pub fn flush(&mut self) -> StreamingToolOutput { + let leftover = std::mem::take(&mut self.buffer); + let inside = self.inside; + self.inside = Inside::None; + + let visible = match inside { + // The `` opener was stripped on entry, so re-prepend it. + Inside::ToolCall => { + let mut v = String::with_capacity(TOOL_CALL_OPEN.len() + leftover.len()); + v.push_str(TOOL_CALL_OPEN); + v.push_str(&leftover); + v + } + // `Function` keeps its ` leftover, + }; + + StreamingToolOutput { + visible, + new_tool_calls: Vec::new(), + } + } +} + #[cfg(test)] -#[allow(clippy::panic, clippy::unwrap_used)] +#[allow(clippy::panic, clippy::unwrap_used, clippy::indexing_slicing)] mod tests { use super::*; @@ -101,7 +750,7 @@ mod tests { expected_tools: usize, text_contains: Option<&str>, ) -> ToolParseResult { - let result = parse_tool_calls(input); + let result = parse_tool_calls(input, None); assert_eq!( result.tool_calls.len(), expected_tools, @@ -245,7 +894,7 @@ this is not json {"arguments": {"key": "value"}, "description": "no name field"} "#; assert_raw_preserved(input); - let result = parse_tool_calls(input); + let result = parse_tool_calls(input, None); assert!(result.text.contains("no name field")); } @@ -253,7 +902,7 @@ this is not json fn test_valid_json_array_not_object_preserved_as_raw() { let input = "\n[1, 2, 3]\n"; assert_raw_preserved(input); - let result = parse_tool_calls(input); + let result = parse_tool_calls(input, None); assert!(result.text.contains("[1, 2, 3]")); } @@ -290,7 +939,7 @@ After last."#; {"name": "inner", "arguments": {}} "#; - let result = parse_tool_calls(input); + let result = parse_tool_calls(input, None); // The parser finds the first , then looks for first . // Content between them: "\n\n{\"name\": \"inner\", \"arguments\": {}}\n" // This is not valid JSON (starts with ), so it's preserved as raw text. @@ -336,4 +985,492 @@ After last."#; let input = "\n \n \t \n"; assert_parse(input, 0, Some("")); } + + // ============================================================ + // StreamingToolCallTracker tests + // + // The tracker is a state machine fed text chunks. It buffers + // until it sees `` boundaries, returning + // (visible_text, completed_tool_calls) on every chunk. + // + // Invariants tested: + // 1. inactive=false → pure passthrough, zero overhead + // 2. complete tag in one chunk → tool call emitted, no visible + // 3. tag split across chunks → tracker reassembles + // 4. text before/after tag → both visible, tool extracted + // 5. invalid JSON inside tag → preserved as visible + // 6. unclosed tag at flush → buffered prefix emitted as visible + // 7. multi-byte UTF-8 boundary at buffer-tail → no panic + // 8. has_tool_calls / completed_count track state correctly + // ============================================================ + + fn drain_visible_and_calls( + tracker: &mut StreamingToolCallTracker, + chunks: &[&str], + ) -> (String, Vec) { + let mut visible = String::new(); + let mut calls = Vec::new(); + for chunk in chunks { + let out = tracker.process(chunk); + visible.push_str(&out.visible); + calls.extend(out.new_tool_calls); + } + let final_out = tracker.flush(); + visible.push_str(&final_out.visible); + calls.extend(final_out.new_tool_calls); + (visible, calls) + } + + #[test] + fn streaming_inactive_is_passthrough() { + let mut t = StreamingToolCallTracker::new(false, None); + let (vis, calls) = drain_visible_and_calls( + &mut t, + &[ + "hello ", + "", + "{\"name\":\"x\"}", + "", + " world", + ], + ); + assert_eq!( + vis, "hello {\"name\":\"x\"} world", + "inactive tracker must pass every chunk through verbatim", + ); + assert!(calls.is_empty()); + assert!(!t.has_tool_calls()); + assert_eq!(t.completed_count(), 0); + } + + #[test] + fn streaming_single_call_one_chunk() { + let mut t = StreamingToolCallTracker::new(true, None); + let (vis, calls) = drain_visible_and_calls( + &mut t, + &[r#"{"name":"get_weather","arguments":{"city":"London"}}"#], + ); + assert!( + vis.trim().is_empty(), + "tool-only input should yield no visible text, got {vis:?}" + ); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "get_weather"); + assert!(t.has_tool_calls()); + assert_eq!(t.completed_count(), 1); + } + + #[test] + fn streaming_tag_split_across_chunks() { + // Open tag arrives in pieces; close tag also chunk-split. Tracker must reassemble. + let mut t = StreamingToolCallTracker::new(true, None); + let (vis, calls) = drain_visible_and_calls( + &mut t, + &[ + "", + r#"{"name":"search","#, + r#""arguments":{"q":"rust"}}"#, + "", + ], + ); + assert!( + vis.trim().is_empty(), + "split tags must not leak into visible, got {vis:?}" + ); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "search"); + } + + #[test] + fn streaming_text_before_and_after() { + let mut t = StreamingToolCallTracker::new(true, None); + let (vis, calls) = drain_visible_and_calls( + &mut t, + &[ + "Let me check. ", + r#"{"name":"lookup","arguments":{}}"#, + " Done.", + ], + ); + assert!(vis.contains("Let me check.")); + assert!(vis.contains("Done.")); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "lookup"); + } + + #[test] + fn streaming_invalid_json_preserved_as_visible() { + let mut t = StreamingToolCallTracker::new(true, None); + let (vis, calls) = + drain_visible_and_calls(&mut t, &["not json after"]); + assert!(vis.contains("")); + assert!(vis.contains("not json")); + assert!(vis.contains("")); + assert!(vis.contains("after")); + assert!(calls.is_empty()); + assert_eq!(t.completed_count(), 0); + } + + #[test] + fn streaming_unclosed_tag_flushed_as_visible() { + let mut t = StreamingToolCallTracker::new(true, None); + let (vis, calls) = drain_visible_and_calls(&mut t, &["{\"name\":\"partial\""]); + // No closing tag ever arrives — at flush, the buffered prefix MUST be + // emitted as visible (otherwise tokens vanish silently). + assert!(vis.contains("")); + assert!(vis.contains("partial")); + assert!(calls.is_empty()); + } + + #[test] + fn streaming_utf8_char_boundary_safety() { + // The tracker's tail-flush logic must respect UTF-8 char boundaries, + // otherwise it can panic when slicing inside a multi-byte sequence. + let mut t = StreamingToolCallTracker::new(true, None); + // Buffer ends just before the `é` byte sequence; next chunk completes it. + let (vis, calls) = + drain_visible_and_calls(&mut t, &["caf", "\u{00e9}", " and more text here"]); + assert!(vis.contains("caf\u{00e9}")); + assert!(vis.contains("more text")); + assert!(calls.is_empty()); + } + + #[test] + fn streaming_unbounded_buffer_capped_and_recovers() { + // CRITICAL guard (closed upstream PR #63 finding): a model that + // opens `` and never closes must not grow `buffer` past + // `MAX_INSIDE_TOOL_CALL_BYTES`. On overflow we drop the parse, + // flush the buffered bytes as visible, and reset so a later valid + // tool call in the same stream still parses. + let mut t = StreamingToolCallTracker::new(true, None); + let huge = "x".repeat(MAX_INSIDE_TOOL_CALL_BYTES + 1); + let (vis, calls) = drain_visible_and_calls( + &mut t, + &[ + "", + huge.as_str(), + // Same stream, after the overflow — a well-formed call + // arrives. The reset state must let it through. + r#"{"name":"after","arguments":{}}"#, + ], + ); + assert!( + vis.contains(""), + "overflow must surface opener as visible, not silently swallow", + ); + assert!( + vis.contains(huge.as_str()), + "overflow must surface buffered bytes as visible", + ); + assert_eq!(calls.len(), 1, "post-overflow valid call still parses"); + assert_eq!(calls[0].name, "after"); + assert_eq!(t.completed_count(), 1); + } + + #[test] + fn streaming_multiple_calls_with_text_between() { + let mut t = StreamingToolCallTracker::new(true, None); + let (vis, calls) = drain_visible_and_calls( + &mut t, + &[ + "first ", + r#"{"name":"a","arguments":{}}"#, + " middle ", + r#"{"name":"b","arguments":{}}"#, + " last", + ], + ); + assert!(vis.contains("first")); + assert!(vis.contains("middle")); + assert!(vis.contains("last")); + assert_eq!(calls.len(), 2); + assert_eq!(calls[0].name, "a"); + assert_eq!(calls[1].name, "b"); + assert_eq!(t.completed_count(), 2); + assert!(t.has_tool_calls()); + } + + // ============================================================ + // Qwen XML tool-call format: … + // ============================================================ + + /// The canonical XML shape Qwen3.5/3.6 emit: one string parameter, + /// values wrapped in newlines by the template. + #[test] + fn xml_single_call_one_param() { + let input = "\n\n\nLondon\n\n\n"; + let result = parse_tool_calls(input, None); + assert_eq!(result.tool_calls.len(), 1); + let tc = result.tool_calls.first().unwrap(); + assert_eq!(tc.name, "get_weather"); + assert_eq!(tc.arguments, serde_json::json!({ "city": "London" })); + assert!(result.text.is_empty()); + } + + /// Multiple parameters, and a multi-line value: only the single wrapping + /// newline is stripped, internal newlines are preserved. + #[test] + fn xml_multi_param_multiline_value() { + let input = "\n\n\nsrc/main.rs\n\n\nline one\nline two\n\n\n"; + let result = parse_tool_calls(input, None); + assert_eq!(result.tool_calls.len(), 1); + assert_eq!( + result.tool_calls.first().unwrap().arguments, + serde_json::json!({ "path": "src/main.rs", "content": "line one\nline two" }) + ); + } + + /// With a declared schema, values are coerced to their JSON types — and + /// crucially a `string`-typed `"123"` stays a string (schema beats the + /// best-effort number guess). + #[test] + fn xml_schema_driven_coercion() { + let tools = vec![serde_json::json!({ + "type": "function", + "function": { + "name": "configure", + "parameters": { + "type": "object", + "properties": { + "count": { "type": "integer" }, + "enabled": { "type": "boolean" }, + "opts": { "type": "object" }, + "label": { "type": "string" } + } + } + } + })]; + let schema = ToolSchema::from_tools(Some(tools.as_slice())); + let input = "\n\n\n42\n\n\ntrue\n\n\n{\"a\": 1}\n\n\n123\n\n\n"; + let result = parse_tool_calls(input, schema.as_ref()); + assert_eq!( + result.tool_calls.first().unwrap().arguments, + serde_json::json!({ "count": 42, "enabled": true, "opts": { "a": 1 }, "label": "123" }) + ); + } + + /// An `integer`-typed parameter must reject fractional input (kept as a + /// string) but accept whole numbers — `is_number` alone would wrongly + /// accept `3.14`. + #[test] + fn xml_integer_rejects_fractional() { + let tools = vec![serde_json::json!({ + "type": "function", + "function": { + "name": "f", + "parameters": { "type": "object", "properties": { "n": { "type": "integer" } } } + } + })]; + let schema = ToolSchema::from_tools(Some(tools.as_slice())); + let frac = "\n\n\n3.14\n\n\n"; + assert_eq!( + parse_tool_calls(frac, schema.as_ref()) + .tool_calls + .first() + .unwrap() + .arguments, + serde_json::json!({ "n": "3.14" }) + ); + let whole = + "\n\n\n42\n\n\n"; + assert_eq!( + parse_tool_calls(whole, schema.as_ref()) + .tool_calls + .first() + .unwrap() + .arguments, + serde_json::json!({ "n": 42 }) + ); + } + + /// Without a schema, coercion is best-effort: valid-JSON scalars parse + /// (`42` → number) while bare words stay strings (`London`). + #[test] + fn xml_no_schema_best_effort_coercion() { + let input = "\n\n\n42\n\n\nLondon\n\n\n"; + let result = parse_tool_calls(input, None); + assert_eq!( + result.tool_calls.first().unwrap().arguments, + serde_json::json!({ "n": 42, "city": "London" }) + ); + } + + /// Backward-compat guard: a JSON `` and an XML `` + /// in the same text both parse (dispatch on shape, not on the model). + #[test] + fn mixed_json_and_xml_tool_calls_both_parse() { + let input = concat!( + "\n{\"name\": \"json_call\", \"arguments\": {\"x\": 1}}\n\n", + "\n\n\nhi\n\n\n" + ); + let result = parse_tool_calls(input, None); + assert_eq!(result.tool_calls.len(), 2); + assert_eq!(result.tool_calls[0].name, "json_call"); + assert_eq!( + result.tool_calls[0].arguments, + serde_json::json!({ "x": 1 }) + ); + assert_eq!(result.tool_calls[1].name, "xml_call"); + assert_eq!( + result.tool_calls[1].arguments, + serde_json::json!({ "y": "hi" }) + ); + } + + /// The streaming tracker must reassemble an XML tool call split across + /// chunk boundaries (inside the `` opener and the value) and + /// not leak any of it to visible content. + #[test] + fn streaming_xml_split_across_chunks() { + let mut t = StreamingToolCallTracker::new(true, None); + let (vis, calls) = drain_visible_and_calls( + &mut t, + &[ + "\n\n\nLon", + "don\n\n\n", + ], + ); + assert!( + vis.trim().is_empty(), + "split XML must not leak to visible, got {vis:?}" + ); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "get_weather"); + assert_eq!(calls[0].arguments, serde_json::json!({ "city": "London" })); + assert_eq!(t.completed_count(), 1); + } + + // ============================================================ + // MiniCPM5 tool-call format: … + // (no wrapper, attribute-named, optional CDATA values) + // ============================================================ + + /// Canonical `MiniCPM` shape: bare `` with one param. + #[test] + fn minicpm_single_call_one_param() { + let input = r#"London"#; + let result = parse_tool_calls(input, None); + assert_eq!(result.tool_calls.len(), 1); + let tc = result.tool_calls.first().unwrap(); + assert_eq!(tc.name, "get_weather"); + assert_eq!(tc.arguments, serde_json::json!({ "city": "London" })); + assert!(result.text.is_empty()); + } + + /// Multiple consecutive blocks, with text before/between them preserved. + #[test] + fn minicpm_multiple_calls_with_text() { + let input = concat!( + "Sure.", + r#"1"#, + " then ", + r#"two"#, + ); + let result = parse_tool_calls(input, None); + assert_eq!(result.tool_calls.len(), 2); + assert_eq!(result.tool_calls[0].name, "a"); + // No schema → best-effort: "1" parses to a number. + assert_eq!( + result.tool_calls[0].arguments, + serde_json::json!({ "x": 1 }) + ); + assert_eq!(result.tool_calls[1].name, "b"); + assert_eq!( + result.tool_calls[1].arguments, + serde_json::json!({ "y": "two" }) + ); + assert!(result.text.contains("Sure.")); + assert!(result.text.contains("then")); + } + + /// A CDATA value containing both a newline and a literal `` + /// must be captured verbatim and must NOT close the block early. + #[test] + fn minicpm_cdata_value_with_literal_close_tag() { + let input = " not a real close\n}]]>"; + let result = parse_tool_calls(input, None); + assert_eq!(result.tool_calls.len(), 1); + let tc = result.tool_calls.first().unwrap(); + assert_eq!(tc.name, "write"); + let code = tc.arguments.get("code").unwrap().as_str().unwrap(); + assert!(code.contains("fn main()")); + assert!(code.contains(" not a real close")); + assert!(code.contains('\n')); + assert!(result.text.is_empty()); + } + + /// Declared schema coerces `MiniCPM` param values; a `string`-typed `"123"` + /// stays a string (schema beats the best-effort number guess). + #[test] + fn minicpm_schema_driven_coercion() { + let tools = vec![serde_json::json!({ + "type": "function", + "function": { + "name": "cfg", + "parameters": { "type": "object", "properties": { + "count": { "type": "integer" }, + "on": { "type": "boolean" }, + "label": { "type": "string" } + }} + } + })]; + let schema = ToolSchema::from_tools(Some(tools.as_slice())); + let input = r#"7true123"#; + let result = parse_tool_calls(input, schema.as_ref()); + assert_eq!( + result.tool_calls.first().unwrap().arguments, + serde_json::json!({ "count": 7, "on": true, "label": "123" }) + ); + } + + /// A function with no params yields empty arguments, not a failure. + #[test] + fn minicpm_no_param_function() { + let input = r#""#; + let result = parse_tool_calls(input, None); + assert_eq!(result.tool_calls.len(), 1); + assert_eq!(result.tool_calls.first().unwrap().name, "ping"); + assert_eq!( + result.tool_calls.first().unwrap().arguments, + serde_json::json!({}) + ); + } + + /// A `` opener with no `name="…"` attribute must NOT borrow the + /// `name` from a nested `` — the block is preserved verbatim rather + /// than routed into the tool-execution path as a call named "city". + #[test] + fn minicpm_function_without_name_is_not_parsed() { + let input = "Paris"; + let result = parse_tool_calls(input, None); + assert!(result.tool_calls.is_empty()); + assert!(result.text.contains("")); + } + + /// Streaming: the tracker reassembles a `MiniCPM` call split inside the + /// `", + "", + ], + ); + assert!( + vis.trim().is_empty(), + "split MiniCPM must not leak to visible, got {vis:?}" + ); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "run"); + assert_eq!(calls[0].arguments, serde_json::json!({ "cmd": "echo hi" })); + assert_eq!(t.completed_count(), 1); + } } diff --git a/crates/higgs-models/src/transformer.rs b/crates/higgs-models/src/transformer.rs index 61ee991d..62a7e22a 100644 --- a/crates/higgs-models/src/transformer.rs +++ b/crates/higgs-models/src/transformer.rs @@ -90,6 +90,14 @@ pub struct ModelArgs { pub sliding_window: Option, #[serde(default)] pub rope_scaling: Option, + /// Explicit attention head dimension (`head_dim` in config.json). Most + /// Llama-family configs omit it, in which case `head_dim()` falls back to + /// `hidden_size / num_attention_heads`. Some models set it to a value that + /// differs from that ratio — e.g. `MiniCPM5` uses 128 while 1536/16 = 96 — + /// and the attention projections, `RoPE`, and scale must use the explicit + /// value or the model produces garbage. + #[serde(default, rename = "head_dim")] + pub head_dim_override: Option, // Quantization (present in pre-quantized MLX models) #[serde(default)] @@ -106,10 +114,14 @@ impl ModelArgs { .unwrap_or(matches!(self.model_type.as_str(), "qwen2")) } - /// Head dimension, computed from `hidden_size / num_attention_heads`. + /// Head dimension. Uses the explicit `head_dim` from config when present, + /// otherwise `hidden_size / num_attention_heads`. /// - /// Panics in debug builds if not evenly divisible. + /// Panics in debug builds if the fallback is not evenly divisible. pub fn head_dim(&self) -> i32 { + if let Some(head_dim) = self.head_dim_override { + return head_dim; + } debug_assert!( self.num_attention_heads != 0 && self.hidden_size % self.num_attention_heads == 0, "hidden_size ({}) must be divisible by num_attention_heads ({})", @@ -119,8 +131,18 @@ impl ModelArgs { self.hidden_size / self.num_attention_heads } - /// Validated head dimension that returns an error if not evenly divisible. + /// Validated head dimension. Honours an explicit `head_dim` from config; + /// otherwise returns an error if `hidden_size` is not evenly divisible by + /// `num_attention_heads`. pub fn checked_head_dim(&self) -> Result { + if let Some(head_dim) = self.head_dim_override { + if head_dim <= 0 { + return Err(ModelError::ShapeMismatch( + "explicit head_dim must be positive".to_owned(), + )); + } + return Ok(head_dim); + } if self.num_attention_heads == 0 { return Err(ModelError::ShapeMismatch( "num_attention_heads must be positive".to_owned(), @@ -994,6 +1016,7 @@ mod tests { use_sliding_window: false, sliding_window: None, rope_scaling: None, + head_dim_override: None, quantization: None, } } @@ -1199,6 +1222,37 @@ mod tests { assert!(args.quantization.is_none()); } + #[test] + fn test_explicit_head_dim_honored() { + // MiniCPM5-1B: head_dim=128 even though hidden/heads = 1536/16 = 96. + // Must use the explicit value or attention/RoPE are wrong. + let json = r#"{ + "architectures": ["LlamaForCausalLM"], + "model_type": "llama", + "hidden_size": 1536, + "num_hidden_layers": 24, + "intermediate_size": 4608, + "num_attention_heads": 16, + "rms_norm_eps": 1e-06, + "vocab_size": 130560, + "num_key_value_heads": 2, + "max_position_embeddings": 131072, + "head_dim": 128 + }"#; + let args = assert_model_config(json, "llama", 1536, 128, false); + assert_eq!(args.head_dim_override, Some(128)); + assert_eq!(args.checked_head_dim().unwrap(), 128); + } + + #[test] + fn test_head_dim_falls_back_when_config_omits_it() { + // Standard Llama config has no `head_dim` → fall back to 4096/32 = 128. + let args = make_model_args("llama", 4096, 32, 32, 32000, 32); + assert_eq!(args.head_dim_override, None); + assert_eq!(args.head_dim(), 128); + assert_eq!(args.checked_head_dim().unwrap(), 128); + } + #[test] fn test_checked_head_dim_valid_cases() { // 768 / 12 = 64 diff --git a/crates/higgs/src/routes/chat.rs b/crates/higgs/src/routes/chat.rs index dd4936fa..5cfecdba 100644 --- a/crates/higgs/src/routes/chat.rs +++ b/crates/higgs/src/routes/chat.rs @@ -23,7 +23,7 @@ use crate::{ types::openai::{ ChatCompletionChoice, ChatCompletionDelta, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse, ChoiceLogprobs, CompletionUsage, MessageContent, StopSequence, - TokenLogprob, ToolCall, ToolCallFunction, TopLogprob, + TokenLogprob, ToolCall, ToolCallDelta, ToolCallFunction, ToolCallFunctionDelta, TopLogprob, }, }; use higgs_models::SamplingParams; @@ -267,7 +267,9 @@ async fn chat_completions_non_streaming( }; let messages = convert_messages(&effective_messages); - let tools = req.tools.as_deref(); + // Treat an empty `tools: []` as absent (mirrors the streaming path) so it + // doesn't define `tools` in the template context or trigger tool parsing. + let tools = req.tools.as_deref().filter(|t| !t.is_empty()); let thinking_enabled = crate::reasoning::effective_thinking_enabled( engine.enable_thinking(), &[engine.model_name(), req.model.as_str()], @@ -316,7 +318,7 @@ async fn chat_completions_non_streaming( .map_err(ServerError::Engine)?; let request_id = generate_request_id(); - let has_tools = req.tools.is_some(); + let has_tools = tools.is_some(); let logprobs_response = output .token_logprobs @@ -348,7 +350,8 @@ async fn chat_completions_non_streaming( }; let (content, tool_calls, finish_reason) = if has_tools { - let parsed = higgs_engine::tool_parser::parse_tool_calls(&raw_text); + let schema = higgs_engine::tool_parser::ToolSchema::from_tools(tools); + let parsed = higgs_engine::tool_parser::parse_tool_calls(&raw_text, schema.as_ref()); if parsed.tool_calls.is_empty() { ( Some(MessageContent::Text(raw_text)), @@ -418,15 +421,16 @@ fn chat_completions_stream( routing_method: crate::router::RoutingMethod, ) -> Result>, ServerError> { let stream_includes_tools = req.tools.as_ref().is_some_and(|t| !t.is_empty()); + // Built here (before the `async_stream::stream!` block, which captures by + // move) so the tracker can coerce XML-format tool-call values to their + // declared JSON types. + let tool_schema = higgs_engine::tool_parser::ToolSchema::from_tools(req.tools.as_deref()); - // Tool-calling responses are not supported in streaming mode. - // Accept requests that include tools (nanobot always sends them) but - // exclude them from prompt rendering so the model generates plain text. if stream_includes_tools { - tracing::warn!( + tracing::debug!( request_model = req.model, tool_count = req.tools.as_ref().map_or(0, Vec::len), - "Streaming API does not support tool-calls; tools will be ignored", + "Streaming with tool-calls enabled; will emit tool_calls deltas via StreamingToolCallTracker", ); } @@ -451,9 +455,17 @@ fn chat_completions_stream( req.reasoning.as_ref(), ); - // Exclude tools from streaming prompt — tool_calls deltas are unsupported. + // Pass tools into prompt rendering so the chat template emits the + // tool spec the model recognises. The on-the-fly + // [`StreamingToolCallTracker`] below intercepts `… + // ` blocks the model produces and turns them into + // structured `ToolCallDelta` SSE events. + let prompt_tools = req + .tools + .as_deref() + .and_then(|t| if t.is_empty() { None } else { Some(t) }); let mut prompt_tokens = engine - .prepare_chat_prompt_with_thinking(&messages, None, thinking_enabled_stream) + .prepare_chat_prompt_with_thinking(&messages, prompt_tools, thinking_enabled_stream) .map_err(ServerError::Engine)?; // Preprocess images for VLM @@ -548,6 +560,29 @@ fn chat_completions_stream( } else { higgs_engine::reasoning_parser::StreamingReasoningTracker::new() }; + // Streaming tool-call extractor — passthrough when no tools were + // requested, otherwise watches for `` + // blocks and emits structured `ToolCallDelta` events. + let mut tool_tracker = higgs_engine::tool_parser::StreamingToolCallTracker::new( + stream_includes_tools, + tool_schema, + ); + + // Closure that turns a `ParsedToolCall` into the OpenAI streaming + // delta shape. Index is the running zero-based position of the + // call in this response. + let make_tool_delta = |index: u32, parsed: &higgs_engine::tool_parser::ParsedToolCall| { + ToolCallDelta { + index, + id: Some(format!("call_{index}_{}", uuid::Uuid::new_v4())), + r#type: Some("function".to_owned()), + function: Some(ToolCallFunctionDelta { + name: Some(parsed.name.clone()), + arguments: Some(parsed.arguments.to_string()), + }), + } + }; + let mut output_token_count: u32 = 0; let mut pending_finish_reason: Option = None; let mut pending_finish_logprobs: Option = None; @@ -560,7 +595,6 @@ fn chat_completions_stream( .map(|lp| logprobs_to_response(std::slice::from_ref(lp), &tokenizer)); let (visible, reasoning) = reasoning_tracker.process(&output.new_text); - let visible_is_empty = visible.is_empty(); if !reasoning.is_empty() { let d = ChatCompletionDelta { @@ -572,10 +606,35 @@ fn chat_completions_stream( emit_delta!(&d, None, None); } - if !visible.is_empty() { + // Run the visible-text portion through the tool-call tracker + // so `` blocks become structured + // deltas rather than being spoken aloud as plain text. + let tool_out = tool_tracker.process(&visible); + let visible_is_empty = tool_out.visible.is_empty(); + + // Tool-call indices count up across the whole response. Each + // chunk that closes N tool calls covers indices + // `[base_index .. base_index+N)` where `base_index` is the + // total completed *before* this chunk. + let base_index = tool_tracker + .completed_count() + .saturating_sub(tool_out.new_tool_calls.len()); + for (i, parsed) in tool_out.new_tool_calls.iter().enumerate() { + #[allow(clippy::cast_possible_truncation)] + let idx = u32::try_from(base_index + i).unwrap_or(u32::MAX); let d = ChatCompletionDelta { role: None, - content: Some(visible), + content: None, + reasoning_content: None, + tool_calls: Some(vec![make_tool_delta(idx, parsed)]), + }; + emit_delta!(&d, None, None); + } + + if !tool_out.visible.is_empty() { + let d = ChatCompletionDelta { + role: None, + content: Some(tool_out.visible), reasoning_content: None, tool_calls: None, }; @@ -588,7 +647,7 @@ fn chat_completions_stream( } } - // Flush any remaining buffered content from the reasoning tracker + // Flush any remaining buffered content. let (flush_vis, flush_reas) = reasoning_tracker.flush(); if !flush_reas.is_empty() { let d = ChatCompletionDelta { @@ -599,23 +658,59 @@ fn chat_completions_stream( }; emit_delta!(&d, None, None); } - if !flush_vis.is_empty() { + // Drain the tool tracker (handles unclosed `` tags by + // re-emitting their buffered prefix as visible content — never + // silently drop tokens). + let flush_tool_out = tool_tracker.process(&flush_vis); + let flush_base_index = tool_tracker + .completed_count() + .saturating_sub(flush_tool_out.new_tool_calls.len()); + for (i, parsed) in flush_tool_out.new_tool_calls.iter().enumerate() { + #[allow(clippy::cast_possible_truncation)] + let idx = u32::try_from(flush_base_index + i).unwrap_or(u32::MAX); + let d = ChatCompletionDelta { + role: None, + content: None, + reasoning_content: None, + tool_calls: Some(vec![make_tool_delta(idx, parsed)]), + }; + emit_delta!(&d, None, None); + } + if !flush_tool_out.visible.is_empty() { + let d = ChatCompletionDelta { + role: None, + content: Some(flush_tool_out.visible), + reasoning_content: None, + tool_calls: None, + }; + emit_delta!(&d, None, None); + } + let final_tool_out = tool_tracker.flush(); + if !final_tool_out.visible.is_empty() { let d = ChatCompletionDelta { role: None, - content: Some(flush_vis), + content: Some(final_tool_out.visible), reasoning_content: None, tool_calls: None, }; emit_delta!(&d, None, None); } + + // Defer `finish_reason` until after the tracker has drained so we + // know whether to report `"tool_calls"` or `"stop"`. if let Some(finish_reason) = pending_finish_reason { + let effective_finish = if tool_tracker.has_tool_calls() { + "tool_calls".to_owned() + } else { + finish_reason + }; let d = ChatCompletionDelta { role: None, content: None, reasoning_content: None, tool_calls: None, }; - emit_delta!(&d, Some(finish_reason.as_str()), pending_finish_logprobs.as_ref()); + emit_delta!(&d, Some(effective_finish.as_str()), pending_finish_logprobs.as_ref()); } // Emit final chunk with usage only when explicitly requested. @@ -654,6 +749,17 @@ fn convert_messages( calls .iter() .filter_map(|tc| serde_json::to_value(tc).ok()) + .map(|mut tc_value| { + // Make the tool call template-friendly: hoist + // `function.{name,arguments}` to the top level + // and parse string-encoded arguments to a JSON + // value. Without this, Qwen's chat template + // crashes on `tool_call.arguments|items`. + higgs_engine::chat_template::normalize_tool_call_for_template( + &mut tc_value, + ); + tc_value + }) .collect() }); let content = m