From 4da261788e4233d0a9188c53df2a44250ad5e638 Mon Sep 17 00:00:00 2001 From: VS Chandra Mourya Date: Tue, 14 Apr 2026 22:28:31 -0700 Subject: [PATCH 01/27] fix(grpc): skip reasoning parser when constrained decoding is active --- .../src/routers/grpc/regular/processor.rs | 8 +- .../src/routers/grpc/regular/streaming.rs | 4 + model_gateway/src/routers/grpc/utils/mod.rs | 2 +- .../src/routers/grpc/utils/parsers.rs | 167 ++++++++++++++++++ 4 files changed, 179 insertions(+), 2 deletions(-) diff --git a/model_gateway/src/routers/grpc/regular/processor.rs b/model_gateway/src/routers/grpc/regular/processor.rs index 582b401ad..6e5737fb1 100644 --- a/model_gateway/src/routers/grpc/regular/processor.rs +++ b/model_gateway/src/routers/grpc/regular/processor.rs @@ -96,7 +96,13 @@ impl ResponseProcessor { let mut reasoning_text: Option = None; let mut processed_text = final_text; - if original_request.separate_reasoning && reasoning_parser_available { + if original_request.separate_reasoning + && reasoning_parser_available + && !utils::has_constrained_output( + original_request.tool_choice.as_ref(), + original_request.response_format.as_ref(), + ) + { let pooled_parser = utils::get_reasoning_parser( &self.reasoning_parser_factory, self.configured_reasoning_parser.as_deref(), diff --git a/model_gateway/src/routers/grpc/regular/streaming.rs b/model_gateway/src/routers/grpc/regular/streaming.rs index 031c93ab0..0bcdd9c51 100644 --- a/model_gateway/src/routers/grpc/regular/streaming.rs +++ b/model_gateway/src/routers/grpc/regular/streaming.rs @@ -233,6 +233,10 @@ impl StreamingProcessor { // Check parser availability once upfront (log warning only once per request) let reasoning_parser_available = separate_reasoning + && !utils::has_constrained_output( + tool_choice.as_ref(), + original_request.response_format.as_ref(), + ) && utils::check_reasoning_parser_availability( &self.reasoning_parser_factory, self.configured_reasoning_parser.as_deref(), diff --git a/model_gateway/src/routers/grpc/utils/mod.rs b/model_gateway/src/routers/grpc/utils/mod.rs index 2415494ae..f7c3174ea 100644 --- a/model_gateway/src/routers/grpc/utils/mod.rs +++ b/model_gateway/src/routers/grpc/utils/mod.rs @@ -22,5 +22,5 @@ pub(crate) use metrics::{error_type_from_status, route_to_endpoint}; pub(crate) use parsers::{ check_reasoning_parser_availability, check_tool_parser_availability, create_reasoning_parser, create_tool_parser, extract_thinking_from_kwargs, get_reasoning_parser, get_tool_parser, - should_mark_reasoning_started, + has_constrained_output, should_mark_reasoning_started, }; diff --git a/model_gateway/src/routers/grpc/utils/parsers.rs b/model_gateway/src/routers/grpc/utils/parsers.rs index 106f0a8ba..d77135b12 100644 --- a/model_gateway/src/routers/grpc/utils/parsers.rs +++ b/model_gateway/src/routers/grpc/utils/parsers.rs @@ -4,6 +4,7 @@ use llm_tokenizer::{ chat_template::{ThinkingKeyName, ThinkingToggle}, traits::Tokenizer, }; +use openai_protocol::common::{ResponseFormat, ToolChoice, ToolChoiceValue}; use reasoning_parser::{ ParserFactory as ReasoningParserFactory, PooledParser as ReasoningPooledParser, ReasoningParser, }; @@ -178,3 +179,169 @@ pub(crate) fn create_tool_parser( tool_parser_factory.registry().create_for_model(model) } } + +/// Returns `true` when constrained decoding is active, meaning the model output +/// is structured JSON rather than free-form text. In that case the reasoning +/// parser must be skipped — otherwise it captures the constrained JSON as +/// `reasoning_content` and leaves `content` empty. +/// +/// Constrained decoding is triggered by: +/// - `tool_choice` = a specific function, `required`, or `allowed_tools` with +/// `mode == "required"` +/// - `response_format` = `json_object` or `json_schema` +pub(crate) fn has_constrained_output( + tool_choice: Option<&ToolChoice>, + response_format: Option<&ResponseFormat>, +) -> bool { + let constrained_tool_choice = matches!( + tool_choice, + Some(ToolChoice::Function { .. }) | Some(ToolChoice::Value(ToolChoiceValue::Required)) + ) || matches!( + tool_choice, + Some(ToolChoice::AllowedTools { mode, .. }) if mode == "required" + ); + + let constrained_response_format = matches!( + response_format, + Some(ResponseFormat::JsonObject) | Some(ResponseFormat::JsonSchema { .. }) + ); + + constrained_tool_choice || constrained_response_format +} + +#[cfg(test)] +mod tests { + use super::*; + use openai_protocol::common::{ + FunctionChoice, JsonSchemaFormat, ToolReference, + }; + + // ── has_constrained_output: tool_choice variants ──────────────────── + + #[test] + fn no_tool_choice_no_response_format_is_unconstrained() { + assert!(!has_constrained_output(None, None)); + } + + #[test] + fn tool_choice_auto_is_unconstrained() { + let tc = ToolChoice::Value(ToolChoiceValue::Auto); + assert!(!has_constrained_output(Some(&tc), None)); + } + + #[test] + fn tool_choice_none_is_unconstrained() { + let tc = ToolChoice::Value(ToolChoiceValue::None); + assert!(!has_constrained_output(Some(&tc), None)); + } + + #[test] + fn tool_choice_required_is_constrained() { + let tc = ToolChoice::Value(ToolChoiceValue::Required); + assert!(has_constrained_output(Some(&tc), None)); + } + + #[test] + fn tool_choice_specific_function_is_constrained() { + let tc = ToolChoice::Function { + tool_type: "function".to_string(), + function: FunctionChoice { + name: "get_weather".to_string(), + }, + }; + assert!(has_constrained_output(Some(&tc), None)); + } + + #[test] + fn allowed_tools_required_is_constrained() { + let tc = ToolChoice::AllowedTools { + tool_type: "allowed_tools".to_string(), + mode: "required".to_string(), + tools: vec![ToolReference::Function { + name: "search".to_string(), + }], + }; + assert!(has_constrained_output(Some(&tc), None)); + } + + #[test] + fn allowed_tools_auto_is_unconstrained() { + let tc = ToolChoice::AllowedTools { + tool_type: "allowed_tools".to_string(), + mode: "auto".to_string(), + tools: vec![ToolReference::Function { + name: "search".to_string(), + }], + }; + assert!(!has_constrained_output(Some(&tc), None)); + } + + // ── has_constrained_output: response_format variants ──────────────── + + #[test] + fn response_format_text_is_unconstrained() { + assert!(!has_constrained_output(None, Some(&ResponseFormat::Text))); + } + + #[test] + fn response_format_json_object_is_constrained() { + assert!(has_constrained_output( + None, + Some(&ResponseFormat::JsonObject) + )); + } + + #[test] + fn response_format_json_schema_is_constrained() { + let rf = ResponseFormat::JsonSchema { + json_schema: JsonSchemaFormat { + name: "feedback".to_string(), + schema: serde_json::json!({"type": "object"}), + strict: Some(true), + }, + }; + assert!(has_constrained_output(None, Some(&rf))); + } + + // ── has_constrained_output: combinations ──────────────────────────── + + #[test] + fn tool_choice_auto_with_json_object_is_constrained() { + let tc = ToolChoice::Value(ToolChoiceValue::Auto); + assert!(has_constrained_output( + Some(&tc), + Some(&ResponseFormat::JsonObject) + )); + } + + #[test] + fn tool_choice_auto_with_text_format_is_unconstrained() { + let tc = ToolChoice::Value(ToolChoiceValue::Auto); + assert!(!has_constrained_output( + Some(&tc), + Some(&ResponseFormat::Text) + )); + } + + #[test] + fn tool_choice_required_with_json_schema_both_constrain() { + let tc = ToolChoice::Value(ToolChoiceValue::Required); + let rf = ResponseFormat::JsonSchema { + json_schema: JsonSchemaFormat { + name: "output".to_string(), + schema: serde_json::json!({"type": "object"}), + strict: None, + }, + }; + assert!(has_constrained_output(Some(&tc), Some(&rf))); + } + + #[test] + fn tool_choice_none_with_json_object_is_constrained_via_format() { + let tc = ToolChoice::Value(ToolChoiceValue::None); + assert!(has_constrained_output( + Some(&tc), + Some(&ResponseFormat::JsonObject) + )); + } +} From b5830db43ea194b3e6b591cfec84f58d39d57dfe Mon Sep 17 00:00:00 2001 From: Connor Li Date: Tue, 21 Apr 2026 20:05:38 -0700 Subject: [PATCH 02/27] fix(grpc): register raise_exception in chat templates and coerce tool argument types MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Cherry-pick from mourya/engine_metrics_grpc_fixes (65f6c1cb). 1. Register `raise_exception` as a Jinja global in chat template env so HuggingFace templates that use `{{ raise_exception("msg") }}` don't crash with "undefined function". 2. Add `coerce_tool_args_to_schema()` post-processor in processor.rs: after tool calls are parsed, walk the function schema's `properties` and coerce values to match declared types: - Number/Bool/Null → String when schema says "string" - String → Number/Integer when schema says "number"/"integer" - String → Bool when schema says "boolean" - Scalar → Array when schema says "array" Fixes fc-dash regressions where the model emits correct tokens but the parsed JSON has wrong types: - bfcl/simple_javascript_43: `error: null` → `error: "null"` - bfcl/simple_python_218: `patient_id: 546382` → `"546382"` - bfcl/parallel_117: `year: 2000` → `"2000"` Signed-off-by: Connor Li Made-with: Cursor Signed-off-by: Connor Li --- crates/tokenizer/src/chat_template.rs | 12 +++ .../src/routers/grpc/regular/processor.rs | 96 ++++++++++++++++++- 2 files changed, 106 insertions(+), 2 deletions(-) diff --git a/crates/tokenizer/src/chat_template.rs b/crates/tokenizer/src/chat_template.rs index 2ac3b50c7..083a7f7d1 100644 --- a/crates/tokenizer/src/chat_template.rs +++ b/crates/tokenizer/src/chat_template.rs @@ -597,6 +597,18 @@ fn build_environment(template: String) -> Result> { // like ensure_ascii, separators, and sort_keys that HuggingFace templates use env.add_filter("tojson", tojson_filter); + // HuggingFace's Jinja2 environment registers `raise_exception` as a global + // callable so templates can do `{{ raise_exception("msg") }}`. + env.add_function( + "raise_exception", + |msg: String| -> Result { + Err(minijinja::Error::new( + minijinja::ErrorKind::InvalidOperation, + msg, + )) + }, + ); + Ok(env) } diff --git a/model_gateway/src/routers/grpc/regular/processor.rs b/model_gateway/src/routers/grpc/regular/processor.rs index 6e5737fb1..8815e1091 100644 --- a/model_gateway/src/routers/grpc/regular/processor.rs +++ b/model_gateway/src/routers/grpc/regular/processor.rs @@ -175,6 +175,7 @@ impl ResponseProcessor { &processed_text, &original_request.model, history_tool_calls_count, + original_request.tools.as_deref(), ) .await; } @@ -318,6 +319,7 @@ impl ResponseProcessor { processed_text: &str, model: &str, history_tool_calls_count: usize, + tools: Option<&[openai_protocol::common::Tool]>, ) -> (Option>, String) { // Get pooled parser for this model let pooled_parser = utils::get_tool_parser( @@ -343,19 +345,23 @@ impl ResponseProcessor { .into_iter() .enumerate() .map(|(index, tc)| { - // Generate ID for this tool call let id = utils::generate_tool_call_id( model, &tc.function.name, index, history_tool_calls_count, ); + let arguments = coerce_tool_args_to_schema( + &tc.function.arguments, + &tc.function.name, + tools.unwrap_or(&[]), + ); ToolCall { id, tool_type: "function".to_string(), function: FunctionCallResponse { name: tc.function.name, - arguments: Some(tc.function.arguments), + arguments: Some(arguments), }, } }) @@ -661,6 +667,7 @@ impl ResponseProcessor { utils::message_utils::get_history_tool_calls_count_messages( &messages_request, ), + None, ) .await; } @@ -865,3 +872,88 @@ impl ResponseProcessor { }) } } + +/// Coerce tool call argument values to match the types declared in the function +/// schema. Models sometimes return `546382` (integer) for a `"type": "string"` +/// parameter, or `null` for `"type": "string"` when the user said "null". This +/// mirrors the type coercion that production tproxy performs. +fn coerce_tool_args_to_schema( + arguments_json: &str, + function_name: &str, + tools: &[openai_protocol::common::Tool], +) -> String { + let schema = tools.iter().find_map(|t| { + let f = &t.function; + if f.name == function_name { + Some(&f.parameters) + } else { + None + } + }); + let schema = match schema { + Some(s) if s.is_object() => s, + _ => return arguments_json.to_string(), + }; + + let mut args: serde_json::Value = match serde_json::from_str(arguments_json) { + Ok(v) => v, + Err(_) => return arguments_json.to_string(), + }; + + let props = match schema.get("properties").and_then(|p| p.as_object()) { + Some(p) => p, + None => return arguments_json.to_string(), + }; + + if let Some(obj) = args.as_object_mut() { + for (key, prop_schema) in props { + let expected_type = prop_schema.get("type").and_then(|t| t.as_str()); + let val = match obj.get(key) { + Some(v) => v.clone(), + None => continue, + }; + + let coerced = match expected_type { + Some("string") => match &val { + serde_json::Value::Number(n) => Some(serde_json::Value::String(n.to_string())), + serde_json::Value::Bool(b) => Some(serde_json::Value::String(b.to_string())), + serde_json::Value::Null => Some(serde_json::Value::String("null".to_string())), + _ => None, + }, + Some("number") => match &val { + serde_json::Value::String(s) => { + s.parse::().ok().and_then(serde_json::Number::from_f64).map(serde_json::Value::Number) + } + _ => None, + }, + Some("integer") => match &val { + serde_json::Value::String(s) => { + s.parse::().ok().map(|n| serde_json::Value::Number(n.into())) + } + _ => None, + }, + Some("array") => match &val { + serde_json::Value::String(_) | serde_json::Value::Number(_) => { + Some(serde_json::Value::Array(vec![val.clone()])) + } + _ => None, + }, + Some("boolean") => match &val { + serde_json::Value::String(s) => match s.as_str() { + "true" => Some(serde_json::Value::Bool(true)), + "false" => Some(serde_json::Value::Bool(false)), + _ => None, + }, + _ => None, + }, + _ => None, + }; + + if let Some(new_val) = coerced { + obj.insert(key.clone(), new_val); + } + } + } + + serde_json::to_string(&args).unwrap_or_else(|_| arguments_json.to_string()) +} From 404afeacac8a6e2749bbcb682267daa736d3815a Mon Sep 17 00:00:00 2001 From: ConnorLi96 Date: Mon, 6 Apr 2026 23:25:33 -0700 Subject: [PATCH 03/27] fix(tool_parser): fix function call parsing for models with native tool-call tokens - Prioritize explicitly configured tool parser over JSON schema parsing - Support alternative delimiters (<|func_start|>/<|func_end|>) in KimiK2 parser - Prevent reasoning parser from consuming tool call markers - Strip leaked chatml tokens only when a parser is explicitly configured - Truncate trailing content after first valid JSON object for JSON response formats Signed-off-by: ConnorLi96 Made-with: Cursor --- crates/reasoning_parser/src/parsers/base.rs | 14 +++ crates/tool_parser/src/parsers/kimik2.rs | 31 ++++--- .../src/routers/grpc/regular/processor.rs | 93 ++++++++++++++++++- .../src/routers/grpc/regular/streaming.rs | 37 ++++++-- 4 files changed, 150 insertions(+), 25 deletions(-) diff --git a/crates/reasoning_parser/src/parsers/base.rs b/crates/reasoning_parser/src/parsers/base.rs index 6214acd69..0adcc36ba 100644 --- a/crates/reasoning_parser/src/parsers/base.rs +++ b/crates/reasoning_parser/src/parsers/base.rs @@ -63,6 +63,12 @@ impl ReasoningParser for BaseReasoningParser { .to_string(); if !processed_text.contains(&self.config.think_end_token) { + // Don't consume tool call markers as reasoning content + if let Some(tool_pos) = processed_text.find("<|tool_calls_section_begin|>") { + let reasoning_text = processed_text[..tool_pos].trim().to_string(); + let normal_text = processed_text[tool_pos..].to_string(); + return Ok(ParserResult::new(normal_text, reasoning_text)); + } // Assume reasoning was truncated before end token return Ok(ParserResult::reasoning(processed_text)); } @@ -133,6 +139,14 @@ impl ReasoningParser for BaseReasoningParser { // Continue with reasoning content if self.in_reasoning && self.config.stream_reasoning { + // Some models skip and go straight to tool calls + if let Some(tool_pos) = current_text.find("<|tool_calls_section_begin|>") { + let reasoning_text = current_text[..tool_pos].trim().to_string(); + let normal_text = current_text[tool_pos..].to_string(); + self.buffer.clear(); + self.in_reasoning = false; + return Ok(ParserResult::new(normal_text, reasoning_text)); + } // Stream the content immediately let reasoning_text = current_text; self.buffer.clear(); diff --git a/crates/tool_parser/src/parsers/kimik2.rs b/crates/tool_parser/src/parsers/kimik2.rs index 52f47de27..6171d4a05 100644 --- a/crates/tool_parser/src/parsers/kimik2.rs +++ b/crates/tool_parser/src/parsers/kimik2.rs @@ -105,16 +105,15 @@ impl KimiK2Parser { reason = "regex patterns are compile-time string literals" )] pub fn new() -> Self { - // Pattern for complete tool calls - let tool_call_pattern = r"<\|tool_call_begin\|>\s*(?P[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P\{.*?\})\s*<\|tool_call_end\|>"; + // Supports alternative delimiters: <|func_start|>/<|func_end|>; (?s) for multi-line JSON + let tool_call_pattern = r"(?s)<\|tool_call_begin\|>\s*(?P[\w\.]+:\d+)\s*(?:<\|tool_call_argument_begin\|>\s*|<\|func_start\|>\s*)?(?P\{.*?\})\s*(?:<\|tool_call_end\|>|<\|func_end\|>)"; let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern"); - // Pattern for streaming (partial) tool calls - let stream_pattern = r"<\|tool_call_begin\|>\s*(?P[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P\{.*)"; + let stream_pattern = r"(?s)<\|tool_call_begin\|>\s*(?P[\w\.]+:\d+)\s*(?:<\|tool_call_argument_begin\|>\s*|<\|func_start\|>\s*)?(?P\{.*)"; let stream_tool_call_extractor = Regex::new(stream_pattern).expect("Valid regex pattern"); // Pattern for removing completed tool calls - let end_pattern = r"<\|tool_call_begin\|>.*?<\|tool_call_end\|>"; + let end_pattern = r"<\|tool_call_begin\|>.*?(?:<\|tool_call_end\|>|<\|func_end\|>)"; let tool_call_end_pattern = Regex::new(end_pattern).expect("Valid regex pattern"); // Robust parser for ids like "functions.search:0" or fallback "search:0" @@ -229,7 +228,7 @@ impl ToolParser for KimiK2Parser { // No tool markers detected - return all buffered content as normal text let mut normal_text = std::mem::take(&mut self.buffer); // Remove end tokens if present - for e_token in ["<|tool_calls_section_end|>", "<|tool_call_end|>"] { + for e_token in ["<|tool_calls_section_end|>", "<|tool_call_end|>", "<|func_end|>"] { normal_text = normal_text.replace(e_token, ""); } return Ok(StreamingParseResult { @@ -290,13 +289,14 @@ impl ToolParser for KimiK2Parser { function_args }; - // Split by end token before sending (like Python does) - let parsed_args_diff = - if let Some(pos) = argument_diff.find("<|tool_call_end|>") { - &argument_diff[..pos] - } else { - argument_diff - }; + // Split by end token before sending + let end_pos = argument_diff.find("<|tool_call_end|>") + .or_else(|| argument_diff.find("<|func_end|>")); + let parsed_args_diff = if let Some(pos) = end_pos { + &argument_diff[..pos] + } else { + argument_diff + }; if !parsed_args_diff.is_empty() { calls.push(ToolCallItem { @@ -313,8 +313,9 @@ impl ToolParser for KimiK2Parser { } // Check completeness - split by end token first - let parsed_args = if let Some(pos) = function_args.find("<|tool_call_end|>") - { + let end_pos2 = function_args.find("<|tool_call_end|>") + .or_else(|| function_args.find("<|func_end|>")); + let parsed_args = if let Some(pos) = end_pos2 { &function_args[..pos] } else { function_args diff --git a/model_gateway/src/routers/grpc/regular/processor.rs b/model_gateway/src/routers/grpc/regular/processor.rs index 8815e1091..051ca7c0e 100644 --- a/model_gateway/src/routers/grpc/regular/processor.rs +++ b/model_gateway/src/routers/grpc/regular/processor.rs @@ -162,7 +162,16 @@ impl ResponseProcessor { } }; - if used_json_schema { + if self.configured_tool_parser.is_some() && tool_parser_available { + // Explicitly configured parser takes priority (models may emit native tokens regardless of tool_choice) + (tool_calls, processed_text) = self + .parse_tool_calls( + &processed_text, + &original_request.model, + history_tool_calls_count, + ) + .await; + } else if used_json_schema { (tool_calls, processed_text) = utils::parse_json_schema_response( &processed_text, original_request.tool_choice.as_ref(), @@ -198,7 +207,58 @@ impl ResponseProcessor { utils::convert_proto_to_openai_logprobs(proto_logprobs, tokenizer) }); - // Step 5: Build ChatCompletionMessage (proper response message type) + // Strip leaked chatml tokens only when a model-specific parser is configured + if self.configured_tool_parser.is_some() || self.configured_reasoning_parser.is_some() { + for token in [ + "<|im_end|>", + "<|im_start|>", + "<|im_user|>", + "<|im_assistant|>", + "<|im_system|>", + "<|im_middle|>", + ] { + processed_text = processed_text.replace(token, ""); + } + processed_text = processed_text.trim().to_string(); + } + + let is_json_response = matches!( + &original_request.response_format, + Some(openai_protocol::common::ResponseFormat::JsonObject) + | Some(openai_protocol::common::ResponseFormat::JsonSchema { .. }) + ); + if is_json_response { + if processed_text.starts_with("```json") || processed_text.starts_with("```JSON") { + if let Some(start) = processed_text.find('\n') { + let inner = &processed_text[start + 1..]; + if let Some(end) = inner.rfind("```") { + processed_text = inner[..end].trim().to_string(); + } + } + } + + if processed_text.starts_with('{') { + let mut depth = 0i32; + let mut in_string = false; + let mut escape = false; + let mut json_end = None; + for (i, ch) in processed_text.char_indices() { + if escape { escape = false; continue; } + if ch == '\\' && in_string { escape = true; continue; } + if ch == '"' { in_string = !in_string; continue; } + if in_string { continue; } + if ch == '{' { depth += 1; } else if ch == '}' { + depth -= 1; + if depth == 0 { json_end = Some(i + 1); break; } + } + } + if let Some(end) = json_end { + processed_text = processed_text[..end].to_string(); + } + } + } + + // Build ChatCompletionMessage (proper response message type) let chat_message = ChatCompletionMessage { role: "assistant".to_string(), content: if processed_text.is_empty() { @@ -646,7 +706,17 @@ impl ResponseProcessor { Some(messages::ToolChoice::Tool { .. } | messages::ToolChoice::Any { .. }) ); - if used_json_schema { + if self.configured_tool_parser.is_some() && tool_parser_available { + (tool_calls, processed_text) = self + .parse_tool_calls( + &processed_text, + &messages_request.model, + utils::message_utils::get_history_tool_calls_count_messages( + &messages_request, + ), + ) + .await; + } else if used_json_schema { // Bridge Messages ToolChoice to Chat ToolChoice for reuse let chat_tool_choice = messages_request .tool_choice @@ -673,7 +743,22 @@ impl ResponseProcessor { } } - // Step 3: Build content blocks + // Strip leaked chatml tokens only when a model-specific parser is configured + if self.configured_tool_parser.is_some() || self.configured_reasoning_parser.is_some() { + for token in [ + "<|im_end|>", + "<|im_start|>", + "<|im_user|>", + "<|im_assistant|>", + "<|im_system|>", + "<|im_middle|>", + ] { + processed_text = processed_text.replace(token, ""); + } + processed_text = processed_text.trim().to_string(); + } + + // Build content blocks let mut content_blocks: Vec = Vec::new(); // Thinking block first (if present) diff --git a/model_gateway/src/routers/grpc/regular/streaming.rs b/model_gateway/src/routers/grpc/regular/streaming.rs index 0bcdd9c51..e9c9ebd70 100644 --- a/model_gateway/src/routers/grpc/regular/streaming.rs +++ b/model_gateway/src/routers/grpc/regular/streaming.rs @@ -400,8 +400,10 @@ impl StreamingProcessor { && tool_choice_enabled && (tool_parser_available || used_json_schema) { - let tool_chunks = if is_specific_function { - // Handle specific function case - emit tool call deltas with arguments + let tool_chunks = if is_specific_function + && !(self.configured_tool_parser.is_some() + && tool_parser_available) + { Self::process_specific_function_stream( &delta, index, @@ -414,7 +416,6 @@ impl StreamingProcessor { history_tool_calls_count, ) } else { - // Use incremental parser for regular/required modes self.process_tool_calls_stream( &delta, index, @@ -443,7 +444,19 @@ impl StreamingProcessor { } } - // Regular content emission + // Strip leaked chatml/think tokens when a parser is configured + let mut delta = delta; + if self.configured_tool_parser.is_some() + || self.configured_reasoning_parser.is_some() + { + for token in [ + "<|im_end|>", "<|im_start|>", "<|im_user|>", + "<|im_assistant|>", "<|im_system|>", "<|im_middle|>", + "", + ] { + delta = delta.replace(token, ""); + } + } if !delta.is_empty() { let content_chunk = ChatCompletionStreamResponse::builder(request_id, model) @@ -1265,7 +1278,17 @@ impl StreamingProcessor { match parser.parse_incremental(delta, tools).await { Ok(StreamingParseResult { normal_text, calls }) => { - // Emit normal text if present + let mut normal_text = normal_text; + if self.configured_tool_parser.is_some() + || self.configured_reasoning_parser.is_some() + { + for token in [ + "<|im_end|>", "<|im_start|>", "<|im_user|>", + "<|im_assistant|>", "<|im_system|>", "<|im_middle|>", + ] { + normal_text = normal_text.replace(token, ""); + } + } if !normal_text.is_empty() { chunks.push( ChatCompletionStreamResponse::builder(request_id, model) @@ -1787,7 +1810,9 @@ impl StreamingProcessor { // Tool call handling: incremental streaming parser if !in_reasoning && streaming_tool_parser.is_some() { - if is_specific_function { + if is_specific_function + && !(self.configured_tool_parser.is_some() && tool_parser_available) + { // Specific function: entire output is arguments for one tool if !has_tool_calls { has_tool_calls = true; From 163dc597edebce6626df78eba1f9486ae9bb9e37 Mon Sep 17 00:00:00 2001 From: ConnorLi96 Date: Wed, 8 Apr 2026 09:21:43 -0700 Subject: [PATCH 04/27] fix(gateway): comprehensive func call and response quality fixes - Fix tiktoken encoder to use encode_with_special_tokens for chat template output - Prioritize configured tool parser over JSON schema with fallthrough - Skip reasoning parsing when output is constrained (specific tool_choice / JSON response_format) - Support alternative delimiters and hyphens in KimiK2 tool call parser - Prevent reasoning parser from consuming tool call markers - Keep tool call arguments as strings to preserve formatting - Generate TypeScript-style tool declarations for Kimi K2.5 - Strip leaked chatml tokens conditionally and clean up JSON response format - Add name field to Tool ChatMessage variant - Graceful metrics server bind failure Signed-off-by: ConnorLi96 Made-with: Cursor --- crates/protocols/src/chat.rs | 1 + crates/tokenizer/src/tiktoken.rs | 14 - crates/tool_parser/src/parsers/kimik2.rs | 11 +- .../src/routers/grpc/harmony/builder.rs | 1 + .../src/routers/grpc/regular/processor.rs | 17 +- .../grpc/regular/responses/conversions.rs | 2 + .../grpc/regular/stages/chat/preparation.rs | 7 +- .../regular/stages/messages/preparation.rs | 4 +- .../src/routers/grpc/regular/streaming.rs | 38 ++- .../src/routers/grpc/utils/chat_utils.rs | 267 ++++++++++++++++-- model_gateway/tests/spec/chat_message.rs | 1 + 11 files changed, 313 insertions(+), 50 deletions(-) diff --git a/crates/protocols/src/chat.rs b/crates/protocols/src/chat.rs index 35fdaf31d..9b6faf946 100644 --- a/crates/protocols/src/chat.rs +++ b/crates/protocols/src/chat.rs @@ -48,6 +48,7 @@ pub enum ChatMessage { Tool { content: MessageContent, tool_call_id: String, + name: Option, }, #[serde(rename = "function")] Function { content: String, name: String }, diff --git a/crates/tokenizer/src/tiktoken.rs b/crates/tokenizer/src/tiktoken.rs index eb5178832..8c8810715 100644 --- a/crates/tokenizer/src/tiktoken.rs +++ b/crates/tokenizer/src/tiktoken.rs @@ -443,20 +443,6 @@ pub fn is_tiktoken_file(path: &Path) -> bool { impl Encoder for TiktokenTokenizer { fn encode(&self, input: &str, _add_special_tokens: bool) -> Result { - // Always use encode_with_special_tokens so that special token strings - // in the input (e.g., <|media_pad|> from chat templates) are recognized - // as single tokens rather than split into BPE sub-tokens. - // - // NOTE: We intentionally ignore `add_special_tokens` here because the - // flag has different semantics across backends. For HuggingFace it - // controls BOS/EOS prepend/append (tiktoken has no such concept). - // For tiktoken, encode_ordinary vs encode_with_special_tokens controls - // whether special-token *patterns* in the input are recognized. - // All callers that encode chat-template-rendered text pass `false` - // (meaning "don't add BOS/EOS"), but tiktoken must still recognize - // the special tokens the template inserted. A proper fix requires - // redesigning the Encoder trait to separate "add wrapper tokens" from - // "recognize special-token patterns". let tokens = self.tokenizer.encode_with_special_tokens(input); Ok(Encoding::Tiktoken(tokens)) } diff --git a/crates/tool_parser/src/parsers/kimik2.rs b/crates/tool_parser/src/parsers/kimik2.rs index 6171d4a05..1eb3dfd88 100644 --- a/crates/tool_parser/src/parsers/kimik2.rs +++ b/crates/tool_parser/src/parsers/kimik2.rs @@ -105,19 +105,20 @@ impl KimiK2Parser { reason = "regex patterns are compile-time string literals" )] pub fn new() -> Self { - // Supports alternative delimiters: <|func_start|>/<|func_end|>; (?s) for multi-line JSON - let tool_call_pattern = r"(?s)<\|tool_call_begin\|>\s*(?P[\w\.]+:\d+)\s*(?:<\|tool_call_argument_begin\|>\s*|<\|func_start\|>\s*)?(?P\{.*?\})\s*(?:<\|tool_call_end\|>|<\|func_end\|>)"; + // Supports alternative delimiters: <|func_start|>/<|func_end|>; (?s) for multi-line JSON. + // Tool call IDs may contain hyphens (e.g. "functions.execute-tool:0"). + let tool_call_pattern = r"(?s)<\|tool_call_begin\|>\s*(?P[\w\.\-]+:\d+)\s*(?:<\|tool_call_argument_begin\|>\s*|<\|func_start\|>\s*)?(?P\{.*?\})\s*(?:<\|tool_call_end\|>|<\|func_end\|>)"; let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern"); - let stream_pattern = r"(?s)<\|tool_call_begin\|>\s*(?P[\w\.]+:\d+)\s*(?:<\|tool_call_argument_begin\|>\s*|<\|func_start\|>\s*)?(?P\{.*)"; + let stream_pattern = r"(?s)<\|tool_call_begin\|>\s*(?P[\w\.\-]+:\d+)\s*(?:<\|tool_call_argument_begin\|>\s*|<\|func_start\|>\s*)?(?P\{.*)"; let stream_tool_call_extractor = Regex::new(stream_pattern).expect("Valid regex pattern"); // Pattern for removing completed tool calls let end_pattern = r"<\|tool_call_begin\|>.*?(?:<\|tool_call_end\|>|<\|func_end\|>)"; let tool_call_end_pattern = Regex::new(end_pattern).expect("Valid regex pattern"); - // Robust parser for ids like "functions.search:0" or fallback "search:0" - let id_pattern = r"^(?:functions\.)?(?P[\w\.]+):(?P\d+)$"; + // Robust parser for ids like "functions.execute-tool:0" or fallback "search:0" + let id_pattern = r"^(?:functions\.)?(?P[\w\.\-]+):(?P\d+)$"; let tool_call_id_regex = Regex::new(id_pattern).expect("Valid regex pattern"); Self { diff --git a/model_gateway/src/routers/grpc/harmony/builder.rs b/model_gateway/src/routers/grpc/harmony/builder.rs index 3efaacd30..64360ff7a 100644 --- a/model_gateway/src/routers/grpc/harmony/builder.rs +++ b/model_gateway/src/routers/grpc/harmony/builder.rs @@ -1073,6 +1073,7 @@ impl HarmonyBuilder { ChatMessage::Tool { content, tool_call_id, + .. } => { // Look up the function name from the tool_call_id let function_name = tool_call_map diff --git a/model_gateway/src/routers/grpc/regular/processor.rs b/model_gateway/src/routers/grpc/regular/processor.rs index 051ca7c0e..a55fdc13d 100644 --- a/model_gateway/src/routers/grpc/regular/processor.rs +++ b/model_gateway/src/routers/grpc/regular/processor.rs @@ -162,8 +162,11 @@ impl ResponseProcessor { } }; - if self.configured_tool_parser.is_some() && tool_parser_available { - // Explicitly configured parser takes priority (models may emit native tokens regardless of tool_choice) + if self.configured_tool_parser.is_some() + && tool_parser_available + && !output_is_constrained + { + // Configured parser for native tool call tokens (auto/required modes) (tool_calls, processed_text) = self .parse_tool_calls( &processed_text, @@ -171,14 +174,20 @@ impl ResponseProcessor { history_tool_calls_count, ) .await; - } else if used_json_schema { + } + + if tool_calls.is_none() && used_json_schema { + // Constrained decoding output: pure JSON wrapped as a tool call (tool_calls, processed_text) = utils::parse_json_schema_response( &processed_text, original_request.tool_choice.as_ref(), &original_request.model, history_tool_calls_count, ); - } else if tool_parser_available { + } + + if tool_calls.is_none() && tool_parser_available { + // Fallback: auto-detected parser (tool_calls, processed_text) = self .parse_tool_calls( &processed_text, diff --git a/model_gateway/src/routers/grpc/regular/responses/conversions.rs b/model_gateway/src/routers/grpc/regular/responses/conversions.rs index ad7387039..53ed450b2 100644 --- a/model_gateway/src/routers/grpc/regular/responses/conversions.rs +++ b/model_gateway/src/routers/grpc/regular/responses/conversions.rs @@ -111,6 +111,7 @@ pub(crate) fn responses_to_chat(req: &ResponsesRequest) -> Result Result, <|im_end|>, ) embedded in the chat template + // output as single token IDs rather than splitting them into characters. + let encoding = match tokenizer.encode(&processed_messages.text, true) { Ok(encoding) => encoding, Err(e) => { error!(function = "ChatPreparationStage::execute", error = %e, "Tokenization failed"); diff --git a/model_gateway/src/routers/grpc/regular/stages/messages/preparation.rs b/model_gateway/src/routers/grpc/regular/stages/messages/preparation.rs index 69d66d1b2..d872ce873 100644 --- a/model_gateway/src/routers/grpc/regular/stages/messages/preparation.rs +++ b/model_gateway/src/routers/grpc/regular/stages/messages/preparation.rs @@ -153,8 +153,8 @@ impl MessagePreparationStage { } }; - // Step 3: Tokenize the processed text - let encoding = match tokenizer.encode(&processed_messages.text, false) { + // Step 3: Tokenize the processed text (with special token recognition) + let encoding = match tokenizer.encode(&processed_messages.text, true) { Ok(encoding) => encoding, Err(e) => { error!(function = "MessagePreparationStage::execute", error = %e, "Tokenization failed"); diff --git a/model_gateway/src/routers/grpc/regular/streaming.rs b/model_gateway/src/routers/grpc/regular/streaming.rs index e9c9ebd70..39ff8fc2b 100644 --- a/model_gateway/src/routers/grpc/regular/streaming.rs +++ b/model_gateway/src/routers/grpc/regular/streaming.rs @@ -277,6 +277,16 @@ impl StreamingProcessor { let is_specific_function = used_json_schema && matches!(tool_choice, Some(ToolChoice::Function { .. })); + // Skip reasoning parsing when constrained decoding is active. + // The model emits pure JSON without wrappers, so the + // reasoning parser would swallow the output as reasoning content. + let output_is_constrained = is_specific_function + || matches!( + &original_request.response_format, + Some(openai_protocol::common::ResponseFormat::JsonObject) + | Some(openai_protocol::common::ResponseFormat::JsonSchema { .. }) + ); + let tool_parser_available = tools.is_some() && utils::check_tool_parser_availability( &self.tool_parser_factory, @@ -366,7 +376,10 @@ impl StreamingProcessor { stream_buffer.push_str(&delta); // Reasoning content handling - let in_reasoning = if separate_reasoning && reasoning_parser_available { + let in_reasoning = if separate_reasoning + && reasoning_parser_available + && !output_is_constrained + { let (normal_text, reasoning_chunk, in_reasoning) = self .process_reasoning_stream( &delta, @@ -401,8 +414,9 @@ impl StreamingProcessor { && (tool_parser_available || used_json_schema) { let tool_chunks = if is_specific_function - && !(self.configured_tool_parser.is_some() - && tool_parser_available) + && (output_is_constrained + || !(self.configured_tool_parser.is_some() + && tool_parser_available)) { Self::process_specific_function_stream( &delta, @@ -1220,12 +1234,24 @@ impl StreamingProcessor { ); } - // Emit arguments delta - if !delta.is_empty() { + // Emit arguments delta, stripping any chatml tokens + let mut clean_delta = delta.to_string(); + for token in [ + "<|im_end|>", + "<|im_start|>", + "<|im_user|>", + "<|im_assistant|>", + "<|im_system|>", + "<|im_middle|>", + "", + ] { + clean_delta = clean_delta.replace(token, ""); + } + if !clean_delta.is_empty() { chunks.push( ChatCompletionStreamResponse::builder(request_id, model) .created(created) - .add_choice_tool_args(index, delta.to_string()) + .add_choice_tool_args(index, clean_delta) .maybe_system_fingerprint(system_fingerprint) .build(), ); diff --git a/model_gateway/src/routers/grpc/utils/chat_utils.rs b/model_gateway/src/routers/grpc/utils/chat_utils.rs index b75b8abf6..5bbb29809 100644 --- a/model_gateway/src/routers/grpc/utils/chat_utils.rs +++ b/model_gateway/src/routers/grpc/utils/chat_utils.rs @@ -83,35 +83,36 @@ pub(crate) fn resolve_tokenizer( /// Process tool call arguments in messages /// Per Transformers docs, tool call arguments in assistant messages should be dicts pub(crate) fn process_tool_call_arguments(messages: &mut [Value]) -> Result<(), String> { + // Validate that arguments are valid JSON but keep them as strings. + // The chat template handles both string and object arguments: + // {% if ... is string %}{{ args }}{% else %}{{ args | tojson }}{% endif %} + // Keeping strings preserves original formatting (e.g. spacing), + // matching the behavior of Python/TGL which passes arguments through as-is. for msg in messages { let role = msg.get("role").and_then(|v| v.as_str()); if role != Some("assistant") { continue; } - let Some(tool_calls) = msg.get_mut("tool_calls").and_then(|tc| tc.as_array_mut()) else { + let Some(tool_calls) = msg.get("tool_calls").and_then(|tc| tc.as_array()) else { continue; }; for call in tool_calls { - let Some(function) = call.get_mut("function") else { + let Some(function) = call.get("function") else { continue; }; - let Some(args) = function.get_mut("arguments") else { + let Some(args) = function.get("arguments") else { continue; }; let Some(args_str) = args.as_str() else { continue; }; - // Parse JSON string to object (like Python json.loads) - match serde_json::from_str::(args_str) { - Ok(parsed) => *args = parsed, - Err(e) => { - return Err(format!( - "Failed to parse tool call arguments as JSON: '{args_str}'. Error: {e}" - )) - } + if serde_json::from_str::(args_str).is_err() { + return Err(format!( + "Invalid JSON in tool call arguments: '{args_str}'" + )); } } } @@ -291,9 +292,23 @@ pub fn process_chat_messages( .transpose() .map_err(|e| format!("Failed to serialize tools: {e}"))?; - let kwargs_capacity = 1 + request.chat_template_kwargs.as_ref().map_or(0, |k| k.len()); + let kwargs_capacity = 2 + request.chat_template_kwargs.as_ref().map_or(0, |k| k.len()); let mut combined_template_kwargs = HashMap::with_capacity(kwargs_capacity); + // Generate TypeScript-style tool declarations for models that use it + // (e.g., Kimi K2.5 chat template checks for tools_ts_str) + if let Some(ref tools) = request.tools { + if !tools.is_empty() { + let ts_str = tools_to_typescript(tools); + if !ts_str.is_empty() { + combined_template_kwargs.insert( + "tools_ts_str".to_string(), + Value::String(ts_str), + ); + } + } + } + // Add reasoning_effort if present (like Python does) if let Some(reasoning_effort) = &request.reasoning_effort { combined_template_kwargs.insert( @@ -426,10 +441,25 @@ pub(crate) fn parse_json_schema_response( model: &str, history_tool_calls_count: usize, ) -> (Option>, String) { + // Strip chatml tokens that may trail the constrained JSON output + let mut clean = processed_text.to_string(); + for token in [ + "<|im_end|>", + "<|im_start|>", + "<|im_user|>", + "<|im_assistant|>", + "<|im_system|>", + "<|im_middle|>", + "", + ] { + clean = clean.replace(token, ""); + } + let clean = clean.trim(); + match tool_choice { Some(ToolChoice::Function { function, .. }) => { // Specific function: Parse parameters directly - match serde_json::from_str::(processed_text) { + match serde_json::from_str::(clean) { Ok(params) => { let tool_call = ToolCall { id: generate_tool_call_id( @@ -450,14 +480,14 @@ pub(crate) fn parse_json_schema_response( } Err(e) => { error!("Failed to parse specific function parameters: {}", e); - (None, processed_text.to_string()) + (None, clean.to_string()) } } } Some(ToolChoice::Value(ToolChoiceValue::Required)) | Some(ToolChoice::AllowedTools { .. }) => { // Required mode: Parse array of tool calls - match serde_json::from_str::>(processed_text) { + match serde_json::from_str::>(clean) { Ok(parsed_array) => { let spec_tool_calls: Vec = parsed_array .into_iter() @@ -489,11 +519,11 @@ pub(crate) fn parse_json_schema_response( } Err(e) => { error!("Failed to parse required tool call array: {}", e); - (None, processed_text.to_string()) + (None, clean.to_string()) } } } - _ => (None, processed_text.to_string()), + _ => (None, clean.to_string()), } } @@ -781,3 +811,206 @@ mod tests { assert_eq!(content_array[1], json!({"type": "image"})); } } + +// ============================================================================ +// TypeScript-style tool declaration generator +// ============================================================================ + +/// Convert OpenAI tools to TypeScript-style declaration string. +/// +/// Produces the format expected by models like Kimi K2.5, whose chat templates +/// check for a `tools_ts_str` variable and prefer it over raw JSON. +/// +/// Example output: +/// ```text +/// # Tools +/// +/// ## functions +/// namespace functions { +/// // Get the current weather +/// type getCurrentWeather = (_: { +/// // The city and state +/// location: string, +/// unit?: "celsius" | "fahrenheit" +/// }) => any; +/// } +/// ``` +pub fn tools_to_typescript(tools: &[Tool]) -> String { + let mut functions = Vec::new(); + + for tool in tools { + if tool.tool_type != "function" { + continue; + } + functions.push(function_to_typescript(&tool.function)); + } + + if functions.is_empty() { + return String::new(); + } + + let mut result = String::from("# Tools\n\n## functions\nnamespace functions {\n"); + result.push_str(&functions.join("\n")); + result.push_str("\n}\n"); + result +} + +fn function_to_typescript( + func: &openai_protocol::common::Function, +) -> String { + let mut out = String::new(); + + // Description comment + if let Some(ref desc) = func.description { + for line in desc.lines() { + if line.is_empty() { + out.push('\n'); + } else { + out.push_str(&format!("// {line}\n")); + } + } + } + + // Parameters + let params_str = if func.parameters.is_null() || func.parameters == Value::Object(Default::default()) { + "{}".to_string() + } else { + schema_to_typescript(&func.parameters, "", &[]) + }; + + out.push_str(&format!("type {} = (_: {params_str}) => any;", func.name)); + out +} + +fn schema_to_typescript(schema: &Value, indent: &str, required: &[&str]) -> String { + match schema.get("type").and_then(|t| t.as_str()) { + Some("object") => object_to_typescript(schema, indent), + Some("array") => array_to_typescript(schema, indent), + Some("string") => { + if let Some(enum_vals) = schema.get("enum").and_then(|e| e.as_array()) { + enum_to_typescript(enum_vals) + } else { + "string".to_string() + } + } + Some("integer" | "number") => "number".to_string(), + Some("boolean") => "boolean".to_string(), + Some("null") => "null".to_string(), + _ => { + // Handle enum without type + if let Some(enum_vals) = schema.get("enum").and_then(|e| e.as_array()) { + return enum_to_typescript(enum_vals); + } + // Handle anyOf + if let Some(any_of) = schema.get("anyOf").and_then(|a| a.as_array()) { + let types: Vec = any_of + .iter() + .map(|t| schema_to_typescript(t, indent, &[])) + .collect(); + return types.join(" | "); + } + "any".to_string() + } + } +} + +fn object_to_typescript(schema: &Value, indent: &str) -> String { + let properties = match schema.get("properties").and_then(|p| p.as_object()) { + Some(p) => p, + None => return "{}".to_string(), + }; + + let required_fields: Vec<&str> = schema + .get("required") + .and_then(|r| r.as_array()) + .map(|arr| arr.iter().filter_map(|v| v.as_str()).collect()) + .unwrap_or_default(); + + let child_indent = format!("{indent} "); + + // Sort: required first, then optional, both alphabetically + let mut required_params: Vec<(&String, &Value)> = Vec::new(); + let mut optional_params: Vec<(&String, &Value)> = Vec::new(); + + for (name, prop) in properties { + if required_fields.contains(&name.as_str()) { + required_params.push((name, prop)); + } else { + optional_params.push((name, prop)); + } + } + required_params.sort_by_key(|(n, _)| *n); + optional_params.sort_by_key(|(n, _)| *n); + + let mut params: Vec<(&String, &Value, bool)> = Vec::new(); + for (n, v) in required_params { + params.push((n, v, false)); + } + for (n, v) in optional_params { + params.push((n, v, true)); + } + + if params.is_empty() { + return "{}".to_string(); + } + + let mut parts = Vec::new(); + for (name, prop, optional) in ¶ms { + let mut part = String::new(); + + // Description comment + if let Some(desc) = prop.get("description").and_then(|d| d.as_str()) { + for line in desc.lines() { + if line.is_empty() { + part.push('\n'); + } else { + part.push_str(&format!("{child_indent}// {line}\n")); + } + } + } + + let type_str = schema_to_typescript(prop, &child_indent, &[]); + let opt_marker = if *optional { "?" } else { "" }; + part.push_str(&format!("{child_indent}{name}{opt_marker}: {type_str}")); + parts.push(part); + } + + format!("{{\n{}\n{indent}}}", parts.join(",\n")) +} + +fn array_to_typescript(schema: &Value, indent: &str) -> String { + let items = schema.get("items"); + let item_type = match items { + Some(item_schema) => { + let child_indent = format!("{indent} "); + // Check if item has description + let item_desc = item_schema + .get("description") + .and_then(|d| d.as_str()); + let type_str = schema_to_typescript(item_schema, &child_indent, &[]); + + if let Some(desc) = item_desc { + return format!( + "Array<\n{child_indent}// {desc}\n{child_indent}{type_str}\n{indent}>" + ); + } + type_str + } + None => "any".to_string(), + }; + format!("Array<{item_type}>") +} + +fn enum_to_typescript(values: &[Value]) -> String { + let parts: Vec = values + .iter() + .map(|v| match v { + Value::String(s) => format!("\"{s}\""), + Value::Number(n) => n.to_string(), + Value::Bool(b) => b.to_string(), + Value::Null => "null".to_string(), + _ => "any".to_string(), + }) + .collect(); + parts.join(" | ") +} diff --git a/model_gateway/tests/spec/chat_message.rs b/model_gateway/tests/spec/chat_message.rs index 232154a62..92ef8e5c0 100644 --- a/model_gateway/tests/spec/chat_message.rs +++ b/model_gateway/tests/spec/chat_message.rs @@ -66,6 +66,7 @@ fn test_chat_message_tagged_by_role_tool() { ChatMessage::Tool { content, tool_call_id, + .. } => { match content { MessageContent::Text(text) => { From 0a999ea1b40066524f54c8a7f3445c768dd64fe2 Mon Sep 17 00:00:00 2001 From: ConnorLi96 Date: Wed, 8 Apr 2026 15:16:30 -0700 Subject: [PATCH 05/27] fix(tokenizer): load merged EOS token IDs from config.json + generation_config.json Models like Kimi-K2.5 define different EOS tokens in config.json ([EOS]=163585) and generation_config.json (<|im_end|>=163586). The engine must stop at both. SGLang handles this internally by merging both sources, but TRT-LLM's gRPC path does not. This change: - Adds load_eos_token_ids() to read and merge EOS IDs from both config files - Exposes eos_token_ids() on the Tokenizer trait - Passes merged IDs as stop_token_ids in TRT-LLM gRPC requests - Removes the hardcoded <|im_end|> stop string workaround (bfab5ada) - Strips [EOS]/[BOS] text from constrained decoding output Result: SMG+TRT-LLM fc-dash 55/60 (was 52/60 with hack, 18/60 without) Signed-off-by: ConnorLi96 Made-with: Cursor --- crates/grpc_client/src/trtllm_service.rs | 12 +++++++++++- crates/tokenizer/src/huggingface.rs | 5 +++++ crates/tokenizer/src/mock.rs | 6 +----- crates/tokenizer/src/tiktoken.rs | 15 ++++----------- crates/tokenizer/src/traits.rs | 11 +++++++++++ model_gateway/src/routers/grpc/client.rs | 2 ++ .../grpc/harmony/stages/request_building.rs | 5 +++++ .../grpc/regular/stages/chat/request_building.rs | 6 ++++++ .../src/routers/grpc/regular/streaming.rs | 4 ++-- .../src/routers/grpc/utils/chat_utils.rs | 4 +++- 10 files changed, 50 insertions(+), 20 deletions(-) diff --git a/crates/grpc_client/src/trtllm_service.rs b/crates/grpc_client/src/trtllm_service.rs index cc4d6ac9e..ba07bfff9 100644 --- a/crates/grpc_client/src/trtllm_service.rs +++ b/crates/grpc_client/src/trtllm_service.rs @@ -273,6 +273,7 @@ impl TrtllmServiceClient { token_ids: Vec, multimodal_input: Option, tool_call_constraint: Option<(String, String)>, // (constraint_type, constraint_value) + eos_token_ids: &[u32], ) -> Result { // Build sampling config let sampling_config = Self::build_sampling_config_from_chat(body); @@ -287,6 +288,15 @@ impl TrtllmServiceClient { let max_tokens = body.max_completion_tokens.unwrap_or(2048); + // Pass merged EOS token IDs from config.json + generation_config.json. + // TRT-LLM's gRPC path does not reliably merge these internally, + // so we provide them explicitly via the standard stop_token_ids field. + let stop_token_ids: Vec = if body.ignore_eos { + vec![] + } else { + eos_token_ids.to_vec() + }; + let grpc_request = proto::GenerateRequest { request_id, tokenized: Some(proto::TokenizedInput { @@ -299,7 +309,7 @@ impl TrtllmServiceClient { max_tokens, streaming: body.stream, stop, - stop_token_ids: vec![], + stop_token_ids, ignore_eos: body.ignore_eos, bad: vec![], bad_token_ids: vec![], diff --git a/crates/tokenizer/src/huggingface.rs b/crates/tokenizer/src/huggingface.rs index e13ca1b00..dd6f18907 100644 --- a/crates/tokenizer/src/huggingface.rs +++ b/crates/tokenizer/src/huggingface.rs @@ -236,6 +236,7 @@ impl HuggingFaceTokenizer { cls_token: find_token(&["[CLS]", "", ""]), mask_token: find_token(&["[MASK]", "", ""]), additional_special_tokens, + ..Default::default() } } @@ -437,6 +438,10 @@ impl TokenizerTrait for HuggingFaceTokenizer { } } + fn eos_token_ids(&self) -> &[TokenIdType] { + &self.special_tokens.eos_token_ids + } + fn set_chat_template(&mut self, template: String) -> Result<()> { self.chat_template.set(template) } diff --git a/crates/tokenizer/src/mock.rs b/crates/tokenizer/src/mock.rs index f057aabe2..ad2a6b636 100644 --- a/crates/tokenizer/src/mock.rs +++ b/crates/tokenizer/src/mock.rs @@ -51,11 +51,7 @@ impl MockTokenizer { bos_token: Some("".to_string()), eos_token: Some("".to_string()), unk_token: Some("".to_string()), - sep_token: None, - pad_token: None, - cls_token: None, - mask_token: None, - additional_special_tokens: vec![], + ..Default::default() }; Self { diff --git a/crates/tokenizer/src/tiktoken.rs b/crates/tokenizer/src/tiktoken.rs index 8c8810715..c55e4d79f 100644 --- a/crates/tokenizer/src/tiktoken.rs +++ b/crates/tokenizer/src/tiktoken.rs @@ -119,6 +119,7 @@ fn parse_special_tokens(config: &serde_json::Value) -> SpecialTokens { cls_token: get_str("cls_token"), mask_token: get_str("mask_token"), additional_special_tokens: additional, + ..Default::default() } } @@ -261,12 +262,11 @@ impl TiktokenTokenizer { }) }; - // Load merged EOS token IDs from config.json + generation_config.json let eos_token_ids = crate::eos::load_eos_token_ids(dir); Ok(TiktokenTokenizer { tokenizer, - special_tokens: config.special_tokens, + special_tokens, vocab, reverse_vocab, vocab_size, @@ -314,27 +314,20 @@ impl TiktokenTokenizer { TiktokenModel::Cl100kBase => SpecialTokens { bos_token: Some("<|endoftext|>".to_string()), eos_token: Some("<|endoftext|>".to_string()), - unk_token: None, - sep_token: None, pad_token: Some("<|endoftext|>".to_string()), - cls_token: None, - mask_token: None, additional_special_tokens: vec![ "<|fim_prefix|>".to_string(), "<|fim_middle|>".to_string(), "<|fim_suffix|>".to_string(), "<|endofprompt|>".to_string(), ], + ..Default::default() }, _ => SpecialTokens { bos_token: Some("<|endoftext|>".to_string()), eos_token: Some("<|endoftext|>".to_string()), - unk_token: None, - sep_token: None, pad_token: Some("<|endoftext|>".to_string()), - cls_token: None, - mask_token: None, - additional_special_tokens: vec![], + ..Default::default() }, } } diff --git a/crates/tokenizer/src/traits.rs b/crates/tokenizer/src/traits.rs index c645d5910..5ffb95990 100644 --- a/crates/tokenizer/src/traits.rs +++ b/crates/tokenizer/src/traits.rs @@ -114,6 +114,12 @@ pub trait Tokenizer: Encoder + Decoder { false } + /// Merged EOS token IDs from config.json and generation_config.json. + /// Backends should stop generation when any of these tokens is produced. + fn eos_token_ids(&self) -> &[TokenIdType] { + &[] + } + /// Set or override the chat template. /// /// Returns an error if the template fails to parse or the tokenizer @@ -184,4 +190,9 @@ pub struct SpecialTokens { pub cls_token: Option, pub mask_token: Option, pub additional_special_tokens: Vec, + /// Merged EOS token IDs from config.json + generation_config.json. + /// Models like Kimi-K2.5 define different EOS tokens in each file + /// (e.g. `[EOS]` in config.json, `<|im_end|>` in generation_config.json). + /// The engine must stop at any of these tokens. + pub eos_token_ids: Vec, } diff --git a/model_gateway/src/routers/grpc/client.rs b/model_gateway/src/routers/grpc/client.rs index 81cfc8f11..dce1302fe 100644 --- a/model_gateway/src/routers/grpc/client.rs +++ b/model_gateway/src/routers/grpc/client.rs @@ -324,6 +324,7 @@ impl GrpcClient { token_ids: Vec, multimodal_inputs: Option, tool_constraints: Option<(String, String)>, + eos_token_ids: &[u32], ) -> Result { match self { Self::Sglang(client) => { @@ -368,6 +369,7 @@ impl GrpcClient { token_ids, trtllm_mm, tool_constraints, + eos_token_ids, )?; Ok(ProtoGenerateRequest::Trtllm(Box::new(req))) } diff --git a/model_gateway/src/routers/grpc/harmony/stages/request_building.rs b/model_gateway/src/routers/grpc/harmony/stages/request_building.rs index d084d66f3..ef6567b7b 100644 --- a/model_gateway/src/routers/grpc/harmony/stages/request_building.rs +++ b/model_gateway/src/routers/grpc/harmony/stages/request_building.rs @@ -189,6 +189,10 @@ impl PipelineStage for HarmonyRequestBuildingStage { ProtoGenerateRequest::Vllm(Box::new(req)) } GrpcClient::Trtllm(trtllm_client) => { + let eos_ids = ctx + .tokenizer_arc() + .map(|t| t.eos_token_ids().to_vec()) + .unwrap_or_default(); let req = match &ctx.input.request_type { RequestType::Chat(request) => { let body = modified_request.as_deref().unwrap_or_else(|| request.as_ref()); @@ -200,6 +204,7 @@ impl PipelineStage for HarmonyRequestBuildingStage { token_ids, None, // No multimodal in Harmony pipeline tool_constraints, + &eos_ids, ) .map_err(|e| { error!(function = "HarmonyRequestBuildingStage::execute", error = %e, "Failed to build TensorRT-LLM generate request"); diff --git a/model_gateway/src/routers/grpc/regular/stages/chat/request_building.rs b/model_gateway/src/routers/grpc/regular/stages/chat/request_building.rs index aadcc89b3..02e538977 100644 --- a/model_gateway/src/routers/grpc/regular/stages/chat/request_building.rs +++ b/model_gateway/src/routers/grpc/regular/stages/chat/request_building.rs @@ -88,6 +88,11 @@ impl PipelineStage for ChatRequestBuildingStage { .multimodal_intermediate .map(|intermediate| assemble_multimodal_data(intermediate, builder_client)); + let eos_token_ids = ctx + .tokenizer_arc() + .map(|t| t.eos_token_ids().to_vec()) + .unwrap_or_default(); + let mut proto_request = builder_client .build_chat_request( request_id, @@ -96,6 +101,7 @@ impl PipelineStage for ChatRequestBuildingStage { token_ids, multimodal_data, tool_constraints, + &eos_token_ids, ) .map_err(|e| { error!(function = "ChatRequestBuildingStage::execute", error = %e, "Failed to build generate request"); diff --git a/model_gateway/src/routers/grpc/regular/streaming.rs b/model_gateway/src/routers/grpc/regular/streaming.rs index 39ff8fc2b..db3bab3bd 100644 --- a/model_gateway/src/routers/grpc/regular/streaming.rs +++ b/model_gateway/src/routers/grpc/regular/streaming.rs @@ -466,7 +466,7 @@ impl StreamingProcessor { for token in [ "<|im_end|>", "<|im_start|>", "<|im_user|>", "<|im_assistant|>", "<|im_system|>", "<|im_middle|>", - "", + "", "[EOS]", "[BOS]", ] { delta = delta.replace(token, ""); } @@ -1243,7 +1243,7 @@ impl StreamingProcessor { "<|im_assistant|>", "<|im_system|>", "<|im_middle|>", - "", + "", "[EOS]", "[BOS]", ] { clean_delta = clean_delta.replace(token, ""); } diff --git a/model_gateway/src/routers/grpc/utils/chat_utils.rs b/model_gateway/src/routers/grpc/utils/chat_utils.rs index 5bbb29809..f205723b1 100644 --- a/model_gateway/src/routers/grpc/utils/chat_utils.rs +++ b/model_gateway/src/routers/grpc/utils/chat_utils.rs @@ -441,7 +441,7 @@ pub(crate) fn parse_json_schema_response( model: &str, history_tool_calls_count: usize, ) -> (Option>, String) { - // Strip chatml tokens that may trail the constrained JSON output + // Strip chatml / special tokens that may trail the constrained JSON output let mut clean = processed_text.to_string(); for token in [ "<|im_end|>", @@ -451,6 +451,8 @@ pub(crate) fn parse_json_schema_response( "<|im_system|>", "<|im_middle|>", "", + "[EOS]", + "[BOS]", ] { clean = clean.replace(token, ""); } From efc36fa761fe2e36940d9f703d7f389bf3fffe8b Mon Sep 17 00:00:00 2001 From: ConnorLi96 Date: Thu, 9 Apr 2026 21:42:08 -0700 Subject: [PATCH 06/27] fix(tokenizer): reduce tiktoken partial UTF-8 decode log from warn to debug During streaming detokenization, multi-byte characters (CJK, emoji) are split across token boundaries. Decoding a partial token sequence produces incomplete UTF-8, which is expected and handled correctly by the lossy fallback + streaming buffer. This was logging at WARN level, flooding production logs on every request containing non-ASCII output. Made-with: Cursor --- crates/tokenizer/src/tiktoken.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/tokenizer/src/tiktoken.rs b/crates/tokenizer/src/tiktoken.rs index c55e4d79f..57478ebcc 100644 --- a/crates/tokenizer/src/tiktoken.rs +++ b/crates/tokenizer/src/tiktoken.rs @@ -459,7 +459,7 @@ impl Decoder for TiktokenTokenizer { ._decode_native_and_split(token_ids.to_vec()) .flatten() .collect(); - tracing::warn!( + tracing::debug!( error = %err, token_count = token_ids.len(), "tiktoken decode failed; returning lossy UTF-8 fallback" From 35de058be686d307125084594784e6d5115f9731 Mon Sep 17 00:00:00 2001 From: ConnorLi96 Date: Fri, 10 Apr 2026 18:59:13 -0700 Subject: [PATCH 07/27] feat(protocol): add thinking param to Chat API and support bare string image_url 1. Add Anthropic-style `thinking` field to ChatCompletionRequest: `{"thinking": {"type": "enabled", "budget_tokens": N}}` or `{"thinking": {"type": "disabled"}}`. Normalized to chat_template_kwargs (both `thinking` and `enable_thinking` keys) so it works with Kimi-K2.5 and Qwen3 templates. 2. Support bare string image_url format: accept both `"image_url": "https://..."` and `"image_url": {"url": "https://..."}` via custom serde deserializer on ImageUrl. Made-with: Cursor --- crates/protocols/src/chat.rs | 20 +++++++++++++ crates/protocols/src/common.rs | 54 ++++++++++++++++++++++++++++++++-- 2 files changed, 72 insertions(+), 2 deletions(-) diff --git a/crates/protocols/src/chat.rs b/crates/protocols/src/chat.rs index 9b6faf946..b12c811e3 100644 --- a/crates/protocols/src/chat.rs +++ b/crates/protocols/src/chat.rs @@ -11,6 +11,7 @@ use super::{ StringOrArray, Tool, ToolCall, ToolCallDelta, ToolChoice, ToolChoiceValue, ToolReference, Usage, }, + messages::ThinkingConfig, sampling_params::{validate_top_k_value, validate_top_p_value}, }; use crate::{ @@ -202,6 +203,10 @@ pub struct ChatCompletionRequest { /// Effort level for reasoning models (low, medium, high) pub reasoning_effort: Option, + /// Configuration for extended thinking (Anthropic-style). + /// Maps to chat_template_kwargs for thinking-capable models. + pub thinking: Option, + /// An object specifying the format that the model must output pub response_format: Option, @@ -568,6 +573,21 @@ impl Normalizable for ChatCompletionRequest { self.function_call = None; // Clear deprecated field } + // Migrate thinking → chat_template_kwargs + if let Some(ref thinking) = self.thinking { + let kwargs = self.chat_template_kwargs.get_or_insert_with(HashMap::new); + match thinking { + ThinkingConfig::Enabled { .. } => { + kwargs.entry("enable_thinking".to_string()).or_insert(Value::Bool(true)); + kwargs.entry("thinking".to_string()).or_insert(Value::Bool(true)); + } + ThinkingConfig::Disabled => { + kwargs.entry("enable_thinking".to_string()).or_insert(Value::Bool(false)); + kwargs.entry("thinking".to_string()).or_insert(Value::Bool(false)); + } + } + } + // Apply tool_choice defaults if self.tool_choice.is_none() { if let Some(tools) = &self.tools { diff --git a/crates/protocols/src/common.rs b/crates/protocols/src/common.rs index 9d2688c9c..e3dfdbdb9 100644 --- a/crates/protocols/src/common.rs +++ b/crates/protocols/src/common.rs @@ -195,13 +195,63 @@ pub enum ContentPart { VideoUrl { video_url: VideoUrl }, } -#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, schemars::JsonSchema)] +#[derive(Debug, Clone, PartialEq, schemars::JsonSchema)] pub struct ImageUrl { pub url: String, - #[serde(skip_serializing_if = "Option::is_none")] pub detail: Option, // "auto", "low", or "high" } +impl Serialize for ImageUrl { + fn serialize(&self, serializer: S) -> Result { + use serde::ser::SerializeStruct; + let field_count = 1 + self.detail.is_some() as usize; + let mut state = serializer.serialize_struct("ImageUrl", field_count)?; + state.serialize_field("url", &self.url)?; + if let Some(ref detail) = self.detail { + state.serialize_field("detail", detail)?; + } + state.end() + } +} + +impl<'de> Deserialize<'de> for ImageUrl { + fn deserialize>(deserializer: D) -> Result { + use serde::de; + + struct ImageUrlVisitor; + + impl<'de> de::Visitor<'de> for ImageUrlVisitor { + type Value = ImageUrl; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a string URL or an object with url field") + } + + fn visit_str(self, v: &str) -> Result { + Ok(ImageUrl { url: v.to_string(), detail: None }) + } + + fn visit_map>(self, mut map: M) -> Result { + let mut url = None; + let mut detail = None; + while let Some(key) = map.next_key::()? { + match key.as_str() { + "url" => url = Some(map.next_value()?), + "detail" => detail = map.next_value()?, + _ => { let _ = map.next_value::()?; } + } + } + Ok(ImageUrl { + url: url.ok_or_else(|| de::Error::missing_field("url"))?, + detail, + }) + } + } + + deserializer.deserialize_any(ImageUrlVisitor) + } +} + #[derive(Debug, Clone, Deserialize, Serialize, PartialEq, schemars::JsonSchema)] pub struct VideoUrl { pub url: String, From 2d10dbb0ef38cbd85e64f694f73da8232775cade Mon Sep 17 00:00:00 2001 From: ConnorLi96 Date: Sat, 18 Apr 2026 07:37:02 -0700 Subject: [PATCH 08/27] fix(reasoning): run reasoning parser before JSON/tool post-processing when thinking is active MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When separate_reasoning is true the chat template injects in the prompt, so the model emits thinking content regardless of backend constraints. TRT-LLM does not have TGL's ReasonerGrammarObject that defers JSON constraint enforcement until after , so the raw output contains ... followed by the actual JSON or tool-call tokens. The existing output_is_constrained gate unconditionally skipped the reasoning parser for response_format json_object/json_schema and tool_choice Function, assuming constrained decoding produces pure JSON. This left thinking content in the response, breaking structured output parsing (0/3 on all 4 structured-output tests) and tool_choice streaming (0/3). Changes: - Gate output_is_constrained on !separate_reasoning so the reasoning parser runs when thinking is expected (processor.rs + streaming.rs) - Use native tool parser instead of JSON parser for streaming tool_choice Function when reasoning is active, since the backend may emit native tool-call tags rather than constrained JSON - Strip markdown code fences (```json / ```) from streaming JSON responses, including residual language tags split across chunk boundaries Tested: all 5 previously-failing tests now pass 3/3 (47/58 → 52/58). Made-with: Cursor --- .../src/routers/grpc/regular/processor.rs | 11 ++-- .../src/routers/grpc/regular/streaming.rs | 52 +++++++++++++++---- 2 files changed, 49 insertions(+), 14 deletions(-) diff --git a/model_gateway/src/routers/grpc/regular/processor.rs b/model_gateway/src/routers/grpc/regular/processor.rs index a55fdc13d..7cefb6068 100644 --- a/model_gateway/src/routers/grpc/regular/processor.rs +++ b/model_gateway/src/routers/grpc/regular/processor.rs @@ -96,12 +96,15 @@ impl ResponseProcessor { let mut reasoning_text: Option = None; let mut processed_text = final_text; - if original_request.separate_reasoning - && reasoning_parser_available - && !utils::has_constrained_output( + let output_is_constrained = !original_request.separate_reasoning + && utils::has_constrained_output( original_request.tool_choice.as_ref(), original_request.response_format.as_ref(), - ) + ); + + if original_request.separate_reasoning + && reasoning_parser_available + && !output_is_constrained { let pooled_parser = utils::get_reasoning_parser( &self.reasoning_parser_factory, diff --git a/model_gateway/src/routers/grpc/regular/streaming.rs b/model_gateway/src/routers/grpc/regular/streaming.rs index db3bab3bd..56f357e6f 100644 --- a/model_gateway/src/routers/grpc/regular/streaming.rs +++ b/model_gateway/src/routers/grpc/regular/streaming.rs @@ -277,15 +277,17 @@ impl StreamingProcessor { let is_specific_function = used_json_schema && matches!(tool_choice, Some(ToolChoice::Function { .. })); - // Skip reasoning parsing when constrained decoding is active. - // The model emits pure JSON without wrappers, so the - // reasoning parser would swallow the output as reasoning content. - let output_is_constrained = is_specific_function - || matches!( - &original_request.response_format, - Some(openai_protocol::common::ResponseFormat::JsonObject) - | Some(openai_protocol::common::ResponseFormat::JsonSchema { .. }) - ); + // Skip reasoning parsing when constrained decoding is active AND + // reasoning is not expected. When separate_reasoning is true the + // chat template injects , so the model emits thinking + // content even with constraints. The reasoning parser must run. + let output_is_constrained = !separate_reasoning + && (is_specific_function + || matches!( + &original_request.response_format, + Some(openai_protocol::common::ResponseFormat::JsonObject) + | Some(openai_protocol::common::ResponseFormat::JsonSchema { .. }) + )); let tool_parser_available = tools.is_some() && utils::check_tool_parser_availability( @@ -430,6 +432,14 @@ impl StreamingProcessor { history_tool_calls_count, ) } else { + // When reasoning is active the backend may + // not enforce constrained decoding, so the + // model can emit native tool-call tags. + // Use the native parser instead of JSON. + let effective_json_parser = used_json_schema + && !(separate_reasoning + && self.configured_tool_parser.is_some() + && tool_parser_available); self.process_tool_calls_stream( &delta, index, @@ -441,7 +451,7 @@ impl StreamingProcessor { created, system_fingerprint, history_tool_calls_count, - used_json_schema, + effective_json_parser, ) .await }; @@ -471,6 +481,28 @@ impl StreamingProcessor { delta = delta.replace(token, ""); } } + + // Strip markdown code fences for JSON responses so + // streamed content is directly parseable. + let is_json_response = matches!( + &original_request.response_format, + Some(openai_protocol::common::ResponseFormat::JsonObject) + | Some(openai_protocol::common::ResponseFormat::JsonSchema { .. }) + ); + if is_json_response { + delta = delta + .replace("```json", "") + .replace("```JSON", "") + .replace("```", ""); + // Fence tokens may split across chunks (e.g. + // "```" in one delta, "json\n" in the next). + // Drop residual language tags left over. + let trimmed = delta.trim(); + if trimmed == "json" || trimmed == "JSON" { + delta = String::new(); + } + } + if !delta.is_empty() { let content_chunk = ChatCompletionStreamResponse::builder(request_id, model) From 849382a71b6c0d91609d07b81ae219dbc7bf2dcd Mon Sep 17 00:00:00 2001 From: Connor Li Date: Wed, 22 Apr 2026 17:37:46 -0700 Subject: [PATCH 09/27] style: fix formatting, clippy warnings, and merge artifacts from cherry-pick Signed-off-by: Connor Li --- crates/grpc_client/src/trtllm_service.rs | 4 + crates/multimodal/src/registry/mod.rs | 9 +- crates/protocols/src/chat.rs | 16 +- crates/protocols/src/common.rs | 9 +- crates/tokenizer/src/chat_template.rs | 5 +- crates/tokenizer/src/huggingface.rs | 4 - crates/tokenizer/src/tiktoken.rs | 2 +- crates/tokenizer/src/traits.rs | 8 - crates/tool_parser/src/parsers/kimik2.rs | 12 +- model_gateway/src/routers/grpc/client.rs | 4 + .../src/routers/grpc/regular/processor.rs | 45 +- .../src/routers/grpc/regular/streaming.rs | 26 +- .../src/routers/grpc/utils/chat_utils.rs | 413 +++++++++--------- .../src/routers/grpc/utils/parsers.rs | 5 +- 14 files changed, 295 insertions(+), 267 deletions(-) diff --git a/crates/grpc_client/src/trtllm_service.rs b/crates/grpc_client/src/trtllm_service.rs index ba07bfff9..81b3bb4c7 100644 --- a/crates/grpc_client/src/trtllm_service.rs +++ b/crates/grpc_client/src/trtllm_service.rs @@ -265,6 +265,10 @@ impl TrtllmServiceClient { clippy::unused_self, reason = "method receiver kept for consistent public API across gRPC backends" )] + #[expect( + clippy::too_many_arguments, + reason = "gRPC request builder requires all fields for the proto message" + )] pub fn build_generate_request_from_chat( &self, request_id: String, diff --git a/crates/multimodal/src/registry/mod.rs b/crates/multimodal/src/registry/mod.rs index 6ddb366da..a70ddfc61 100644 --- a/crates/multimodal/src/registry/mod.rs +++ b/crates/multimodal/src/registry/mod.rs @@ -126,14 +126,7 @@ pub(super) mod test_helpers { fn get_special_tokens(&self) -> &SpecialTokens { static TOKENS: Lazy = Lazy::new(|| SpecialTokens { - bos_token: None, - eos_token: None, - unk_token: None, - sep_token: None, - pad_token: None, - cls_token: None, - mask_token: None, - additional_special_tokens: vec![], + ..Default::default() }); &TOKENS } diff --git a/crates/protocols/src/chat.rs b/crates/protocols/src/chat.rs index b12c811e3..1a7919523 100644 --- a/crates/protocols/src/chat.rs +++ b/crates/protocols/src/chat.rs @@ -578,12 +578,20 @@ impl Normalizable for ChatCompletionRequest { let kwargs = self.chat_template_kwargs.get_or_insert_with(HashMap::new); match thinking { ThinkingConfig::Enabled { .. } => { - kwargs.entry("enable_thinking".to_string()).or_insert(Value::Bool(true)); - kwargs.entry("thinking".to_string()).or_insert(Value::Bool(true)); + kwargs + .entry("enable_thinking".to_string()) + .or_insert(Value::Bool(true)); + kwargs + .entry("thinking".to_string()) + .or_insert(Value::Bool(true)); } ThinkingConfig::Disabled => { - kwargs.entry("enable_thinking".to_string()).or_insert(Value::Bool(false)); - kwargs.entry("thinking".to_string()).or_insert(Value::Bool(false)); + kwargs + .entry("enable_thinking".to_string()) + .or_insert(Value::Bool(false)); + kwargs + .entry("thinking".to_string()) + .or_insert(Value::Bool(false)); } } } diff --git a/crates/protocols/src/common.rs b/crates/protocols/src/common.rs index e3dfdbdb9..08cac4a87 100644 --- a/crates/protocols/src/common.rs +++ b/crates/protocols/src/common.rs @@ -228,7 +228,10 @@ impl<'de> Deserialize<'de> for ImageUrl { } fn visit_str(self, v: &str) -> Result { - Ok(ImageUrl { url: v.to_string(), detail: None }) + Ok(ImageUrl { + url: v.to_string(), + detail: None, + }) } fn visit_map>(self, mut map: M) -> Result { @@ -238,7 +241,9 @@ impl<'de> Deserialize<'de> for ImageUrl { match key.as_str() { "url" => url = Some(map.next_value()?), "detail" => detail = map.next_value()?, - _ => { let _ = map.next_value::()?; } + _ => { + let _ = map.next_value::()?; + } } } Ok(ImageUrl { diff --git a/crates/tokenizer/src/chat_template.rs b/crates/tokenizer/src/chat_template.rs index 083a7f7d1..7a45a9273 100644 --- a/crates/tokenizer/src/chat_template.rs +++ b/crates/tokenizer/src/chat_template.rs @@ -602,10 +602,7 @@ fn build_environment(template: String) -> Result> { env.add_function( "raise_exception", |msg: String| -> Result { - Err(minijinja::Error::new( - minijinja::ErrorKind::InvalidOperation, - msg, - )) + Err(MinijinjaError::new(ErrorKind::InvalidOperation, msg)) }, ); diff --git a/crates/tokenizer/src/huggingface.rs b/crates/tokenizer/src/huggingface.rs index dd6f18907..fc3cbeebc 100644 --- a/crates/tokenizer/src/huggingface.rs +++ b/crates/tokenizer/src/huggingface.rs @@ -438,10 +438,6 @@ impl TokenizerTrait for HuggingFaceTokenizer { } } - fn eos_token_ids(&self) -> &[TokenIdType] { - &self.special_tokens.eos_token_ids - } - fn set_chat_template(&mut self, template: String) -> Result<()> { self.chat_template.set(template) } diff --git a/crates/tokenizer/src/tiktoken.rs b/crates/tokenizer/src/tiktoken.rs index 57478ebcc..e56a34ddf 100644 --- a/crates/tokenizer/src/tiktoken.rs +++ b/crates/tokenizer/src/tiktoken.rs @@ -266,7 +266,7 @@ impl TiktokenTokenizer { Ok(TiktokenTokenizer { tokenizer, - special_tokens, + special_tokens: config.special_tokens, vocab, reverse_vocab, vocab_size, diff --git a/crates/tokenizer/src/traits.rs b/crates/tokenizer/src/traits.rs index 5ffb95990..12ae7ab4d 100644 --- a/crates/tokenizer/src/traits.rs +++ b/crates/tokenizer/src/traits.rs @@ -129,14 +129,6 @@ pub trait Tokenizer: Encoder + Decoder { "set_chat_template is not supported by this tokenizer" )) } - - /// EOS token IDs for stop detection. - /// - /// Merged from `config.json` and `generation_config.json` (eos_token_id, int or list). - /// Models can have multiple EOS tokens (e.g., Llama 3: end_of_text + eom_id + eot_id). - fn eos_token_ids(&self) -> &[TokenIdType] { - &[] - } } /// Contains the results of tokenizing text: token IDs, string tokens, and their spans diff --git a/crates/tool_parser/src/parsers/kimik2.rs b/crates/tool_parser/src/parsers/kimik2.rs index 1eb3dfd88..d5af43791 100644 --- a/crates/tool_parser/src/parsers/kimik2.rs +++ b/crates/tool_parser/src/parsers/kimik2.rs @@ -229,7 +229,11 @@ impl ToolParser for KimiK2Parser { // No tool markers detected - return all buffered content as normal text let mut normal_text = std::mem::take(&mut self.buffer); // Remove end tokens if present - for e_token in ["<|tool_calls_section_end|>", "<|tool_call_end|>", "<|func_end|>"] { + for e_token in [ + "<|tool_calls_section_end|>", + "<|tool_call_end|>", + "<|func_end|>", + ] { normal_text = normal_text.replace(e_token, ""); } return Ok(StreamingParseResult { @@ -291,7 +295,8 @@ impl ToolParser for KimiK2Parser { }; // Split by end token before sending - let end_pos = argument_diff.find("<|tool_call_end|>") + let end_pos = argument_diff + .find("<|tool_call_end|>") .or_else(|| argument_diff.find("<|func_end|>")); let parsed_args_diff = if let Some(pos) = end_pos { &argument_diff[..pos] @@ -314,7 +319,8 @@ impl ToolParser for KimiK2Parser { } // Check completeness - split by end token first - let end_pos2 = function_args.find("<|tool_call_end|>") + let end_pos2 = function_args + .find("<|tool_call_end|>") .or_else(|| function_args.find("<|func_end|>")); let parsed_args = if let Some(pos) = end_pos2 { &function_args[..pos] diff --git a/model_gateway/src/routers/grpc/client.rs b/model_gateway/src/routers/grpc/client.rs index dce1302fe..4cd9c20c6 100644 --- a/model_gateway/src/routers/grpc/client.rs +++ b/model_gateway/src/routers/grpc/client.rs @@ -316,6 +316,10 @@ impl GrpcClient { clippy::unreachable, reason = "assembly stage guarantees matching MultimodalData variant for each backend" )] + #[expect( + clippy::too_many_arguments, + reason = "request builder requires all fields for each backend's proto message" + )] pub fn build_chat_request( &self, request_id: String, diff --git a/model_gateway/src/routers/grpc/regular/processor.rs b/model_gateway/src/routers/grpc/regular/processor.rs index 7cefb6068..e24b0427f 100644 --- a/model_gateway/src/routers/grpc/regular/processor.rs +++ b/model_gateway/src/routers/grpc/regular/processor.rs @@ -175,6 +175,7 @@ impl ResponseProcessor { &processed_text, &original_request.model, history_tool_calls_count, + original_request.tools.as_deref(), ) .await; } @@ -255,13 +256,29 @@ impl ResponseProcessor { let mut escape = false; let mut json_end = None; for (i, ch) in processed_text.char_indices() { - if escape { escape = false; continue; } - if ch == '\\' && in_string { escape = true; continue; } - if ch == '"' { in_string = !in_string; continue; } - if in_string { continue; } - if ch == '{' { depth += 1; } else if ch == '}' { + if escape { + escape = false; + continue; + } + if ch == '\\' && in_string { + escape = true; + continue; + } + if ch == '"' { + in_string = !in_string; + continue; + } + if in_string { + continue; + } + if ch == '{' { + depth += 1; + } else if ch == '}' { depth -= 1; - if depth == 0 { json_end = Some(i + 1); break; } + if depth == 0 { + json_end = Some(i + 1); + break; + } } } if let Some(end) = json_end { @@ -726,6 +743,7 @@ impl ResponseProcessor { utils::message_utils::get_history_tool_calls_count_messages( &messages_request, ), + None, ) .await; } else if used_json_schema { @@ -1018,15 +1036,18 @@ fn coerce_tool_args_to_schema( _ => None, }, Some("number") => match &val { - serde_json::Value::String(s) => { - s.parse::().ok().and_then(serde_json::Number::from_f64).map(serde_json::Value::Number) - } + serde_json::Value::String(s) => s + .parse::() + .ok() + .and_then(serde_json::Number::from_f64) + .map(serde_json::Value::Number), _ => None, }, Some("integer") => match &val { - serde_json::Value::String(s) => { - s.parse::().ok().map(|n| serde_json::Value::Number(n.into())) - } + serde_json::Value::String(s) => s + .parse::() + .ok() + .map(|n| serde_json::Value::Number(n.into())), _ => None, }, Some("array") => match &val { diff --git a/model_gateway/src/routers/grpc/regular/streaming.rs b/model_gateway/src/routers/grpc/regular/streaming.rs index 56f357e6f..36fd109cb 100644 --- a/model_gateway/src/routers/grpc/regular/streaming.rs +++ b/model_gateway/src/routers/grpc/regular/streaming.rs @@ -468,15 +468,19 @@ impl StreamingProcessor { } } - // Strip leaked chatml/think tokens when a parser is configured - let mut delta = delta; if self.configured_tool_parser.is_some() || self.configured_reasoning_parser.is_some() { for token in [ - "<|im_end|>", "<|im_start|>", "<|im_user|>", - "<|im_assistant|>", "<|im_system|>", "<|im_middle|>", - "", "[EOS]", "[BOS]", + "<|im_end|>", + "<|im_start|>", + "<|im_user|>", + "<|im_assistant|>", + "<|im_system|>", + "<|im_middle|>", + "", + "[EOS]", + "[BOS]", ] { delta = delta.replace(token, ""); } @@ -1275,7 +1279,9 @@ impl StreamingProcessor { "<|im_assistant|>", "<|im_system|>", "<|im_middle|>", - "", "[EOS]", "[BOS]", + "", + "[EOS]", + "[BOS]", ] { clean_delta = clean_delta.replace(token, ""); } @@ -1341,8 +1347,12 @@ impl StreamingProcessor { || self.configured_reasoning_parser.is_some() { for token in [ - "<|im_end|>", "<|im_start|>", "<|im_user|>", - "<|im_assistant|>", "<|im_system|>", "<|im_middle|>", + "<|im_end|>", + "<|im_start|>", + "<|im_user|>", + "<|im_assistant|>", + "<|im_system|>", + "<|im_middle|>", ] { normal_text = normal_text.replace(token, ""); } diff --git a/model_gateway/src/routers/grpc/utils/chat_utils.rs b/model_gateway/src/routers/grpc/utils/chat_utils.rs index f205723b1..230d5c9ce 100644 --- a/model_gateway/src/routers/grpc/utils/chat_utils.rs +++ b/model_gateway/src/routers/grpc/utils/chat_utils.rs @@ -110,9 +110,7 @@ pub(crate) fn process_tool_call_arguments(messages: &mut [Value]) -> Result<(), }; if serde_json::from_str::(args_str).is_err() { - return Err(format!( - "Invalid JSON in tool call arguments: '{args_str}'" - )); + return Err(format!("Invalid JSON in tool call arguments: '{args_str}'")); } } } @@ -301,10 +299,8 @@ pub fn process_chat_messages( if !tools.is_empty() { let ts_str = tools_to_typescript(tools); if !ts_str.is_empty() { - combined_template_kwargs.insert( - "tools_ts_str".to_string(), - Value::String(ts_str), - ); + combined_template_kwargs + .insert("tools_ts_str".to_string(), Value::String(ts_str)); } } } @@ -612,6 +608,206 @@ pub(crate) fn parse_finish_reason( } } +// ============================================================================ +// TypeScript-style tool declaration generator +// ============================================================================ + +/// Convert OpenAI tools to TypeScript-style declaration string. +/// +/// Produces the format expected by models like Kimi K2.5, whose chat templates +/// check for a `tools_ts_str` variable and prefer it over raw JSON. +/// +/// Example output: +/// ```text +/// # Tools +/// +/// ## functions +/// namespace functions { +/// // Get the current weather +/// type getCurrentWeather = (_: { +/// // The city and state +/// location: string, +/// unit?: "celsius" | "fahrenheit" +/// }) => any; +/// } +/// ``` +pub fn tools_to_typescript(tools: &[Tool]) -> String { + let mut functions = Vec::new(); + + for tool in tools { + if tool.tool_type != "function" { + continue; + } + functions.push(function_to_typescript(&tool.function)); + } + + if functions.is_empty() { + return String::new(); + } + + let mut result = String::from("# Tools\n\n## functions\nnamespace functions {\n"); + result.push_str(&functions.join("\n")); + result.push_str("\n}\n"); + result +} + +fn function_to_typescript(func: &openai_protocol::common::Function) -> String { + let mut out = String::new(); + + // Description comment + if let Some(ref desc) = func.description { + for line in desc.lines() { + if line.is_empty() { + out.push('\n'); + } else { + out.push_str(&format!("// {line}\n")); + } + } + } + + // Parameters + let params_str = + if func.parameters.is_null() || func.parameters == Value::Object(Default::default()) { + "{}".to_string() + } else { + schema_to_typescript(&func.parameters, "", &[]) + }; + + out.push_str(&format!("type {} = (_: {params_str}) => any;", func.name)); + out +} + +fn schema_to_typescript(schema: &Value, indent: &str, _required: &[&str]) -> String { + match schema.get("type").and_then(|t| t.as_str()) { + Some("object") => object_to_typescript(schema, indent), + Some("array") => array_to_typescript(schema, indent), + Some("string") => { + if let Some(enum_vals) = schema.get("enum").and_then(|e| e.as_array()) { + enum_to_typescript(enum_vals) + } else { + "string".to_string() + } + } + Some("integer" | "number") => "number".to_string(), + Some("boolean") => "boolean".to_string(), + Some("null") => "null".to_string(), + _ => { + // Handle enum without type + if let Some(enum_vals) = schema.get("enum").and_then(|e| e.as_array()) { + return enum_to_typescript(enum_vals); + } + // Handle anyOf + if let Some(any_of) = schema.get("anyOf").and_then(|a| a.as_array()) { + let types: Vec = any_of + .iter() + .map(|t| schema_to_typescript(t, indent, &[])) + .collect(); + return types.join(" | "); + } + "any".to_string() + } + } +} + +fn object_to_typescript(schema: &Value, indent: &str) -> String { + let properties = match schema.get("properties").and_then(|p| p.as_object()) { + Some(p) => p, + None => return "{}".to_string(), + }; + + let required_fields: Vec<&str> = schema + .get("required") + .and_then(|r| r.as_array()) + .map(|arr| arr.iter().filter_map(|v| v.as_str()).collect()) + .unwrap_or_default(); + + let child_indent = format!("{indent} "); + + // Sort: required first, then optional, both alphabetically + let mut required_params: Vec<(&String, &Value)> = Vec::new(); + let mut optional_params: Vec<(&String, &Value)> = Vec::new(); + + for (name, prop) in properties { + if required_fields.contains(&name.as_str()) { + required_params.push((name, prop)); + } else { + optional_params.push((name, prop)); + } + } + required_params.sort_by_key(|(n, _)| *n); + optional_params.sort_by_key(|(n, _)| *n); + + let mut params: Vec<(&String, &Value, bool)> = Vec::new(); + for (n, v) in required_params { + params.push((n, v, false)); + } + for (n, v) in optional_params { + params.push((n, v, true)); + } + + if params.is_empty() { + return "{}".to_string(); + } + + let mut parts = Vec::new(); + for (name, prop, optional) in ¶ms { + let mut part = String::new(); + + // Description comment + if let Some(desc) = prop.get("description").and_then(|d| d.as_str()) { + for line in desc.lines() { + if line.is_empty() { + part.push('\n'); + } else { + part.push_str(&format!("{child_indent}// {line}\n")); + } + } + } + + let type_str = schema_to_typescript(prop, &child_indent, &[]); + let opt_marker = if *optional { "?" } else { "" }; + part.push_str(&format!("{child_indent}{name}{opt_marker}: {type_str}")); + parts.push(part); + } + + format!("{{\n{}\n{indent}}}", parts.join(",\n")) +} + +fn array_to_typescript(schema: &Value, indent: &str) -> String { + let items = schema.get("items"); + let item_type = match items { + Some(item_schema) => { + let child_indent = format!("{indent} "); + // Check if item has description + let item_desc = item_schema.get("description").and_then(|d| d.as_str()); + let type_str = schema_to_typescript(item_schema, &child_indent, &[]); + + if let Some(desc) = item_desc { + return format!( + "Array<\n{child_indent}// {desc}\n{child_indent}{type_str}\n{indent}>" + ); + } + type_str + } + None => "any".to_string(), + }; + format!("Array<{item_type}>") +} + +fn enum_to_typescript(values: &[Value]) -> String { + let parts: Vec = values + .iter() + .map(|v| match v { + Value::String(s) => format!("\"{s}\""), + Value::Number(n) => n.to_string(), + Value::Bool(b) => b.to_string(), + Value::Null => "null".to_string(), + _ => "any".to_string(), + }) + .collect(); + parts.join(" | ") +} + #[cfg(test)] mod tests { use llm_tokenizer::chat_template::ChatTemplateContentFormat; @@ -813,206 +1009,3 @@ mod tests { assert_eq!(content_array[1], json!({"type": "image"})); } } - -// ============================================================================ -// TypeScript-style tool declaration generator -// ============================================================================ - -/// Convert OpenAI tools to TypeScript-style declaration string. -/// -/// Produces the format expected by models like Kimi K2.5, whose chat templates -/// check for a `tools_ts_str` variable and prefer it over raw JSON. -/// -/// Example output: -/// ```text -/// # Tools -/// -/// ## functions -/// namespace functions { -/// // Get the current weather -/// type getCurrentWeather = (_: { -/// // The city and state -/// location: string, -/// unit?: "celsius" | "fahrenheit" -/// }) => any; -/// } -/// ``` -pub fn tools_to_typescript(tools: &[Tool]) -> String { - let mut functions = Vec::new(); - - for tool in tools { - if tool.tool_type != "function" { - continue; - } - functions.push(function_to_typescript(&tool.function)); - } - - if functions.is_empty() { - return String::new(); - } - - let mut result = String::from("# Tools\n\n## functions\nnamespace functions {\n"); - result.push_str(&functions.join("\n")); - result.push_str("\n}\n"); - result -} - -fn function_to_typescript( - func: &openai_protocol::common::Function, -) -> String { - let mut out = String::new(); - - // Description comment - if let Some(ref desc) = func.description { - for line in desc.lines() { - if line.is_empty() { - out.push('\n'); - } else { - out.push_str(&format!("// {line}\n")); - } - } - } - - // Parameters - let params_str = if func.parameters.is_null() || func.parameters == Value::Object(Default::default()) { - "{}".to_string() - } else { - schema_to_typescript(&func.parameters, "", &[]) - }; - - out.push_str(&format!("type {} = (_: {params_str}) => any;", func.name)); - out -} - -fn schema_to_typescript(schema: &Value, indent: &str, required: &[&str]) -> String { - match schema.get("type").and_then(|t| t.as_str()) { - Some("object") => object_to_typescript(schema, indent), - Some("array") => array_to_typescript(schema, indent), - Some("string") => { - if let Some(enum_vals) = schema.get("enum").and_then(|e| e.as_array()) { - enum_to_typescript(enum_vals) - } else { - "string".to_string() - } - } - Some("integer" | "number") => "number".to_string(), - Some("boolean") => "boolean".to_string(), - Some("null") => "null".to_string(), - _ => { - // Handle enum without type - if let Some(enum_vals) = schema.get("enum").and_then(|e| e.as_array()) { - return enum_to_typescript(enum_vals); - } - // Handle anyOf - if let Some(any_of) = schema.get("anyOf").and_then(|a| a.as_array()) { - let types: Vec = any_of - .iter() - .map(|t| schema_to_typescript(t, indent, &[])) - .collect(); - return types.join(" | "); - } - "any".to_string() - } - } -} - -fn object_to_typescript(schema: &Value, indent: &str) -> String { - let properties = match schema.get("properties").and_then(|p| p.as_object()) { - Some(p) => p, - None => return "{}".to_string(), - }; - - let required_fields: Vec<&str> = schema - .get("required") - .and_then(|r| r.as_array()) - .map(|arr| arr.iter().filter_map(|v| v.as_str()).collect()) - .unwrap_or_default(); - - let child_indent = format!("{indent} "); - - // Sort: required first, then optional, both alphabetically - let mut required_params: Vec<(&String, &Value)> = Vec::new(); - let mut optional_params: Vec<(&String, &Value)> = Vec::new(); - - for (name, prop) in properties { - if required_fields.contains(&name.as_str()) { - required_params.push((name, prop)); - } else { - optional_params.push((name, prop)); - } - } - required_params.sort_by_key(|(n, _)| *n); - optional_params.sort_by_key(|(n, _)| *n); - - let mut params: Vec<(&String, &Value, bool)> = Vec::new(); - for (n, v) in required_params { - params.push((n, v, false)); - } - for (n, v) in optional_params { - params.push((n, v, true)); - } - - if params.is_empty() { - return "{}".to_string(); - } - - let mut parts = Vec::new(); - for (name, prop, optional) in ¶ms { - let mut part = String::new(); - - // Description comment - if let Some(desc) = prop.get("description").and_then(|d| d.as_str()) { - for line in desc.lines() { - if line.is_empty() { - part.push('\n'); - } else { - part.push_str(&format!("{child_indent}// {line}\n")); - } - } - } - - let type_str = schema_to_typescript(prop, &child_indent, &[]); - let opt_marker = if *optional { "?" } else { "" }; - part.push_str(&format!("{child_indent}{name}{opt_marker}: {type_str}")); - parts.push(part); - } - - format!("{{\n{}\n{indent}}}", parts.join(",\n")) -} - -fn array_to_typescript(schema: &Value, indent: &str) -> String { - let items = schema.get("items"); - let item_type = match items { - Some(item_schema) => { - let child_indent = format!("{indent} "); - // Check if item has description - let item_desc = item_schema - .get("description") - .and_then(|d| d.as_str()); - let type_str = schema_to_typescript(item_schema, &child_indent, &[]); - - if let Some(desc) = item_desc { - return format!( - "Array<\n{child_indent}// {desc}\n{child_indent}{type_str}\n{indent}>" - ); - } - type_str - } - None => "any".to_string(), - }; - format!("Array<{item_type}>") -} - -fn enum_to_typescript(values: &[Value]) -> String { - let parts: Vec = values - .iter() - .map(|v| match v { - Value::String(s) => format!("\"{s}\""), - Value::Number(n) => n.to_string(), - Value::Bool(b) => b.to_string(), - Value::Null => "null".to_string(), - _ => "any".to_string(), - }) - .collect(); - parts.join(" | ") -} diff --git a/model_gateway/src/routers/grpc/utils/parsers.rs b/model_gateway/src/routers/grpc/utils/parsers.rs index d77135b12..e69a24f5c 100644 --- a/model_gateway/src/routers/grpc/utils/parsers.rs +++ b/model_gateway/src/routers/grpc/utils/parsers.rs @@ -211,10 +211,9 @@ pub(crate) fn has_constrained_output( #[cfg(test)] mod tests { + use openai_protocol::common::{FunctionChoice, JsonSchemaFormat, ToolReference}; + use super::*; - use openai_protocol::common::{ - FunctionChoice, JsonSchemaFormat, ToolReference, - }; // ── has_constrained_output: tool_choice variants ──────────────────── From 1ea9977ebd1715be9c0b7580af5089ca5b1aa010 Mon Sep 17 00:00:00 2001 From: Connor Li Date: Wed, 22 Apr 2026 18:32:19 -0700 Subject: [PATCH 10/27] fix(streaming): enable reasoning parser for constrained outputs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The streaming path incorrectly gated reasoning_parser_available on !has_constrained_output(), preventing the reasoning parser from running when response_format (json_object/json_schema) was set — even when separate_reasoning was true. This caused thinking content to leak into the content field instead of reasoning_content during streaming. The non-streaming path correctly handled this case by not including the constrained output check in reasoning_parser_available. The downstream output_is_constrained variable already correctly gates on !separate_reasoning, so the upfront check was both redundant and wrong. --- model_gateway/src/routers/grpc/regular/streaming.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/model_gateway/src/routers/grpc/regular/streaming.rs b/model_gateway/src/routers/grpc/regular/streaming.rs index 36fd109cb..8fd99f345 100644 --- a/model_gateway/src/routers/grpc/regular/streaming.rs +++ b/model_gateway/src/routers/grpc/regular/streaming.rs @@ -233,10 +233,6 @@ impl StreamingProcessor { // Check parser availability once upfront (log warning only once per request) let reasoning_parser_available = separate_reasoning - && !utils::has_constrained_output( - tool_choice.as_ref(), - original_request.response_format.as_ref(), - ) && utils::check_reasoning_parser_availability( &self.reasoning_parser_factory, self.configured_reasoning_parser.as_deref(), From 4e123d9c3d9008da43e18ac9e0444ddce7dd0027 Mon Sep 17 00:00:00 2001 From: Connor Li Date: Thu, 23 Apr 2026 19:26:45 -0700 Subject: [PATCH 11/27] fix(kimik2): rewrite tool_call IDs and fix cross-chunk fence stripping Two fixes for Kimi K2.5 fc-dash regressions: 1. Rewrite user-supplied tool_call IDs (e.g. "call_1") to the "functions.NAME:INDEX" format that the Kimi K2 chat template and parser expect. Without this, the model generates unparseable IDs and the tool parser returns raw special tokens in content. 2. Handle markdown code fences (```json) that split across streaming chunk boundaries. Add a cross-chunk buffer that holds partial fence suffixes and flushes/drops fence residue at end-of-stream. Fixes: multi-turn/native-tag-leak, native-tag-leak-streaming, structured-outputs/json-mode-streaming, call-feedback-classification-streaming Signed-off-by: Connor Li --- .../src/routers/grpc/regular/streaming.rs | 77 ++++++++++++++++++- .../src/routers/grpc/utils/chat_utils.rs | 70 +++++++++++++---- 2 files changed, 130 insertions(+), 17 deletions(-) diff --git a/model_gateway/src/routers/grpc/regular/streaming.rs b/model_gateway/src/routers/grpc/regular/streaming.rs index 8fd99f345..97854205e 100644 --- a/model_gateway/src/routers/grpc/regular/streaming.rs +++ b/model_gateway/src/routers/grpc/regular/streaming.rs @@ -225,6 +225,12 @@ impl StreamingProcessor { // Reusable SSE formatting buffer to avoid allocations per chunk let mut sse_buffer = Vec::with_capacity(512); + // Buffer for cross-chunk markdown fence stripping in JSON responses. + // Holds trailing content that could be a partial fence (e.g. "`", "``", + // "```", "```j", "```js", "```jso", "```json", or just "json"/"JSON" + // after a ``` was already stripped). + let mut fence_buffer = String::new(); + // Use dispatch metadata for consistent response fields let request_id = &dispatch.request_id; let model = &dispatch.model; @@ -490,16 +496,52 @@ impl StreamingProcessor { | Some(openai_protocol::common::ResponseFormat::JsonSchema { .. }) ); if is_json_response { + // Prepend any buffered partial-fence content from the + // previous chunk so cross-chunk fences are handled + // atomically. + if !fence_buffer.is_empty() { + delta = std::mem::take(&mut fence_buffer) + δ + } + delta = delta .replace("```json", "") .replace("```JSON", "") .replace("```", ""); - // Fence tokens may split across chunks (e.g. - // "```" in one delta, "json\n" in the next). - // Drop residual language tags left over. + + // Drop residual language tag left after fence removal let trimmed = delta.trim(); if trimmed == "json" || trimmed == "JSON" { delta = String::new(); + } else if !delta.is_empty() { + // Check if the delta ends with a partial fence + // or language tag that could complete in the + // next chunk. Buffer it instead of emitting. + let suffixes: &[&str] = &[ + "`", "``", "```", "```j", "```js", "```jso", "```J", "```JS", + "```JSO", + ]; + let mut buffered = false; + for suffix in suffixes { + if delta.ends_with(suffix) { + let split_at = delta.len() - suffix.len(); + fence_buffer = delta[split_at..].to_string(); + delta.truncate(split_at); + buffered = true; + break; + } + } + if !buffered { + // Also check if we end with a standalone + // "json"/"JSON" that may be residual + let end_trimmed = delta.trim_end(); + if end_trimmed.ends_with("json") || end_trimmed.ends_with("JSON") { + let tag_start = end_trimmed.len() - 4; + let before_tag = &end_trimmed[..tag_start]; + if before_tag.is_empty() || before_tag.ends_with('\n') { + delta.truncate(tag_start); + } + } + } } } @@ -523,6 +565,35 @@ impl StreamingProcessor { ProtoResponseVariant::Complete(complete) => { let index = complete.index(); + // Flush fence_buffer: at end-of-stream, partial fences + // (backticks, "json" tags) are closing-fence residue and + // should be dropped. Only emit non-fence content. + if !fence_buffer.is_empty() { + let flushed = std::mem::take(&mut fence_buffer) + .replace("```json", "") + .replace("```JSON", "") + .replace("```", ""); + let trimmed = flushed.trim(); + if !trimmed.is_empty() + && trimmed != "json" + && trimmed != "JSON" + && trimmed != "`" + && trimmed != "``" + { + let stream_buffer = stream_buffers.entry(index).or_default(); + stream_buffer.push_str(&flushed); + + let fb_chunk = ChatCompletionStreamResponse::builder(request_id, model) + .created(created) + .add_choice_content(index, "assistant", flushed) + .maybe_system_fingerprint(system_fingerprint) + .build(); + Self::format_sse_chunk_into(&mut sse_buffer, &fb_chunk); + tx.send(Ok(Bytes::from(sse_buffer.clone()))) + .map_err(|_| "Failed to send fence buffer".to_string())?; + } + } + // Flush any remaining text for this index's stop_decoder if let Some(decoder) = stop_decoders.get_mut(&index) { if let SequenceDecoderOutput::Text(text) = decoder.flush() { diff --git a/model_gateway/src/routers/grpc/utils/chat_utils.rs b/model_gateway/src/routers/grpc/utils/chat_utils.rs index 230d5c9ce..85a8090c5 100644 --- a/model_gateway/src/routers/grpc/utils/chat_utils.rs +++ b/model_gateway/src/routers/grpc/utils/chat_utils.rs @@ -88,32 +88,74 @@ pub(crate) fn process_tool_call_arguments(messages: &mut [Value]) -> Result<(), // {% if ... is string %}{{ args }}{% else %}{{ args | tojson }}{% endif %} // Keeping strings preserves original formatting (e.g. spacing), // matching the behavior of Python/TGL which passes arguments through as-is. - for msg in messages { + // + // Also rewrites tool_call IDs to `functions.NAME:INDEX` format. Some chat + // templates (Kimi K2) pass the raw ID through to the model prompt. When + // users send IDs like `call_1`, the model continues the pattern (`call_2`) + // instead of using the expected `functions.NAME:INDEX` format, causing the + // tool parser to fail. + + // First pass: rewrite assistant tool_call IDs and collect old→new mapping. + let mut id_rewrites: HashMap = HashMap::new(); + + for msg in messages.iter_mut() { let role = msg.get("role").and_then(|v| v.as_str()); if role != Some("assistant") { continue; } - let Some(tool_calls) = msg.get("tool_calls").and_then(|tc| tc.as_array()) else { + let Some(tool_calls) = msg.get_mut("tool_calls").and_then(|tc| tc.as_array_mut()) else { continue; }; - for call in tool_calls { - let Some(function) = call.get("function") else { - continue; - }; - let Some(args) = function.get("arguments") else { - continue; - }; - let Some(args_str) = args.as_str() else { - continue; - }; + for (index, call) in tool_calls.iter_mut().enumerate() { + // Validate arguments JSON + if let Some(args_str) = call + .get("function") + .and_then(|f| f.get("arguments")) + .and_then(|a| a.as_str()) + { + if serde_json::from_str::(args_str).is_err() { + return Err(format!("Invalid JSON in tool call arguments: '{args_str}'")); + } + } - if serde_json::from_str::(args_str).is_err() { - return Err(format!("Invalid JSON in tool call arguments: '{args_str}'")); + // Rewrite ID to functions.NAME:INDEX if not already in that format + let func_name = call + .get("function") + .and_then(|f| f.get("name")) + .and_then(|n| n.as_str()); + let old_id = call.get("id").and_then(|v| v.as_str()); + + if let (Some(name), Some(old)) = (func_name, old_id) { + let canonical = format!("functions.{name}:{index}"); + if old != canonical { + id_rewrites.insert(old.to_string(), canonical.clone()); + if let Some(obj) = call.as_object_mut() { + obj.insert("id".to_string(), Value::String(canonical)); + } + } + } + } + } + + // Second pass: rewrite tool message tool_call_ids to match + if !id_rewrites.is_empty() { + for msg in messages.iter_mut() { + let role = msg.get("role").and_then(|v| v.as_str()); + if role != Some("tool") { + continue; + } + if let Some(old_id) = msg.get("tool_call_id").and_then(|v| v.as_str()) { + if let Some(new_id) = id_rewrites.get(old_id) { + if let Some(obj) = msg.as_object_mut() { + obj.insert("tool_call_id".to_string(), Value::String(new_id.clone())); + } + } } } } + Ok(()) } From 0bc9f24a268bad17cb3fc1db472bc3681ab95e3d Mon Sep 17 00:00:00 2001 From: Connor Li Date: Fri, 24 Apr 2026 10:40:44 -0700 Subject: [PATCH 12/27] feat(health): make /health_generate issue a real backend probe with logging MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the cached-flag-only health_generate with a real inference probe: sends POST /v1/chat/completions (max_tokens=1, temperature=0, stream=false, 3s timeout) through the router's own API port (self-call via 127.0.0.1). Adds structured INFO/WARN logs at smg::health target: - "sending real inference probe" — before the request, with model + URL - "probe succeeded" — on 200, with duration_ms + worker count - "probe failed (backend error)" — on non-2xx, with status + duration - "probe failed (transport error)" — on timeout/network error Widens scope from External-runtime workers only to ALL workers, so gRPC-connected backends (sglang, TRT-LLM, vLLM) are visible to K8s readiness probes. Both changes are intentional K8s probe semantic changes. Signed-off-by: Connor Li --- model_gateway/src/server.rs | 98 ++++++++++++++++++++++++++++++++++++- 1 file changed, 96 insertions(+), 2 deletions(-) diff --git a/model_gateway/src/server.rs b/model_gateway/src/server.rs index 8ec3aa0b4..fa9b7ecb2 100644 --- a/model_gateway/src/server.rs +++ b/model_gateway/src/server.rs @@ -87,6 +87,7 @@ pub struct AppState { pub concurrency_queue_tx: Option>, pub router_manager: Option>, pub mesh_handler: Option>, + pub api_port: u16, } async fn parse_function_call( @@ -161,8 +162,100 @@ async fn health(_state: State>) -> Response { liveness().await } -async fn health_generate(State(state): State>, req: Request) -> Response { - state.router.health_generate(req).await +async fn health_generate(State(state): State>, _req: Request) -> Response { + use std::time::Instant; + + let registry = &state.context.worker_registry; + let workers = registry.get_all(); + if workers.is_empty() { + return (StatusCode::SERVICE_UNAVAILABLE, "No workers registered").into_response(); + } + + let healthy: Vec<_> = workers.iter().filter(|w| w.is_healthy()).collect(); + if healthy.is_empty() { + let info: Vec<_> = workers.iter().map(|w| format!("{} ({})", w.model_id(), w.url())).collect(); + return ( + StatusCode::SERVICE_UNAVAILABLE, + format!("0/{} workers healthy: {}", workers.len(), info.join(", ")), + ).into_response(); + } + + let model_id = healthy[0].model_id().to_string(); + let probe_url = format!("http://127.0.0.1:{}/v1/chat/completions", state.api_port); + let probe_body = serde_json::json!({ + "model": model_id, + "messages": [{"role": "user", "content": "ping"}], + "max_tokens": 1, + "stream": false, + "temperature": 0 + }); + + info!( + target: "smg::health", + model = %model_id, + probe_url = %probe_url, + max_tokens = 1, + "health_generate: sending real inference probe" + ); + + let start = Instant::now(); + let probe_result = state.context.client + .post(&probe_url) + .json(&probe_body) + .timeout(std::time::Duration::from_secs(3)) + .send() + .await; + let duration_ms = start.elapsed().as_millis(); + + match probe_result { + Ok(resp) if resp.status().is_success() => { + info!( + target: "smg::health", + model = %model_id, + duration_ms = %duration_ms, + workers = healthy.len(), + "health_generate: probe succeeded" + ); + ( + StatusCode::OK, + format!( + "OK - {} workers healthy, probe succeeded in {}ms (model: {})", + healthy.len(), duration_ms, model_id + ), + ).into_response() + } + Ok(resp) => { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + warn!( + target: "smg::health", + model = %model_id, + status = %status, + duration_ms = %duration_ms, + "health_generate: probe failed (backend error)" + ); + ( + StatusCode::SERVICE_UNAVAILABLE, + format!( + "Probe failed: backend returned {} in {}ms — {}", + status, duration_ms, body + ), + ).into_response() + } + Err(e) => { + warn!( + target: "smg::health", + model = %model_id, + error = %e, + duration_ms = %duration_ms, + "health_generate: probe failed (transport error)" + ); + ( + StatusCode::SERVICE_UNAVAILABLE, + format!("Probe failed: {} in {}ms", e, duration_ms), + ).into_response() + } + } } async fn engine_metrics(State(state): State>) -> Response { @@ -1372,6 +1465,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box Date: Fri, 24 Apr 2026 14:14:10 -0700 Subject: [PATCH 13/27] feat(logging): pass through user-supplied request_id to engine Add request_id field to ChatCompletionRequest, CompletionRequest, and CreateMessageRequest protocol structs. When a user supplies request_id in the request payload, SMG now passes it through to the engine as the gRPC request ID instead of generating chatcmpl-{Uuid}. This enables correlating engine logs with upstream request tracking. All 4 API paths updated: Chat, Completions, Messages, and Responses (Harmony). Each path logs an INFO message when a user-supplied request_id is used. Also fixes clippy uninlined_format_args and unnecessary qualification warnings in server.rs (from cargo fmt reformatting). Signed-off-by: Connor Li Signed-off-by: Connor Li --- crates/protocols/src/chat.rs | 5 +++ crates/protocols/src/completion.rs | 4 +++ crates/protocols/src/messages.rs | 5 +++ model_gateway/benches/request_processing.rs | 1 + .../grpc/harmony/stages/request_building.rs | 36 +++++++++++++++---- .../regular/stages/chat/request_building.rs | 25 ++++++++++--- .../stages/completion/request_building.rs | 13 +++++-- .../stages/generate/request_building.rs | 2 +- .../stages/messages/request_building.rs | 25 ++++++++++--- .../grpc/regular/stages/request_building.rs | 9 +++-- .../src/routers/grpc/utils/message_utils.rs | 2 ++ model_gateway/src/server.rs | 34 +++++++++++------- .../tests/routing/test_openai_routing.rs | 1 + 13 files changed, 126 insertions(+), 36 deletions(-) diff --git a/crates/protocols/src/chat.rs b/crates/protocols/src/chat.rs index 1a7919523..fe0203104 100644 --- a/crates/protocols/src/chat.rs +++ b/crates/protocols/src/chat.rs @@ -323,6 +323,11 @@ pub struct ChatCompletionRequest { /// Random seed for sampling for deterministic outputs pub sampling_seed: Option, + /// User-supplied request ID for log correlation. + /// If set, SMG passes it through to the engine instead of generating its own UUID. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub request_id: Option, + /// Additional fields not explicitly defined above (e.g. engine-specific parameters) #[serde(flatten)] pub other: Map, diff --git a/crates/protocols/src/completion.rs b/crates/protocols/src/completion.rs index 393596fe9..6229f29cd 100644 --- a/crates/protocols/src/completion.rs +++ b/crates/protocols/src/completion.rs @@ -140,6 +140,10 @@ pub struct CompletionRequest { /// Sampling seed for deterministic outputs pub sampling_seed: Option, + /// User-supplied request ID for log correlation. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub request_id: Option, + /// Additional fields including bootstrap info for PD routing #[serde(flatten)] pub other: Map, diff --git a/crates/protocols/src/messages.rs b/crates/protocols/src/messages.rs index 76adcf77e..4b6c3b4d9 100644 --- a/crates/protocols/src/messages.rs +++ b/crates/protocols/src/messages.rs @@ -73,6 +73,10 @@ pub struct CreateMessageRequest { /// MCP servers to be utilized in this request (beta). pub mcp_servers: Option>, + + /// User-supplied request ID for log correlation. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub request_id: Option, } impl Normalizable for CreateMessageRequest { @@ -1843,6 +1847,7 @@ mod tests { top_p: None, container: None, mcp_servers: None, + request_id: None, } } diff --git a/model_gateway/benches/request_processing.rs b/model_gateway/benches/request_processing.rs index 465143ec4..573deca16 100644 --- a/model_gateway/benches/request_processing.rs +++ b/model_gateway/benches/request_processing.rs @@ -126,6 +126,7 @@ fn default_completion_request() -> CompletionRequest { session_params: None, return_hidden_states: false, sampling_seed: None, + request_id: None, other: serde_json::Map::new(), } } diff --git a/model_gateway/src/routers/grpc/harmony/stages/request_building.rs b/model_gateway/src/routers/grpc/harmony/stages/request_building.rs index ef6567b7b..5d1867356 100644 --- a/model_gateway/src/routers/grpc/harmony/stages/request_building.rs +++ b/model_gateway/src/routers/grpc/harmony/stages/request_building.rs @@ -2,7 +2,7 @@ use async_trait::async_trait; use axum::response::Response; -use tracing::{debug, error}; +use tracing::{debug, error, info}; use uuid::Uuid; use crate::routers::{ @@ -21,12 +21,16 @@ use crate::routers::{ /// Unlike regular request building, this uses token_ids directly (Harmony encoding handles messages). pub(crate) struct HarmonyRequestBuildingStage { inject_pd_metadata: bool, + enable_message_hash: bool, } impl HarmonyRequestBuildingStage { /// Create a new Harmony request building stage - pub fn new(inject_pd_metadata: bool) -> Self { - Self { inject_pd_metadata } + pub fn new(inject_pd_metadata: bool, enable_message_hash: bool) -> Self { + Self { + inject_pd_metadata, + enable_message_hash, + } } } @@ -72,10 +76,24 @@ impl PipelineStage for HarmonyRequestBuildingStage { ClientSelection::Dual { prefill, .. } => prefill, }; - // Generate request_id based on request type + // Generate request_id based on request type — use user-supplied request_id if provided let request_id = match &ctx.input.request_type { - RequestType::Chat(_) => format!("chatcmpl-{}", Uuid::now_v7()), - RequestType::Responses(_) => format!("responses-{}", Uuid::now_v7()), + RequestType::Chat(req) => { + if let Some(id) = req.request_id.clone() { + info!(target: "smg::request", request_id = %id, "Using user-supplied request ID"); + id + } else { + format!("chatcmpl-{}", Uuid::now_v7()) + } + } + RequestType::Responses(req) => { + if let Some(id) = req.request_id.clone() { + info!(target: "smg::request", request_id = %id, "Using user-supplied request ID"); + id + } else { + format!("responses-{}", Uuid::now_v7()) + } + } request_type @ (RequestType::Generate(_) | RequestType::Completion(_) | RequestType::Embedding(_) @@ -93,6 +111,12 @@ impl PipelineStage for HarmonyRequestBuildingStage { } }; + if self.enable_message_hash { + if let RequestType::Chat(req) = &ctx.input.request_type { + helpers::compute_and_log_message_hashes(&request_id, &req.messages); + } + } + // Build gRPC request using token_ids directly (Harmony encoding already handled message rendering) let placeholder_processed_text = "[harmony]".to_string(); diff --git a/model_gateway/src/routers/grpc/regular/stages/chat/request_building.rs b/model_gateway/src/routers/grpc/regular/stages/chat/request_building.rs index 02e538977..59eb35222 100644 --- a/model_gateway/src/routers/grpc/regular/stages/chat/request_building.rs +++ b/model_gateway/src/routers/grpc/regular/stages/chat/request_building.rs @@ -2,7 +2,7 @@ use async_trait::async_trait; use axum::response::Response; -use tracing::error; +use tracing::{error, info}; use uuid::Uuid; use crate::routers::{ @@ -20,11 +20,15 @@ use crate::routers::{ /// Extracts chat-specific request building logic from the old unified RequestBuildingStage. pub(crate) struct ChatRequestBuildingStage { inject_pd_metadata: bool, + enable_message_hash: bool, } impl ChatRequestBuildingStage { - pub fn new(inject_pd_metadata: bool) -> Self { - Self { inject_pd_metadata } + pub fn new(inject_pd_metadata: bool, enable_message_hash: bool) -> Self { + Self { + inject_pd_metadata, + enable_message_hash, + } } } @@ -72,8 +76,19 @@ impl PipelineStage for ChatRequestBuildingStage { )); }; - // Build chat request - let request_id = format!("chatcmpl-{}", Uuid::now_v7()); + // Build chat request — use user-supplied request_id if provided + let user_supplied = chat_request.request_id.is_some(); + let request_id = chat_request + .request_id + .clone() + .unwrap_or_else(|| format!("chatcmpl-{}", Uuid::now_v7())); + if user_supplied { + info!(target: "smg::request", request_id = %request_id, "Using user-supplied request ID"); + } + + if self.enable_message_hash { + helpers::compute_and_log_message_hashes(&request_id, &chat_request.messages); + } // Reject multimodal for backends that don't support it, before assembling if processed_messages.multimodal_intermediate.is_some() && builder_client.is_mlx() { diff --git a/model_gateway/src/routers/grpc/regular/stages/completion/request_building.rs b/model_gateway/src/routers/grpc/regular/stages/completion/request_building.rs index 20784b842..3098bdf82 100644 --- a/model_gateway/src/routers/grpc/regular/stages/completion/request_building.rs +++ b/model_gateway/src/routers/grpc/regular/stages/completion/request_building.rs @@ -10,7 +10,7 @@ use async_trait::async_trait; use axum::response::Response; -use tracing::error; +use tracing::{error, info}; use uuid::Uuid; use crate::routers::{ @@ -27,7 +27,7 @@ pub(crate) struct CompletionRequestBuildingStage { } impl CompletionRequestBuildingStage { - pub fn new(inject_pd_metadata: bool) -> Self { + pub fn new(inject_pd_metadata: bool, _enable_message_hash: bool) -> Self { Self { inject_pd_metadata } } } @@ -61,7 +61,14 @@ impl PipelineStage for CompletionRequestBuildingStage { ClientSelection::Dual { prefill, .. } => prefill, }; - let request_id = format!("cmpl_{}", Uuid::now_v7()); + let user_supplied = completion_request.request_id.is_some(); + let request_id = completion_request + .request_id + .clone() + .unwrap_or_else(|| format!("cmpl_{}", Uuid::now_v7())); + if user_supplied { + info!(target: "smg::request", request_id = %request_id, "Using user-supplied request ID"); + } let mut proto_request = builder_client .build_completion_request( diff --git a/model_gateway/src/routers/grpc/regular/stages/generate/request_building.rs b/model_gateway/src/routers/grpc/regular/stages/generate/request_building.rs index d2d4273b6..ebf682e72 100644 --- a/model_gateway/src/routers/grpc/regular/stages/generate/request_building.rs +++ b/model_gateway/src/routers/grpc/regular/stages/generate/request_building.rs @@ -22,7 +22,7 @@ pub(crate) struct GenerateRequestBuildingStage { } impl GenerateRequestBuildingStage { - pub fn new(inject_pd_metadata: bool) -> Self { + pub fn new(inject_pd_metadata: bool, _enable_message_hash: bool) -> Self { Self { inject_pd_metadata } } } diff --git a/model_gateway/src/routers/grpc/regular/stages/messages/request_building.rs b/model_gateway/src/routers/grpc/regular/stages/messages/request_building.rs index fc1285456..1cb088037 100644 --- a/model_gateway/src/routers/grpc/regular/stages/messages/request_building.rs +++ b/model_gateway/src/routers/grpc/regular/stages/messages/request_building.rs @@ -2,7 +2,7 @@ use async_trait::async_trait; use axum::response::Response; -use tracing::error; +use tracing::{error, info}; use uuid::Uuid; use crate::routers::{ @@ -21,11 +21,15 @@ use crate::routers::{ /// and CreateMessageRequest sampling parameters. pub(crate) struct MessageRequestBuildingStage { inject_pd_metadata: bool, + enable_message_hash: bool, } impl MessageRequestBuildingStage { - pub fn new(inject_pd_metadata: bool) -> Self { - Self { inject_pd_metadata } + pub fn new(inject_pd_metadata: bool, enable_message_hash: bool) -> Self { + Self { + inject_pd_metadata, + enable_message_hash, + } } } @@ -73,8 +77,19 @@ impl PipelineStage for MessageRequestBuildingStage { )); }; - // Build message request - let request_id = format!("msg_{}", Uuid::now_v7()); + // Build message request — use user-supplied request_id if provided + let user_supplied = messages_request.request_id.is_some(); + let request_id = messages_request + .request_id + .clone() + .unwrap_or_else(|| format!("msg_{}", Uuid::now_v7())); + if user_supplied { + info!(target: "smg::request", request_id = %request_id, "Using user-supplied request ID"); + } + + if self.enable_message_hash { + helpers::compute_and_log_input_message_hashes(&request_id, &messages_request.messages); + } // Reject multimodal for backends that don't support it, before assembling if processed_messages.multimodal_intermediate.is_some() && builder_client.is_mlx() { diff --git a/model_gateway/src/routers/grpc/regular/stages/request_building.rs b/model_gateway/src/routers/grpc/regular/stages/request_building.rs index 48a1879cd..cc5477706 100644 --- a/model_gateway/src/routers/grpc/regular/stages/request_building.rs +++ b/model_gateway/src/routers/grpc/regular/stages/request_building.rs @@ -24,10 +24,13 @@ pub(crate) struct ChatGenerateRequestBuildingStage { } impl ChatGenerateRequestBuildingStage { - pub fn new(inject_pd_metadata: bool) -> Self { + pub fn new(inject_pd_metadata: bool, enable_message_hash: bool) -> Self { Self { - chat_stage: ChatRequestBuildingStage::new(inject_pd_metadata), - generate_stage: GenerateRequestBuildingStage::new(inject_pd_metadata), + chat_stage: ChatRequestBuildingStage::new(inject_pd_metadata, enable_message_hash), + generate_stage: GenerateRequestBuildingStage::new( + inject_pd_metadata, + enable_message_hash, + ), } } } diff --git a/model_gateway/src/routers/grpc/utils/message_utils.rs b/model_gateway/src/routers/grpc/utils/message_utils.rs index eda24a7ad..89b2c98eb 100644 --- a/model_gateway/src/routers/grpc/utils/message_utils.rs +++ b/model_gateway/src/routers/grpc/utils/message_utils.rs @@ -648,6 +648,7 @@ mod tests { top_p: None, container: None, mcp_servers: None, + request_id: None, }; assert_eq!(get_history_tool_calls_count_messages(&request), 0); @@ -696,6 +697,7 @@ mod tests { top_p: None, container: None, mcp_servers: None, + request_id: None, }; assert_eq!(get_history_tool_calls_count_messages(&request), 2); } diff --git a/model_gateway/src/server.rs b/model_gateway/src/server.rs index fa9b7ecb2..09bf1a5bd 100644 --- a/model_gateway/src/server.rs +++ b/model_gateway/src/server.rs @@ -173,11 +173,15 @@ async fn health_generate(State(state): State>, _req: Request) -> R let healthy: Vec<_> = workers.iter().filter(|w| w.is_healthy()).collect(); if healthy.is_empty() { - let info: Vec<_> = workers.iter().map(|w| format!("{} ({})", w.model_id(), w.url())).collect(); + let info: Vec<_> = workers + .iter() + .map(|w| format!("{} ({})", w.model_id(), w.url())) + .collect(); return ( StatusCode::SERVICE_UNAVAILABLE, format!("0/{} workers healthy: {}", workers.len(), info.join(", ")), - ).into_response(); + ) + .into_response(); } let model_id = healthy[0].model_id().to_string(); @@ -199,10 +203,12 @@ async fn health_generate(State(state): State>, _req: Request) -> R ); let start = Instant::now(); - let probe_result = state.context.client + let probe_result = state + .context + .client .post(&probe_url) .json(&probe_body) - .timeout(std::time::Duration::from_secs(3)) + .timeout(Duration::from_secs(3)) .send() .await; let duration_ms = start.elapsed().as_millis(); @@ -220,9 +226,12 @@ async fn health_generate(State(state): State>, _req: Request) -> R StatusCode::OK, format!( "OK - {} workers healthy, probe succeeded in {}ms (model: {})", - healthy.len(), duration_ms, model_id + healthy.len(), + duration_ms, + model_id ), - ).into_response() + ) + .into_response() } Ok(resp) => { let status = resp.status(); @@ -236,11 +245,9 @@ async fn health_generate(State(state): State>, _req: Request) -> R ); ( StatusCode::SERVICE_UNAVAILABLE, - format!( - "Probe failed: backend returned {} in {}ms — {}", - status, duration_ms, body - ), - ).into_response() + format!("Probe failed: backend returned {status} in {duration_ms}ms — {body}"), + ) + .into_response() } Err(e) => { warn!( @@ -252,8 +259,9 @@ async fn health_generate(State(state): State>, _req: Request) -> R ); ( StatusCode::SERVICE_UNAVAILABLE, - format!("Probe failed: {} in {}ms", e, duration_ms), - ).into_response() + format!("Probe failed: {e} in {duration_ms}ms"), + ) + .into_response() } } } diff --git a/model_gateway/tests/routing/test_openai_routing.rs b/model_gateway/tests/routing/test_openai_routing.rs index 0f66d3c5a..dd94bf78b 100644 --- a/model_gateway/tests/routing/test_openai_routing.rs +++ b/model_gateway/tests/routing/test_openai_routing.rs @@ -94,6 +94,7 @@ fn create_minimal_completion_request() -> CompletionRequest { session_params: None, return_hidden_states: false, sampling_seed: None, + request_id: None, other: serde_json::Map::new(), } } From ddccd3b6c903676ad4b76ff9c107d1ffcda672d4 Mon Sep 17 00:00:00 2001 From: Connor Li Date: Fri, 24 Apr 2026 14:14:22 -0700 Subject: [PATCH 14/27] feat(logging): compute per-message SHA-256 hashes for session reconstruction Add --enable-message-hash CLI flag (default: false) that computes per-message SHA-256 hashes matching TRT-LLM's openai_server.py format: sha256(role + "\x00" + content).hexdigest()[:12] When enabled, hashes are emitted as structured INFO logs on the smg::request target with the request_id, enabling session reconstruction for traffic simulation across all engine runtimes (TRT-LLM, sglang, vLLM). Hash computation wired through RouterConfig -> pipeline constructors -> stage constructors for all API paths. Python bindings updated with enable_message_hash field and --enable-message-hash CLI argument. Signed-off-by: Connor Li Signed-off-by: Connor Li --- bindings/python/src/lib.rs | 5 + bindings/python/src/smg/router_args.py | 8 ++ model_gateway/src/config/builder.rs | 7 ++ model_gateway/src/config/types.rs | 4 + model_gateway/src/main.rs | 5 + .../src/routers/grpc/common/stages/helpers.rs | 91 ++++++++++++++++++- model_gateway/src/routers/grpc/pd_router.rs | 11 ++- model_gateway/src/routers/grpc/pipeline.rs | 42 +++++++-- model_gateway/src/routers/grpc/router.rs | 12 ++- 9 files changed, 172 insertions(+), 13 deletions(-) diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index 715cd98e1..7e9b2a6ed 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -445,6 +445,7 @@ struct Router { reasoning_parser: Option, tool_call_parser: Option, mcp_config_path: Option, + enable_message_hash: bool, storage_hook_wasm_path: Option, backend: BackendType, history_backend: HistoryBackendType, @@ -731,6 +732,7 @@ impl Router { .maybe_tool_call_parser(self.tool_call_parser.as_ref()) .maybe_mcp_config_path(self.mcp_config_path.as_ref()) .maybe_storage_hook_wasm_path(self.storage_hook_wasm_path.as_deref()) + .enable_message_hash(self.enable_message_hash) .dp_aware(self.dp_aware) .retries(!self.disable_retries) .circuit_breaker(!self.disable_circuit_breaker) @@ -833,6 +835,7 @@ impl Router { reasoning_parser = None, tool_call_parser = None, mcp_config_path = None, + enable_message_hash = false, storage_hook_wasm_path = None, backend = BackendType::Sglang, history_backend = HistoryBackendType::Memory, @@ -942,6 +945,7 @@ impl Router { reasoning_parser: Option, tool_call_parser: Option, mcp_config_path: Option, + enable_message_hash: bool, storage_hook_wasm_path: Option, backend: BackendType, history_backend: HistoryBackendType, @@ -1062,6 +1066,7 @@ impl Router { reasoning_parser, tool_call_parser, mcp_config_path, + enable_message_hash, storage_hook_wasm_path, backend, history_backend, diff --git a/bindings/python/src/smg/router_args.py b/bindings/python/src/smg/router_args.py index f769eccd2..65955b298 100644 --- a/bindings/python/src/smg/router_args.py +++ b/bindings/python/src/smg/router_args.py @@ -119,6 +119,8 @@ class RouterArgs: mcp_config_path: str | None = None # Backend selection backend: str = "sglang" + # Message hash logging for session reconstruction + enable_message_hash: bool = False # Storage hooks (WASM) storage_hook_wasm_path: str | None = None # History backend configuration @@ -472,6 +474,12 @@ def add_cli_args( default=RouterArgs.log_json, help="Output logs in JSON format", ) + logging_group.add_argument( + f"--{prefix}enable-message-hash", + action="store_true", + default=RouterArgs.enable_message_hash, + help="Compute per-message SHA-256 hashes for session reconstruction logging", + ) # Service discovery configuration k8s_group.add_argument( diff --git a/model_gateway/src/config/builder.rs b/model_gateway/src/config/builder.rs index 209daefdc..16ff17acf 100644 --- a/model_gateway/src/config/builder.rs +++ b/model_gateway/src/config/builder.rs @@ -385,6 +385,13 @@ impl RouterConfigBuilder { self } + // ==================== LOGGING ==================== + + pub fn enable_message_hash(mut self, enable: bool) -> Self { + self.config.enable_message_hash = enable; + self + } + // ==================== WASM ==================== pub fn enable_wasm(mut self, enable: bool) -> Self { diff --git a/model_gateway/src/config/types.rs b/model_gateway/src/config/types.rs index 1a675bd4b..55fb887a6 100644 --- a/model_gateway/src/config/types.rs +++ b/model_gateway/src/config/types.rs @@ -175,6 +175,9 @@ pub struct RouterConfig { /// Loaded from skills_config_path during config creation. #[serde(skip)] pub skills: Option, + /// Compute per-message SHA-256 hashes for session reconstruction logging + #[serde(default)] + pub enable_message_hash: bool, /// Enable WASM support #[serde(default)] pub enable_wasm: bool, @@ -674,6 +677,7 @@ impl Default for RouterConfig { mcp_config: None, skills_enabled: false, skills: None, + enable_message_hash: false, enable_wasm: false, storage_hook_wasm_path: None, server_cert: None, diff --git a/model_gateway/src/main.rs b/model_gateway/src/main.rs index b2d99fca2..2087e1b06 100644 --- a/model_gateway/src/main.rs +++ b/model_gateway/src/main.rs @@ -500,6 +500,10 @@ struct CliArgs { #[arg(long, default_value = "memory", value_parser = ["memory", "none", "oracle", "postgres", "redis"], help_heading = "Backend")] history_backend: String, + /// Compute per-message SHA-256 hashes for session reconstruction logging + #[arg(long, default_value_t = false, help_heading = "Logging")] + enable_message_hash: bool, + /// Enable WebAssembly support #[arg(long, default_value_t = false, help_heading = "Backend")] enable_wasm: bool, @@ -1263,6 +1267,7 @@ impl CliArgs { .dp_aware(self.dp_aware) .retries(!self.disable_retries) .circuit_breaker(!self.disable_circuit_breaker) + .enable_message_hash(self.enable_message_hash) .enable_wasm(self.enable_wasm) .maybe_storage_hook_wasm_path(self.storage_hook_wasm_path.as_deref()) .igw(self.enable_igw) diff --git a/model_gateway/src/routers/grpc/common/stages/helpers.rs b/model_gateway/src/routers/grpc/common/stages/helpers.rs index c7cf84a5c..d5ffc4c26 100644 --- a/model_gateway/src/routers/grpc/common/stages/helpers.rs +++ b/model_gateway/src/routers/grpc/common/stages/helpers.rs @@ -2,9 +2,11 @@ use std::sync::Arc; +use openai_protocol::chat::ChatMessage; use rand::Rng; +use sha2::{Digest, Sha256}; use smg_grpc_client::sglang_proto::DisaggregatedParams; -use tracing::debug; +use tracing::{debug, info}; use crate::{ routers::grpc::{context::WorkerSelection, proto_wrapper::ProtoGenerateRequest}, @@ -56,3 +58,90 @@ fn inject_sglang_bootstrap_metadata( hostname, bootstrap_port, room_id ); } + +fn chat_message_role(msg: &ChatMessage) -> &'static str { + match msg { + ChatMessage::System { .. } => "system", + ChatMessage::User { .. } => "user", + ChatMessage::Assistant { .. } => "assistant", + ChatMessage::Tool { .. } => "tool", + ChatMessage::Function { .. } => "function", + ChatMessage::Developer { .. } => "developer", + } +} + +fn chat_message_text_content(msg: &ChatMessage) -> String { + match msg { + ChatMessage::System { content, .. } + | ChatMessage::User { content, .. } + | ChatMessage::Tool { content, .. } + | ChatMessage::Developer { content, .. } => content.to_simple_string(), + ChatMessage::Assistant { content, .. } => content + .as_ref() + .map_or_else(String::new, |c| c.to_simple_string()), + ChatMessage::Function { content, .. } => content.clone(), + } +} + +/// Compute per-message SHA-256 hashes matching TRT-LLM's `openai_server.py` format: +/// `sha256(role + "\x00" + content).hexdigest()[:12]` +pub(crate) fn compute_and_log_message_hashes(request_id: &str, messages: &[ChatMessage]) { + let hashes: Vec<(&str, String)> = messages + .iter() + .map(|msg| { + let role = chat_message_role(msg); + let content = chat_message_text_content(msg); + let mut hasher = Sha256::new(); + hasher.update(format!("{role}\x00{content}").as_bytes()); + let hash = format!("{:x}", hasher.finalize()); + (role, hash[..12].to_string()) + }) + .collect(); + info!( + target: "smg::request", + request_id = %request_id, + message_hashes = ?hashes, + "Request message hashes for session reconstruction" + ); +} + +/// Compute per-message SHA-256 hashes from InputMessage (Messages API) format. +pub(crate) fn compute_and_log_input_message_hashes( + request_id: &str, + messages: &[openai_protocol::messages::InputMessage], +) { + use openai_protocol::messages::Role; + let hashes: Vec<(&str, String)> = messages + .iter() + .map(|msg| { + let role = match msg.role { + Role::User => "user", + Role::Assistant => "assistant", + }; + let content = match &msg.content { + openai_protocol::messages::InputContent::String(s) => s.clone(), + openai_protocol::messages::InputContent::Blocks(blocks) => blocks + .iter() + .filter_map(|b| { + if let openai_protocol::messages::InputContentBlock::Text(t) = b { + Some(t.text.as_str()) + } else { + None + } + }) + .collect::>() + .join(" "), + }; + let mut hasher = Sha256::new(); + hasher.update(format!("{role}\x00{content}").as_bytes()); + let hash = format!("{:x}", hasher.finalize()); + (role, hash[..12].to_string()) + }) + .collect(); + info!( + target: "smg::request", + request_id = %request_id, + message_hashes = ?hashes, + "Request message hashes for session reconstruction" + ); +} diff --git a/model_gateway/src/routers/grpc/pd_router.rs b/model_gateway/src/routers/grpc/pd_router.rs index ab96d4101..f8cc3e58c 100644 --- a/model_gateway/src/routers/grpc/pd_router.rs +++ b/model_gateway/src/routers/grpc/pd_router.rs @@ -73,6 +73,8 @@ impl GrpcPDRouter { multimodal, }); + let enable_message_hash = ctx.router_config.enable_message_hash; + // Create PD pipeline let pipeline = RequestPipeline::new_pd( worker_registry.clone(), @@ -81,6 +83,7 @@ impl GrpcPDRouter { reasoning_parser_factory.clone(), ctx.configured_tool_parser.clone(), ctx.configured_reasoning_parser.clone(), + enable_message_hash, ); // Create Messages PD pipeline @@ -91,11 +94,15 @@ impl GrpcPDRouter { reasoning_parser_factory.clone(), ctx.configured_tool_parser.clone(), ctx.configured_reasoning_parser.clone(), + enable_message_hash, ); // Create Completion PD pipeline - let completion_pipeline = - RequestPipeline::new_completion_pd(worker_registry.clone(), policy_registry.clone()); + let completion_pipeline = RequestPipeline::new_completion_pd( + worker_registry.clone(), + policy_registry.clone(), + enable_message_hash, + ); Ok(GrpcPDRouter { worker_registry, diff --git a/model_gateway/src/routers/grpc/pipeline.rs b/model_gateway/src/routers/grpc/pipeline.rs index f5e4f20b4..fd397ecd9 100644 --- a/model_gateway/src/routers/grpc/pipeline.rs +++ b/model_gateway/src/routers/grpc/pipeline.rs @@ -117,6 +117,7 @@ impl RequestPipeline { reasoning_parser_factory: ReasoningParserFactory, configured_tool_parser: Option, configured_reasoning_parser: Option, + enable_message_hash: bool, ) -> Self { let processor = processor::ResponseProcessor::new( tool_parser_factory.clone(), @@ -141,7 +142,10 @@ impl RequestPipeline { WorkerSelectionMode::Regular, )), Box::new(ClientAcquisitionStage), - Box::new(ChatGenerateRequestBuildingStage::new(false)), // No PD metadata + Box::new(ChatGenerateRequestBuildingStage::new( + false, + enable_message_hash, + )), Box::new(DispatchMetadataStage), Box::new(RequestExecutionStage::new(ExecutionMode::Single)), Box::new(ChatGenerateResponseProcessingStage::new( @@ -164,6 +168,7 @@ impl RequestPipeline { _reasoning_parser_factory: ReasoningParserFactory, _configured_tool_parser: Option, _configured_reasoning_parser: Option, + enable_message_hash: bool, ) -> Self { let stages: Vec> = vec![ Box::new(harmony::stages::HarmonyPreparationStage::new()), @@ -173,7 +178,10 @@ impl RequestPipeline { WorkerSelectionMode::Regular, )), Box::new(ClientAcquisitionStage), - Box::new(harmony::stages::HarmonyRequestBuildingStage::new(false)), + Box::new(harmony::stages::HarmonyRequestBuildingStage::new( + false, + enable_message_hash, + )), Box::new(DispatchMetadataStage), Box::new(RequestExecutionStage::new(ExecutionMode::Single)), Box::new(harmony::stages::HarmonyResponseProcessingStage::new()), @@ -194,6 +202,7 @@ impl RequestPipeline { _reasoning_parser_factory: ReasoningParserFactory, _configured_tool_parser: Option, _configured_reasoning_parser: Option, + enable_message_hash: bool, ) -> Self { let stages: Vec> = vec![ Box::new(harmony::stages::HarmonyPreparationStage::new()), @@ -203,7 +212,10 @@ impl RequestPipeline { WorkerSelectionMode::PrefillDecode, )), Box::new(ClientAcquisitionStage), - Box::new(harmony::stages::HarmonyRequestBuildingStage::new(true)), + Box::new(harmony::stages::HarmonyRequestBuildingStage::new( + true, + enable_message_hash, + )), Box::new(DispatchMetadataStage), Box::new(RequestExecutionStage::new(ExecutionMode::DualDispatch)), Box::new(harmony::stages::HarmonyResponseProcessingStage::new()), @@ -223,6 +235,7 @@ impl RequestPipeline { reasoning_parser_factory: ReasoningParserFactory, configured_tool_parser: Option, configured_reasoning_parser: Option, + enable_message_hash: bool, ) -> Self { let processor = processor::ResponseProcessor::new( tool_parser_factory.clone(), @@ -247,7 +260,10 @@ impl RequestPipeline { WorkerSelectionMode::PrefillDecode, )), Box::new(ClientAcquisitionStage), - Box::new(ChatGenerateRequestBuildingStage::new(true)), // Inject PD metadata + Box::new(ChatGenerateRequestBuildingStage::new( + true, + enable_message_hash, + )), Box::new(DispatchMetadataStage), Box::new(RequestExecutionStage::new(ExecutionMode::DualDispatch)), Box::new(ChatGenerateResponseProcessingStage::new( @@ -327,6 +343,7 @@ impl RequestPipeline { reasoning_parser_factory: ReasoningParserFactory, configured_tool_parser: Option, configured_reasoning_parser: Option, + enable_message_hash: bool, ) -> Self { let processor = processor::ResponseProcessor::new( tool_parser_factory.clone(), @@ -351,7 +368,7 @@ impl RequestPipeline { WorkerSelectionMode::Regular, )), Box::new(ClientAcquisitionStage), - Box::new(MessageRequestBuildingStage::new(false)), // No PD metadata + Box::new(MessageRequestBuildingStage::new(false, enable_message_hash)), Box::new(DispatchMetadataStage), Box::new(RequestExecutionStage::new(ExecutionMode::Single)), Box::new(MessageResponseProcessingStage::new( @@ -374,6 +391,7 @@ impl RequestPipeline { reasoning_parser_factory: ReasoningParserFactory, configured_tool_parser: Option, configured_reasoning_parser: Option, + enable_message_hash: bool, ) -> Self { let processor = processor::ResponseProcessor::new( tool_parser_factory.clone(), @@ -398,7 +416,7 @@ impl RequestPipeline { WorkerSelectionMode::PrefillDecode, )), Box::new(ClientAcquisitionStage), - Box::new(MessageRequestBuildingStage::new(true)), // Inject PD metadata + Box::new(MessageRequestBuildingStage::new(true, enable_message_hash)), Box::new(DispatchMetadataStage), Box::new(RequestExecutionStage::new(ExecutionMode::DualDispatch)), Box::new(MessageResponseProcessingStage::new( @@ -421,6 +439,7 @@ impl RequestPipeline { pub fn new_completion( worker_registry: Arc, policy_registry: Arc, + enable_message_hash: bool, ) -> Self { let processor = processor::ResponseProcessor::new( ToolParserFactory::default(), @@ -445,7 +464,10 @@ impl RequestPipeline { WorkerSelectionMode::Regular, )), Box::new(ClientAcquisitionStage), - Box::new(CompletionRequestBuildingStage::new(false)), // No PD metadata + Box::new(CompletionRequestBuildingStage::new( + false, + enable_message_hash, + )), Box::new(DispatchMetadataStage), Box::new(RequestExecutionStage::new(ExecutionMode::Single)), Box::new(CompletionResponseProcessingStage::new( @@ -464,6 +486,7 @@ impl RequestPipeline { pub fn new_completion_pd( worker_registry: Arc, policy_registry: Arc, + enable_message_hash: bool, ) -> Self { let processor = processor::ResponseProcessor::new( ToolParserFactory::default(), @@ -488,7 +511,10 @@ impl RequestPipeline { WorkerSelectionMode::PrefillDecode, )), Box::new(ClientAcquisitionStage), - Box::new(CompletionRequestBuildingStage::new(true)), // Inject PD metadata + Box::new(CompletionRequestBuildingStage::new( + true, + enable_message_hash, + )), Box::new(DispatchMetadataStage), Box::new(RequestExecutionStage::new(ExecutionMode::DualDispatch)), Box::new(CompletionResponseProcessingStage::new( diff --git a/model_gateway/src/routers/grpc/router.rs b/model_gateway/src/routers/grpc/router.rs index 524da0773..9e5619dbb 100644 --- a/model_gateway/src/routers/grpc/router.rs +++ b/model_gateway/src/routers/grpc/router.rs @@ -88,6 +88,8 @@ impl GrpcRouter { multimodal, }); + let enable_message_hash = ctx.router_config.enable_message_hash; + // Create regular pipeline let pipeline = RequestPipeline::new_regular( worker_registry.clone(), @@ -96,6 +98,7 @@ impl GrpcRouter { reasoning_parser_factory.clone(), ctx.configured_tool_parser.clone(), ctx.configured_reasoning_parser.clone(), + enable_message_hash, ); // Create Harmony pipelines @@ -106,6 +109,7 @@ impl GrpcRouter { reasoning_parser_factory.clone(), ctx.configured_tool_parser.clone(), ctx.configured_reasoning_parser.clone(), + enable_message_hash, ); // Create Embedding pipeline @@ -124,11 +128,15 @@ impl GrpcRouter { reasoning_parser_factory.clone(), ctx.configured_tool_parser.clone(), ctx.configured_reasoning_parser.clone(), + enable_message_hash, ); // Create Completion pipeline - let completion_pipeline = - RequestPipeline::new_completion(worker_registry.clone(), _policy_registry.clone()); + let completion_pipeline = RequestPipeline::new_completion( + worker_registry.clone(), + _policy_registry.clone(), + enable_message_hash, + ); // Extract shared dependencies for responses contexts let mcp_orchestrator = ctx From 9edea19f0234bc0061648030c6720210f17efd56 Mon Sep 17 00:00:00 2001 From: Connor Li Date: Fri, 24 Apr 2026 15:35:53 -0700 Subject: [PATCH 15/27] fix(streaming): replace fence_buffer with simple cross-chunk fence stripping The fence_buffer approach from 06574da0 failed to strip the "json\n" prefix when markdown code fences split across streaming chunks (e.g. "```" in chunk N, "json\n{..." in chunk N+1). Replace with a simpler fence_backticks_stripped state flag that tracks when backticks were stripped, then removes leading language tags from the next chunk. Signed-off-by: Connor Li --- .../src/routers/grpc/regular/streaming.rs | 82 +++---------------- 1 file changed, 12 insertions(+), 70 deletions(-) diff --git a/model_gateway/src/routers/grpc/regular/streaming.rs b/model_gateway/src/routers/grpc/regular/streaming.rs index 97854205e..1661e7a2b 100644 --- a/model_gateway/src/routers/grpc/regular/streaming.rs +++ b/model_gateway/src/routers/grpc/regular/streaming.rs @@ -225,11 +225,9 @@ impl StreamingProcessor { // Reusable SSE formatting buffer to avoid allocations per chunk let mut sse_buffer = Vec::with_capacity(512); - // Buffer for cross-chunk markdown fence stripping in JSON responses. - // Holds trailing content that could be a partial fence (e.g. "`", "``", - // "```", "```j", "```js", "```jso", "```json", or just "json"/"JSON" - // after a ``` was already stripped). - let mut fence_buffer = String::new(); + // Tracks whether the previous chunk's backticks were stripped so that + // a language tag arriving at the start of the next chunk can be removed. + let mut fence_backticks_stripped = false; // Use dispatch metadata for consistent response fields let request_id = &dispatch.request_id; @@ -496,53 +494,26 @@ impl StreamingProcessor { | Some(openai_protocol::common::ResponseFormat::JsonSchema { .. }) ); if is_json_response { - // Prepend any buffered partial-fence content from the - // previous chunk so cross-chunk fences are handled - // atomically. - if !fence_buffer.is_empty() { - delta = std::mem::take(&mut fence_buffer) + δ - } - delta = delta .replace("```json", "") .replace("```JSON", "") .replace("```", ""); - - // Drop residual language tag left after fence removal let trimmed = delta.trim(); if trimmed == "json" || trimmed == "JSON" { delta = String::new(); - } else if !delta.is_empty() { - // Check if the delta ends with a partial fence - // or language tag that could complete in the - // next chunk. Buffer it instead of emitting. - let suffixes: &[&str] = &[ - "`", "``", "```", "```j", "```js", "```jso", "```J", "```JS", - "```JSO", - ]; - let mut buffered = false; - for suffix in suffixes { - if delta.ends_with(suffix) { - let split_at = delta.len() - suffix.len(); - fence_buffer = delta[split_at..].to_string(); - delta.truncate(split_at); - buffered = true; + } + // Handle cross-chunk fence split: backticks were + // stripped from the previous chunk, language tag + // arrives at the start of this one. + if fence_backticks_stripped { + for tag in ["json\r\n", "JSON\r\n", "json\n", "JSON\n"] { + if delta.starts_with(tag) { + delta = delta[tag.len()..].to_string(); break; } } - if !buffered { - // Also check if we end with a standalone - // "json"/"JSON" that may be residual - let end_trimmed = delta.trim_end(); - if end_trimmed.ends_with("json") || end_trimmed.ends_with("JSON") { - let tag_start = end_trimmed.len() - 4; - let before_tag = &end_trimmed[..tag_start]; - if before_tag.is_empty() || before_tag.ends_with('\n') { - delta.truncate(tag_start); - } - } - } } + fence_backticks_stripped = delta.is_empty(); } if !delta.is_empty() { @@ -565,35 +536,6 @@ impl StreamingProcessor { ProtoResponseVariant::Complete(complete) => { let index = complete.index(); - // Flush fence_buffer: at end-of-stream, partial fences - // (backticks, "json" tags) are closing-fence residue and - // should be dropped. Only emit non-fence content. - if !fence_buffer.is_empty() { - let flushed = std::mem::take(&mut fence_buffer) - .replace("```json", "") - .replace("```JSON", "") - .replace("```", ""); - let trimmed = flushed.trim(); - if !trimmed.is_empty() - && trimmed != "json" - && trimmed != "JSON" - && trimmed != "`" - && trimmed != "``" - { - let stream_buffer = stream_buffers.entry(index).or_default(); - stream_buffer.push_str(&flushed); - - let fb_chunk = ChatCompletionStreamResponse::builder(request_id, model) - .created(created) - .add_choice_content(index, "assistant", flushed) - .maybe_system_fingerprint(system_fingerprint) - .build(); - Self::format_sse_chunk_into(&mut sse_buffer, &fb_chunk); - tx.send(Ok(Bytes::from(sse_buffer.clone()))) - .map_err(|_| "Failed to send fence buffer".to_string())?; - } - } - // Flush any remaining text for this index's stop_decoder if let Some(decoder) = stop_decoders.get_mut(&index) { if let SequenceDecoderOutput::Text(text) = decoder.flush() { From 70868745ae12524ca458ca37ad6ebfb416081383 Mon Sep 17 00:00:00 2001 From: Connor Li Date: Fri, 24 Apr 2026 23:16:11 -0700 Subject: [PATCH 16/27] refactor(streaming): extract strip_json_fence helper with unit tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extract inline fence stripping logic into a standalone strip_json_fence() function for testability. Add 7 unit tests covering: - full fence in single chunk - cross-chunk split (backticks then language tag) - CRLF line endings - standalone language tag drop - no-fence passthrough - uppercase JSON/json variants No behavior change — pure refactor of commit eae26bc6. Signed-off-by: Connor Li Made-with: Cursor --- .../src/routers/grpc/regular/streaming.rs | 102 ++++++++++++++---- 1 file changed, 82 insertions(+), 20 deletions(-) diff --git a/model_gateway/src/routers/grpc/regular/streaming.rs b/model_gateway/src/routers/grpc/regular/streaming.rs index 1661e7a2b..dd3b80f4b 100644 --- a/model_gateway/src/routers/grpc/regular/streaming.rs +++ b/model_gateway/src/routers/grpc/regular/streaming.rs @@ -494,26 +494,7 @@ impl StreamingProcessor { | Some(openai_protocol::common::ResponseFormat::JsonSchema { .. }) ); if is_json_response { - delta = delta - .replace("```json", "") - .replace("```JSON", "") - .replace("```", ""); - let trimmed = delta.trim(); - if trimmed == "json" || trimmed == "JSON" { - delta = String::new(); - } - // Handle cross-chunk fence split: backticks were - // stripped from the previous chunk, language tag - // arrives at the start of this one. - if fence_backticks_stripped { - for tag in ["json\r\n", "JSON\r\n", "json\n", "JSON\n"] { - if delta.starts_with(tag) { - delta = delta[tag.len()..].to_string(); - break; - } - } - } - fence_backticks_stripped = delta.is_empty(); + delta = strip_json_fence(delta, &mut fence_backticks_stripped); } if !delta.is_empty() { @@ -2795,3 +2776,84 @@ impl StreamingProcessor { buffer.extend_from_slice(b"\n\n"); } } + +fn strip_json_fence(mut delta: String, fence_backticks_stripped: &mut bool) -> String { + delta = delta + .replace("```json", "") + .replace("```JSON", "") + .replace("```", ""); + let trimmed = delta.trim(); + if trimmed == "json" || trimmed == "JSON" { + delta = String::new(); + } + // Handle cross-chunk fence split: backticks were + // stripped from the previous chunk, language tag + // arrives at the start of this one. + if *fence_backticks_stripped { + for tag in ["json\r\n", "JSON\r\n", "json\n", "JSON\n"] { + if delta.starts_with(tag) { + delta = delta[tag.len()..].to_string(); + break; + } + } + } + *fence_backticks_stripped = delta.is_empty(); + delta +} + +#[cfg(test)] +mod tests { + use super::strip_json_fence; + + fn apply_chunks(chunks: &[&str]) -> String { + let mut fence_state = false; + let mut result = String::new(); + for chunk in chunks { + let out = strip_json_fence(chunk.to_string(), &mut fence_state); + result.push_str(&out); + } + result + } + + #[test] + fn full_fence_single_chunk() { + let out = apply_chunks(&["```json\n{\"a\":1}\n```"]); + assert_eq!(out, "\n{\"a\":1}\n"); + } + + #[test] + fn fence_split_backticks_then_language_tag() { + let out = apply_chunks(&["```", "json\n{\"a\":1}\n```"]); + assert_eq!(out, "{\"a\":1}\n"); + } + + #[test] + fn fence_split_backticks_then_language_tag_crlf() { + let out = apply_chunks(&["```", "json\r\n{\"a\":1}\r\n```"]); + assert_eq!(out, "{\"a\":1}\r\n"); + } + + #[test] + fn standalone_json_tag_dropped() { + let out = apply_chunks(&["json"]); + assert_eq!(out, ""); + } + + #[test] + fn no_fence_passthrough() { + let out = apply_chunks(&["{\"a\":", "1}"]); + assert_eq!(out, "{\"a\":1}"); + } + + #[test] + fn uppercase_fence() { + let out = apply_chunks(&["```JSON\n{\"a\":1}\n```"]); + assert_eq!(out, "\n{\"a\":1}\n"); + } + + #[test] + fn cross_chunk_uppercase() { + let out = apply_chunks(&["```", "JSON\n{\"a\":1}"]); + assert_eq!(out, "{\"a\":1}"); + } +} From 52dd84bc84424eb00f516afacd891dd5d2fa0377 Mon Sep 17 00:00:00 2001 From: Connor Li Date: Fri, 24 Apr 2026 23:42:40 -0700 Subject: [PATCH 17/27] feat(grpc): pass message_hashes through gRPC proto to TRT-LLM Add MessageHash message type and message_hashes field to GenerateRequest proto. Wire computed hashes from request building stages through the gRPC client to TRT-LLM backend, gated behind --enable-message-hash. Signed-off-by: Connor Li Signed-off-by: Connor Li --- crates/grpc_client/proto/trtllm_service.proto | 10 ++++++ crates/grpc_client/src/trtllm_service.rs | 32 +++++++++++++++++++ model_gateway/src/routers/grpc/client.rs | 4 +++ .../src/routers/grpc/common/stages/helpers.rs | 17 ++++++---- .../grpc/harmony/stages/request_building.rs | 12 +++++-- .../regular/stages/chat/request_building.rs | 9 ++++-- .../stages/messages/request_building.rs | 9 ++++-- 7 files changed, 78 insertions(+), 15 deletions(-) diff --git a/crates/grpc_client/proto/trtllm_service.proto b/crates/grpc_client/proto/trtllm_service.proto index e8427f528..a593ebf86 100644 --- a/crates/grpc_client/proto/trtllm_service.proto +++ b/crates/grpc_client/proto/trtllm_service.proto @@ -132,6 +132,16 @@ message GenerateRequest { // When true, stop token IDs are retained in output_token_ids instead of // being stripped. bool include_stop_token_in_output = 26; + + // Per-message SHA-256 hashes for session reconstruction. + repeated MessageHash message_hashes = 27; +} + +// Per-message hash for session reconstruction auditing. +// Hash is the first 12 hex chars of sha256(role + "\x00" + content). +message MessageHash { + string role = 1; + string hash = 2; } // Tokenized input from router diff --git a/crates/grpc_client/src/trtllm_service.rs b/crates/grpc_client/src/trtllm_service.rs index 81b3bb4c7..c5a5e5679 100644 --- a/crates/grpc_client/src/trtllm_service.rs +++ b/crates/grpc_client/src/trtllm_service.rs @@ -278,6 +278,7 @@ impl TrtllmServiceClient { multimodal_input: Option, tool_call_constraint: Option<(String, String)>, // (constraint_type, constraint_value) eos_token_ids: &[u32], + message_hashes: Option>, ) -> Result { // Build sampling config let sampling_config = Self::build_sampling_config_from_chat(body); @@ -301,6 +302,14 @@ impl TrtllmServiceClient { eos_token_ids.to_vec() }; + let proto_message_hashes = message_hashes + .map(|h| { + h.into_iter() + .map(|(role, hash)| proto::MessageHash { role, hash }) + .collect() + }) + .unwrap_or_default(); + let grpc_request = proto::GenerateRequest { request_id, tokenized: Some(proto::TokenizedInput { @@ -328,6 +337,7 @@ impl TrtllmServiceClient { cache_salt_id: None, arrival_time: None, include_stop_token_in_output: false, + message_hashes: proto_message_hashes, }; Ok(grpc_request) @@ -411,6 +421,7 @@ impl TrtllmServiceClient { cache_salt_id: None, arrival_time: None, include_stop_token_in_output: false, + message_hashes: vec![], }; Ok(grpc_request) @@ -428,6 +439,7 @@ impl TrtllmServiceClient { processed_text: String, token_ids: Vec, constraint: Option<(String, String)>, + message_hashes: Option>, ) -> Result { let sampling_config = Self::build_sampling_config_from_responses(body); let output_config = proto::OutputConfig { @@ -444,6 +456,14 @@ impl TrtllmServiceClient { let max_tokens = body.max_output_tokens.unwrap_or(2048); + let proto_message_hashes = message_hashes + .map(|h| { + h.into_iter() + .map(|(role, hash)| proto::MessageHash { role, hash }) + .collect() + }) + .unwrap_or_default(); + let grpc_request = proto::GenerateRequest { request_id, tokenized: Some(proto::TokenizedInput { @@ -471,6 +491,7 @@ impl TrtllmServiceClient { cache_salt_id: None, arrival_time: None, include_stop_token_in_output: false, + message_hashes: proto_message_hashes, }; Ok(grpc_request) @@ -663,6 +684,7 @@ impl TrtllmServiceClient { token_ids: Vec, multimodal_input: Option, tool_call_constraint: Option<(String, String)>, + message_hashes: Option>, ) -> Result { let sampling_config = Self::build_sampling_config_from_messages(body); let output_config = proto::OutputConfig { @@ -680,6 +702,14 @@ impl TrtllmServiceClient { let stop = body.stop_sequences.clone().unwrap_or_default(); let max_tokens = body.max_tokens; + let proto_message_hashes = message_hashes + .map(|h| { + h.into_iter() + .map(|(role, hash)| proto::MessageHash { role, hash }) + .collect() + }) + .unwrap_or_default(); + let grpc_request = proto::GenerateRequest { request_id, tokenized: Some(proto::TokenizedInput { @@ -707,6 +737,7 @@ impl TrtllmServiceClient { cache_salt_id: None, arrival_time: None, include_stop_token_in_output: false, + message_hashes: proto_message_hashes, }; Ok(grpc_request) @@ -798,6 +829,7 @@ impl TrtllmServiceClient { cache_salt_id: None, arrival_time: None, include_stop_token_in_output: body.no_stop_trim, + message_hashes: vec![], }; Ok(grpc_request) diff --git a/model_gateway/src/routers/grpc/client.rs b/model_gateway/src/routers/grpc/client.rs index 4cd9c20c6..7e3a910ec 100644 --- a/model_gateway/src/routers/grpc/client.rs +++ b/model_gateway/src/routers/grpc/client.rs @@ -329,6 +329,7 @@ impl GrpcClient { multimodal_inputs: Option, tool_constraints: Option<(String, String)>, eos_token_ids: &[u32], + message_hashes: Option>, ) -> Result { match self { Self::Sglang(client) => { @@ -374,6 +375,7 @@ impl GrpcClient { trtllm_mm, tool_constraints, eos_token_ids, + message_hashes, )?; Ok(ProtoGenerateRequest::Trtllm(Box::new(req))) } @@ -403,6 +405,7 @@ impl GrpcClient { token_ids: Vec, multimodal_inputs: Option, tool_constraints: Option<(String, String)>, + message_hashes: Option>, ) -> Result { match self { Self::Sglang(client) => { @@ -447,6 +450,7 @@ impl GrpcClient { token_ids, trtllm_mm, tool_constraints, + message_hashes, )?; Ok(ProtoGenerateRequest::Trtllm(Box::new(req))) } diff --git a/model_gateway/src/routers/grpc/common/stages/helpers.rs b/model_gateway/src/routers/grpc/common/stages/helpers.rs index d5ffc4c26..419ef7155 100644 --- a/model_gateway/src/routers/grpc/common/stages/helpers.rs +++ b/model_gateway/src/routers/grpc/common/stages/helpers.rs @@ -85,8 +85,11 @@ fn chat_message_text_content(msg: &ChatMessage) -> String { /// Compute per-message SHA-256 hashes matching TRT-LLM's `openai_server.py` format: /// `sha256(role + "\x00" + content).hexdigest()[:12]` -pub(crate) fn compute_and_log_message_hashes(request_id: &str, messages: &[ChatMessage]) { - let hashes: Vec<(&str, String)> = messages +pub(crate) fn compute_and_log_message_hashes( + request_id: &str, + messages: &[ChatMessage], +) -> Vec<(String, String)> { + let hashes: Vec<(String, String)> = messages .iter() .map(|msg| { let role = chat_message_role(msg); @@ -94,7 +97,7 @@ pub(crate) fn compute_and_log_message_hashes(request_id: &str, messages: &[ChatM let mut hasher = Sha256::new(); hasher.update(format!("{role}\x00{content}").as_bytes()); let hash = format!("{:x}", hasher.finalize()); - (role, hash[..12].to_string()) + (role.to_string(), hash[..12].to_string()) }) .collect(); info!( @@ -103,15 +106,16 @@ pub(crate) fn compute_and_log_message_hashes(request_id: &str, messages: &[ChatM message_hashes = ?hashes, "Request message hashes for session reconstruction" ); + hashes } /// Compute per-message SHA-256 hashes from InputMessage (Messages API) format. pub(crate) fn compute_and_log_input_message_hashes( request_id: &str, messages: &[openai_protocol::messages::InputMessage], -) { +) -> Vec<(String, String)> { use openai_protocol::messages::Role; - let hashes: Vec<(&str, String)> = messages + let hashes: Vec<(String, String)> = messages .iter() .map(|msg| { let role = match msg.role { @@ -135,7 +139,7 @@ pub(crate) fn compute_and_log_input_message_hashes( let mut hasher = Sha256::new(); hasher.update(format!("{role}\x00{content}").as_bytes()); let hash = format!("{:x}", hasher.finalize()); - (role, hash[..12].to_string()) + (role.to_string(), hash[..12].to_string()) }) .collect(); info!( @@ -144,4 +148,5 @@ pub(crate) fn compute_and_log_input_message_hashes( message_hashes = ?hashes, "Request message hashes for session reconstruction" ); + hashes } diff --git a/model_gateway/src/routers/grpc/harmony/stages/request_building.rs b/model_gateway/src/routers/grpc/harmony/stages/request_building.rs index 5d1867356..7815159bf 100644 --- a/model_gateway/src/routers/grpc/harmony/stages/request_building.rs +++ b/model_gateway/src/routers/grpc/harmony/stages/request_building.rs @@ -111,11 +111,15 @@ impl PipelineStage for HarmonyRequestBuildingStage { } }; - if self.enable_message_hash { + let message_hashes = if self.enable_message_hash { if let RequestType::Chat(req) = &ctx.input.request_type { - helpers::compute_and_log_message_hashes(&request_id, &req.messages); + Some(helpers::compute_and_log_message_hashes(&request_id, &req.messages)) + } else { + None } - } + } else { + None + }; // Build gRPC request using token_ids directly (Harmony encoding already handled message rendering) let placeholder_processed_text = "[harmony]".to_string(); @@ -229,6 +233,7 @@ impl PipelineStage for HarmonyRequestBuildingStage { None, // No multimodal in Harmony pipeline tool_constraints, &eos_ids, + message_hashes, ) .map_err(|e| { error!(function = "HarmonyRequestBuildingStage::execute", error = %e, "Failed to build TensorRT-LLM generate request"); @@ -242,6 +247,7 @@ impl PipelineStage for HarmonyRequestBuildingStage { placeholder_processed_text, token_ids, tool_constraints, + None, ) .map_err(|e| { error!(function = "HarmonyRequestBuildingStage::execute", error = %e, "Failed to build TensorRT-LLM generate request from responses"); diff --git a/model_gateway/src/routers/grpc/regular/stages/chat/request_building.rs b/model_gateway/src/routers/grpc/regular/stages/chat/request_building.rs index 59eb35222..c88e910b1 100644 --- a/model_gateway/src/routers/grpc/regular/stages/chat/request_building.rs +++ b/model_gateway/src/routers/grpc/regular/stages/chat/request_building.rs @@ -86,9 +86,11 @@ impl PipelineStage for ChatRequestBuildingStage { info!(target: "smg::request", request_id = %request_id, "Using user-supplied request ID"); } - if self.enable_message_hash { - helpers::compute_and_log_message_hashes(&request_id, &chat_request.messages); - } + let message_hashes = if self.enable_message_hash { + Some(helpers::compute_and_log_message_hashes(&request_id, &chat_request.messages)) + } else { + None + }; // Reject multimodal for backends that don't support it, before assembling if processed_messages.multimodal_intermediate.is_some() && builder_client.is_mlx() { @@ -117,6 +119,7 @@ impl PipelineStage for ChatRequestBuildingStage { multimodal_data, tool_constraints, &eos_token_ids, + message_hashes, ) .map_err(|e| { error!(function = "ChatRequestBuildingStage::execute", error = %e, "Failed to build generate request"); diff --git a/model_gateway/src/routers/grpc/regular/stages/messages/request_building.rs b/model_gateway/src/routers/grpc/regular/stages/messages/request_building.rs index 1cb088037..f4ebca7b1 100644 --- a/model_gateway/src/routers/grpc/regular/stages/messages/request_building.rs +++ b/model_gateway/src/routers/grpc/regular/stages/messages/request_building.rs @@ -87,9 +87,11 @@ impl PipelineStage for MessageRequestBuildingStage { info!(target: "smg::request", request_id = %request_id, "Using user-supplied request ID"); } - if self.enable_message_hash { - helpers::compute_and_log_input_message_hashes(&request_id, &messages_request.messages); - } + let message_hashes = if self.enable_message_hash { + Some(helpers::compute_and_log_input_message_hashes(&request_id, &messages_request.messages)) + } else { + None + }; // Reject multimodal for backends that don't support it, before assembling if processed_messages.multimodal_intermediate.is_some() && builder_client.is_mlx() { @@ -112,6 +114,7 @@ impl PipelineStage for MessageRequestBuildingStage { token_ids, multimodal_data, tool_constraints, + message_hashes, ) .map_err(|e| { error!(function = "MessageRequestBuildingStage::execute", error = %e, "Failed to build generate request"); From f9791668d76be7620a635cae108255a386cfc955 Mon Sep 17 00:00:00 2001 From: Connor Li Date: Mon, 27 Apr 2026 22:59:23 -0700 Subject: [PATCH 18/27] fix(health): increase health_generate probe timeout from 3s to 60s Prevents false-positive health check failures under load and during graceful shutdown, which were causing unnecessary k8s pod restarts. --- model_gateway/src/server.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model_gateway/src/server.rs b/model_gateway/src/server.rs index 09bf1a5bd..454d76bbe 100644 --- a/model_gateway/src/server.rs +++ b/model_gateway/src/server.rs @@ -208,7 +208,7 @@ async fn health_generate(State(state): State>, _req: Request) -> R .client .post(&probe_url) .json(&probe_body) - .timeout(Duration::from_secs(3)) + .timeout(Duration::from_secs(60)) .send() .await; let duration_ms = start.elapsed().as_millis(); From 65ebdbf6e613308ba8ea7521c127cbf41f53b5cc Mon Sep 17 00:00:00 2001 From: Connor Li Date: Tue, 28 Apr 2026 01:57:50 -0700 Subject: [PATCH 19/27] fix(reasoning): skip reasoning parsing for structured output requests When response_format is JsonObject or JsonSchema, the model produces JSON content, not reasoning tokens. The kimi_k25 reasoning parser was consuming this JSON as reasoning_content, causing structured-outputs fc-dash tests to fail (53/58 -> 58/58 with this fix). Check response_format directly instead of gating on !separate_reasoning (which was always false for Kimi, making the constrained-output check a no-op). Applied to all three code paths: non-streaming single choice, non-streaming batch, and streaming. Also fix pre-existing clippy warnings: unnecessary qualifications for ResponseFormat, unneeded binding pattern in monitor.rs, and missing too_many_arguments expects in client.rs and trtllm_service.rs. Format request_building.rs files. Signed-off-by: Connor Li --- crates/grpc_client/src/trtllm_service.rs | 4 +++ model_gateway/src/routers/grpc/client.rs | 4 +++ .../grpc/harmony/stages/request_building.rs | 5 +++- .../src/routers/grpc/regular/processor.rs | 26 ++++++++++++------- .../regular/stages/chat/request_building.rs | 5 +++- .../stages/messages/request_building.rs | 5 +++- .../src/routers/grpc/regular/streaming.rs | 17 +++++++----- model_gateway/src/worker/monitor.rs | 5 +--- 8 files changed, 48 insertions(+), 23 deletions(-) diff --git a/crates/grpc_client/src/trtllm_service.rs b/crates/grpc_client/src/trtllm_service.rs index c5a5e5679..5acd50374 100644 --- a/crates/grpc_client/src/trtllm_service.rs +++ b/crates/grpc_client/src/trtllm_service.rs @@ -676,6 +676,10 @@ impl TrtllmServiceClient { clippy::unused_self, reason = "method receiver kept for consistent public API" )] + #[expect( + clippy::too_many_arguments, + reason = "gRPC request builder needs all fields from the Messages API request" + )] pub fn build_generate_request_from_messages( &self, request_id: String, diff --git a/model_gateway/src/routers/grpc/client.rs b/model_gateway/src/routers/grpc/client.rs index 7e3a910ec..f14d59ac2 100644 --- a/model_gateway/src/routers/grpc/client.rs +++ b/model_gateway/src/routers/grpc/client.rs @@ -397,6 +397,10 @@ impl GrpcClient { clippy::unreachable, reason = "assembly stage guarantees matching MultimodalData variant for each backend" )] + #[expect( + clippy::too_many_arguments, + reason = "gRPC request builder needs all fields from the Messages API request" + )] pub fn build_messages_request( &self, request_id: String, diff --git a/model_gateway/src/routers/grpc/harmony/stages/request_building.rs b/model_gateway/src/routers/grpc/harmony/stages/request_building.rs index 7815159bf..317f25c39 100644 --- a/model_gateway/src/routers/grpc/harmony/stages/request_building.rs +++ b/model_gateway/src/routers/grpc/harmony/stages/request_building.rs @@ -113,7 +113,10 @@ impl PipelineStage for HarmonyRequestBuildingStage { let message_hashes = if self.enable_message_hash { if let RequestType::Chat(req) = &ctx.input.request_type { - Some(helpers::compute_and_log_message_hashes(&request_id, &req.messages)) + Some(helpers::compute_and_log_message_hashes( + &request_id, + &req.messages, + )) } else { None } diff --git a/model_gateway/src/routers/grpc/regular/processor.rs b/model_gateway/src/routers/grpc/regular/processor.rs index e24b0427f..98b8c6017 100644 --- a/model_gateway/src/routers/grpc/regular/processor.rs +++ b/model_gateway/src/routers/grpc/regular/processor.rs @@ -11,7 +11,7 @@ use llm_tokenizer::{ }; use openai_protocol::{ chat::{ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse}, - common::{FunctionCallResponse, ToolCall, ToolChoice, ToolChoiceValue, Usage}, + common::{FunctionCallResponse, ResponseFormat, ToolCall, ToolChoice, ToolChoiceValue, Usage}, completion::{CompletionChoice, CompletionRequest, CompletionResponse}, generate::{GenerateMetaInfo, GenerateRequest, GenerateResponse}, messages::{self, CreateMessageRequest, Message}, @@ -96,15 +96,14 @@ impl ResponseProcessor { let mut reasoning_text: Option = None; let mut processed_text = final_text; - let output_is_constrained = !original_request.separate_reasoning - && utils::has_constrained_output( - original_request.tool_choice.as_ref(), - original_request.response_format.as_ref(), - ); + let has_structured_output = matches!( + original_request.response_format, + Some(ResponseFormat::JsonObject | ResponseFormat::JsonSchema { .. }) + ); if original_request.separate_reasoning && reasoning_parser_available - && !output_is_constrained + && !has_structured_output { let pooled_parser = utils::get_reasoning_parser( &self.reasoning_parser_factory, @@ -165,6 +164,10 @@ impl ResponseProcessor { } }; + let output_is_constrained = utils::has_constrained_output( + original_request.tool_choice.as_ref(), + original_request.response_format.as_ref(), + ); if self.configured_tool_parser.is_some() && tool_parser_available && !output_is_constrained @@ -237,8 +240,7 @@ impl ResponseProcessor { let is_json_response = matches!( &original_request.response_format, - Some(openai_protocol::common::ResponseFormat::JsonObject) - | Some(openai_protocol::common::ResponseFormat::JsonSchema { .. }) + Some(ResponseFormat::JsonObject) | Some(ResponseFormat::JsonSchema { .. }) ); if is_json_response { if processed_text.starts_with("```json") || processed_text.starts_with("```JSON") { @@ -326,8 +328,14 @@ impl ResponseProcessor { let history_tool_calls_count = utils::get_history_tool_calls_count(&chat_request); + let has_structured_output = matches!( + chat_request.response_format, + Some(ResponseFormat::JsonObject | ResponseFormat::JsonSchema { .. }) + ); + // Check parser availability once upfront (not per choice) let reasoning_parser_available = chat_request.separate_reasoning + && !has_structured_output && utils::check_reasoning_parser_availability( &self.reasoning_parser_factory, self.configured_reasoning_parser.as_deref(), diff --git a/model_gateway/src/routers/grpc/regular/stages/chat/request_building.rs b/model_gateway/src/routers/grpc/regular/stages/chat/request_building.rs index c88e910b1..139b811a7 100644 --- a/model_gateway/src/routers/grpc/regular/stages/chat/request_building.rs +++ b/model_gateway/src/routers/grpc/regular/stages/chat/request_building.rs @@ -87,7 +87,10 @@ impl PipelineStage for ChatRequestBuildingStage { } let message_hashes = if self.enable_message_hash { - Some(helpers::compute_and_log_message_hashes(&request_id, &chat_request.messages)) + Some(helpers::compute_and_log_message_hashes( + &request_id, + &chat_request.messages, + )) } else { None }; diff --git a/model_gateway/src/routers/grpc/regular/stages/messages/request_building.rs b/model_gateway/src/routers/grpc/regular/stages/messages/request_building.rs index f4ebca7b1..6cd3cdd8d 100644 --- a/model_gateway/src/routers/grpc/regular/stages/messages/request_building.rs +++ b/model_gateway/src/routers/grpc/regular/stages/messages/request_building.rs @@ -88,7 +88,10 @@ impl PipelineStage for MessageRequestBuildingStage { } let message_hashes = if self.enable_message_hash { - Some(helpers::compute_and_log_input_message_hashes(&request_id, &messages_request.messages)) + Some(helpers::compute_and_log_input_message_hashes( + &request_id, + &messages_request.messages, + )) } else { None }; diff --git a/model_gateway/src/routers/grpc/regular/streaming.rs b/model_gateway/src/routers/grpc/regular/streaming.rs index dd3b80f4b..dc91aa61d 100644 --- a/model_gateway/src/routers/grpc/regular/streaming.rs +++ b/model_gateway/src/routers/grpc/regular/streaming.rs @@ -18,7 +18,8 @@ use llm_tokenizer::{ use openai_protocol::{ chat::{ChatCompletionRequest, ChatCompletionStreamResponse}, common::{ - FunctionCallDelta, StringOrArray, Tool, ToolCallDelta, ToolChoice, ToolChoiceValue, Usage, + FunctionCallDelta, ResponseFormat, StringOrArray, Tool, ToolCallDelta, ToolChoice, + ToolChoiceValue, Usage, }, completion::{CompletionRequest, CompletionStreamChoice, CompletionStreamResponse}, generate::GenerateRequest, @@ -195,8 +196,12 @@ impl StreamingProcessor { let start_time = Instant::now(); let mut first_token_time: Option = None; - // Extract request parameters - let separate_reasoning = original_request.separate_reasoning; + // Extract request parameters — skip reasoning parsing for structured output + let has_structured_output = matches!( + original_request.response_format, + Some(ResponseFormat::JsonObject | ResponseFormat::JsonSchema { .. }) + ); + let separate_reasoning = original_request.separate_reasoning && !has_structured_output; let tool_choice = &original_request.tool_choice; let tools = &original_request.tools; let history_tool_calls_count = utils::get_history_tool_calls_count(&original_request); @@ -285,8 +290,7 @@ impl StreamingProcessor { && (is_specific_function || matches!( &original_request.response_format, - Some(openai_protocol::common::ResponseFormat::JsonObject) - | Some(openai_protocol::common::ResponseFormat::JsonSchema { .. }) + Some(ResponseFormat::JsonObject) | Some(ResponseFormat::JsonSchema { .. }) )); let tool_parser_available = tools.is_some() @@ -490,8 +494,7 @@ impl StreamingProcessor { // streamed content is directly parseable. let is_json_response = matches!( &original_request.response_format, - Some(openai_protocol::common::ResponseFormat::JsonObject) - | Some(openai_protocol::common::ResponseFormat::JsonSchema { .. }) + Some(ResponseFormat::JsonObject) | Some(ResponseFormat::JsonSchema { .. }) ); if is_json_response { delta = strip_json_fence(delta, &mut fence_backticks_stripped); diff --git a/model_gateway/src/worker/monitor.rs b/model_gateway/src/worker/monitor.rs index 9be5b225f..583b2c667 100644 --- a/model_gateway/src/worker/monitor.rs +++ b/model_gateway/src/worker/monitor.rs @@ -537,10 +537,7 @@ async fn run_event_loop( } } Ok(WorkerEvent::StatusChanged { - worker, - new_status, - old_status: _, - .. + worker, new_status, .. }) => { if new_status != WorkerStatus::Ready { monitor.evict_worker_loads(worker.url()); From 0061eead9e4e2f2a6073612cad16e638b1e91660 Mon Sep 17 00:00:00 2001 From: Connor Li Date: Tue, 28 Apr 2026 14:25:27 -0700 Subject: [PATCH 20/27] feat(health): skip inference probe when tokens forwarded recently Mirrors TRT-LLM / TGL approach: if any token chunk was forwarded to a client within the last 10 seconds, /health_generate returns 200 immediately without sending a synthetic inference request. Under high load the probe adds backpressure and can itself time out under k8s's 10s kubelet deadline, causing spurious restarts. The timestamp fast-path eliminates that failure mode. Implementation: - AppContext gains last_token_time: Arc (unix seconds, 0 = none) - StreamingProcessor (gRPC path) wraps the SSE channel with a map() that stores the current timestamp on every successful chunk via build_tracked_sse_response() - OpenAI HTTP relay (chat.rs) stores the timestamp after each successful tx.send() in the bytes relay task - health_generate checks last_token_time first; falls back to the real inference probe only when the server has been idle for >=10s Signed-off-by: Connor Li --- .../benches/wasm_middleware_latency.rs | 1 + model_gateway/src/app_context.rs | 10 +- .../src/routers/grpc/common/responses/mod.rs | 2 +- .../grpc/common/responses/streaming.rs | 95 ++++++++++++- model_gateway/src/routers/grpc/pd_router.rs | 3 + model_gateway/src/routers/grpc/pipeline.rs | 18 ++- .../src/routers/grpc/regular/streaming.rs | 21 ++- model_gateway/src/routers/grpc/router.rs | 3 + model_gateway/src/routers/openai/chat.rs | 8 ++ model_gateway/src/routers/openai/context.rs | 6 +- model_gateway/src/routers/openai/router.rs | 1 + model_gateway/src/server.rs | 34 ++++- model_gateway/src/service_discovery.rs | 3 + model_gateway/tests/api/api_endpoints_test.rs | 133 +++++++++++++++++- model_gateway/tests/common/test_app.rs | 2 + model_gateway/tests/wasm_test.rs | 1 + 16 files changed, 327 insertions(+), 14 deletions(-) diff --git a/model_gateway/benches/wasm_middleware_latency.rs b/model_gateway/benches/wasm_middleware_latency.rs index 33b1b3221..f1c8889e1 100644 --- a/model_gateway/benches/wasm_middleware_latency.rs +++ b/model_gateway/benches/wasm_middleware_latency.rs @@ -79,6 +79,7 @@ fn bench_wasm_middleware_buffering(c: &mut Criterion) { concurrency_queue_tx: None, router_manager: None, mesh_handler: None, + api_port: 0, }); c.bench_function("wasm_middleware_pre_fix_latency", |b| { diff --git a/model_gateway/src/app_context.rs b/model_gateway/src/app_context.rs index f53898e3c..ea81dde1e 100644 --- a/model_gateway/src/app_context.rs +++ b/model_gateway/src/app_context.rs @@ -1,5 +1,8 @@ use std::{ - sync::{Arc, OnceLock}, + sync::{ + atomic::AtomicU64, + Arc, OnceLock, + }, time::Duration, }; @@ -77,6 +80,10 @@ pub struct AppContext { pub wasm_manager: Option>, pub worker_service: Arc, pub inflight_tracker: Arc, + /// Unix timestamp (seconds) of the last token chunk forwarded to a client. + /// Zero means no tokens have been forwarded yet. Used by health_generate to + /// skip the real inference probe when traffic is flowing recently. + pub last_token_time: Arc, pub kv_event_monitor: Option>, pub realtime_registry: Arc, /// Bind address for WebRTC UDP sockets (`None` = `0.0.0.0`, auto-detect). @@ -400,6 +407,7 @@ impl AppContextBuilder { wasm_manager: self.wasm_manager, worker_service, inflight_tracker: InFlightRequestTracker::new(), + last_token_time: Arc::new(AtomicU64::new(0)), kv_event_monitor: self.kv_event_monitor, realtime_registry: Arc::new(RealtimeRegistry::new()), webrtc_bind_addr: self.webrtc_bind_addr, diff --git a/model_gateway/src/routers/grpc/common/responses/mod.rs b/model_gateway/src/routers/grpc/common/responses/mod.rs index fd94e65ff..1bbdbdf3b 100644 --- a/model_gateway/src/routers/grpc/common/responses/mod.rs +++ b/model_gateway/src/routers/grpc/common/responses/mod.rs @@ -7,7 +7,7 @@ pub(crate) mod utils; // Re-export commonly used items pub(crate) use context::ResponsesContext; -pub(crate) use streaming::build_sse_response; +pub(crate) use streaming::{build_sse_response, build_tracked_sse_response}; pub(crate) use utils::{ensure_mcp_connection, persist_response_if_needed}; pub(crate) use crate::routers::common::mcp_utils::collect_user_function_names; diff --git a/model_gateway/src/routers/grpc/common/responses/streaming.rs b/model_gateway/src/routers/grpc/common/responses/streaming.rs index 5cc119df8..cea5f9944 100644 --- a/model_gateway/src/routers/grpc/common/responses/streaming.rs +++ b/model_gateway/src/routers/grpc/common/responses/streaming.rs @@ -1,5 +1,7 @@ //! Streaming infrastructure for /v1/responses endpoint +use std::sync::{atomic::AtomicU64, Arc}; + use axum::{body::Body, http::StatusCode, response::Response}; use bytes::Bytes; use http::header::{HeaderValue, CONTENT_TYPE}; @@ -18,7 +20,7 @@ use openai_protocol::{ use serde_json::json; use smg_mcp::{self as mcp, ResponseFormat}; use tokio::sync::mpsc; -use tokio_stream::wrappers::UnboundedReceiverStream; +use tokio_stream::{wrappers::UnboundedReceiverStream, StreamExt}; use tracing::warn; use uuid::Uuid; @@ -1042,6 +1044,42 @@ pub(crate) fn build_sse_response( .expect("infallible: static headers and valid status code") } +/// Like `build_sse_response` but atomically records the current time whenever a +/// chunk is successfully written, allowing the health check to skip its inference +/// probe when tokens have flowed recently. +#[expect( + clippy::expect_used, + reason = "Response::builder with static headers and valid status code is infallible" +)] +pub(crate) fn build_tracked_sse_response( + rx: mpsc::UnboundedReceiver>, + last_token_time: &Arc, +) -> Response { + use std::sync::atomic::Ordering; + + let last_token_time = last_token_time.clone(); + let stream = UnboundedReceiverStream::new(rx).map(move |result| { + if result.is_ok() { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + last_token_time.store(now, Ordering::Relaxed); + } + result + }); + Response::builder() + .status(StatusCode::OK) + .header( + CONTENT_TYPE, + HeaderValue::from_static("text/event-stream; charset=utf-8"), + ) + .header("Cache-Control", HeaderValue::from_static("no-cache")) + .header("Connection", HeaderValue::from_static("keep-alive")) + .body(Body::from_stream(stream)) + .expect("infallible: static headers and valid status code") +} + /// Attach `server_label` to an MCP tool-call JSON item. /// /// Only sets the field when `response_format` indicates a passthrough (mcp_call) @@ -1056,3 +1094,58 @@ pub(crate) fn attach_mcp_server_label( item["server_label"] = json!(label); } } + +#[cfg(test)] +mod tests { + use std::sync::{atomic::{AtomicU64, Ordering}, Arc}; + + use axum::{body::to_bytes, http::StatusCode}; + use bytes::Bytes; + use tokio::sync::mpsc; + + use super::build_tracked_sse_response; + + #[tokio::test] + async fn tracked_sse_response_updates_timestamp_on_ok_chunk() { + use std::time::{SystemTime, UNIX_EPOCH}; + + let last_token_time = Arc::new(AtomicU64::new(0)); + let (tx, rx) = mpsc::unbounded_channel::>(); + + let resp = build_tracked_sse_response(rx, &last_token_time); + assert_eq!(resp.status(), StatusCode::OK); + + let before = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + + tx.send(Ok(Bytes::from("data: hello\n\n"))).unwrap(); + drop(tx); + + let _ = to_bytes(resp.into_body(), usize::MAX).await.unwrap(); + + let stored = last_token_time.load(Ordering::Relaxed); + assert!(stored >= before, "timestamp should be >= before: {stored} < {before}"); + assert!(stored <= before + 5, "timestamp should be within 5s: {stored}"); + } + + #[tokio::test] + async fn tracked_sse_response_does_not_update_timestamp_on_error_chunk() { + let last_token_time = Arc::new(AtomicU64::new(0)); + let (tx, rx) = mpsc::unbounded_channel::>(); + + let resp = build_tracked_sse_response(rx, &last_token_time); + + tx.send(Err(std::io::Error::other("stream error"))).unwrap(); + drop(tx); + + let _ = to_bytes(resp.into_body(), usize::MAX).await; + + assert_eq!( + last_token_time.load(Ordering::Relaxed), + 0, + "error chunk should not update the timestamp" + ); + } +} diff --git a/model_gateway/src/routers/grpc/pd_router.rs b/model_gateway/src/routers/grpc/pd_router.rs index f8cc3e58c..221c62492 100644 --- a/model_gateway/src/routers/grpc/pd_router.rs +++ b/model_gateway/src/routers/grpc/pd_router.rs @@ -84,6 +84,7 @@ impl GrpcPDRouter { ctx.configured_tool_parser.clone(), ctx.configured_reasoning_parser.clone(), enable_message_hash, + ctx.last_token_time.clone(), ); // Create Messages PD pipeline @@ -95,6 +96,7 @@ impl GrpcPDRouter { ctx.configured_tool_parser.clone(), ctx.configured_reasoning_parser.clone(), enable_message_hash, + ctx.last_token_time.clone(), ); // Create Completion PD pipeline @@ -102,6 +104,7 @@ impl GrpcPDRouter { worker_registry.clone(), policy_registry.clone(), enable_message_hash, + ctx.last_token_time.clone(), ); Ok(GrpcPDRouter { diff --git a/model_gateway/src/routers/grpc/pipeline.rs b/model_gateway/src/routers/grpc/pipeline.rs index fd397ecd9..c9a081e4f 100644 --- a/model_gateway/src/routers/grpc/pipeline.rs +++ b/model_gateway/src/routers/grpc/pipeline.rs @@ -3,7 +3,7 @@ //! This module defines the RequestPipeline orchestrator that coordinates //! the execution of pipeline stages from request preparation to response delivery. -use std::{sync::Arc, time::Instant}; +use std::{sync::{atomic::AtomicU64, Arc}, time::Instant}; use axum::response::{IntoResponse, Response}; use openai_protocol::{ @@ -110,6 +110,7 @@ impl RequestPipeline { } /// Create a regular (single-worker) pipeline + #[expect(clippy::too_many_arguments, reason = "all params are distinct required dependencies")] pub fn new_regular( worker_registry: Arc, policy_registry: Arc, @@ -118,6 +119,7 @@ impl RequestPipeline { configured_tool_parser: Option, configured_reasoning_parser: Option, enable_message_hash: bool, + last_token_time: Arc, ) -> Self { let processor = processor::ResponseProcessor::new( tool_parser_factory.clone(), @@ -132,6 +134,7 @@ impl RequestPipeline { configured_tool_parser, configured_reasoning_parser, metrics_labels::BACKEND_REGULAR, + last_token_time, )); let stages: Vec> = vec![ @@ -228,6 +231,7 @@ impl RequestPipeline { } /// Create a PD (prefill-decode) pipeline + #[expect(clippy::too_many_arguments, reason = "all params are distinct required dependencies")] pub fn new_pd( worker_registry: Arc, policy_registry: Arc, @@ -236,6 +240,7 @@ impl RequestPipeline { configured_tool_parser: Option, configured_reasoning_parser: Option, enable_message_hash: bool, + last_token_time: Arc, ) -> Self { let processor = processor::ResponseProcessor::new( tool_parser_factory.clone(), @@ -250,6 +255,7 @@ impl RequestPipeline { configured_tool_parser, configured_reasoning_parser, metrics_labels::BACKEND_PD, + last_token_time, )); let stages: Vec> = vec![ @@ -336,6 +342,7 @@ impl RequestPipeline { /// Uses Messages-specific stages for preparation, request building, and response /// processing. Shares worker selection, client acquisition, dispatch metadata, /// and request execution stages with other pipelines. + #[expect(clippy::too_many_arguments, reason = "all params are distinct required dependencies")] pub fn new_messages( worker_registry: Arc, policy_registry: Arc, @@ -344,6 +351,7 @@ impl RequestPipeline { configured_tool_parser: Option, configured_reasoning_parser: Option, enable_message_hash: bool, + last_token_time: Arc, ) -> Self { let processor = processor::ResponseProcessor::new( tool_parser_factory.clone(), @@ -358,6 +366,7 @@ impl RequestPipeline { configured_tool_parser, configured_reasoning_parser, metrics_labels::BACKEND_REGULAR, + last_token_time, )); let stages: Vec> = vec![ @@ -384,6 +393,7 @@ impl RequestPipeline { } /// Create a Messages API PD (prefill-decode) pipeline + #[expect(clippy::too_many_arguments, reason = "all params are distinct required dependencies")] pub fn new_messages_pd( worker_registry: Arc, policy_registry: Arc, @@ -392,6 +402,7 @@ impl RequestPipeline { configured_tool_parser: Option, configured_reasoning_parser: Option, enable_message_hash: bool, + last_token_time: Arc, ) -> Self { let processor = processor::ResponseProcessor::new( tool_parser_factory.clone(), @@ -406,6 +417,7 @@ impl RequestPipeline { configured_tool_parser, configured_reasoning_parser, metrics_labels::BACKEND_PD, + last_token_time, )); let stages: Vec> = vec![ @@ -440,6 +452,7 @@ impl RequestPipeline { worker_registry: Arc, policy_registry: Arc, enable_message_hash: bool, + last_token_time: Arc, ) -> Self { let processor = processor::ResponseProcessor::new( ToolParserFactory::default(), @@ -454,6 +467,7 @@ impl RequestPipeline { None, None, metrics_labels::BACKEND_REGULAR, + last_token_time, )); let stages: Vec> = vec![ @@ -487,6 +501,7 @@ impl RequestPipeline { worker_registry: Arc, policy_registry: Arc, enable_message_hash: bool, + last_token_time: Arc, ) -> Self { let processor = processor::ResponseProcessor::new( ToolParserFactory::default(), @@ -501,6 +516,7 @@ impl RequestPipeline { None, None, metrics_labels::BACKEND_PD, + last_token_time, )); let stages: Vec> = vec![ diff --git a/model_gateway/src/routers/grpc/regular/streaming.rs b/model_gateway/src/routers/grpc/regular/streaming.rs index dc91aa61d..876567e5c 100644 --- a/model_gateway/src/routers/grpc/regular/streaming.rs +++ b/model_gateway/src/routers/grpc/regular/streaming.rs @@ -5,7 +5,10 @@ use std::{ collections::{HashMap, HashSet}, io, - sync::Arc, + sync::{ + atomic::AtomicU64, + Arc, + }, time::Instant, }; @@ -37,7 +40,10 @@ use tracing::{debug, error, warn}; use crate::{ observability::metrics::{metrics_labels, Metrics, StreamingMetricsParams}, routers::grpc::{ - common::{response_formatting::CompletionTokenTracker, responses::build_sse_response}, + common::{ + response_formatting::CompletionTokenTracker, + responses::build_tracked_sse_response, + }, context, proto_wrapper::{ProtoResponseVariant, ProtoStream}, utils, @@ -53,6 +59,7 @@ pub(crate) struct StreamingProcessor { configured_tool_parser: Option, configured_reasoning_parser: Option, backend_type: &'static str, + last_token_time: Arc, } /// Context for generate endpoint streaming - groups config params to reduce function arguments @@ -71,6 +78,7 @@ impl StreamingProcessor { configured_tool_parser: Option, configured_reasoning_parser: Option, backend_type: &'static str, + last_token_time: Arc, ) -> Self { Self { tool_parser_factory, @@ -78,6 +86,7 @@ impl StreamingProcessor { configured_tool_parser, configured_reasoning_parser, backend_type, + last_token_time, } } @@ -179,7 +188,7 @@ impl StreamingProcessor { } // Return SSE response - build_sse_response(rx) + build_tracked_sse_response(rx, &self.last_token_time) } /// Process streaming chunks from a single stream (Regular mode) @@ -787,7 +796,7 @@ impl StreamingProcessor { } // Return SSE response - build_sse_response(rx) + build_tracked_sse_response(rx, &self.last_token_time) } /// Process streaming chunks for generate endpoint (no tool/reasoning parsing) @@ -1633,7 +1642,7 @@ impl StreamingProcessor { } } - build_sse_response(rx) + build_tracked_sse_response(rx, &self.last_token_time) } /// Process Messages API streaming chunks from a single stream. @@ -2379,7 +2388,7 @@ impl StreamingProcessor { } } - build_sse_response(rx) + build_tracked_sse_response(rx, &self.last_token_time) } /// Process completion streaming chunks from a single stream. diff --git a/model_gateway/src/routers/grpc/router.rs b/model_gateway/src/routers/grpc/router.rs index 9e5619dbb..1776ae1f0 100644 --- a/model_gateway/src/routers/grpc/router.rs +++ b/model_gateway/src/routers/grpc/router.rs @@ -99,6 +99,7 @@ impl GrpcRouter { ctx.configured_tool_parser.clone(), ctx.configured_reasoning_parser.clone(), enable_message_hash, + ctx.last_token_time.clone(), ); // Create Harmony pipelines @@ -129,6 +130,7 @@ impl GrpcRouter { ctx.configured_tool_parser.clone(), ctx.configured_reasoning_parser.clone(), enable_message_hash, + ctx.last_token_time.clone(), ); // Create Completion pipeline @@ -136,6 +138,7 @@ impl GrpcRouter { worker_registry.clone(), _policy_registry.clone(), enable_message_hash, + ctx.last_token_time.clone(), ); // Extract shared dependencies for responses contexts diff --git a/model_gateway/src/routers/openai/chat.rs b/model_gateway/src/routers/openai/chat.rs index 724dda03a..69acf9435 100644 --- a/model_gateway/src/routers/openai/chat.rs +++ b/model_gateway/src/routers/openai/chat.rs @@ -52,6 +52,7 @@ pub(super) async fn route_chat( let start = Instant::now(); let model = model_id; let streaming = body.stream; + let last_token_time = deps.shared_components.last_token_time.clone(); Metrics::record_router_request( metrics_labels::ROUTER_OPENAI, @@ -160,6 +161,7 @@ pub(super) async fn route_chat( let headers = Arc::clone(&headers_cloned); let worker_api_key = Arc::clone(&worker_api_key); let worker = Arc::clone(&worker); + let last_token_time = last_token_time.clone(); async move { let mut req = client.post(&url).json(&*payload); @@ -196,6 +198,7 @@ pub(super) async fn route_chat( let (tx, rx) = mpsc::unbounded_channel(); #[expect(clippy::disallowed_methods, reason = "fire-and-forget stream relay; gateway shutdown need not wait for individual stream forwarding")] tokio::spawn(async move { + use std::sync::atomic::Ordering; let mut s = stream; while let Some(chunk) = s.next().await { match chunk { @@ -203,6 +206,11 @@ pub(super) async fn route_chat( if tx.send(Ok(bytes)).is_err() { break; } + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + last_token_time.store(now, Ordering::Relaxed); } Err(e) => { let _ = tx.send(Err(format!("Stream error: {e}"))); diff --git a/model_gateway/src/routers/openai/context.rs b/model_gateway/src/routers/openai/context.rs index 9e94ccd24..57b704d32 100644 --- a/model_gateway/src/routers/openai/context.rs +++ b/model_gateway/src/routers/openai/context.rs @@ -1,6 +1,9 @@ //! Request context types for OpenAI router pipeline. -use std::sync::Arc; +use std::sync::{ + atomic::AtomicU64, + Arc, +}; use axum::http::HeaderMap; use openai_protocol::{chat::ChatCompletionRequest, responses::ResponsesRequest}; @@ -42,6 +45,7 @@ pub enum RequestType { pub struct SharedComponents { pub client: reqwest::Client, pub router_config: Arc, + pub last_token_time: Arc, } pub struct ResponsesComponents { diff --git a/model_gateway/src/routers/openai/router.rs b/model_gateway/src/routers/openai/router.rs index 0bda10e99..ee6da1de3 100644 --- a/model_gateway/src/routers/openai/router.rs +++ b/model_gateway/src/routers/openai/router.rs @@ -91,6 +91,7 @@ impl OpenAIRouter { let shared_components = Arc::new(SharedComponents { client: ctx.client.clone(), router_config: Arc::new(ctx.router_config.clone()), + last_token_time: ctx.last_token_time.clone(), }); let responses_components = Arc::new(ResponsesComponents { diff --git a/model_gateway/src/server.rs b/model_gateway/src/server.rs index 454d76bbe..532dbc100 100644 --- a/model_gateway/src/server.rs +++ b/model_gateway/src/server.rs @@ -163,7 +163,10 @@ async fn health(_state: State>) -> Response { } async fn health_generate(State(state): State>, _req: Request) -> Response { - use std::time::Instant; + use std::{ + sync::atomic::Ordering, + time::{Instant, SystemTime, UNIX_EPOCH}, + }; let registry = &state.context.worker_registry; let workers = registry.get_all(); @@ -184,6 +187,35 @@ async fn health_generate(State(state): State>, _req: Request) -> R .into_response(); } + // Fast path: if a token was forwarded within the last 10 seconds, skip the + // probe entirely. Under high load this avoids adding a synthetic request that + // could tip the backend over its concurrency limit. + let last_token_unix = state.context.last_token_time.load(Ordering::Relaxed); + if last_token_unix > 0 { + let now_unix = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + let age_secs = now_unix.saturating_sub(last_token_unix); + if age_secs < 10 { + info!( + target: "smg::health", + workers = healthy.len(), + last_token_age_secs = age_secs, + "health_generate: skipping probe — token forwarded recently" + ); + return ( + StatusCode::OK, + format!( + "OK - {} workers healthy, last token {}s ago (probe skipped)", + healthy.len(), + age_secs + ), + ) + .into_response(); + } + } + let model_id = healthy[0].model_id().to_string(); let probe_url = format!("http://127.0.0.1:{}/v1/chat/completions", state.api_port); let probe_body = serde_json::json!({ diff --git a/model_gateway/src/service_discovery.rs b/model_gateway/src/service_discovery.rs index fc35a7997..5f57dddff 100644 --- a/model_gateway/src/service_discovery.rs +++ b/model_gateway/src/service_discovery.rs @@ -1237,6 +1237,8 @@ mod tests { } fn create_test_app_context() -> Arc { + use std::sync::atomic::AtomicU64; + use crate::{ config::RouterConfig, middleware::TokenBucket, observability::inflight_tracker::InFlightRequestTracker, @@ -1288,6 +1290,7 @@ mod tests { router_config, )), inflight_tracker: InFlightRequestTracker::new(), + last_token_time: Arc::new(AtomicU64::new(0)), kv_event_monitor: None, realtime_registry: Arc::new(RealtimeRegistry::new()), webrtc_bind_addr: None, diff --git a/model_gateway/tests/api/api_endpoints_test.rs b/model_gateway/tests/api/api_endpoints_test.rs index d5357c27d..66cf103f5 100644 --- a/model_gateway/tests/api/api_endpoints_test.rs +++ b/model_gateway/tests/api/api_endpoints_test.rs @@ -112,6 +112,9 @@ mod health_tests { #[tokio::test] async fn test_health_generate_endpoint() { + use std::sync::atomic::Ordering; + use std::time::{SystemTime, UNIX_EPOCH}; + let ctx = AppTestContext::new(vec![MockWorkerConfig { port: 18005, worker_type: WorkerType::Regular, @@ -121,6 +124,16 @@ mod health_tests { }]) .await; + // Simulate recent traffic so the fast path fires and avoids + // a real probe (which would fail since api_port=0 in tests). + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + ctx.app_context + .last_token_time + .store(now, Ordering::Relaxed); + let app = ctx.create_app(); let req = Request::builder() @@ -135,8 +148,124 @@ mod health_tests { let body = axum::body::to_bytes(resp.into_body(), usize::MAX) .await .unwrap(); - let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap(); - assert!(body_json.is_object()); + let body_str = String::from_utf8_lossy(&body); + assert!( + body_str.contains("workers healthy"), + "Expected workers healthy in body, got: {body_str}" + ); + + ctx.shutdown().await; + } + + /// Fast path: if last_token_time is recent (< 10s), the probe is skipped. + #[tokio::test] + async fn test_health_generate_fast_path_skips_probe() { + use std::sync::atomic::Ordering; + use std::time::{SystemTime, UNIX_EPOCH}; + + let ctx = AppTestContext::new(vec![MockWorkerConfig { + port: 18006, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + ctx.app_context + .last_token_time + .store(now, Ordering::Relaxed); + + let app = ctx.create_app(); + + let req = Request::builder() + .method("GET") + .uri("/health_generate") + .body(Body::empty()) + .unwrap(); + + let resp = app.oneshot(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + let body = axum::body::to_bytes(resp.into_body(), usize::MAX) + .await + .unwrap(); + let body_str = String::from_utf8_lossy(&body); + assert!( + body_str.contains("probe skipped"), + "Expected 'probe skipped' in body, got: {body_str}" + ); + + ctx.shutdown().await; + } + + /// Stale fast path: last_token_time older than 10s falls through to the probe. + /// The probe will fail (no real server at api_port=0), returning 503. + #[tokio::test] + async fn test_health_generate_stale_token_falls_through_to_probe() { + use std::sync::atomic::Ordering; + use std::time::{SystemTime, UNIX_EPOCH}; + + let ctx = AppTestContext::new(vec![MockWorkerConfig { + port: 18007, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let stale = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() + .saturating_sub(30); + ctx.app_context + .last_token_time + .store(stale, Ordering::Relaxed); + + let app = ctx.create_app(); + + let req = Request::builder() + .method("GET") + .uri("/health_generate") + .body(Body::empty()) + .unwrap(); + + let resp = app.oneshot(req).await.unwrap(); + // Probe to api_port=0 will fail → 503 + assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE); + + let body = axum::body::to_bytes(resp.into_body(), usize::MAX) + .await + .unwrap(); + let body_str = String::from_utf8_lossy(&body); + assert!( + !body_str.contains("probe skipped"), + "Stale token should not skip probe, got: {body_str}" + ); + + ctx.shutdown().await; + } + + /// No workers → 503 immediately, no probe attempted. + #[tokio::test] + async fn test_health_generate_no_workers() { + let ctx = AppTestContext::new(vec![]).await; + let app = ctx.create_app(); + + let req = Request::builder() + .method("GET") + .uri("/health_generate") + .body(Body::empty()) + .unwrap(); + + let resp = app.oneshot(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE); ctx.shutdown().await; } diff --git a/model_gateway/tests/common/test_app.rs b/model_gateway/tests/common/test_app.rs index 40954e12b..8d640b5a5 100644 --- a/model_gateway/tests/common/test_app.rs +++ b/model_gateway/tests/common/test_app.rs @@ -106,6 +106,7 @@ pub fn create_test_app( concurrency_queue_tx: None, router_manager: None, mesh_handler: None, + api_port: 0, }); // Configure request ID headers (use defaults if not specified) @@ -148,6 +149,7 @@ pub fn create_test_app_with_context( concurrency_queue_tx: None, router_manager: None, mesh_handler: None, + api_port: 0, }); // Get config from the context diff --git a/model_gateway/tests/wasm_test.rs b/model_gateway/tests/wasm_test.rs index c877499b2..fc7527a24 100644 --- a/model_gateway/tests/wasm_test.rs +++ b/model_gateway/tests/wasm_test.rs @@ -196,6 +196,7 @@ async fn create_test_app_with_wasm() -> (axum::Router, Arc, TempDir) concurrency_queue_tx: None, router_manager: None, mesh_handler: None, + api_port: 0, }); let request_id_headers = vec!["x-request-id".to_string(), "x-correlation-id".to_string()]; From 10af42dfac33143315cf91693c965035a16422e1 Mon Sep 17 00:00:00 2001 From: Connor Li Date: Tue, 28 Apr 2026 15:15:26 -0700 Subject: [PATCH 21/27] fix(multimodal): use 1 placeholder per image for Kimi-K2.5 Both TRT-LLM (KimiK25InputProcessor) and TGL/SGLang (KimiGPUProcessorWrapper) expand a single <|media_pad|> placeholder to N vision tokens server-side based on grid_thws. The previous code pre-expanded to N tokens in SMG, causing double expansion and a "More media placeholder tokens than media items" 400 error from the engine. Switch from PromptReplacement::repeated(num_tokens) to PromptReplacement::sequence(vec![1]) so SMG sends exactly 1 placeholder per image, matching the engine contract. Signed-off-by: Connor Li Made-with: Cursor --- crates/multimodal/src/registry/kimi_k25.rs | 27 +++++++++++----------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/crates/multimodal/src/registry/kimi_k25.rs b/crates/multimodal/src/registry/kimi_k25.rs index 1347bf285..5ef515c05 100644 --- a/crates/multimodal/src/registry/kimi_k25.rs +++ b/crates/multimodal/src/registry/kimi_k25.rs @@ -61,16 +61,16 @@ impl ModelProcessorSpec for KimiK25VisionSpec { ) -> RegistryResult> { let pad_token_id = Self::pad_token_id(metadata)?; let placeholder_token = self.placeholder_token(metadata)?; + // Keep 1 placeholder per image — TRT-LLM's KimiK25InputProcessor + // handles expansion to N vision tokens server-side based on grid_thws. + // SMG must NOT pre-expand or the engine will see N placeholders and + // attempt to expand each one again. Ok(preprocessed .num_img_tokens .iter() - .map(|&num_tokens| { - PromptReplacement::repeated( - Modality::Image, - &placeholder_token, - pad_token_id, - num_tokens, - ) + .map(|_| { + let tokens = vec![pad_token_id; 1]; + PromptReplacement::sequence(Modality::Image, &placeholder_token, tokens) }) .collect()) } @@ -142,9 +142,10 @@ mod tests { ) .unwrap(); - // 256 pad tokens (no start/end wrapper — SGLang handles that in the chat template) - assert_eq!(replacements[0].tokens.len(), 256); - assert!(replacements[0].tokens.iter().all(|&t| t == 163605)); + // 1 placeholder per image (engine expands to N vision tokens server-side) + assert_eq!(replacements.len(), 1); + assert_eq!(replacements[0].tokens.len(), 1); + assert_eq!(replacements[0].tokens[0], 163605); } #[test] @@ -173,9 +174,9 @@ mod tests { .unwrap(); assert_eq!(replacements.len(), 2); - assert_eq!(replacements[0].tokens.len(), 256); - assert_eq!(replacements[1].tokens.len(), 64); - assert!(replacements[1].tokens.iter().all(|&t| t == 163605)); + assert_eq!(replacements[0].tokens.len(), 1); + assert_eq!(replacements[1].tokens.len(), 1); + assert_eq!(replacements[0].tokens[0], 163605); } #[test] From 6764b57db43ea88e0f0513982dd0d9414e3fa87e Mon Sep 17 00:00:00 2001 From: Connor Li Date: Tue, 28 Apr 2026 17:44:19 -0700 Subject: [PATCH 22/27] fix(multimodal): collapse media placeholders for TRT-LLM only SGLang uses N <|media_pad|> tokens per image directly for embedding lookup, while TRT-LLM's KimiK25InputProcessor re-expands each single placeholder to N tokens based on grid_thws. The previous commit used 1 placeholder per image unconditionally, which fixed TRT-LLM (no double-expansion) but broke SGLang (embedding split_with_sizes mismatch: tensor=3588 vs split_sizes=[1]). Now: prompt_replacements emits N tokens (SGLang-compatible), and collapse_media_placeholders() in request_building collapses consecutive runs to 1 for TRT-LLM only, gated on builder_client.is_trtllm(). Applied to both Chat and Messages API request building paths. Signed-off-by: Connor Li Made-with: Cursor --- crates/multimodal/src/registry/kimi_k25.rs | 31 ++++++++++--------- model_gateway/src/routers/grpc/multimodal.rs | 24 ++++++++++++++ .../regular/stages/chat/request_building.rs | 18 +++++++++-- .../stages/messages/request_building.rs | 15 +++++++-- 4 files changed, 68 insertions(+), 20 deletions(-) diff --git a/crates/multimodal/src/registry/kimi_k25.rs b/crates/multimodal/src/registry/kimi_k25.rs index 5ef515c05..ccdae33da 100644 --- a/crates/multimodal/src/registry/kimi_k25.rs +++ b/crates/multimodal/src/registry/kimi_k25.rs @@ -61,16 +61,20 @@ impl ModelProcessorSpec for KimiK25VisionSpec { ) -> RegistryResult> { let pad_token_id = Self::pad_token_id(metadata)?; let placeholder_token = self.placeholder_token(metadata)?; - // Keep 1 placeholder per image — TRT-LLM's KimiK25InputProcessor - // handles expansion to N vision tokens server-side based on grid_thws. - // SMG must NOT pre-expand or the engine will see N placeholders and - // attempt to expand each one again. + // Expand to N pad tokens per image. SGLang uses these directly for + // embedding lookup. TRT-LLM needs only 1 (it re-expands server-side), + // so the router collapses consecutive runs before sending to TRT-LLM + // via `collapse_media_placeholders` in multimodal.rs. Ok(preprocessed .num_img_tokens .iter() - .map(|_| { - let tokens = vec![pad_token_id; 1]; - PromptReplacement::sequence(Modality::Image, &placeholder_token, tokens) + .map(|&num_tokens| { + PromptReplacement::repeated( + Modality::Image, + &placeholder_token, + pad_token_id, + num_tokens, + ) }) .collect()) } @@ -142,10 +146,9 @@ mod tests { ) .unwrap(); - // 1 placeholder per image (engine expands to N vision tokens server-side) - assert_eq!(replacements.len(), 1); - assert_eq!(replacements[0].tokens.len(), 1); - assert_eq!(replacements[0].tokens[0], 163605); + // N pad tokens per image (SGLang uses directly; TRT-LLM collapses to 1) + assert_eq!(replacements[0].tokens.len(), 256); + assert!(replacements[0].tokens.iter().all(|&t| t == 163605)); } #[test] @@ -174,9 +177,9 @@ mod tests { .unwrap(); assert_eq!(replacements.len(), 2); - assert_eq!(replacements[0].tokens.len(), 1); - assert_eq!(replacements[1].tokens.len(), 1); - assert_eq!(replacements[0].tokens[0], 163605); + assert_eq!(replacements[0].tokens.len(), 256); + assert_eq!(replacements[1].tokens.len(), 64); + assert!(replacements[1].tokens.iter().all(|&t| t == 163605)); } #[test] diff --git a/model_gateway/src/routers/grpc/multimodal.rs b/model_gateway/src/routers/grpc/multimodal.rs index d28c7e3fc..8ae1a2c89 100644 --- a/model_gateway/src/routers/grpc/multimodal.rs +++ b/model_gateway/src/routers/grpc/multimodal.rs @@ -690,6 +690,30 @@ fn expand_tokens( // Assembly: convert MultimodalIntermediate → backend-specific MultimodalData // --------------------------------------------------------------------------- +/// Collapse consecutive runs of a media placeholder token ID to a single +/// occurrence. TRT-LLM's `KimiK25InputProcessor` re-expands each single +/// placeholder to N vision tokens based on `grid_thws`, so sending N +/// placeholders causes double-expansion and a count mismatch error. +/// +/// SGLang expects N placeholders (it uses them directly for embedding lookup), +/// so this function must only be called for TRT-LLM requests. +pub(crate) fn collapse_media_placeholders(token_ids: &[u32], im_token_id: u32) -> Vec { + let mut result = Vec::with_capacity(token_ids.len()); + let mut prev_was_placeholder = false; + for &tok in token_ids { + if tok == im_token_id { + if !prev_was_placeholder { + result.push(tok); + } + prev_was_placeholder = true; + } else { + result.push(tok); + prev_was_placeholder = false; + } + } + result +} + /// Assemble backend-specific multimodal data from the intermediate. /// /// Called in request_building after worker selection, when the backend is known. diff --git a/model_gateway/src/routers/grpc/regular/stages/chat/request_building.rs b/model_gateway/src/routers/grpc/regular/stages/chat/request_building.rs index 139b811a7..9c3284b6e 100644 --- a/model_gateway/src/routers/grpc/regular/stages/chat/request_building.rs +++ b/model_gateway/src/routers/grpc/regular/stages/chat/request_building.rs @@ -10,7 +10,7 @@ use crate::routers::{ grpc::{ common::stages::{helpers, PipelineStage}, context::{ClientSelection, PreparationOutput, RequestContext}, - multimodal::assemble_multimodal_data, + multimodal::{assemble_multimodal_data, collapse_media_placeholders}, proto_wrapper::ProtoRequest, }, }; @@ -64,7 +64,7 @@ impl PipelineStage for ChatRequestBuildingStage { }; let PreparationOutput::Chat { - token_ids, + mut token_ids, processed_messages, tool_constraints, } = prep @@ -103,10 +103,22 @@ impl PipelineStage for ChatRequestBuildingStage { )); } - // Assemble backend-specific multimodal data now that the backend is known + // Assemble backend-specific multimodal data now that the backend is known. + // For TRT-LLM, collapse N consecutive media placeholders to 1 because + // TRT-LLM's KimiK25InputProcessor re-expands each single placeholder + // to N vision tokens. SGLang uses them directly for embedding lookup. + let im_token_id = processed_messages + .multimodal_intermediate + .as_ref() + .and_then(|i| i.im_token_id); let multimodal_data = processed_messages .multimodal_intermediate .map(|intermediate| assemble_multimodal_data(intermediate, builder_client)); + if builder_client.is_trtllm() { + if let Some(id) = im_token_id { + token_ids = collapse_media_placeholders(&token_ids, id); + } + } let eos_token_ids = ctx .tokenizer_arc() diff --git a/model_gateway/src/routers/grpc/regular/stages/messages/request_building.rs b/model_gateway/src/routers/grpc/regular/stages/messages/request_building.rs index 6cd3cdd8d..51fa9f79c 100644 --- a/model_gateway/src/routers/grpc/regular/stages/messages/request_building.rs +++ b/model_gateway/src/routers/grpc/regular/stages/messages/request_building.rs @@ -10,7 +10,7 @@ use crate::routers::{ grpc::{ common::stages::{helpers, PipelineStage}, context::{ClientSelection, PreparationOutput, RequestContext}, - multimodal::assemble_multimodal_data, + multimodal::{assemble_multimodal_data, collapse_media_placeholders}, proto_wrapper::ProtoRequest, }, }; @@ -65,7 +65,7 @@ impl PipelineStage for MessageRequestBuildingStage { }; let PreparationOutput::Messages { - token_ids, + mut token_ids, processed_messages, tool_constraints, } = prep @@ -104,10 +104,19 @@ impl PipelineStage for MessageRequestBuildingStage { )); } - // Assemble backend-specific multimodal data now that the backend is known + // Assemble backend-specific multimodal data; collapse placeholders for TRT-LLM + let im_token_id = processed_messages + .multimodal_intermediate + .as_ref() + .and_then(|i| i.im_token_id); let multimodal_data = processed_messages .multimodal_intermediate .map(|intermediate| assemble_multimodal_data(intermediate, builder_client)); + if builder_client.is_trtllm() { + if let Some(id) = im_token_id { + token_ids = collapse_media_placeholders(&token_ids, id); + } + } let mut proto_request = builder_client .build_messages_request( From 4398a4f8963192f19163d418a9639999fa1191f6 Mon Sep 17 00:00:00 2001 From: Connor Li Date: Wed, 29 Apr 2026 13:59:50 -0700 Subject: [PATCH 23/27] fix(logging): downgrade message_hash log from info to debug --- model_gateway/src/routers/grpc/common/stages/helpers.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/model_gateway/src/routers/grpc/common/stages/helpers.rs b/model_gateway/src/routers/grpc/common/stages/helpers.rs index 419ef7155..f82509dbc 100644 --- a/model_gateway/src/routers/grpc/common/stages/helpers.rs +++ b/model_gateway/src/routers/grpc/common/stages/helpers.rs @@ -100,7 +100,7 @@ pub(crate) fn compute_and_log_message_hashes( (role.to_string(), hash[..12].to_string()) }) .collect(); - info!( + debug!( target: "smg::request", request_id = %request_id, message_hashes = ?hashes, @@ -142,7 +142,7 @@ pub(crate) fn compute_and_log_input_message_hashes( (role.to_string(), hash[..12].to_string()) }) .collect(); - info!( + debug!( target: "smg::request", request_id = %request_id, message_hashes = ?hashes, From 6a2278fad2bb5cfc85087ea065f6758984c6be55 Mon Sep 17 00:00:00 2001 From: msrinivasa Date: Fri, 17 Apr 2026 18:23:59 -0700 Subject: [PATCH 24/27] feat(protocols): add response_format.type=regex support --- crates/grpc_client/src/sglang_scheduler.rs | 3 +++ crates/grpc_client/src/trtllm_service.rs | 6 ++++++ crates/grpc_client/src/vllm_engine.rs | 3 +++ crates/protocols/src/chat.rs | 1 + crates/protocols/src/common.rs | 2 ++ .../src/routers/grpc/harmony/stages/preparation.rs | 2 +- model_gateway/src/routers/grpc/utils/parsers.rs | 4 +++- 7 files changed, 19 insertions(+), 2 deletions(-) diff --git a/crates/grpc_client/src/sglang_scheduler.rs b/crates/grpc_client/src/sglang_scheduler.rs index edfc43e0f..8ea8e741f 100644 --- a/crates/grpc_client/src/sglang_scheduler.rs +++ b/crates/grpc_client/src/sglang_scheduler.rs @@ -497,6 +497,9 @@ impl SglangSchedulerClient { .map_err(|e| format!("Failed to serialize JSON schema: {e}"))?; constraints.push(proto::sampling_params::Constraint::JsonSchema(schema_str)); } + Some(ResponseFormat::Regex { pattern }) => { + constraints.push(proto::sampling_params::Constraint::Regex(pattern.clone())); + } Some(ResponseFormat::Text) | None => { // No constraint for text format } diff --git a/crates/grpc_client/src/trtllm_service.rs b/crates/grpc_client/src/trtllm_service.rs index 5acd50374..1b8b9ddf3 100644 --- a/crates/grpc_client/src/trtllm_service.rs +++ b/crates/grpc_client/src/trtllm_service.rs @@ -580,6 +580,12 @@ impl TrtllmServiceClient { guide: schema_str, }); } + Some(ResponseFormat::Regex { pattern }) => { + return Ok(Some(proto::GuidedDecodingParams { + guide_type: proto::guided_decoding_params::GuideType::Regex as i32, + guide: pattern.clone(), + })); + } Some(ResponseFormat::Text) | None => {} } diff --git a/crates/grpc_client/src/vllm_engine.rs b/crates/grpc_client/src/vllm_engine.rs index a6daff4a2..2c408ba07 100644 --- a/crates/grpc_client/src/vllm_engine.rs +++ b/crates/grpc_client/src/vllm_engine.rs @@ -435,6 +435,9 @@ impl VllmEngineClient { .map_err(|e| format!("Failed to serialize JSON schema: {e}"))?; constraints.push(proto::sampling_params::Constraint::JsonSchema(schema_str)); } + Some(ResponseFormat::Regex { pattern }) => { + constraints.push(proto::sampling_params::Constraint::Regex(pattern.clone())); + } Some(ResponseFormat::Text) | None => { // No constraint for text format } diff --git a/crates/protocols/src/chat.rs b/crates/protocols/src/chat.rs index fe0203104..1b621117a 100644 --- a/crates/protocols/src/chat.rs +++ b/crates/protocols/src/chat.rs @@ -414,6 +414,7 @@ fn validate_chat_cross_parameters( req.regex.is_some(), req.ebnf.is_some(), matches!(req.response_format, Some(ResponseFormat::JsonSchema { .. })), + matches!(req.response_format, Some(ResponseFormat::Regex { .. })), ] .iter() .filter(|&&x| x) diff --git a/crates/protocols/src/common.rs b/crates/protocols/src/common.rs index 08cac4a87..92b9b0e9e 100644 --- a/crates/protocols/src/common.rs +++ b/crates/protocols/src/common.rs @@ -275,6 +275,8 @@ pub enum ResponseFormat { JsonObject, #[serde(rename = "json_schema")] JsonSchema { json_schema: JsonSchemaFormat }, + #[serde(rename = "regex")] + Regex { pattern: String }, } #[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)] diff --git a/model_gateway/src/routers/grpc/harmony/stages/preparation.rs b/model_gateway/src/routers/grpc/harmony/stages/preparation.rs index 596860750..2ae306ab3 100644 --- a/model_gateway/src/routers/grpc/harmony/stages/preparation.rs +++ b/model_gateway/src/routers/grpc/harmony/stages/preparation.rs @@ -291,7 +291,7 @@ impl HarmonyPreparationStage { }; let schema = match format { - ResponseFormat::Text => return Ok(None), + ResponseFormat::Text | ResponseFormat::Regex { .. } => return Ok(None), ResponseFormat::JsonObject => Cow::Owned(serde_json::json!({"type": "object"})), ResponseFormat::JsonSchema { json_schema } => Cow::Borrowed(&json_schema.schema), }; diff --git a/model_gateway/src/routers/grpc/utils/parsers.rs b/model_gateway/src/routers/grpc/utils/parsers.rs index e69a24f5c..2cbadcefc 100644 --- a/model_gateway/src/routers/grpc/utils/parsers.rs +++ b/model_gateway/src/routers/grpc/utils/parsers.rs @@ -203,7 +203,9 @@ pub(crate) fn has_constrained_output( let constrained_response_format = matches!( response_format, - Some(ResponseFormat::JsonObject) | Some(ResponseFormat::JsonSchema { .. }) + Some(ResponseFormat::JsonObject) + | Some(ResponseFormat::JsonSchema { .. }) + | Some(ResponseFormat::Regex { .. }) ); constrained_tool_choice || constrained_response_format From ad34ef9ddb58f09408749ce8eca3ec8dceb0b009 Mon Sep 17 00:00:00 2001 From: Wei Gong Date: Fri, 1 May 2026 17:55:32 -0700 Subject: [PATCH 25/27] Fix regex structured output reasoning parsing --- model_gateway/src/routers/grpc/regular/processor.rs | 12 ++++++------ model_gateway/src/routers/grpc/regular/streaming.rs | 12 ++++++------ model_gateway/src/routers/grpc/utils/parsers.rs | 8 ++++++++ 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/model_gateway/src/routers/grpc/regular/processor.rs b/model_gateway/src/routers/grpc/regular/processor.rs index 98b8c6017..093418034 100644 --- a/model_gateway/src/routers/grpc/regular/processor.rs +++ b/model_gateway/src/routers/grpc/regular/processor.rs @@ -96,9 +96,9 @@ impl ResponseProcessor { let mut reasoning_text: Option = None; let mut processed_text = final_text; - let has_structured_output = matches!( - original_request.response_format, - Some(ResponseFormat::JsonObject | ResponseFormat::JsonSchema { .. }) + let has_structured_output = utils::has_constrained_output( + original_request.tool_choice.as_ref(), + original_request.response_format.as_ref(), ); if original_request.separate_reasoning @@ -328,9 +328,9 @@ impl ResponseProcessor { let history_tool_calls_count = utils::get_history_tool_calls_count(&chat_request); - let has_structured_output = matches!( - chat_request.response_format, - Some(ResponseFormat::JsonObject | ResponseFormat::JsonSchema { .. }) + let has_structured_output = utils::has_constrained_output( + chat_request.tool_choice.as_ref(), + chat_request.response_format.as_ref(), ); // Check parser availability once upfront (not per choice) diff --git a/model_gateway/src/routers/grpc/regular/streaming.rs b/model_gateway/src/routers/grpc/regular/streaming.rs index 876567e5c..3bc34615c 100644 --- a/model_gateway/src/routers/grpc/regular/streaming.rs +++ b/model_gateway/src/routers/grpc/regular/streaming.rs @@ -206,9 +206,9 @@ impl StreamingProcessor { let mut first_token_time: Option = None; // Extract request parameters — skip reasoning parsing for structured output - let has_structured_output = matches!( - original_request.response_format, - Some(ResponseFormat::JsonObject | ResponseFormat::JsonSchema { .. }) + let has_structured_output = utils::has_constrained_output( + original_request.tool_choice.as_ref(), + original_request.response_format.as_ref(), ); let separate_reasoning = original_request.separate_reasoning && !has_structured_output; let tool_choice = &original_request.tool_choice; @@ -297,9 +297,9 @@ impl StreamingProcessor { // content even with constraints. The reasoning parser must run. let output_is_constrained = !separate_reasoning && (is_specific_function - || matches!( - &original_request.response_format, - Some(ResponseFormat::JsonObject) | Some(ResponseFormat::JsonSchema { .. }) + || utils::has_constrained_output( + tool_choice.as_ref(), + original_request.response_format.as_ref(), )); let tool_parser_available = tools.is_some() diff --git a/model_gateway/src/routers/grpc/utils/parsers.rs b/model_gateway/src/routers/grpc/utils/parsers.rs index 2cbadcefc..1c438b70d 100644 --- a/model_gateway/src/routers/grpc/utils/parsers.rs +++ b/model_gateway/src/routers/grpc/utils/parsers.rs @@ -304,6 +304,14 @@ mod tests { assert!(has_constrained_output(None, Some(&rf))); } + #[test] + fn response_format_regex_is_constrained() { + let rf = ResponseFormat::Regex { + pattern: "(positive|neutral|negative)".to_string(), + }; + assert!(has_constrained_output(None, Some(&rf))); + } + // ── has_constrained_output: combinations ──────────────────────────── #[test] From 883c8e77974179c5b73008c27f0cbaecdade6e1d Mon Sep 17 00:00:00 2001 From: Wei Gong Date: Fri, 1 May 2026 17:59:57 -0700 Subject: [PATCH 26/27] Fix lint formatting --- model_gateway/src/app_context.rs | 5 +--- .../grpc/common/responses/streaming.rs | 15 ++++++++--- .../src/routers/grpc/common/stages/helpers.rs | 2 +- model_gateway/src/routers/grpc/pipeline.rs | 25 +++++++++++++++---- .../src/routers/grpc/regular/streaming.rs | 8 ++---- model_gateway/src/routers/openai/context.rs | 5 +--- 6 files changed, 37 insertions(+), 23 deletions(-) diff --git a/model_gateway/src/app_context.rs b/model_gateway/src/app_context.rs index ea81dde1e..4572b5bf0 100644 --- a/model_gateway/src/app_context.rs +++ b/model_gateway/src/app_context.rs @@ -1,8 +1,5 @@ use std::{ - sync::{ - atomic::AtomicU64, - Arc, OnceLock, - }, + sync::{atomic::AtomicU64, Arc, OnceLock}, time::Duration, }; diff --git a/model_gateway/src/routers/grpc/common/responses/streaming.rs b/model_gateway/src/routers/grpc/common/responses/streaming.rs index cea5f9944..c6c3df18d 100644 --- a/model_gateway/src/routers/grpc/common/responses/streaming.rs +++ b/model_gateway/src/routers/grpc/common/responses/streaming.rs @@ -1097,7 +1097,10 @@ pub(crate) fn attach_mcp_server_label( #[cfg(test)] mod tests { - use std::sync::{atomic::{AtomicU64, Ordering}, Arc}; + use std::sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }; use axum::{body::to_bytes, http::StatusCode}; use bytes::Bytes; @@ -1126,8 +1129,14 @@ mod tests { let _ = to_bytes(resp.into_body(), usize::MAX).await.unwrap(); let stored = last_token_time.load(Ordering::Relaxed); - assert!(stored >= before, "timestamp should be >= before: {stored} < {before}"); - assert!(stored <= before + 5, "timestamp should be within 5s: {stored}"); + assert!( + stored >= before, + "timestamp should be >= before: {stored} < {before}" + ); + assert!( + stored <= before + 5, + "timestamp should be within 5s: {stored}" + ); } #[tokio::test] diff --git a/model_gateway/src/routers/grpc/common/stages/helpers.rs b/model_gateway/src/routers/grpc/common/stages/helpers.rs index f82509dbc..90becc88c 100644 --- a/model_gateway/src/routers/grpc/common/stages/helpers.rs +++ b/model_gateway/src/routers/grpc/common/stages/helpers.rs @@ -6,7 +6,7 @@ use openai_protocol::chat::ChatMessage; use rand::Rng; use sha2::{Digest, Sha256}; use smg_grpc_client::sglang_proto::DisaggregatedParams; -use tracing::{debug, info}; +use tracing::debug; use crate::{ routers::grpc::{context::WorkerSelection, proto_wrapper::ProtoGenerateRequest}, diff --git a/model_gateway/src/routers/grpc/pipeline.rs b/model_gateway/src/routers/grpc/pipeline.rs index c9a081e4f..6be279c86 100644 --- a/model_gateway/src/routers/grpc/pipeline.rs +++ b/model_gateway/src/routers/grpc/pipeline.rs @@ -3,7 +3,10 @@ //! This module defines the RequestPipeline orchestrator that coordinates //! the execution of pipeline stages from request preparation to response delivery. -use std::{sync::{atomic::AtomicU64, Arc}, time::Instant}; +use std::{ + sync::{atomic::AtomicU64, Arc}, + time::Instant, +}; use axum::response::{IntoResponse, Response}; use openai_protocol::{ @@ -110,7 +113,10 @@ impl RequestPipeline { } /// Create a regular (single-worker) pipeline - #[expect(clippy::too_many_arguments, reason = "all params are distinct required dependencies")] + #[expect( + clippy::too_many_arguments, + reason = "all params are distinct required dependencies" + )] pub fn new_regular( worker_registry: Arc, policy_registry: Arc, @@ -231,7 +237,10 @@ impl RequestPipeline { } /// Create a PD (prefill-decode) pipeline - #[expect(clippy::too_many_arguments, reason = "all params are distinct required dependencies")] + #[expect( + clippy::too_many_arguments, + reason = "all params are distinct required dependencies" + )] pub fn new_pd( worker_registry: Arc, policy_registry: Arc, @@ -342,7 +351,10 @@ impl RequestPipeline { /// Uses Messages-specific stages for preparation, request building, and response /// processing. Shares worker selection, client acquisition, dispatch metadata, /// and request execution stages with other pipelines. - #[expect(clippy::too_many_arguments, reason = "all params are distinct required dependencies")] + #[expect( + clippy::too_many_arguments, + reason = "all params are distinct required dependencies" + )] pub fn new_messages( worker_registry: Arc, policy_registry: Arc, @@ -393,7 +405,10 @@ impl RequestPipeline { } /// Create a Messages API PD (prefill-decode) pipeline - #[expect(clippy::too_many_arguments, reason = "all params are distinct required dependencies")] + #[expect( + clippy::too_many_arguments, + reason = "all params are distinct required dependencies" + )] pub fn new_messages_pd( worker_registry: Arc, policy_registry: Arc, diff --git a/model_gateway/src/routers/grpc/regular/streaming.rs b/model_gateway/src/routers/grpc/regular/streaming.rs index 3bc34615c..dcb3601ba 100644 --- a/model_gateway/src/routers/grpc/regular/streaming.rs +++ b/model_gateway/src/routers/grpc/regular/streaming.rs @@ -5,10 +5,7 @@ use std::{ collections::{HashMap, HashSet}, io, - sync::{ - atomic::AtomicU64, - Arc, - }, + sync::{atomic::AtomicU64, Arc}, time::Instant, }; @@ -41,8 +38,7 @@ use crate::{ observability::metrics::{metrics_labels, Metrics, StreamingMetricsParams}, routers::grpc::{ common::{ - response_formatting::CompletionTokenTracker, - responses::build_tracked_sse_response, + response_formatting::CompletionTokenTracker, responses::build_tracked_sse_response, }, context, proto_wrapper::{ProtoResponseVariant, ProtoStream}, diff --git a/model_gateway/src/routers/openai/context.rs b/model_gateway/src/routers/openai/context.rs index 57b704d32..32b709dbc 100644 --- a/model_gateway/src/routers/openai/context.rs +++ b/model_gateway/src/routers/openai/context.rs @@ -1,9 +1,6 @@ //! Request context types for OpenAI router pipeline. -use std::sync::{ - atomic::AtomicU64, - Arc, -}; +use std::sync::{atomic::AtomicU64, Arc}; use axum::http::HeaderMap; use openai_protocol::{chat::ChatCompletionRequest, responses::ResponsesRequest}; From 3c42c86c2e4724adba952a2c6e027d78430a36c8 Mon Sep 17 00:00:00 2001 From: Wei Gong Date: Fri, 1 May 2026 18:29:57 -0700 Subject: [PATCH 27/27] Make health_generate probe timeout configurable --- model_gateway/src/config/builder.rs | 5 +++++ model_gateway/src/config/types.rs | 9 +++++++++ model_gateway/src/config/validation.rs | 8 ++++++++ model_gateway/src/main.rs | 5 +++++ model_gateway/src/server.rs | 5 ++++- 5 files changed, 31 insertions(+), 1 deletion(-) diff --git a/model_gateway/src/config/builder.rs b/model_gateway/src/config/builder.rs index 16ff17acf..b69b0d287 100644 --- a/model_gateway/src/config/builder.rs +++ b/model_gateway/src/config/builder.rs @@ -283,6 +283,11 @@ impl RouterConfigBuilder { self } + pub fn health_generate_timeout_secs(mut self, timeout: u64) -> Self { + self.config.health_generate_timeout_secs = timeout; + self + } + // ==================== Discovery ==================== pub fn discovery_config(mut self, discovery: DiscoveryConfig) -> Self { diff --git a/model_gateway/src/config/types.rs b/model_gateway/src/config/types.rs index 55fb887a6..43a6f620b 100644 --- a/model_gateway/src/config/types.rs +++ b/model_gateway/src/config/types.rs @@ -126,6 +126,8 @@ pub struct RouterConfig { #[serde(default)] pub disable_circuit_breaker: bool, pub health_check: HealthCheckConfig, + #[serde(default = "default_health_generate_timeout_secs")] + pub health_generate_timeout_secs: u64, #[serde(default)] pub enable_igw: bool, /// Can be a HuggingFace model ID or local path @@ -222,6 +224,10 @@ fn default_load_monitor_interval_secs() -> u64 { 10 } +fn default_health_generate_timeout_secs() -> u64 { + 3 +} + fn default_enable_l0() -> bool { false } @@ -659,6 +665,7 @@ impl Default for RouterConfig { disable_retries: false, disable_circuit_breaker: false, health_check: HealthCheckConfig::default(), + health_generate_timeout_secs: default_health_generate_timeout_secs(), enable_igw: false, connection_mode: ConnectionMode::Http, model_path: None, @@ -773,6 +780,7 @@ mod tests { assert_eq!(config.worker_startup_timeout_secs, 1800); assert_eq!(config.worker_startup_check_interval_secs, 30); assert_eq!(config.load_monitor_interval_secs, 10); + assert_eq!(config.health_generate_timeout_secs, 3); assert!(config.discovery.is_none()); assert!(config.metrics.is_none()); assert!(config.trace_config.is_none()); @@ -960,6 +968,7 @@ stream_retention_secs: 3600 assert!(deserialized.skills_enabled); assert!(deserialized.skills.is_none()); assert!(!deserialized.tenant_resolution.trust_tenant_header); + assert_eq!(deserialized.health_generate_timeout_secs, 3); assert_eq!( deserialized.tenant_resolution.tenant_header_name, DEFAULT_TENANT_HEADER_NAME diff --git a/model_gateway/src/config/validation.rs b/model_gateway/src/config/validation.rs index 7dfe99d00..8914a297f 100644 --- a/model_gateway/src/config/validation.rs +++ b/model_gateway/src/config/validation.rs @@ -522,6 +522,14 @@ impl ConfigValidator { }); } + if config.health_generate_timeout_secs == 0 { + return Err(ConfigError::InvalidValue { + field: "health_generate_timeout_secs".to_string(), + value: config.health_generate_timeout_secs.to_string(), + reason: "Must be > 0".to_string(), + }); + } + if config.worker_startup_check_interval_secs == 0 { return Err(ConfigError::InvalidValue { field: "worker_startup_check_interval_secs".to_string(), diff --git a/model_gateway/src/main.rs b/model_gateway/src/main.rs index 2087e1b06..5946fba4e 100644 --- a/model_gateway/src/main.rs +++ b/model_gateway/src/main.rs @@ -426,6 +426,10 @@ struct CliArgs { #[arg(long, default_value_t = false, help_heading = "Health Checks")] disable_health_check: bool, + /// Timeout in seconds for /health_generate inference probes + #[arg(long, default_value_t = 3, help_heading = "Health Checks")] + health_generate_timeout_secs: u64, + /// Remove workers from the registry when they are marked unhealthy #[arg(long, default_value_t = false, help_heading = "Health Checks")] remove_unhealthy_workers: bool, @@ -1229,6 +1233,7 @@ impl CliArgs { disable_health_check: self.disable_health_check, remove_unhealthy_workers: self.remove_unhealthy_workers, }) + .health_generate_timeout_secs(self.health_generate_timeout_secs) .tokenizer_cache(TokenizerCacheConfig { enable_l0: self.tokenizer_cache_enable_l0, l0_max_entries: self.tokenizer_cache_l0_max_entries, diff --git a/model_gateway/src/server.rs b/model_gateway/src/server.rs index 532dbc100..08d699fed 100644 --- a/model_gateway/src/server.rs +++ b/model_gateway/src/server.rs @@ -230,6 +230,7 @@ async fn health_generate(State(state): State>, _req: Request) -> R target: "smg::health", model = %model_id, probe_url = %probe_url, + timeout_secs = state.context.router_config.health_generate_timeout_secs, max_tokens = 1, "health_generate: sending real inference probe" ); @@ -240,7 +241,9 @@ async fn health_generate(State(state): State>, _req: Request) -> R .client .post(&probe_url) .json(&probe_body) - .timeout(Duration::from_secs(60)) + .timeout(Duration::from_secs( + state.context.router_config.health_generate_timeout_secs, + )) .send() .await; let duration_ms = start.elapsed().as_millis();