diff --git a/bindings/python/src/lib.rs b/bindings/python/src/lib.rs index 715cd98e1..7e9b2a6ed 100644 --- a/bindings/python/src/lib.rs +++ b/bindings/python/src/lib.rs @@ -445,6 +445,7 @@ struct Router { reasoning_parser: Option, tool_call_parser: Option, mcp_config_path: Option, + enable_message_hash: bool, storage_hook_wasm_path: Option, backend: BackendType, history_backend: HistoryBackendType, @@ -731,6 +732,7 @@ impl Router { .maybe_tool_call_parser(self.tool_call_parser.as_ref()) .maybe_mcp_config_path(self.mcp_config_path.as_ref()) .maybe_storage_hook_wasm_path(self.storage_hook_wasm_path.as_deref()) + .enable_message_hash(self.enable_message_hash) .dp_aware(self.dp_aware) .retries(!self.disable_retries) .circuit_breaker(!self.disable_circuit_breaker) @@ -833,6 +835,7 @@ impl Router { reasoning_parser = None, tool_call_parser = None, mcp_config_path = None, + enable_message_hash = false, storage_hook_wasm_path = None, backend = BackendType::Sglang, history_backend = HistoryBackendType::Memory, @@ -942,6 +945,7 @@ impl Router { reasoning_parser: Option, tool_call_parser: Option, mcp_config_path: Option, + enable_message_hash: bool, storage_hook_wasm_path: Option, backend: BackendType, history_backend: HistoryBackendType, @@ -1062,6 +1066,7 @@ impl Router { reasoning_parser, tool_call_parser, mcp_config_path, + enable_message_hash, storage_hook_wasm_path, backend, history_backend, diff --git a/bindings/python/src/smg/router_args.py b/bindings/python/src/smg/router_args.py index f769eccd2..65955b298 100644 --- a/bindings/python/src/smg/router_args.py +++ b/bindings/python/src/smg/router_args.py @@ -119,6 +119,8 @@ class RouterArgs: mcp_config_path: str | None = None # Backend selection backend: str = "sglang" + # Message hash logging for session reconstruction + enable_message_hash: bool = False # Storage hooks (WASM) storage_hook_wasm_path: str | None = None # History backend configuration @@ -472,6 +474,12 @@ def add_cli_args( default=RouterArgs.log_json, help="Output logs in JSON format", ) + logging_group.add_argument( + f"--{prefix}enable-message-hash", + action="store_true", + default=RouterArgs.enable_message_hash, + help="Compute per-message SHA-256 hashes for session reconstruction logging", + ) # Service discovery configuration k8s_group.add_argument( diff --git a/crates/grpc_client/proto/trtllm_service.proto b/crates/grpc_client/proto/trtllm_service.proto index e8427f528..a593ebf86 100644 --- a/crates/grpc_client/proto/trtllm_service.proto +++ b/crates/grpc_client/proto/trtllm_service.proto @@ -132,6 +132,16 @@ message GenerateRequest { // When true, stop token IDs are retained in output_token_ids instead of // being stripped. bool include_stop_token_in_output = 26; + + // Per-message SHA-256 hashes for session reconstruction. + repeated MessageHash message_hashes = 27; +} + +// Per-message hash for session reconstruction auditing. +// Hash is the first 12 hex chars of sha256(role + "\x00" + content). +message MessageHash { + string role = 1; + string hash = 2; } // Tokenized input from router diff --git a/crates/grpc_client/src/sglang_scheduler.rs b/crates/grpc_client/src/sglang_scheduler.rs index edfc43e0f..8ea8e741f 100644 --- a/crates/grpc_client/src/sglang_scheduler.rs +++ b/crates/grpc_client/src/sglang_scheduler.rs @@ -497,6 +497,9 @@ impl SglangSchedulerClient { .map_err(|e| format!("Failed to serialize JSON schema: {e}"))?; constraints.push(proto::sampling_params::Constraint::JsonSchema(schema_str)); } + Some(ResponseFormat::Regex { pattern }) => { + constraints.push(proto::sampling_params::Constraint::Regex(pattern.clone())); + } Some(ResponseFormat::Text) | None => { // No constraint for text format } diff --git a/crates/grpc_client/src/trtllm_service.rs b/crates/grpc_client/src/trtllm_service.rs index cc4d6ac9e..1b8b9ddf3 100644 --- a/crates/grpc_client/src/trtllm_service.rs +++ b/crates/grpc_client/src/trtllm_service.rs @@ -265,6 +265,10 @@ impl TrtllmServiceClient { clippy::unused_self, reason = "method receiver kept for consistent public API across gRPC backends" )] + #[expect( + clippy::too_many_arguments, + reason = "gRPC request builder requires all fields for the proto message" + )] pub fn build_generate_request_from_chat( &self, request_id: String, @@ -273,6 +277,8 @@ impl TrtllmServiceClient { token_ids: Vec, multimodal_input: Option, tool_call_constraint: Option<(String, String)>, // (constraint_type, constraint_value) + eos_token_ids: &[u32], + message_hashes: Option>, ) -> Result { // Build sampling config let sampling_config = Self::build_sampling_config_from_chat(body); @@ -287,6 +293,23 @@ impl TrtllmServiceClient { let max_tokens = body.max_completion_tokens.unwrap_or(2048); + // Pass merged EOS token IDs from config.json + generation_config.json. + // TRT-LLM's gRPC path does not reliably merge these internally, + // so we provide them explicitly via the standard stop_token_ids field. + let stop_token_ids: Vec = if body.ignore_eos { + vec![] + } else { + eos_token_ids.to_vec() + }; + + let proto_message_hashes = message_hashes + .map(|h| { + h.into_iter() + .map(|(role, hash)| proto::MessageHash { role, hash }) + .collect() + }) + .unwrap_or_default(); + let grpc_request = proto::GenerateRequest { request_id, tokenized: Some(proto::TokenizedInput { @@ -299,7 +322,7 @@ impl TrtllmServiceClient { max_tokens, streaming: body.stream, stop, - stop_token_ids: vec![], + stop_token_ids, ignore_eos: body.ignore_eos, bad: vec![], bad_token_ids: vec![], @@ -314,6 +337,7 @@ impl TrtllmServiceClient { cache_salt_id: None, arrival_time: None, include_stop_token_in_output: false, + message_hashes: proto_message_hashes, }; Ok(grpc_request) @@ -397,6 +421,7 @@ impl TrtllmServiceClient { cache_salt_id: None, arrival_time: None, include_stop_token_in_output: false, + message_hashes: vec![], }; Ok(grpc_request) @@ -414,6 +439,7 @@ impl TrtllmServiceClient { processed_text: String, token_ids: Vec, constraint: Option<(String, String)>, + message_hashes: Option>, ) -> Result { let sampling_config = Self::build_sampling_config_from_responses(body); let output_config = proto::OutputConfig { @@ -430,6 +456,14 @@ impl TrtllmServiceClient { let max_tokens = body.max_output_tokens.unwrap_or(2048); + let proto_message_hashes = message_hashes + .map(|h| { + h.into_iter() + .map(|(role, hash)| proto::MessageHash { role, hash }) + .collect() + }) + .unwrap_or_default(); + let grpc_request = proto::GenerateRequest { request_id, tokenized: Some(proto::TokenizedInput { @@ -457,6 +491,7 @@ impl TrtllmServiceClient { cache_salt_id: None, arrival_time: None, include_stop_token_in_output: false, + message_hashes: proto_message_hashes, }; Ok(grpc_request) @@ -545,6 +580,12 @@ impl TrtllmServiceClient { guide: schema_str, }); } + Some(ResponseFormat::Regex { pattern }) => { + return Ok(Some(proto::GuidedDecodingParams { + guide_type: proto::guided_decoding_params::GuideType::Regex as i32, + guide: pattern.clone(), + })); + } Some(ResponseFormat::Text) | None => {} } @@ -641,6 +682,10 @@ impl TrtllmServiceClient { clippy::unused_self, reason = "method receiver kept for consistent public API" )] + #[expect( + clippy::too_many_arguments, + reason = "gRPC request builder needs all fields from the Messages API request" + )] pub fn build_generate_request_from_messages( &self, request_id: String, @@ -649,6 +694,7 @@ impl TrtllmServiceClient { token_ids: Vec, multimodal_input: Option, tool_call_constraint: Option<(String, String)>, + message_hashes: Option>, ) -> Result { let sampling_config = Self::build_sampling_config_from_messages(body); let output_config = proto::OutputConfig { @@ -666,6 +712,14 @@ impl TrtllmServiceClient { let stop = body.stop_sequences.clone().unwrap_or_default(); let max_tokens = body.max_tokens; + let proto_message_hashes = message_hashes + .map(|h| { + h.into_iter() + .map(|(role, hash)| proto::MessageHash { role, hash }) + .collect() + }) + .unwrap_or_default(); + let grpc_request = proto::GenerateRequest { request_id, tokenized: Some(proto::TokenizedInput { @@ -693,6 +747,7 @@ impl TrtllmServiceClient { cache_salt_id: None, arrival_time: None, include_stop_token_in_output: false, + message_hashes: proto_message_hashes, }; Ok(grpc_request) @@ -784,6 +839,7 @@ impl TrtllmServiceClient { cache_salt_id: None, arrival_time: None, include_stop_token_in_output: body.no_stop_trim, + message_hashes: vec![], }; Ok(grpc_request) diff --git a/crates/grpc_client/src/vllm_engine.rs b/crates/grpc_client/src/vllm_engine.rs index a6daff4a2..2c408ba07 100644 --- a/crates/grpc_client/src/vllm_engine.rs +++ b/crates/grpc_client/src/vllm_engine.rs @@ -435,6 +435,9 @@ impl VllmEngineClient { .map_err(|e| format!("Failed to serialize JSON schema: {e}"))?; constraints.push(proto::sampling_params::Constraint::JsonSchema(schema_str)); } + Some(ResponseFormat::Regex { pattern }) => { + constraints.push(proto::sampling_params::Constraint::Regex(pattern.clone())); + } Some(ResponseFormat::Text) | None => { // No constraint for text format } diff --git a/crates/multimodal/src/registry/kimi_k25.rs b/crates/multimodal/src/registry/kimi_k25.rs index 1347bf285..ccdae33da 100644 --- a/crates/multimodal/src/registry/kimi_k25.rs +++ b/crates/multimodal/src/registry/kimi_k25.rs @@ -61,6 +61,10 @@ impl ModelProcessorSpec for KimiK25VisionSpec { ) -> RegistryResult> { let pad_token_id = Self::pad_token_id(metadata)?; let placeholder_token = self.placeholder_token(metadata)?; + // Expand to N pad tokens per image. SGLang uses these directly for + // embedding lookup. TRT-LLM needs only 1 (it re-expands server-side), + // so the router collapses consecutive runs before sending to TRT-LLM + // via `collapse_media_placeholders` in multimodal.rs. Ok(preprocessed .num_img_tokens .iter() @@ -142,7 +146,7 @@ mod tests { ) .unwrap(); - // 256 pad tokens (no start/end wrapper — SGLang handles that in the chat template) + // N pad tokens per image (SGLang uses directly; TRT-LLM collapses to 1) assert_eq!(replacements[0].tokens.len(), 256); assert!(replacements[0].tokens.iter().all(|&t| t == 163605)); } diff --git a/crates/multimodal/src/registry/mod.rs b/crates/multimodal/src/registry/mod.rs index 6ddb366da..a70ddfc61 100644 --- a/crates/multimodal/src/registry/mod.rs +++ b/crates/multimodal/src/registry/mod.rs @@ -126,14 +126,7 @@ pub(super) mod test_helpers { fn get_special_tokens(&self) -> &SpecialTokens { static TOKENS: Lazy = Lazy::new(|| SpecialTokens { - bos_token: None, - eos_token: None, - unk_token: None, - sep_token: None, - pad_token: None, - cls_token: None, - mask_token: None, - additional_special_tokens: vec![], + ..Default::default() }); &TOKENS } diff --git a/crates/protocols/src/chat.rs b/crates/protocols/src/chat.rs index 35fdaf31d..1b621117a 100644 --- a/crates/protocols/src/chat.rs +++ b/crates/protocols/src/chat.rs @@ -11,6 +11,7 @@ use super::{ StringOrArray, Tool, ToolCall, ToolCallDelta, ToolChoice, ToolChoiceValue, ToolReference, Usage, }, + messages::ThinkingConfig, sampling_params::{validate_top_k_value, validate_top_p_value}, }; use crate::{ @@ -48,6 +49,7 @@ pub enum ChatMessage { Tool { content: MessageContent, tool_call_id: String, + name: Option, }, #[serde(rename = "function")] Function { content: String, name: String }, @@ -201,6 +203,10 @@ pub struct ChatCompletionRequest { /// Effort level for reasoning models (low, medium, high) pub reasoning_effort: Option, + /// Configuration for extended thinking (Anthropic-style). + /// Maps to chat_template_kwargs for thinking-capable models. + pub thinking: Option, + /// An object specifying the format that the model must output pub response_format: Option, @@ -317,6 +323,11 @@ pub struct ChatCompletionRequest { /// Random seed for sampling for deterministic outputs pub sampling_seed: Option, + /// User-supplied request ID for log correlation. + /// If set, SMG passes it through to the engine instead of generating its own UUID. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub request_id: Option, + /// Additional fields not explicitly defined above (e.g. engine-specific parameters) #[serde(flatten)] pub other: Map, @@ -403,6 +414,7 @@ fn validate_chat_cross_parameters( req.regex.is_some(), req.ebnf.is_some(), matches!(req.response_format, Some(ResponseFormat::JsonSchema { .. })), + matches!(req.response_format, Some(ResponseFormat::Regex { .. })), ] .iter() .filter(|&&x| x) @@ -567,6 +579,29 @@ impl Normalizable for ChatCompletionRequest { self.function_call = None; // Clear deprecated field } + // Migrate thinking → chat_template_kwargs + if let Some(ref thinking) = self.thinking { + let kwargs = self.chat_template_kwargs.get_or_insert_with(HashMap::new); + match thinking { + ThinkingConfig::Enabled { .. } => { + kwargs + .entry("enable_thinking".to_string()) + .or_insert(Value::Bool(true)); + kwargs + .entry("thinking".to_string()) + .or_insert(Value::Bool(true)); + } + ThinkingConfig::Disabled => { + kwargs + .entry("enable_thinking".to_string()) + .or_insert(Value::Bool(false)); + kwargs + .entry("thinking".to_string()) + .or_insert(Value::Bool(false)); + } + } + } + // Apply tool_choice defaults if self.tool_choice.is_none() { if let Some(tools) = &self.tools { diff --git a/crates/protocols/src/common.rs b/crates/protocols/src/common.rs index 9d2688c9c..92b9b0e9e 100644 --- a/crates/protocols/src/common.rs +++ b/crates/protocols/src/common.rs @@ -195,13 +195,68 @@ pub enum ContentPart { VideoUrl { video_url: VideoUrl }, } -#[derive(Debug, Clone, Deserialize, Serialize, PartialEq, schemars::JsonSchema)] +#[derive(Debug, Clone, PartialEq, schemars::JsonSchema)] pub struct ImageUrl { pub url: String, - #[serde(skip_serializing_if = "Option::is_none")] pub detail: Option, // "auto", "low", or "high" } +impl Serialize for ImageUrl { + fn serialize(&self, serializer: S) -> Result { + use serde::ser::SerializeStruct; + let field_count = 1 + self.detail.is_some() as usize; + let mut state = serializer.serialize_struct("ImageUrl", field_count)?; + state.serialize_field("url", &self.url)?; + if let Some(ref detail) = self.detail { + state.serialize_field("detail", detail)?; + } + state.end() + } +} + +impl<'de> Deserialize<'de> for ImageUrl { + fn deserialize>(deserializer: D) -> Result { + use serde::de; + + struct ImageUrlVisitor; + + impl<'de> de::Visitor<'de> for ImageUrlVisitor { + type Value = ImageUrl; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a string URL or an object with url field") + } + + fn visit_str(self, v: &str) -> Result { + Ok(ImageUrl { + url: v.to_string(), + detail: None, + }) + } + + fn visit_map>(self, mut map: M) -> Result { + let mut url = None; + let mut detail = None; + while let Some(key) = map.next_key::()? { + match key.as_str() { + "url" => url = Some(map.next_value()?), + "detail" => detail = map.next_value()?, + _ => { + let _ = map.next_value::()?; + } + } + } + Ok(ImageUrl { + url: url.ok_or_else(|| de::Error::missing_field("url"))?, + detail, + }) + } + } + + deserializer.deserialize_any(ImageUrlVisitor) + } +} + #[derive(Debug, Clone, Deserialize, Serialize, PartialEq, schemars::JsonSchema)] pub struct VideoUrl { pub url: String, @@ -220,6 +275,8 @@ pub enum ResponseFormat { JsonObject, #[serde(rename = "json_schema")] JsonSchema { json_schema: JsonSchemaFormat }, + #[serde(rename = "regex")] + Regex { pattern: String }, } #[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)] diff --git a/crates/protocols/src/completion.rs b/crates/protocols/src/completion.rs index 393596fe9..6229f29cd 100644 --- a/crates/protocols/src/completion.rs +++ b/crates/protocols/src/completion.rs @@ -140,6 +140,10 @@ pub struct CompletionRequest { /// Sampling seed for deterministic outputs pub sampling_seed: Option, + /// User-supplied request ID for log correlation. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub request_id: Option, + /// Additional fields including bootstrap info for PD routing #[serde(flatten)] pub other: Map, diff --git a/crates/protocols/src/messages.rs b/crates/protocols/src/messages.rs index 76adcf77e..4b6c3b4d9 100644 --- a/crates/protocols/src/messages.rs +++ b/crates/protocols/src/messages.rs @@ -73,6 +73,10 @@ pub struct CreateMessageRequest { /// MCP servers to be utilized in this request (beta). pub mcp_servers: Option>, + + /// User-supplied request ID for log correlation. + #[serde(default, skip_serializing_if = "Option::is_none")] + pub request_id: Option, } impl Normalizable for CreateMessageRequest { @@ -1843,6 +1847,7 @@ mod tests { top_p: None, container: None, mcp_servers: None, + request_id: None, } } diff --git a/crates/reasoning_parser/src/parsers/base.rs b/crates/reasoning_parser/src/parsers/base.rs index 6214acd69..0adcc36ba 100644 --- a/crates/reasoning_parser/src/parsers/base.rs +++ b/crates/reasoning_parser/src/parsers/base.rs @@ -63,6 +63,12 @@ impl ReasoningParser for BaseReasoningParser { .to_string(); if !processed_text.contains(&self.config.think_end_token) { + // Don't consume tool call markers as reasoning content + if let Some(tool_pos) = processed_text.find("<|tool_calls_section_begin|>") { + let reasoning_text = processed_text[..tool_pos].trim().to_string(); + let normal_text = processed_text[tool_pos..].to_string(); + return Ok(ParserResult::new(normal_text, reasoning_text)); + } // Assume reasoning was truncated before end token return Ok(ParserResult::reasoning(processed_text)); } @@ -133,6 +139,14 @@ impl ReasoningParser for BaseReasoningParser { // Continue with reasoning content if self.in_reasoning && self.config.stream_reasoning { + // Some models skip and go straight to tool calls + if let Some(tool_pos) = current_text.find("<|tool_calls_section_begin|>") { + let reasoning_text = current_text[..tool_pos].trim().to_string(); + let normal_text = current_text[tool_pos..].to_string(); + self.buffer.clear(); + self.in_reasoning = false; + return Ok(ParserResult::new(normal_text, reasoning_text)); + } // Stream the content immediately let reasoning_text = current_text; self.buffer.clear(); diff --git a/crates/tokenizer/src/chat_template.rs b/crates/tokenizer/src/chat_template.rs index 2ac3b50c7..7a45a9273 100644 --- a/crates/tokenizer/src/chat_template.rs +++ b/crates/tokenizer/src/chat_template.rs @@ -597,6 +597,15 @@ fn build_environment(template: String) -> Result> { // like ensure_ascii, separators, and sort_keys that HuggingFace templates use env.add_filter("tojson", tojson_filter); + // HuggingFace's Jinja2 environment registers `raise_exception` as a global + // callable so templates can do `{{ raise_exception("msg") }}`. + env.add_function( + "raise_exception", + |msg: String| -> Result { + Err(MinijinjaError::new(ErrorKind::InvalidOperation, msg)) + }, + ); + Ok(env) } diff --git a/crates/tokenizer/src/huggingface.rs b/crates/tokenizer/src/huggingface.rs index e13ca1b00..fc3cbeebc 100644 --- a/crates/tokenizer/src/huggingface.rs +++ b/crates/tokenizer/src/huggingface.rs @@ -236,6 +236,7 @@ impl HuggingFaceTokenizer { cls_token: find_token(&["[CLS]", "", ""]), mask_token: find_token(&["[MASK]", "", ""]), additional_special_tokens, + ..Default::default() } } diff --git a/crates/tokenizer/src/mock.rs b/crates/tokenizer/src/mock.rs index f057aabe2..ad2a6b636 100644 --- a/crates/tokenizer/src/mock.rs +++ b/crates/tokenizer/src/mock.rs @@ -51,11 +51,7 @@ impl MockTokenizer { bos_token: Some("".to_string()), eos_token: Some("".to_string()), unk_token: Some("".to_string()), - sep_token: None, - pad_token: None, - cls_token: None, - mask_token: None, - additional_special_tokens: vec![], + ..Default::default() }; Self { diff --git a/crates/tokenizer/src/tiktoken.rs b/crates/tokenizer/src/tiktoken.rs index eb5178832..e56a34ddf 100644 --- a/crates/tokenizer/src/tiktoken.rs +++ b/crates/tokenizer/src/tiktoken.rs @@ -119,6 +119,7 @@ fn parse_special_tokens(config: &serde_json::Value) -> SpecialTokens { cls_token: get_str("cls_token"), mask_token: get_str("mask_token"), additional_special_tokens: additional, + ..Default::default() } } @@ -261,7 +262,6 @@ impl TiktokenTokenizer { }) }; - // Load merged EOS token IDs from config.json + generation_config.json let eos_token_ids = crate::eos::load_eos_token_ids(dir); Ok(TiktokenTokenizer { @@ -314,27 +314,20 @@ impl TiktokenTokenizer { TiktokenModel::Cl100kBase => SpecialTokens { bos_token: Some("<|endoftext|>".to_string()), eos_token: Some("<|endoftext|>".to_string()), - unk_token: None, - sep_token: None, pad_token: Some("<|endoftext|>".to_string()), - cls_token: None, - mask_token: None, additional_special_tokens: vec![ "<|fim_prefix|>".to_string(), "<|fim_middle|>".to_string(), "<|fim_suffix|>".to_string(), "<|endofprompt|>".to_string(), ], + ..Default::default() }, _ => SpecialTokens { bos_token: Some("<|endoftext|>".to_string()), eos_token: Some("<|endoftext|>".to_string()), - unk_token: None, - sep_token: None, pad_token: Some("<|endoftext|>".to_string()), - cls_token: None, - mask_token: None, - additional_special_tokens: vec![], + ..Default::default() }, } } @@ -443,20 +436,6 @@ pub fn is_tiktoken_file(path: &Path) -> bool { impl Encoder for TiktokenTokenizer { fn encode(&self, input: &str, _add_special_tokens: bool) -> Result { - // Always use encode_with_special_tokens so that special token strings - // in the input (e.g., <|media_pad|> from chat templates) are recognized - // as single tokens rather than split into BPE sub-tokens. - // - // NOTE: We intentionally ignore `add_special_tokens` here because the - // flag has different semantics across backends. For HuggingFace it - // controls BOS/EOS prepend/append (tiktoken has no such concept). - // For tiktoken, encode_ordinary vs encode_with_special_tokens controls - // whether special-token *patterns* in the input are recognized. - // All callers that encode chat-template-rendered text pass `false` - // (meaning "don't add BOS/EOS"), but tiktoken must still recognize - // the special tokens the template inserted. A proper fix requires - // redesigning the Encoder trait to separate "add wrapper tokens" from - // "recognize special-token patterns". let tokens = self.tokenizer.encode_with_special_tokens(input); Ok(Encoding::Tiktoken(tokens)) } @@ -480,7 +459,7 @@ impl Decoder for TiktokenTokenizer { ._decode_native_and_split(token_ids.to_vec()) .flatten() .collect(); - tracing::warn!( + tracing::debug!( error = %err, token_count = token_ids.len(), "tiktoken decode failed; returning lossy UTF-8 fallback" diff --git a/crates/tokenizer/src/traits.rs b/crates/tokenizer/src/traits.rs index c645d5910..12ae7ab4d 100644 --- a/crates/tokenizer/src/traits.rs +++ b/crates/tokenizer/src/traits.rs @@ -114,6 +114,12 @@ pub trait Tokenizer: Encoder + Decoder { false } + /// Merged EOS token IDs from config.json and generation_config.json. + /// Backends should stop generation when any of these tokens is produced. + fn eos_token_ids(&self) -> &[TokenIdType] { + &[] + } + /// Set or override the chat template. /// /// Returns an error if the template fails to parse or the tokenizer @@ -123,14 +129,6 @@ pub trait Tokenizer: Encoder + Decoder { "set_chat_template is not supported by this tokenizer" )) } - - /// EOS token IDs for stop detection. - /// - /// Merged from `config.json` and `generation_config.json` (eos_token_id, int or list). - /// Models can have multiple EOS tokens (e.g., Llama 3: end_of_text + eom_id + eot_id). - fn eos_token_ids(&self) -> &[TokenIdType] { - &[] - } } /// Contains the results of tokenizing text: token IDs, string tokens, and their spans @@ -184,4 +182,9 @@ pub struct SpecialTokens { pub cls_token: Option, pub mask_token: Option, pub additional_special_tokens: Vec, + /// Merged EOS token IDs from config.json + generation_config.json. + /// Models like Kimi-K2.5 define different EOS tokens in each file + /// (e.g. `[EOS]` in config.json, `<|im_end|>` in generation_config.json). + /// The engine must stop at any of these tokens. + pub eos_token_ids: Vec, } diff --git a/crates/tool_parser/src/parsers/kimik2.rs b/crates/tool_parser/src/parsers/kimik2.rs index 52f47de27..d5af43791 100644 --- a/crates/tool_parser/src/parsers/kimik2.rs +++ b/crates/tool_parser/src/parsers/kimik2.rs @@ -105,20 +105,20 @@ impl KimiK2Parser { reason = "regex patterns are compile-time string literals" )] pub fn new() -> Self { - // Pattern for complete tool calls - let tool_call_pattern = r"<\|tool_call_begin\|>\s*(?P[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P\{.*?\})\s*<\|tool_call_end\|>"; + // Supports alternative delimiters: <|func_start|>/<|func_end|>; (?s) for multi-line JSON. + // Tool call IDs may contain hyphens (e.g. "functions.execute-tool:0"). + let tool_call_pattern = r"(?s)<\|tool_call_begin\|>\s*(?P[\w\.\-]+:\d+)\s*(?:<\|tool_call_argument_begin\|>\s*|<\|func_start\|>\s*)?(?P\{.*?\})\s*(?:<\|tool_call_end\|>|<\|func_end\|>)"; let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern"); - // Pattern for streaming (partial) tool calls - let stream_pattern = r"<\|tool_call_begin\|>\s*(?P[\w\.]+:\d+)\s*<\|tool_call_argument_begin\|>\s*(?P\{.*)"; + let stream_pattern = r"(?s)<\|tool_call_begin\|>\s*(?P[\w\.\-]+:\d+)\s*(?:<\|tool_call_argument_begin\|>\s*|<\|func_start\|>\s*)?(?P\{.*)"; let stream_tool_call_extractor = Regex::new(stream_pattern).expect("Valid regex pattern"); // Pattern for removing completed tool calls - let end_pattern = r"<\|tool_call_begin\|>.*?<\|tool_call_end\|>"; + let end_pattern = r"<\|tool_call_begin\|>.*?(?:<\|tool_call_end\|>|<\|func_end\|>)"; let tool_call_end_pattern = Regex::new(end_pattern).expect("Valid regex pattern"); - // Robust parser for ids like "functions.search:0" or fallback "search:0" - let id_pattern = r"^(?:functions\.)?(?P[\w\.]+):(?P\d+)$"; + // Robust parser for ids like "functions.execute-tool:0" or fallback "search:0" + let id_pattern = r"^(?:functions\.)?(?P[\w\.\-]+):(?P\d+)$"; let tool_call_id_regex = Regex::new(id_pattern).expect("Valid regex pattern"); Self { @@ -229,7 +229,11 @@ impl ToolParser for KimiK2Parser { // No tool markers detected - return all buffered content as normal text let mut normal_text = std::mem::take(&mut self.buffer); // Remove end tokens if present - for e_token in ["<|tool_calls_section_end|>", "<|tool_call_end|>"] { + for e_token in [ + "<|tool_calls_section_end|>", + "<|tool_call_end|>", + "<|func_end|>", + ] { normal_text = normal_text.replace(e_token, ""); } return Ok(StreamingParseResult { @@ -290,13 +294,15 @@ impl ToolParser for KimiK2Parser { function_args }; - // Split by end token before sending (like Python does) - let parsed_args_diff = - if let Some(pos) = argument_diff.find("<|tool_call_end|>") { - &argument_diff[..pos] - } else { - argument_diff - }; + // Split by end token before sending + let end_pos = argument_diff + .find("<|tool_call_end|>") + .or_else(|| argument_diff.find("<|func_end|>")); + let parsed_args_diff = if let Some(pos) = end_pos { + &argument_diff[..pos] + } else { + argument_diff + }; if !parsed_args_diff.is_empty() { calls.push(ToolCallItem { @@ -313,8 +319,10 @@ impl ToolParser for KimiK2Parser { } // Check completeness - split by end token first - let parsed_args = if let Some(pos) = function_args.find("<|tool_call_end|>") - { + let end_pos2 = function_args + .find("<|tool_call_end|>") + .or_else(|| function_args.find("<|func_end|>")); + let parsed_args = if let Some(pos) = end_pos2 { &function_args[..pos] } else { function_args diff --git a/model_gateway/benches/request_processing.rs b/model_gateway/benches/request_processing.rs index 465143ec4..573deca16 100644 --- a/model_gateway/benches/request_processing.rs +++ b/model_gateway/benches/request_processing.rs @@ -126,6 +126,7 @@ fn default_completion_request() -> CompletionRequest { session_params: None, return_hidden_states: false, sampling_seed: None, + request_id: None, other: serde_json::Map::new(), } } diff --git a/model_gateway/benches/wasm_middleware_latency.rs b/model_gateway/benches/wasm_middleware_latency.rs index 33b1b3221..f1c8889e1 100644 --- a/model_gateway/benches/wasm_middleware_latency.rs +++ b/model_gateway/benches/wasm_middleware_latency.rs @@ -79,6 +79,7 @@ fn bench_wasm_middleware_buffering(c: &mut Criterion) { concurrency_queue_tx: None, router_manager: None, mesh_handler: None, + api_port: 0, }); c.bench_function("wasm_middleware_pre_fix_latency", |b| { diff --git a/model_gateway/src/app_context.rs b/model_gateway/src/app_context.rs index f53898e3c..4572b5bf0 100644 --- a/model_gateway/src/app_context.rs +++ b/model_gateway/src/app_context.rs @@ -1,5 +1,5 @@ use std::{ - sync::{Arc, OnceLock}, + sync::{atomic::AtomicU64, Arc, OnceLock}, time::Duration, }; @@ -77,6 +77,10 @@ pub struct AppContext { pub wasm_manager: Option>, pub worker_service: Arc, pub inflight_tracker: Arc, + /// Unix timestamp (seconds) of the last token chunk forwarded to a client. + /// Zero means no tokens have been forwarded yet. Used by health_generate to + /// skip the real inference probe when traffic is flowing recently. + pub last_token_time: Arc, pub kv_event_monitor: Option>, pub realtime_registry: Arc, /// Bind address for WebRTC UDP sockets (`None` = `0.0.0.0`, auto-detect). @@ -400,6 +404,7 @@ impl AppContextBuilder { wasm_manager: self.wasm_manager, worker_service, inflight_tracker: InFlightRequestTracker::new(), + last_token_time: Arc::new(AtomicU64::new(0)), kv_event_monitor: self.kv_event_monitor, realtime_registry: Arc::new(RealtimeRegistry::new()), webrtc_bind_addr: self.webrtc_bind_addr, diff --git a/model_gateway/src/config/builder.rs b/model_gateway/src/config/builder.rs index 209daefdc..b69b0d287 100644 --- a/model_gateway/src/config/builder.rs +++ b/model_gateway/src/config/builder.rs @@ -283,6 +283,11 @@ impl RouterConfigBuilder { self } + pub fn health_generate_timeout_secs(mut self, timeout: u64) -> Self { + self.config.health_generate_timeout_secs = timeout; + self + } + // ==================== Discovery ==================== pub fn discovery_config(mut self, discovery: DiscoveryConfig) -> Self { @@ -385,6 +390,13 @@ impl RouterConfigBuilder { self } + // ==================== LOGGING ==================== + + pub fn enable_message_hash(mut self, enable: bool) -> Self { + self.config.enable_message_hash = enable; + self + } + // ==================== WASM ==================== pub fn enable_wasm(mut self, enable: bool) -> Self { diff --git a/model_gateway/src/config/types.rs b/model_gateway/src/config/types.rs index 1a675bd4b..43a6f620b 100644 --- a/model_gateway/src/config/types.rs +++ b/model_gateway/src/config/types.rs @@ -126,6 +126,8 @@ pub struct RouterConfig { #[serde(default)] pub disable_circuit_breaker: bool, pub health_check: HealthCheckConfig, + #[serde(default = "default_health_generate_timeout_secs")] + pub health_generate_timeout_secs: u64, #[serde(default)] pub enable_igw: bool, /// Can be a HuggingFace model ID or local path @@ -175,6 +177,9 @@ pub struct RouterConfig { /// Loaded from skills_config_path during config creation. #[serde(skip)] pub skills: Option, + /// Compute per-message SHA-256 hashes for session reconstruction logging + #[serde(default)] + pub enable_message_hash: bool, /// Enable WASM support #[serde(default)] pub enable_wasm: bool, @@ -219,6 +224,10 @@ fn default_load_monitor_interval_secs() -> u64 { 10 } +fn default_health_generate_timeout_secs() -> u64 { + 3 +} + fn default_enable_l0() -> bool { false } @@ -656,6 +665,7 @@ impl Default for RouterConfig { disable_retries: false, disable_circuit_breaker: false, health_check: HealthCheckConfig::default(), + health_generate_timeout_secs: default_health_generate_timeout_secs(), enable_igw: false, connection_mode: ConnectionMode::Http, model_path: None, @@ -674,6 +684,7 @@ impl Default for RouterConfig { mcp_config: None, skills_enabled: false, skills: None, + enable_message_hash: false, enable_wasm: false, storage_hook_wasm_path: None, server_cert: None, @@ -769,6 +780,7 @@ mod tests { assert_eq!(config.worker_startup_timeout_secs, 1800); assert_eq!(config.worker_startup_check_interval_secs, 30); assert_eq!(config.load_monitor_interval_secs, 10); + assert_eq!(config.health_generate_timeout_secs, 3); assert!(config.discovery.is_none()); assert!(config.metrics.is_none()); assert!(config.trace_config.is_none()); @@ -956,6 +968,7 @@ stream_retention_secs: 3600 assert!(deserialized.skills_enabled); assert!(deserialized.skills.is_none()); assert!(!deserialized.tenant_resolution.trust_tenant_header); + assert_eq!(deserialized.health_generate_timeout_secs, 3); assert_eq!( deserialized.tenant_resolution.tenant_header_name, DEFAULT_TENANT_HEADER_NAME diff --git a/model_gateway/src/config/validation.rs b/model_gateway/src/config/validation.rs index 7dfe99d00..8914a297f 100644 --- a/model_gateway/src/config/validation.rs +++ b/model_gateway/src/config/validation.rs @@ -522,6 +522,14 @@ impl ConfigValidator { }); } + if config.health_generate_timeout_secs == 0 { + return Err(ConfigError::InvalidValue { + field: "health_generate_timeout_secs".to_string(), + value: config.health_generate_timeout_secs.to_string(), + reason: "Must be > 0".to_string(), + }); + } + if config.worker_startup_check_interval_secs == 0 { return Err(ConfigError::InvalidValue { field: "worker_startup_check_interval_secs".to_string(), diff --git a/model_gateway/src/main.rs b/model_gateway/src/main.rs index b2d99fca2..5946fba4e 100644 --- a/model_gateway/src/main.rs +++ b/model_gateway/src/main.rs @@ -426,6 +426,10 @@ struct CliArgs { #[arg(long, default_value_t = false, help_heading = "Health Checks")] disable_health_check: bool, + /// Timeout in seconds for /health_generate inference probes + #[arg(long, default_value_t = 3, help_heading = "Health Checks")] + health_generate_timeout_secs: u64, + /// Remove workers from the registry when they are marked unhealthy #[arg(long, default_value_t = false, help_heading = "Health Checks")] remove_unhealthy_workers: bool, @@ -500,6 +504,10 @@ struct CliArgs { #[arg(long, default_value = "memory", value_parser = ["memory", "none", "oracle", "postgres", "redis"], help_heading = "Backend")] history_backend: String, + /// Compute per-message SHA-256 hashes for session reconstruction logging + #[arg(long, default_value_t = false, help_heading = "Logging")] + enable_message_hash: bool, + /// Enable WebAssembly support #[arg(long, default_value_t = false, help_heading = "Backend")] enable_wasm: bool, @@ -1225,6 +1233,7 @@ impl CliArgs { disable_health_check: self.disable_health_check, remove_unhealthy_workers: self.remove_unhealthy_workers, }) + .health_generate_timeout_secs(self.health_generate_timeout_secs) .tokenizer_cache(TokenizerCacheConfig { enable_l0: self.tokenizer_cache_enable_l0, l0_max_entries: self.tokenizer_cache_l0_max_entries, @@ -1263,6 +1272,7 @@ impl CliArgs { .dp_aware(self.dp_aware) .retries(!self.disable_retries) .circuit_breaker(!self.disable_circuit_breaker) + .enable_message_hash(self.enable_message_hash) .enable_wasm(self.enable_wasm) .maybe_storage_hook_wasm_path(self.storage_hook_wasm_path.as_deref()) .igw(self.enable_igw) diff --git a/model_gateway/src/routers/grpc/client.rs b/model_gateway/src/routers/grpc/client.rs index 81cfc8f11..f14d59ac2 100644 --- a/model_gateway/src/routers/grpc/client.rs +++ b/model_gateway/src/routers/grpc/client.rs @@ -316,6 +316,10 @@ impl GrpcClient { clippy::unreachable, reason = "assembly stage guarantees matching MultimodalData variant for each backend" )] + #[expect( + clippy::too_many_arguments, + reason = "request builder requires all fields for each backend's proto message" + )] pub fn build_chat_request( &self, request_id: String, @@ -324,6 +328,8 @@ impl GrpcClient { token_ids: Vec, multimodal_inputs: Option, tool_constraints: Option<(String, String)>, + eos_token_ids: &[u32], + message_hashes: Option>, ) -> Result { match self { Self::Sglang(client) => { @@ -368,6 +374,8 @@ impl GrpcClient { token_ids, trtllm_mm, tool_constraints, + eos_token_ids, + message_hashes, )?; Ok(ProtoGenerateRequest::Trtllm(Box::new(req))) } @@ -389,6 +397,10 @@ impl GrpcClient { clippy::unreachable, reason = "assembly stage guarantees matching MultimodalData variant for each backend" )] + #[expect( + clippy::too_many_arguments, + reason = "gRPC request builder needs all fields from the Messages API request" + )] pub fn build_messages_request( &self, request_id: String, @@ -397,6 +409,7 @@ impl GrpcClient { token_ids: Vec, multimodal_inputs: Option, tool_constraints: Option<(String, String)>, + message_hashes: Option>, ) -> Result { match self { Self::Sglang(client) => { @@ -441,6 +454,7 @@ impl GrpcClient { token_ids, trtllm_mm, tool_constraints, + message_hashes, )?; Ok(ProtoGenerateRequest::Trtllm(Box::new(req))) } diff --git a/model_gateway/src/routers/grpc/common/responses/mod.rs b/model_gateway/src/routers/grpc/common/responses/mod.rs index fd94e65ff..1bbdbdf3b 100644 --- a/model_gateway/src/routers/grpc/common/responses/mod.rs +++ b/model_gateway/src/routers/grpc/common/responses/mod.rs @@ -7,7 +7,7 @@ pub(crate) mod utils; // Re-export commonly used items pub(crate) use context::ResponsesContext; -pub(crate) use streaming::build_sse_response; +pub(crate) use streaming::{build_sse_response, build_tracked_sse_response}; pub(crate) use utils::{ensure_mcp_connection, persist_response_if_needed}; pub(crate) use crate::routers::common::mcp_utils::collect_user_function_names; diff --git a/model_gateway/src/routers/grpc/common/responses/streaming.rs b/model_gateway/src/routers/grpc/common/responses/streaming.rs index 5cc119df8..c6c3df18d 100644 --- a/model_gateway/src/routers/grpc/common/responses/streaming.rs +++ b/model_gateway/src/routers/grpc/common/responses/streaming.rs @@ -1,5 +1,7 @@ //! Streaming infrastructure for /v1/responses endpoint +use std::sync::{atomic::AtomicU64, Arc}; + use axum::{body::Body, http::StatusCode, response::Response}; use bytes::Bytes; use http::header::{HeaderValue, CONTENT_TYPE}; @@ -18,7 +20,7 @@ use openai_protocol::{ use serde_json::json; use smg_mcp::{self as mcp, ResponseFormat}; use tokio::sync::mpsc; -use tokio_stream::wrappers::UnboundedReceiverStream; +use tokio_stream::{wrappers::UnboundedReceiverStream, StreamExt}; use tracing::warn; use uuid::Uuid; @@ -1042,6 +1044,42 @@ pub(crate) fn build_sse_response( .expect("infallible: static headers and valid status code") } +/// Like `build_sse_response` but atomically records the current time whenever a +/// chunk is successfully written, allowing the health check to skip its inference +/// probe when tokens have flowed recently. +#[expect( + clippy::expect_used, + reason = "Response::builder with static headers and valid status code is infallible" +)] +pub(crate) fn build_tracked_sse_response( + rx: mpsc::UnboundedReceiver>, + last_token_time: &Arc, +) -> Response { + use std::sync::atomic::Ordering; + + let last_token_time = last_token_time.clone(); + let stream = UnboundedReceiverStream::new(rx).map(move |result| { + if result.is_ok() { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + last_token_time.store(now, Ordering::Relaxed); + } + result + }); + Response::builder() + .status(StatusCode::OK) + .header( + CONTENT_TYPE, + HeaderValue::from_static("text/event-stream; charset=utf-8"), + ) + .header("Cache-Control", HeaderValue::from_static("no-cache")) + .header("Connection", HeaderValue::from_static("keep-alive")) + .body(Body::from_stream(stream)) + .expect("infallible: static headers and valid status code") +} + /// Attach `server_label` to an MCP tool-call JSON item. /// /// Only sets the field when `response_format` indicates a passthrough (mcp_call) @@ -1056,3 +1094,67 @@ pub(crate) fn attach_mcp_server_label( item["server_label"] = json!(label); } } + +#[cfg(test)] +mod tests { + use std::sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }; + + use axum::{body::to_bytes, http::StatusCode}; + use bytes::Bytes; + use tokio::sync::mpsc; + + use super::build_tracked_sse_response; + + #[tokio::test] + async fn tracked_sse_response_updates_timestamp_on_ok_chunk() { + use std::time::{SystemTime, UNIX_EPOCH}; + + let last_token_time = Arc::new(AtomicU64::new(0)); + let (tx, rx) = mpsc::unbounded_channel::>(); + + let resp = build_tracked_sse_response(rx, &last_token_time); + assert_eq!(resp.status(), StatusCode::OK); + + let before = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + + tx.send(Ok(Bytes::from("data: hello\n\n"))).unwrap(); + drop(tx); + + let _ = to_bytes(resp.into_body(), usize::MAX).await.unwrap(); + + let stored = last_token_time.load(Ordering::Relaxed); + assert!( + stored >= before, + "timestamp should be >= before: {stored} < {before}" + ); + assert!( + stored <= before + 5, + "timestamp should be within 5s: {stored}" + ); + } + + #[tokio::test] + async fn tracked_sse_response_does_not_update_timestamp_on_error_chunk() { + let last_token_time = Arc::new(AtomicU64::new(0)); + let (tx, rx) = mpsc::unbounded_channel::>(); + + let resp = build_tracked_sse_response(rx, &last_token_time); + + tx.send(Err(std::io::Error::other("stream error"))).unwrap(); + drop(tx); + + let _ = to_bytes(resp.into_body(), usize::MAX).await; + + assert_eq!( + last_token_time.load(Ordering::Relaxed), + 0, + "error chunk should not update the timestamp" + ); + } +} diff --git a/model_gateway/src/routers/grpc/common/stages/helpers.rs b/model_gateway/src/routers/grpc/common/stages/helpers.rs index c7cf84a5c..90becc88c 100644 --- a/model_gateway/src/routers/grpc/common/stages/helpers.rs +++ b/model_gateway/src/routers/grpc/common/stages/helpers.rs @@ -2,7 +2,9 @@ use std::sync::Arc; +use openai_protocol::chat::ChatMessage; use rand::Rng; +use sha2::{Digest, Sha256}; use smg_grpc_client::sglang_proto::DisaggregatedParams; use tracing::debug; @@ -56,3 +58,95 @@ fn inject_sglang_bootstrap_metadata( hostname, bootstrap_port, room_id ); } + +fn chat_message_role(msg: &ChatMessage) -> &'static str { + match msg { + ChatMessage::System { .. } => "system", + ChatMessage::User { .. } => "user", + ChatMessage::Assistant { .. } => "assistant", + ChatMessage::Tool { .. } => "tool", + ChatMessage::Function { .. } => "function", + ChatMessage::Developer { .. } => "developer", + } +} + +fn chat_message_text_content(msg: &ChatMessage) -> String { + match msg { + ChatMessage::System { content, .. } + | ChatMessage::User { content, .. } + | ChatMessage::Tool { content, .. } + | ChatMessage::Developer { content, .. } => content.to_simple_string(), + ChatMessage::Assistant { content, .. } => content + .as_ref() + .map_or_else(String::new, |c| c.to_simple_string()), + ChatMessage::Function { content, .. } => content.clone(), + } +} + +/// Compute per-message SHA-256 hashes matching TRT-LLM's `openai_server.py` format: +/// `sha256(role + "\x00" + content).hexdigest()[:12]` +pub(crate) fn compute_and_log_message_hashes( + request_id: &str, + messages: &[ChatMessage], +) -> Vec<(String, String)> { + let hashes: Vec<(String, String)> = messages + .iter() + .map(|msg| { + let role = chat_message_role(msg); + let content = chat_message_text_content(msg); + let mut hasher = Sha256::new(); + hasher.update(format!("{role}\x00{content}").as_bytes()); + let hash = format!("{:x}", hasher.finalize()); + (role.to_string(), hash[..12].to_string()) + }) + .collect(); + debug!( + target: "smg::request", + request_id = %request_id, + message_hashes = ?hashes, + "Request message hashes for session reconstruction" + ); + hashes +} + +/// Compute per-message SHA-256 hashes from InputMessage (Messages API) format. +pub(crate) fn compute_and_log_input_message_hashes( + request_id: &str, + messages: &[openai_protocol::messages::InputMessage], +) -> Vec<(String, String)> { + use openai_protocol::messages::Role; + let hashes: Vec<(String, String)> = messages + .iter() + .map(|msg| { + let role = match msg.role { + Role::User => "user", + Role::Assistant => "assistant", + }; + let content = match &msg.content { + openai_protocol::messages::InputContent::String(s) => s.clone(), + openai_protocol::messages::InputContent::Blocks(blocks) => blocks + .iter() + .filter_map(|b| { + if let openai_protocol::messages::InputContentBlock::Text(t) = b { + Some(t.text.as_str()) + } else { + None + } + }) + .collect::>() + .join(" "), + }; + let mut hasher = Sha256::new(); + hasher.update(format!("{role}\x00{content}").as_bytes()); + let hash = format!("{:x}", hasher.finalize()); + (role.to_string(), hash[..12].to_string()) + }) + .collect(); + debug!( + target: "smg::request", + request_id = %request_id, + message_hashes = ?hashes, + "Request message hashes for session reconstruction" + ); + hashes +} diff --git a/model_gateway/src/routers/grpc/harmony/builder.rs b/model_gateway/src/routers/grpc/harmony/builder.rs index 3efaacd30..64360ff7a 100644 --- a/model_gateway/src/routers/grpc/harmony/builder.rs +++ b/model_gateway/src/routers/grpc/harmony/builder.rs @@ -1073,6 +1073,7 @@ impl HarmonyBuilder { ChatMessage::Tool { content, tool_call_id, + .. } => { // Look up the function name from the tool_call_id let function_name = tool_call_map diff --git a/model_gateway/src/routers/grpc/harmony/stages/preparation.rs b/model_gateway/src/routers/grpc/harmony/stages/preparation.rs index 596860750..2ae306ab3 100644 --- a/model_gateway/src/routers/grpc/harmony/stages/preparation.rs +++ b/model_gateway/src/routers/grpc/harmony/stages/preparation.rs @@ -291,7 +291,7 @@ impl HarmonyPreparationStage { }; let schema = match format { - ResponseFormat::Text => return Ok(None), + ResponseFormat::Text | ResponseFormat::Regex { .. } => return Ok(None), ResponseFormat::JsonObject => Cow::Owned(serde_json::json!({"type": "object"})), ResponseFormat::JsonSchema { json_schema } => Cow::Borrowed(&json_schema.schema), }; diff --git a/model_gateway/src/routers/grpc/harmony/stages/request_building.rs b/model_gateway/src/routers/grpc/harmony/stages/request_building.rs index d084d66f3..317f25c39 100644 --- a/model_gateway/src/routers/grpc/harmony/stages/request_building.rs +++ b/model_gateway/src/routers/grpc/harmony/stages/request_building.rs @@ -2,7 +2,7 @@ use async_trait::async_trait; use axum::response::Response; -use tracing::{debug, error}; +use tracing::{debug, error, info}; use uuid::Uuid; use crate::routers::{ @@ -21,12 +21,16 @@ use crate::routers::{ /// Unlike regular request building, this uses token_ids directly (Harmony encoding handles messages). pub(crate) struct HarmonyRequestBuildingStage { inject_pd_metadata: bool, + enable_message_hash: bool, } impl HarmonyRequestBuildingStage { /// Create a new Harmony request building stage - pub fn new(inject_pd_metadata: bool) -> Self { - Self { inject_pd_metadata } + pub fn new(inject_pd_metadata: bool, enable_message_hash: bool) -> Self { + Self { + inject_pd_metadata, + enable_message_hash, + } } } @@ -72,10 +76,24 @@ impl PipelineStage for HarmonyRequestBuildingStage { ClientSelection::Dual { prefill, .. } => prefill, }; - // Generate request_id based on request type + // Generate request_id based on request type — use user-supplied request_id if provided let request_id = match &ctx.input.request_type { - RequestType::Chat(_) => format!("chatcmpl-{}", Uuid::now_v7()), - RequestType::Responses(_) => format!("responses-{}", Uuid::now_v7()), + RequestType::Chat(req) => { + if let Some(id) = req.request_id.clone() { + info!(target: "smg::request", request_id = %id, "Using user-supplied request ID"); + id + } else { + format!("chatcmpl-{}", Uuid::now_v7()) + } + } + RequestType::Responses(req) => { + if let Some(id) = req.request_id.clone() { + info!(target: "smg::request", request_id = %id, "Using user-supplied request ID"); + id + } else { + format!("responses-{}", Uuid::now_v7()) + } + } request_type @ (RequestType::Generate(_) | RequestType::Completion(_) | RequestType::Embedding(_) @@ -93,6 +111,19 @@ impl PipelineStage for HarmonyRequestBuildingStage { } }; + let message_hashes = if self.enable_message_hash { + if let RequestType::Chat(req) = &ctx.input.request_type { + Some(helpers::compute_and_log_message_hashes( + &request_id, + &req.messages, + )) + } else { + None + } + } else { + None + }; + // Build gRPC request using token_ids directly (Harmony encoding already handled message rendering) let placeholder_processed_text = "[harmony]".to_string(); @@ -189,6 +220,10 @@ impl PipelineStage for HarmonyRequestBuildingStage { ProtoGenerateRequest::Vllm(Box::new(req)) } GrpcClient::Trtllm(trtllm_client) => { + let eos_ids = ctx + .tokenizer_arc() + .map(|t| t.eos_token_ids().to_vec()) + .unwrap_or_default(); let req = match &ctx.input.request_type { RequestType::Chat(request) => { let body = modified_request.as_deref().unwrap_or_else(|| request.as_ref()); @@ -200,6 +235,8 @@ impl PipelineStage for HarmonyRequestBuildingStage { token_ids, None, // No multimodal in Harmony pipeline tool_constraints, + &eos_ids, + message_hashes, ) .map_err(|e| { error!(function = "HarmonyRequestBuildingStage::execute", error = %e, "Failed to build TensorRT-LLM generate request"); @@ -213,6 +250,7 @@ impl PipelineStage for HarmonyRequestBuildingStage { placeholder_processed_text, token_ids, tool_constraints, + None, ) .map_err(|e| { error!(function = "HarmonyRequestBuildingStage::execute", error = %e, "Failed to build TensorRT-LLM generate request from responses"); diff --git a/model_gateway/src/routers/grpc/multimodal.rs b/model_gateway/src/routers/grpc/multimodal.rs index d28c7e3fc..8ae1a2c89 100644 --- a/model_gateway/src/routers/grpc/multimodal.rs +++ b/model_gateway/src/routers/grpc/multimodal.rs @@ -690,6 +690,30 @@ fn expand_tokens( // Assembly: convert MultimodalIntermediate → backend-specific MultimodalData // --------------------------------------------------------------------------- +/// Collapse consecutive runs of a media placeholder token ID to a single +/// occurrence. TRT-LLM's `KimiK25InputProcessor` re-expands each single +/// placeholder to N vision tokens based on `grid_thws`, so sending N +/// placeholders causes double-expansion and a count mismatch error. +/// +/// SGLang expects N placeholders (it uses them directly for embedding lookup), +/// so this function must only be called for TRT-LLM requests. +pub(crate) fn collapse_media_placeholders(token_ids: &[u32], im_token_id: u32) -> Vec { + let mut result = Vec::with_capacity(token_ids.len()); + let mut prev_was_placeholder = false; + for &tok in token_ids { + if tok == im_token_id { + if !prev_was_placeholder { + result.push(tok); + } + prev_was_placeholder = true; + } else { + result.push(tok); + prev_was_placeholder = false; + } + } + result +} + /// Assemble backend-specific multimodal data from the intermediate. /// /// Called in request_building after worker selection, when the backend is known. diff --git a/model_gateway/src/routers/grpc/pd_router.rs b/model_gateway/src/routers/grpc/pd_router.rs index ab96d4101..221c62492 100644 --- a/model_gateway/src/routers/grpc/pd_router.rs +++ b/model_gateway/src/routers/grpc/pd_router.rs @@ -73,6 +73,8 @@ impl GrpcPDRouter { multimodal, }); + let enable_message_hash = ctx.router_config.enable_message_hash; + // Create PD pipeline let pipeline = RequestPipeline::new_pd( worker_registry.clone(), @@ -81,6 +83,8 @@ impl GrpcPDRouter { reasoning_parser_factory.clone(), ctx.configured_tool_parser.clone(), ctx.configured_reasoning_parser.clone(), + enable_message_hash, + ctx.last_token_time.clone(), ); // Create Messages PD pipeline @@ -91,11 +95,17 @@ impl GrpcPDRouter { reasoning_parser_factory.clone(), ctx.configured_tool_parser.clone(), ctx.configured_reasoning_parser.clone(), + enable_message_hash, + ctx.last_token_time.clone(), ); // Create Completion PD pipeline - let completion_pipeline = - RequestPipeline::new_completion_pd(worker_registry.clone(), policy_registry.clone()); + let completion_pipeline = RequestPipeline::new_completion_pd( + worker_registry.clone(), + policy_registry.clone(), + enable_message_hash, + ctx.last_token_time.clone(), + ); Ok(GrpcPDRouter { worker_registry, diff --git a/model_gateway/src/routers/grpc/pipeline.rs b/model_gateway/src/routers/grpc/pipeline.rs index f5e4f20b4..6be279c86 100644 --- a/model_gateway/src/routers/grpc/pipeline.rs +++ b/model_gateway/src/routers/grpc/pipeline.rs @@ -3,7 +3,10 @@ //! This module defines the RequestPipeline orchestrator that coordinates //! the execution of pipeline stages from request preparation to response delivery. -use std::{sync::Arc, time::Instant}; +use std::{ + sync::{atomic::AtomicU64, Arc}, + time::Instant, +}; use axum::response::{IntoResponse, Response}; use openai_protocol::{ @@ -110,6 +113,10 @@ impl RequestPipeline { } /// Create a regular (single-worker) pipeline + #[expect( + clippy::too_many_arguments, + reason = "all params are distinct required dependencies" + )] pub fn new_regular( worker_registry: Arc, policy_registry: Arc, @@ -117,6 +124,8 @@ impl RequestPipeline { reasoning_parser_factory: ReasoningParserFactory, configured_tool_parser: Option, configured_reasoning_parser: Option, + enable_message_hash: bool, + last_token_time: Arc, ) -> Self { let processor = processor::ResponseProcessor::new( tool_parser_factory.clone(), @@ -131,6 +140,7 @@ impl RequestPipeline { configured_tool_parser, configured_reasoning_parser, metrics_labels::BACKEND_REGULAR, + last_token_time, )); let stages: Vec> = vec![ @@ -141,7 +151,10 @@ impl RequestPipeline { WorkerSelectionMode::Regular, )), Box::new(ClientAcquisitionStage), - Box::new(ChatGenerateRequestBuildingStage::new(false)), // No PD metadata + Box::new(ChatGenerateRequestBuildingStage::new( + false, + enable_message_hash, + )), Box::new(DispatchMetadataStage), Box::new(RequestExecutionStage::new(ExecutionMode::Single)), Box::new(ChatGenerateResponseProcessingStage::new( @@ -164,6 +177,7 @@ impl RequestPipeline { _reasoning_parser_factory: ReasoningParserFactory, _configured_tool_parser: Option, _configured_reasoning_parser: Option, + enable_message_hash: bool, ) -> Self { let stages: Vec> = vec![ Box::new(harmony::stages::HarmonyPreparationStage::new()), @@ -173,7 +187,10 @@ impl RequestPipeline { WorkerSelectionMode::Regular, )), Box::new(ClientAcquisitionStage), - Box::new(harmony::stages::HarmonyRequestBuildingStage::new(false)), + Box::new(harmony::stages::HarmonyRequestBuildingStage::new( + false, + enable_message_hash, + )), Box::new(DispatchMetadataStage), Box::new(RequestExecutionStage::new(ExecutionMode::Single)), Box::new(harmony::stages::HarmonyResponseProcessingStage::new()), @@ -194,6 +211,7 @@ impl RequestPipeline { _reasoning_parser_factory: ReasoningParserFactory, _configured_tool_parser: Option, _configured_reasoning_parser: Option, + enable_message_hash: bool, ) -> Self { let stages: Vec> = vec![ Box::new(harmony::stages::HarmonyPreparationStage::new()), @@ -203,7 +221,10 @@ impl RequestPipeline { WorkerSelectionMode::PrefillDecode, )), Box::new(ClientAcquisitionStage), - Box::new(harmony::stages::HarmonyRequestBuildingStage::new(true)), + Box::new(harmony::stages::HarmonyRequestBuildingStage::new( + true, + enable_message_hash, + )), Box::new(DispatchMetadataStage), Box::new(RequestExecutionStage::new(ExecutionMode::DualDispatch)), Box::new(harmony::stages::HarmonyResponseProcessingStage::new()), @@ -216,6 +237,10 @@ impl RequestPipeline { } /// Create a PD (prefill-decode) pipeline + #[expect( + clippy::too_many_arguments, + reason = "all params are distinct required dependencies" + )] pub fn new_pd( worker_registry: Arc, policy_registry: Arc, @@ -223,6 +248,8 @@ impl RequestPipeline { reasoning_parser_factory: ReasoningParserFactory, configured_tool_parser: Option, configured_reasoning_parser: Option, + enable_message_hash: bool, + last_token_time: Arc, ) -> Self { let processor = processor::ResponseProcessor::new( tool_parser_factory.clone(), @@ -237,6 +264,7 @@ impl RequestPipeline { configured_tool_parser, configured_reasoning_parser, metrics_labels::BACKEND_PD, + last_token_time, )); let stages: Vec> = vec![ @@ -247,7 +275,10 @@ impl RequestPipeline { WorkerSelectionMode::PrefillDecode, )), Box::new(ClientAcquisitionStage), - Box::new(ChatGenerateRequestBuildingStage::new(true)), // Inject PD metadata + Box::new(ChatGenerateRequestBuildingStage::new( + true, + enable_message_hash, + )), Box::new(DispatchMetadataStage), Box::new(RequestExecutionStage::new(ExecutionMode::DualDispatch)), Box::new(ChatGenerateResponseProcessingStage::new( @@ -320,6 +351,10 @@ impl RequestPipeline { /// Uses Messages-specific stages for preparation, request building, and response /// processing. Shares worker selection, client acquisition, dispatch metadata, /// and request execution stages with other pipelines. + #[expect( + clippy::too_many_arguments, + reason = "all params are distinct required dependencies" + )] pub fn new_messages( worker_registry: Arc, policy_registry: Arc, @@ -327,6 +362,8 @@ impl RequestPipeline { reasoning_parser_factory: ReasoningParserFactory, configured_tool_parser: Option, configured_reasoning_parser: Option, + enable_message_hash: bool, + last_token_time: Arc, ) -> Self { let processor = processor::ResponseProcessor::new( tool_parser_factory.clone(), @@ -341,6 +378,7 @@ impl RequestPipeline { configured_tool_parser, configured_reasoning_parser, metrics_labels::BACKEND_REGULAR, + last_token_time, )); let stages: Vec> = vec![ @@ -351,7 +389,7 @@ impl RequestPipeline { WorkerSelectionMode::Regular, )), Box::new(ClientAcquisitionStage), - Box::new(MessageRequestBuildingStage::new(false)), // No PD metadata + Box::new(MessageRequestBuildingStage::new(false, enable_message_hash)), Box::new(DispatchMetadataStage), Box::new(RequestExecutionStage::new(ExecutionMode::Single)), Box::new(MessageResponseProcessingStage::new( @@ -367,6 +405,10 @@ impl RequestPipeline { } /// Create a Messages API PD (prefill-decode) pipeline + #[expect( + clippy::too_many_arguments, + reason = "all params are distinct required dependencies" + )] pub fn new_messages_pd( worker_registry: Arc, policy_registry: Arc, @@ -374,6 +416,8 @@ impl RequestPipeline { reasoning_parser_factory: ReasoningParserFactory, configured_tool_parser: Option, configured_reasoning_parser: Option, + enable_message_hash: bool, + last_token_time: Arc, ) -> Self { let processor = processor::ResponseProcessor::new( tool_parser_factory.clone(), @@ -388,6 +432,7 @@ impl RequestPipeline { configured_tool_parser, configured_reasoning_parser, metrics_labels::BACKEND_PD, + last_token_time, )); let stages: Vec> = vec![ @@ -398,7 +443,7 @@ impl RequestPipeline { WorkerSelectionMode::PrefillDecode, )), Box::new(ClientAcquisitionStage), - Box::new(MessageRequestBuildingStage::new(true)), // Inject PD metadata + Box::new(MessageRequestBuildingStage::new(true, enable_message_hash)), Box::new(DispatchMetadataStage), Box::new(RequestExecutionStage::new(ExecutionMode::DualDispatch)), Box::new(MessageResponseProcessingStage::new( @@ -421,6 +466,8 @@ impl RequestPipeline { pub fn new_completion( worker_registry: Arc, policy_registry: Arc, + enable_message_hash: bool, + last_token_time: Arc, ) -> Self { let processor = processor::ResponseProcessor::new( ToolParserFactory::default(), @@ -435,6 +482,7 @@ impl RequestPipeline { None, None, metrics_labels::BACKEND_REGULAR, + last_token_time, )); let stages: Vec> = vec![ @@ -445,7 +493,10 @@ impl RequestPipeline { WorkerSelectionMode::Regular, )), Box::new(ClientAcquisitionStage), - Box::new(CompletionRequestBuildingStage::new(false)), // No PD metadata + Box::new(CompletionRequestBuildingStage::new( + false, + enable_message_hash, + )), Box::new(DispatchMetadataStage), Box::new(RequestExecutionStage::new(ExecutionMode::Single)), Box::new(CompletionResponseProcessingStage::new( @@ -464,6 +515,8 @@ impl RequestPipeline { pub fn new_completion_pd( worker_registry: Arc, policy_registry: Arc, + enable_message_hash: bool, + last_token_time: Arc, ) -> Self { let processor = processor::ResponseProcessor::new( ToolParserFactory::default(), @@ -478,6 +531,7 @@ impl RequestPipeline { None, None, metrics_labels::BACKEND_PD, + last_token_time, )); let stages: Vec> = vec![ @@ -488,7 +542,10 @@ impl RequestPipeline { WorkerSelectionMode::PrefillDecode, )), Box::new(ClientAcquisitionStage), - Box::new(CompletionRequestBuildingStage::new(true)), // Inject PD metadata + Box::new(CompletionRequestBuildingStage::new( + true, + enable_message_hash, + )), Box::new(DispatchMetadataStage), Box::new(RequestExecutionStage::new(ExecutionMode::DualDispatch)), Box::new(CompletionResponseProcessingStage::new( diff --git a/model_gateway/src/routers/grpc/regular/processor.rs b/model_gateway/src/routers/grpc/regular/processor.rs index 582b401ad..093418034 100644 --- a/model_gateway/src/routers/grpc/regular/processor.rs +++ b/model_gateway/src/routers/grpc/regular/processor.rs @@ -11,7 +11,7 @@ use llm_tokenizer::{ }; use openai_protocol::{ chat::{ChatChoice, ChatCompletionMessage, ChatCompletionRequest, ChatCompletionResponse}, - common::{FunctionCallResponse, ToolCall, ToolChoice, ToolChoiceValue, Usage}, + common::{FunctionCallResponse, ResponseFormat, ToolCall, ToolChoice, ToolChoiceValue, Usage}, completion::{CompletionChoice, CompletionRequest, CompletionResponse}, generate::{GenerateMetaInfo, GenerateRequest, GenerateResponse}, messages::{self, CreateMessageRequest, Message}, @@ -96,7 +96,15 @@ impl ResponseProcessor { let mut reasoning_text: Option = None; let mut processed_text = final_text; - if original_request.separate_reasoning && reasoning_parser_available { + let has_structured_output = utils::has_constrained_output( + original_request.tool_choice.as_ref(), + original_request.response_format.as_ref(), + ); + + if original_request.separate_reasoning + && reasoning_parser_available + && !has_structured_output + { let pooled_parser = utils::get_reasoning_parser( &self.reasoning_parser_factory, self.configured_reasoning_parser.as_deref(), @@ -156,19 +164,43 @@ impl ResponseProcessor { } }; - if used_json_schema { + let output_is_constrained = utils::has_constrained_output( + original_request.tool_choice.as_ref(), + original_request.response_format.as_ref(), + ); + if self.configured_tool_parser.is_some() + && tool_parser_available + && !output_is_constrained + { + // Configured parser for native tool call tokens (auto/required modes) + (tool_calls, processed_text) = self + .parse_tool_calls( + &processed_text, + &original_request.model, + history_tool_calls_count, + original_request.tools.as_deref(), + ) + .await; + } + + if tool_calls.is_none() && used_json_schema { + // Constrained decoding output: pure JSON wrapped as a tool call (tool_calls, processed_text) = utils::parse_json_schema_response( &processed_text, original_request.tool_choice.as_ref(), &original_request.model, history_tool_calls_count, ); - } else if tool_parser_available { + } + + if tool_calls.is_none() && tool_parser_available { + // Fallback: auto-detected parser (tool_calls, processed_text) = self .parse_tool_calls( &processed_text, &original_request.model, history_tool_calls_count, + original_request.tools.as_deref(), ) .await; } @@ -191,7 +223,73 @@ impl ResponseProcessor { utils::convert_proto_to_openai_logprobs(proto_logprobs, tokenizer) }); - // Step 5: Build ChatCompletionMessage (proper response message type) + // Strip leaked chatml tokens only when a model-specific parser is configured + if self.configured_tool_parser.is_some() || self.configured_reasoning_parser.is_some() { + for token in [ + "<|im_end|>", + "<|im_start|>", + "<|im_user|>", + "<|im_assistant|>", + "<|im_system|>", + "<|im_middle|>", + ] { + processed_text = processed_text.replace(token, ""); + } + processed_text = processed_text.trim().to_string(); + } + + let is_json_response = matches!( + &original_request.response_format, + Some(ResponseFormat::JsonObject) | Some(ResponseFormat::JsonSchema { .. }) + ); + if is_json_response { + if processed_text.starts_with("```json") || processed_text.starts_with("```JSON") { + if let Some(start) = processed_text.find('\n') { + let inner = &processed_text[start + 1..]; + if let Some(end) = inner.rfind("```") { + processed_text = inner[..end].trim().to_string(); + } + } + } + + if processed_text.starts_with('{') { + let mut depth = 0i32; + let mut in_string = false; + let mut escape = false; + let mut json_end = None; + for (i, ch) in processed_text.char_indices() { + if escape { + escape = false; + continue; + } + if ch == '\\' && in_string { + escape = true; + continue; + } + if ch == '"' { + in_string = !in_string; + continue; + } + if in_string { + continue; + } + if ch == '{' { + depth += 1; + } else if ch == '}' { + depth -= 1; + if depth == 0 { + json_end = Some(i + 1); + break; + } + } + } + if let Some(end) = json_end { + processed_text = processed_text[..end].to_string(); + } + } + } + + // Build ChatCompletionMessage (proper response message type) let chat_message = ChatCompletionMessage { role: "assistant".to_string(), content: if processed_text.is_empty() { @@ -230,8 +328,14 @@ impl ResponseProcessor { let history_tool_calls_count = utils::get_history_tool_calls_count(&chat_request); + let has_structured_output = utils::has_constrained_output( + chat_request.tool_choice.as_ref(), + chat_request.response_format.as_ref(), + ); + // Check parser availability once upfront (not per choice) let reasoning_parser_available = chat_request.separate_reasoning + && !has_structured_output && utils::check_reasoning_parser_availability( &self.reasoning_parser_factory, self.configured_reasoning_parser.as_deref(), @@ -312,6 +416,7 @@ impl ResponseProcessor { processed_text: &str, model: &str, history_tool_calls_count: usize, + tools: Option<&[openai_protocol::common::Tool]>, ) -> (Option>, String) { // Get pooled parser for this model let pooled_parser = utils::get_tool_parser( @@ -337,19 +442,23 @@ impl ResponseProcessor { .into_iter() .enumerate() .map(|(index, tc)| { - // Generate ID for this tool call let id = utils::generate_tool_call_id( model, &tc.function.name, index, history_tool_calls_count, ); + let arguments = coerce_tool_args_to_schema( + &tc.function.arguments, + &tc.function.name, + tools.unwrap_or(&[]), + ); ToolCall { id, tool_type: "function".to_string(), function: FunctionCallResponse { name: tc.function.name, - arguments: Some(tc.function.arguments), + arguments: Some(arguments), }, } }) @@ -634,7 +743,18 @@ impl ResponseProcessor { Some(messages::ToolChoice::Tool { .. } | messages::ToolChoice::Any { .. }) ); - if used_json_schema { + if self.configured_tool_parser.is_some() && tool_parser_available { + (tool_calls, processed_text) = self + .parse_tool_calls( + &processed_text, + &messages_request.model, + utils::message_utils::get_history_tool_calls_count_messages( + &messages_request, + ), + None, + ) + .await; + } else if used_json_schema { // Bridge Messages ToolChoice to Chat ToolChoice for reuse let chat_tool_choice = messages_request .tool_choice @@ -655,12 +775,28 @@ impl ResponseProcessor { utils::message_utils::get_history_tool_calls_count_messages( &messages_request, ), + None, ) .await; } } - // Step 3: Build content blocks + // Strip leaked chatml tokens only when a model-specific parser is configured + if self.configured_tool_parser.is_some() || self.configured_reasoning_parser.is_some() { + for token in [ + "<|im_end|>", + "<|im_start|>", + "<|im_user|>", + "<|im_assistant|>", + "<|im_system|>", + "<|im_middle|>", + ] { + processed_text = processed_text.replace(token, ""); + } + processed_text = processed_text.trim().to_string(); + } + + // Build content blocks let mut content_blocks: Vec = Vec::new(); // Thinking block first (if present) @@ -859,3 +995,91 @@ impl ResponseProcessor { }) } } + +/// Coerce tool call argument values to match the types declared in the function +/// schema. Models sometimes return `546382` (integer) for a `"type": "string"` +/// parameter, or `null` for `"type": "string"` when the user said "null". This +/// mirrors the type coercion that production tproxy performs. +fn coerce_tool_args_to_schema( + arguments_json: &str, + function_name: &str, + tools: &[openai_protocol::common::Tool], +) -> String { + let schema = tools.iter().find_map(|t| { + let f = &t.function; + if f.name == function_name { + Some(&f.parameters) + } else { + None + } + }); + let schema = match schema { + Some(s) if s.is_object() => s, + _ => return arguments_json.to_string(), + }; + + let mut args: serde_json::Value = match serde_json::from_str(arguments_json) { + Ok(v) => v, + Err(_) => return arguments_json.to_string(), + }; + + let props = match schema.get("properties").and_then(|p| p.as_object()) { + Some(p) => p, + None => return arguments_json.to_string(), + }; + + if let Some(obj) = args.as_object_mut() { + for (key, prop_schema) in props { + let expected_type = prop_schema.get("type").and_then(|t| t.as_str()); + let val = match obj.get(key) { + Some(v) => v.clone(), + None => continue, + }; + + let coerced = match expected_type { + Some("string") => match &val { + serde_json::Value::Number(n) => Some(serde_json::Value::String(n.to_string())), + serde_json::Value::Bool(b) => Some(serde_json::Value::String(b.to_string())), + serde_json::Value::Null => Some(serde_json::Value::String("null".to_string())), + _ => None, + }, + Some("number") => match &val { + serde_json::Value::String(s) => s + .parse::() + .ok() + .and_then(serde_json::Number::from_f64) + .map(serde_json::Value::Number), + _ => None, + }, + Some("integer") => match &val { + serde_json::Value::String(s) => s + .parse::() + .ok() + .map(|n| serde_json::Value::Number(n.into())), + _ => None, + }, + Some("array") => match &val { + serde_json::Value::String(_) | serde_json::Value::Number(_) => { + Some(serde_json::Value::Array(vec![val.clone()])) + } + _ => None, + }, + Some("boolean") => match &val { + serde_json::Value::String(s) => match s.as_str() { + "true" => Some(serde_json::Value::Bool(true)), + "false" => Some(serde_json::Value::Bool(false)), + _ => None, + }, + _ => None, + }, + _ => None, + }; + + if let Some(new_val) = coerced { + obj.insert(key.clone(), new_val); + } + } + } + + serde_json::to_string(&args).unwrap_or_else(|_| arguments_json.to_string()) +} diff --git a/model_gateway/src/routers/grpc/regular/responses/conversions.rs b/model_gateway/src/routers/grpc/regular/responses/conversions.rs index ad7387039..53ed450b2 100644 --- a/model_gateway/src/routers/grpc/regular/responses/conversions.rs +++ b/model_gateway/src/routers/grpc/regular/responses/conversions.rs @@ -111,6 +111,7 @@ pub(crate) fn responses_to_chat(req: &ResponsesRequest) -> Result Result, <|im_end|>, ) embedded in the chat template + // output as single token IDs rather than splitting them into characters. + let encoding = match tokenizer.encode(&processed_messages.text, true) { Ok(encoding) => encoding, Err(e) => { error!(function = "ChatPreparationStage::execute", error = %e, "Tokenization failed"); diff --git a/model_gateway/src/routers/grpc/regular/stages/chat/request_building.rs b/model_gateway/src/routers/grpc/regular/stages/chat/request_building.rs index aadcc89b3..9c3284b6e 100644 --- a/model_gateway/src/routers/grpc/regular/stages/chat/request_building.rs +++ b/model_gateway/src/routers/grpc/regular/stages/chat/request_building.rs @@ -2,7 +2,7 @@ use async_trait::async_trait; use axum::response::Response; -use tracing::error; +use tracing::{error, info}; use uuid::Uuid; use crate::routers::{ @@ -10,7 +10,7 @@ use crate::routers::{ grpc::{ common::stages::{helpers, PipelineStage}, context::{ClientSelection, PreparationOutput, RequestContext}, - multimodal::assemble_multimodal_data, + multimodal::{assemble_multimodal_data, collapse_media_placeholders}, proto_wrapper::ProtoRequest, }, }; @@ -20,11 +20,15 @@ use crate::routers::{ /// Extracts chat-specific request building logic from the old unified RequestBuildingStage. pub(crate) struct ChatRequestBuildingStage { inject_pd_metadata: bool, + enable_message_hash: bool, } impl ChatRequestBuildingStage { - pub fn new(inject_pd_metadata: bool) -> Self { - Self { inject_pd_metadata } + pub fn new(inject_pd_metadata: bool, enable_message_hash: bool) -> Self { + Self { + inject_pd_metadata, + enable_message_hash, + } } } @@ -60,7 +64,7 @@ impl PipelineStage for ChatRequestBuildingStage { }; let PreparationOutput::Chat { - token_ids, + mut token_ids, processed_messages, tool_constraints, } = prep @@ -72,8 +76,24 @@ impl PipelineStage for ChatRequestBuildingStage { )); }; - // Build chat request - let request_id = format!("chatcmpl-{}", Uuid::now_v7()); + // Build chat request — use user-supplied request_id if provided + let user_supplied = chat_request.request_id.is_some(); + let request_id = chat_request + .request_id + .clone() + .unwrap_or_else(|| format!("chatcmpl-{}", Uuid::now_v7())); + if user_supplied { + info!(target: "smg::request", request_id = %request_id, "Using user-supplied request ID"); + } + + let message_hashes = if self.enable_message_hash { + Some(helpers::compute_and_log_message_hashes( + &request_id, + &chat_request.messages, + )) + } else { + None + }; // Reject multimodal for backends that don't support it, before assembling if processed_messages.multimodal_intermediate.is_some() && builder_client.is_mlx() { @@ -83,10 +103,27 @@ impl PipelineStage for ChatRequestBuildingStage { )); } - // Assemble backend-specific multimodal data now that the backend is known + // Assemble backend-specific multimodal data now that the backend is known. + // For TRT-LLM, collapse N consecutive media placeholders to 1 because + // TRT-LLM's KimiK25InputProcessor re-expands each single placeholder + // to N vision tokens. SGLang uses them directly for embedding lookup. + let im_token_id = processed_messages + .multimodal_intermediate + .as_ref() + .and_then(|i| i.im_token_id); let multimodal_data = processed_messages .multimodal_intermediate .map(|intermediate| assemble_multimodal_data(intermediate, builder_client)); + if builder_client.is_trtllm() { + if let Some(id) = im_token_id { + token_ids = collapse_media_placeholders(&token_ids, id); + } + } + + let eos_token_ids = ctx + .tokenizer_arc() + .map(|t| t.eos_token_ids().to_vec()) + .unwrap_or_default(); let mut proto_request = builder_client .build_chat_request( @@ -96,6 +133,8 @@ impl PipelineStage for ChatRequestBuildingStage { token_ids, multimodal_data, tool_constraints, + &eos_token_ids, + message_hashes, ) .map_err(|e| { error!(function = "ChatRequestBuildingStage::execute", error = %e, "Failed to build generate request"); diff --git a/model_gateway/src/routers/grpc/regular/stages/completion/request_building.rs b/model_gateway/src/routers/grpc/regular/stages/completion/request_building.rs index 20784b842..3098bdf82 100644 --- a/model_gateway/src/routers/grpc/regular/stages/completion/request_building.rs +++ b/model_gateway/src/routers/grpc/regular/stages/completion/request_building.rs @@ -10,7 +10,7 @@ use async_trait::async_trait; use axum::response::Response; -use tracing::error; +use tracing::{error, info}; use uuid::Uuid; use crate::routers::{ @@ -27,7 +27,7 @@ pub(crate) struct CompletionRequestBuildingStage { } impl CompletionRequestBuildingStage { - pub fn new(inject_pd_metadata: bool) -> Self { + pub fn new(inject_pd_metadata: bool, _enable_message_hash: bool) -> Self { Self { inject_pd_metadata } } } @@ -61,7 +61,14 @@ impl PipelineStage for CompletionRequestBuildingStage { ClientSelection::Dual { prefill, .. } => prefill, }; - let request_id = format!("cmpl_{}", Uuid::now_v7()); + let user_supplied = completion_request.request_id.is_some(); + let request_id = completion_request + .request_id + .clone() + .unwrap_or_else(|| format!("cmpl_{}", Uuid::now_v7())); + if user_supplied { + info!(target: "smg::request", request_id = %request_id, "Using user-supplied request ID"); + } let mut proto_request = builder_client .build_completion_request( diff --git a/model_gateway/src/routers/grpc/regular/stages/generate/request_building.rs b/model_gateway/src/routers/grpc/regular/stages/generate/request_building.rs index d2d4273b6..ebf682e72 100644 --- a/model_gateway/src/routers/grpc/regular/stages/generate/request_building.rs +++ b/model_gateway/src/routers/grpc/regular/stages/generate/request_building.rs @@ -22,7 +22,7 @@ pub(crate) struct GenerateRequestBuildingStage { } impl GenerateRequestBuildingStage { - pub fn new(inject_pd_metadata: bool) -> Self { + pub fn new(inject_pd_metadata: bool, _enable_message_hash: bool) -> Self { Self { inject_pd_metadata } } } diff --git a/model_gateway/src/routers/grpc/regular/stages/messages/preparation.rs b/model_gateway/src/routers/grpc/regular/stages/messages/preparation.rs index 69d66d1b2..d872ce873 100644 --- a/model_gateway/src/routers/grpc/regular/stages/messages/preparation.rs +++ b/model_gateway/src/routers/grpc/regular/stages/messages/preparation.rs @@ -153,8 +153,8 @@ impl MessagePreparationStage { } }; - // Step 3: Tokenize the processed text - let encoding = match tokenizer.encode(&processed_messages.text, false) { + // Step 3: Tokenize the processed text (with special token recognition) + let encoding = match tokenizer.encode(&processed_messages.text, true) { Ok(encoding) => encoding, Err(e) => { error!(function = "MessagePreparationStage::execute", error = %e, "Tokenization failed"); diff --git a/model_gateway/src/routers/grpc/regular/stages/messages/request_building.rs b/model_gateway/src/routers/grpc/regular/stages/messages/request_building.rs index fc1285456..51fa9f79c 100644 --- a/model_gateway/src/routers/grpc/regular/stages/messages/request_building.rs +++ b/model_gateway/src/routers/grpc/regular/stages/messages/request_building.rs @@ -2,7 +2,7 @@ use async_trait::async_trait; use axum::response::Response; -use tracing::error; +use tracing::{error, info}; use uuid::Uuid; use crate::routers::{ @@ -10,7 +10,7 @@ use crate::routers::{ grpc::{ common::stages::{helpers, PipelineStage}, context::{ClientSelection, PreparationOutput, RequestContext}, - multimodal::assemble_multimodal_data, + multimodal::{assemble_multimodal_data, collapse_media_placeholders}, proto_wrapper::ProtoRequest, }, }; @@ -21,11 +21,15 @@ use crate::routers::{ /// and CreateMessageRequest sampling parameters. pub(crate) struct MessageRequestBuildingStage { inject_pd_metadata: bool, + enable_message_hash: bool, } impl MessageRequestBuildingStage { - pub fn new(inject_pd_metadata: bool) -> Self { - Self { inject_pd_metadata } + pub fn new(inject_pd_metadata: bool, enable_message_hash: bool) -> Self { + Self { + inject_pd_metadata, + enable_message_hash, + } } } @@ -61,7 +65,7 @@ impl PipelineStage for MessageRequestBuildingStage { }; let PreparationOutput::Messages { - token_ids, + mut token_ids, processed_messages, tool_constraints, } = prep @@ -73,8 +77,24 @@ impl PipelineStage for MessageRequestBuildingStage { )); }; - // Build message request - let request_id = format!("msg_{}", Uuid::now_v7()); + // Build message request — use user-supplied request_id if provided + let user_supplied = messages_request.request_id.is_some(); + let request_id = messages_request + .request_id + .clone() + .unwrap_or_else(|| format!("msg_{}", Uuid::now_v7())); + if user_supplied { + info!(target: "smg::request", request_id = %request_id, "Using user-supplied request ID"); + } + + let message_hashes = if self.enable_message_hash { + Some(helpers::compute_and_log_input_message_hashes( + &request_id, + &messages_request.messages, + )) + } else { + None + }; // Reject multimodal for backends that don't support it, before assembling if processed_messages.multimodal_intermediate.is_some() && builder_client.is_mlx() { @@ -84,10 +104,19 @@ impl PipelineStage for MessageRequestBuildingStage { )); } - // Assemble backend-specific multimodal data now that the backend is known + // Assemble backend-specific multimodal data; collapse placeholders for TRT-LLM + let im_token_id = processed_messages + .multimodal_intermediate + .as_ref() + .and_then(|i| i.im_token_id); let multimodal_data = processed_messages .multimodal_intermediate .map(|intermediate| assemble_multimodal_data(intermediate, builder_client)); + if builder_client.is_trtllm() { + if let Some(id) = im_token_id { + token_ids = collapse_media_placeholders(&token_ids, id); + } + } let mut proto_request = builder_client .build_messages_request( @@ -97,6 +126,7 @@ impl PipelineStage for MessageRequestBuildingStage { token_ids, multimodal_data, tool_constraints, + message_hashes, ) .map_err(|e| { error!(function = "MessageRequestBuildingStage::execute", error = %e, "Failed to build generate request"); diff --git a/model_gateway/src/routers/grpc/regular/stages/request_building.rs b/model_gateway/src/routers/grpc/regular/stages/request_building.rs index 48a1879cd..cc5477706 100644 --- a/model_gateway/src/routers/grpc/regular/stages/request_building.rs +++ b/model_gateway/src/routers/grpc/regular/stages/request_building.rs @@ -24,10 +24,13 @@ pub(crate) struct ChatGenerateRequestBuildingStage { } impl ChatGenerateRequestBuildingStage { - pub fn new(inject_pd_metadata: bool) -> Self { + pub fn new(inject_pd_metadata: bool, enable_message_hash: bool) -> Self { Self { - chat_stage: ChatRequestBuildingStage::new(inject_pd_metadata), - generate_stage: GenerateRequestBuildingStage::new(inject_pd_metadata), + chat_stage: ChatRequestBuildingStage::new(inject_pd_metadata, enable_message_hash), + generate_stage: GenerateRequestBuildingStage::new( + inject_pd_metadata, + enable_message_hash, + ), } } } diff --git a/model_gateway/src/routers/grpc/regular/streaming.rs b/model_gateway/src/routers/grpc/regular/streaming.rs index 031c93ab0..dcb3601ba 100644 --- a/model_gateway/src/routers/grpc/regular/streaming.rs +++ b/model_gateway/src/routers/grpc/regular/streaming.rs @@ -5,7 +5,7 @@ use std::{ collections::{HashMap, HashSet}, io, - sync::Arc, + sync::{atomic::AtomicU64, Arc}, time::Instant, }; @@ -18,7 +18,8 @@ use llm_tokenizer::{ use openai_protocol::{ chat::{ChatCompletionRequest, ChatCompletionStreamResponse}, common::{ - FunctionCallDelta, StringOrArray, Tool, ToolCallDelta, ToolChoice, ToolChoiceValue, Usage, + FunctionCallDelta, ResponseFormat, StringOrArray, Tool, ToolCallDelta, ToolChoice, + ToolChoiceValue, Usage, }, completion::{CompletionRequest, CompletionStreamChoice, CompletionStreamResponse}, generate::GenerateRequest, @@ -36,7 +37,9 @@ use tracing::{debug, error, warn}; use crate::{ observability::metrics::{metrics_labels, Metrics, StreamingMetricsParams}, routers::grpc::{ - common::{response_formatting::CompletionTokenTracker, responses::build_sse_response}, + common::{ + response_formatting::CompletionTokenTracker, responses::build_tracked_sse_response, + }, context, proto_wrapper::{ProtoResponseVariant, ProtoStream}, utils, @@ -52,6 +55,7 @@ pub(crate) struct StreamingProcessor { configured_tool_parser: Option, configured_reasoning_parser: Option, backend_type: &'static str, + last_token_time: Arc, } /// Context for generate endpoint streaming - groups config params to reduce function arguments @@ -70,6 +74,7 @@ impl StreamingProcessor { configured_tool_parser: Option, configured_reasoning_parser: Option, backend_type: &'static str, + last_token_time: Arc, ) -> Self { Self { tool_parser_factory, @@ -77,6 +82,7 @@ impl StreamingProcessor { configured_tool_parser, configured_reasoning_parser, backend_type, + last_token_time, } } @@ -178,7 +184,7 @@ impl StreamingProcessor { } // Return SSE response - build_sse_response(rx) + build_tracked_sse_response(rx, &self.last_token_time) } /// Process streaming chunks from a single stream (Regular mode) @@ -195,8 +201,12 @@ impl StreamingProcessor { let start_time = Instant::now(); let mut first_token_time: Option = None; - // Extract request parameters - let separate_reasoning = original_request.separate_reasoning; + // Extract request parameters — skip reasoning parsing for structured output + let has_structured_output = utils::has_constrained_output( + original_request.tool_choice.as_ref(), + original_request.response_format.as_ref(), + ); + let separate_reasoning = original_request.separate_reasoning && !has_structured_output; let tool_choice = &original_request.tool_choice; let tools = &original_request.tools; let history_tool_calls_count = utils::get_history_tool_calls_count(&original_request); @@ -225,6 +235,10 @@ impl StreamingProcessor { // Reusable SSE formatting buffer to avoid allocations per chunk let mut sse_buffer = Vec::with_capacity(512); + // Tracks whether the previous chunk's backticks were stripped so that + // a language tag arriving at the start of the next chunk can be removed. + let mut fence_backticks_stripped = false; + // Use dispatch metadata for consistent response fields let request_id = &dispatch.request_id; let model = &dispatch.model; @@ -273,6 +287,17 @@ impl StreamingProcessor { let is_specific_function = used_json_schema && matches!(tool_choice, Some(ToolChoice::Function { .. })); + // Skip reasoning parsing when constrained decoding is active AND + // reasoning is not expected. When separate_reasoning is true the + // chat template injects , so the model emits thinking + // content even with constraints. The reasoning parser must run. + let output_is_constrained = !separate_reasoning + && (is_specific_function + || utils::has_constrained_output( + tool_choice.as_ref(), + original_request.response_format.as_ref(), + )); + let tool_parser_available = tools.is_some() && utils::check_tool_parser_availability( &self.tool_parser_factory, @@ -362,7 +387,10 @@ impl StreamingProcessor { stream_buffer.push_str(&delta); // Reasoning content handling - let in_reasoning = if separate_reasoning && reasoning_parser_available { + let in_reasoning = if separate_reasoning + && reasoning_parser_available + && !output_is_constrained + { let (normal_text, reasoning_chunk, in_reasoning) = self .process_reasoning_stream( &delta, @@ -396,8 +424,11 @@ impl StreamingProcessor { && tool_choice_enabled && (tool_parser_available || used_json_schema) { - let tool_chunks = if is_specific_function { - // Handle specific function case - emit tool call deltas with arguments + let tool_chunks = if is_specific_function + && (output_is_constrained + || !(self.configured_tool_parser.is_some() + && tool_parser_available)) + { Self::process_specific_function_stream( &delta, index, @@ -410,7 +441,14 @@ impl StreamingProcessor { history_tool_calls_count, ) } else { - // Use incremental parser for regular/required modes + // When reasoning is active the backend may + // not enforce constrained decoding, so the + // model can emit native tool-call tags. + // Use the native parser instead of JSON. + let effective_json_parser = used_json_schema + && !(separate_reasoning + && self.configured_tool_parser.is_some() + && tool_parser_available); self.process_tool_calls_stream( &delta, index, @@ -422,7 +460,7 @@ impl StreamingProcessor { created, system_fingerprint, history_tool_calls_count, - used_json_schema, + effective_json_parser, ) .await }; @@ -439,7 +477,34 @@ impl StreamingProcessor { } } - // Regular content emission + if self.configured_tool_parser.is_some() + || self.configured_reasoning_parser.is_some() + { + for token in [ + "<|im_end|>", + "<|im_start|>", + "<|im_user|>", + "<|im_assistant|>", + "<|im_system|>", + "<|im_middle|>", + "", + "[EOS]", + "[BOS]", + ] { + delta = delta.replace(token, ""); + } + } + + // Strip markdown code fences for JSON responses so + // streamed content is directly parseable. + let is_json_response = matches!( + &original_request.response_format, + Some(ResponseFormat::JsonObject) | Some(ResponseFormat::JsonSchema { .. }) + ); + if is_json_response { + delta = strip_json_fence(delta, &mut fence_backticks_stripped); + } + if !delta.is_empty() { let content_chunk = ChatCompletionStreamResponse::builder(request_id, model) @@ -727,7 +792,7 @@ impl StreamingProcessor { } // Return SSE response - build_sse_response(rx) + build_tracked_sse_response(rx, &self.last_token_time) } /// Process streaming chunks for generate endpoint (no tool/reasoning parsing) @@ -1203,12 +1268,26 @@ impl StreamingProcessor { ); } - // Emit arguments delta - if !delta.is_empty() { + // Emit arguments delta, stripping any chatml tokens + let mut clean_delta = delta.to_string(); + for token in [ + "<|im_end|>", + "<|im_start|>", + "<|im_user|>", + "<|im_assistant|>", + "<|im_system|>", + "<|im_middle|>", + "", + "[EOS]", + "[BOS]", + ] { + clean_delta = clean_delta.replace(token, ""); + } + if !clean_delta.is_empty() { chunks.push( ChatCompletionStreamResponse::builder(request_id, model) .created(created) - .add_choice_tool_args(index, delta.to_string()) + .add_choice_tool_args(index, clean_delta) .maybe_system_fingerprint(system_fingerprint) .build(), ); @@ -1261,7 +1340,21 @@ impl StreamingProcessor { match parser.parse_incremental(delta, tools).await { Ok(StreamingParseResult { normal_text, calls }) => { - // Emit normal text if present + let mut normal_text = normal_text; + if self.configured_tool_parser.is_some() + || self.configured_reasoning_parser.is_some() + { + for token in [ + "<|im_end|>", + "<|im_start|>", + "<|im_user|>", + "<|im_assistant|>", + "<|im_system|>", + "<|im_middle|>", + ] { + normal_text = normal_text.replace(token, ""); + } + } if !normal_text.is_empty() { chunks.push( ChatCompletionStreamResponse::builder(request_id, model) @@ -1545,7 +1638,7 @@ impl StreamingProcessor { } } - build_sse_response(rx) + build_tracked_sse_response(rx, &self.last_token_time) } /// Process Messages API streaming chunks from a single stream. @@ -1783,7 +1876,9 @@ impl StreamingProcessor { // Tool call handling: incremental streaming parser if !in_reasoning && streaming_tool_parser.is_some() { - if is_specific_function { + if is_specific_function + && !(self.configured_tool_parser.is_some() && tool_parser_available) + { // Specific function: entire output is arguments for one tool if !has_tool_calls { has_tool_calls = true; @@ -2289,7 +2384,7 @@ impl StreamingProcessor { } } - build_sse_response(rx) + build_tracked_sse_response(rx, &self.last_token_time) } /// Process completion streaming chunks from a single stream. @@ -2689,3 +2784,84 @@ impl StreamingProcessor { buffer.extend_from_slice(b"\n\n"); } } + +fn strip_json_fence(mut delta: String, fence_backticks_stripped: &mut bool) -> String { + delta = delta + .replace("```json", "") + .replace("```JSON", "") + .replace("```", ""); + let trimmed = delta.trim(); + if trimmed == "json" || trimmed == "JSON" { + delta = String::new(); + } + // Handle cross-chunk fence split: backticks were + // stripped from the previous chunk, language tag + // arrives at the start of this one. + if *fence_backticks_stripped { + for tag in ["json\r\n", "JSON\r\n", "json\n", "JSON\n"] { + if delta.starts_with(tag) { + delta = delta[tag.len()..].to_string(); + break; + } + } + } + *fence_backticks_stripped = delta.is_empty(); + delta +} + +#[cfg(test)] +mod tests { + use super::strip_json_fence; + + fn apply_chunks(chunks: &[&str]) -> String { + let mut fence_state = false; + let mut result = String::new(); + for chunk in chunks { + let out = strip_json_fence(chunk.to_string(), &mut fence_state); + result.push_str(&out); + } + result + } + + #[test] + fn full_fence_single_chunk() { + let out = apply_chunks(&["```json\n{\"a\":1}\n```"]); + assert_eq!(out, "\n{\"a\":1}\n"); + } + + #[test] + fn fence_split_backticks_then_language_tag() { + let out = apply_chunks(&["```", "json\n{\"a\":1}\n```"]); + assert_eq!(out, "{\"a\":1}\n"); + } + + #[test] + fn fence_split_backticks_then_language_tag_crlf() { + let out = apply_chunks(&["```", "json\r\n{\"a\":1}\r\n```"]); + assert_eq!(out, "{\"a\":1}\r\n"); + } + + #[test] + fn standalone_json_tag_dropped() { + let out = apply_chunks(&["json"]); + assert_eq!(out, ""); + } + + #[test] + fn no_fence_passthrough() { + let out = apply_chunks(&["{\"a\":", "1}"]); + assert_eq!(out, "{\"a\":1}"); + } + + #[test] + fn uppercase_fence() { + let out = apply_chunks(&["```JSON\n{\"a\":1}\n```"]); + assert_eq!(out, "\n{\"a\":1}\n"); + } + + #[test] + fn cross_chunk_uppercase() { + let out = apply_chunks(&["```", "JSON\n{\"a\":1}"]); + assert_eq!(out, "{\"a\":1}"); + } +} diff --git a/model_gateway/src/routers/grpc/router.rs b/model_gateway/src/routers/grpc/router.rs index 524da0773..1776ae1f0 100644 --- a/model_gateway/src/routers/grpc/router.rs +++ b/model_gateway/src/routers/grpc/router.rs @@ -88,6 +88,8 @@ impl GrpcRouter { multimodal, }); + let enable_message_hash = ctx.router_config.enable_message_hash; + // Create regular pipeline let pipeline = RequestPipeline::new_regular( worker_registry.clone(), @@ -96,6 +98,8 @@ impl GrpcRouter { reasoning_parser_factory.clone(), ctx.configured_tool_parser.clone(), ctx.configured_reasoning_parser.clone(), + enable_message_hash, + ctx.last_token_time.clone(), ); // Create Harmony pipelines @@ -106,6 +110,7 @@ impl GrpcRouter { reasoning_parser_factory.clone(), ctx.configured_tool_parser.clone(), ctx.configured_reasoning_parser.clone(), + enable_message_hash, ); // Create Embedding pipeline @@ -124,11 +129,17 @@ impl GrpcRouter { reasoning_parser_factory.clone(), ctx.configured_tool_parser.clone(), ctx.configured_reasoning_parser.clone(), + enable_message_hash, + ctx.last_token_time.clone(), ); // Create Completion pipeline - let completion_pipeline = - RequestPipeline::new_completion(worker_registry.clone(), _policy_registry.clone()); + let completion_pipeline = RequestPipeline::new_completion( + worker_registry.clone(), + _policy_registry.clone(), + enable_message_hash, + ctx.last_token_time.clone(), + ); // Extract shared dependencies for responses contexts let mcp_orchestrator = ctx diff --git a/model_gateway/src/routers/grpc/utils/chat_utils.rs b/model_gateway/src/routers/grpc/utils/chat_utils.rs index b75b8abf6..85a8090c5 100644 --- a/model_gateway/src/routers/grpc/utils/chat_utils.rs +++ b/model_gateway/src/routers/grpc/utils/chat_utils.rs @@ -83,7 +83,22 @@ pub(crate) fn resolve_tokenizer( /// Process tool call arguments in messages /// Per Transformers docs, tool call arguments in assistant messages should be dicts pub(crate) fn process_tool_call_arguments(messages: &mut [Value]) -> Result<(), String> { - for msg in messages { + // Validate that arguments are valid JSON but keep them as strings. + // The chat template handles both string and object arguments: + // {% if ... is string %}{{ args }}{% else %}{{ args | tojson }}{% endif %} + // Keeping strings preserves original formatting (e.g. spacing), + // matching the behavior of Python/TGL which passes arguments through as-is. + // + // Also rewrites tool_call IDs to `functions.NAME:INDEX` format. Some chat + // templates (Kimi K2) pass the raw ID through to the model prompt. When + // users send IDs like `call_1`, the model continues the pattern (`call_2`) + // instead of using the expected `functions.NAME:INDEX` format, causing the + // tool parser to fail. + + // First pass: rewrite assistant tool_call IDs and collect old→new mapping. + let mut id_rewrites: HashMap = HashMap::new(); + + for msg in messages.iter_mut() { let role = msg.get("role").and_then(|v| v.as_str()); if role != Some("assistant") { continue; @@ -93,28 +108,54 @@ pub(crate) fn process_tool_call_arguments(messages: &mut [Value]) -> Result<(), continue; }; - for call in tool_calls { - let Some(function) = call.get_mut("function") else { - continue; - }; - let Some(args) = function.get_mut("arguments") else { - continue; - }; - let Some(args_str) = args.as_str() else { - continue; - }; + for (index, call) in tool_calls.iter_mut().enumerate() { + // Validate arguments JSON + if let Some(args_str) = call + .get("function") + .and_then(|f| f.get("arguments")) + .and_then(|a| a.as_str()) + { + if serde_json::from_str::(args_str).is_err() { + return Err(format!("Invalid JSON in tool call arguments: '{args_str}'")); + } + } - // Parse JSON string to object (like Python json.loads) - match serde_json::from_str::(args_str) { - Ok(parsed) => *args = parsed, - Err(e) => { - return Err(format!( - "Failed to parse tool call arguments as JSON: '{args_str}'. Error: {e}" - )) + // Rewrite ID to functions.NAME:INDEX if not already in that format + let func_name = call + .get("function") + .and_then(|f| f.get("name")) + .and_then(|n| n.as_str()); + let old_id = call.get("id").and_then(|v| v.as_str()); + + if let (Some(name), Some(old)) = (func_name, old_id) { + let canonical = format!("functions.{name}:{index}"); + if old != canonical { + id_rewrites.insert(old.to_string(), canonical.clone()); + if let Some(obj) = call.as_object_mut() { + obj.insert("id".to_string(), Value::String(canonical)); + } } } } } + + // Second pass: rewrite tool message tool_call_ids to match + if !id_rewrites.is_empty() { + for msg in messages.iter_mut() { + let role = msg.get("role").and_then(|v| v.as_str()); + if role != Some("tool") { + continue; + } + if let Some(old_id) = msg.get("tool_call_id").and_then(|v| v.as_str()) { + if let Some(new_id) = id_rewrites.get(old_id) { + if let Some(obj) = msg.as_object_mut() { + obj.insert("tool_call_id".to_string(), Value::String(new_id.clone())); + } + } + } + } + } + Ok(()) } @@ -291,9 +332,21 @@ pub fn process_chat_messages( .transpose() .map_err(|e| format!("Failed to serialize tools: {e}"))?; - let kwargs_capacity = 1 + request.chat_template_kwargs.as_ref().map_or(0, |k| k.len()); + let kwargs_capacity = 2 + request.chat_template_kwargs.as_ref().map_or(0, |k| k.len()); let mut combined_template_kwargs = HashMap::with_capacity(kwargs_capacity); + // Generate TypeScript-style tool declarations for models that use it + // (e.g., Kimi K2.5 chat template checks for tools_ts_str) + if let Some(ref tools) = request.tools { + if !tools.is_empty() { + let ts_str = tools_to_typescript(tools); + if !ts_str.is_empty() { + combined_template_kwargs + .insert("tools_ts_str".to_string(), Value::String(ts_str)); + } + } + } + // Add reasoning_effort if present (like Python does) if let Some(reasoning_effort) = &request.reasoning_effort { combined_template_kwargs.insert( @@ -426,10 +479,27 @@ pub(crate) fn parse_json_schema_response( model: &str, history_tool_calls_count: usize, ) -> (Option>, String) { + // Strip chatml / special tokens that may trail the constrained JSON output + let mut clean = processed_text.to_string(); + for token in [ + "<|im_end|>", + "<|im_start|>", + "<|im_user|>", + "<|im_assistant|>", + "<|im_system|>", + "<|im_middle|>", + "", + "[EOS]", + "[BOS]", + ] { + clean = clean.replace(token, ""); + } + let clean = clean.trim(); + match tool_choice { Some(ToolChoice::Function { function, .. }) => { // Specific function: Parse parameters directly - match serde_json::from_str::(processed_text) { + match serde_json::from_str::(clean) { Ok(params) => { let tool_call = ToolCall { id: generate_tool_call_id( @@ -450,14 +520,14 @@ pub(crate) fn parse_json_schema_response( } Err(e) => { error!("Failed to parse specific function parameters: {}", e); - (None, processed_text.to_string()) + (None, clean.to_string()) } } } Some(ToolChoice::Value(ToolChoiceValue::Required)) | Some(ToolChoice::AllowedTools { .. }) => { // Required mode: Parse array of tool calls - match serde_json::from_str::>(processed_text) { + match serde_json::from_str::>(clean) { Ok(parsed_array) => { let spec_tool_calls: Vec = parsed_array .into_iter() @@ -489,11 +559,11 @@ pub(crate) fn parse_json_schema_response( } Err(e) => { error!("Failed to parse required tool call array: {}", e); - (None, processed_text.to_string()) + (None, clean.to_string()) } } } - _ => (None, processed_text.to_string()), + _ => (None, clean.to_string()), } } @@ -580,6 +650,206 @@ pub(crate) fn parse_finish_reason( } } +// ============================================================================ +// TypeScript-style tool declaration generator +// ============================================================================ + +/// Convert OpenAI tools to TypeScript-style declaration string. +/// +/// Produces the format expected by models like Kimi K2.5, whose chat templates +/// check for a `tools_ts_str` variable and prefer it over raw JSON. +/// +/// Example output: +/// ```text +/// # Tools +/// +/// ## functions +/// namespace functions { +/// // Get the current weather +/// type getCurrentWeather = (_: { +/// // The city and state +/// location: string, +/// unit?: "celsius" | "fahrenheit" +/// }) => any; +/// } +/// ``` +pub fn tools_to_typescript(tools: &[Tool]) -> String { + let mut functions = Vec::new(); + + for tool in tools { + if tool.tool_type != "function" { + continue; + } + functions.push(function_to_typescript(&tool.function)); + } + + if functions.is_empty() { + return String::new(); + } + + let mut result = String::from("# Tools\n\n## functions\nnamespace functions {\n"); + result.push_str(&functions.join("\n")); + result.push_str("\n}\n"); + result +} + +fn function_to_typescript(func: &openai_protocol::common::Function) -> String { + let mut out = String::new(); + + // Description comment + if let Some(ref desc) = func.description { + for line in desc.lines() { + if line.is_empty() { + out.push('\n'); + } else { + out.push_str(&format!("// {line}\n")); + } + } + } + + // Parameters + let params_str = + if func.parameters.is_null() || func.parameters == Value::Object(Default::default()) { + "{}".to_string() + } else { + schema_to_typescript(&func.parameters, "", &[]) + }; + + out.push_str(&format!("type {} = (_: {params_str}) => any;", func.name)); + out +} + +fn schema_to_typescript(schema: &Value, indent: &str, _required: &[&str]) -> String { + match schema.get("type").and_then(|t| t.as_str()) { + Some("object") => object_to_typescript(schema, indent), + Some("array") => array_to_typescript(schema, indent), + Some("string") => { + if let Some(enum_vals) = schema.get("enum").and_then(|e| e.as_array()) { + enum_to_typescript(enum_vals) + } else { + "string".to_string() + } + } + Some("integer" | "number") => "number".to_string(), + Some("boolean") => "boolean".to_string(), + Some("null") => "null".to_string(), + _ => { + // Handle enum without type + if let Some(enum_vals) = schema.get("enum").and_then(|e| e.as_array()) { + return enum_to_typescript(enum_vals); + } + // Handle anyOf + if let Some(any_of) = schema.get("anyOf").and_then(|a| a.as_array()) { + let types: Vec = any_of + .iter() + .map(|t| schema_to_typescript(t, indent, &[])) + .collect(); + return types.join(" | "); + } + "any".to_string() + } + } +} + +fn object_to_typescript(schema: &Value, indent: &str) -> String { + let properties = match schema.get("properties").and_then(|p| p.as_object()) { + Some(p) => p, + None => return "{}".to_string(), + }; + + let required_fields: Vec<&str> = schema + .get("required") + .and_then(|r| r.as_array()) + .map(|arr| arr.iter().filter_map(|v| v.as_str()).collect()) + .unwrap_or_default(); + + let child_indent = format!("{indent} "); + + // Sort: required first, then optional, both alphabetically + let mut required_params: Vec<(&String, &Value)> = Vec::new(); + let mut optional_params: Vec<(&String, &Value)> = Vec::new(); + + for (name, prop) in properties { + if required_fields.contains(&name.as_str()) { + required_params.push((name, prop)); + } else { + optional_params.push((name, prop)); + } + } + required_params.sort_by_key(|(n, _)| *n); + optional_params.sort_by_key(|(n, _)| *n); + + let mut params: Vec<(&String, &Value, bool)> = Vec::new(); + for (n, v) in required_params { + params.push((n, v, false)); + } + for (n, v) in optional_params { + params.push((n, v, true)); + } + + if params.is_empty() { + return "{}".to_string(); + } + + let mut parts = Vec::new(); + for (name, prop, optional) in ¶ms { + let mut part = String::new(); + + // Description comment + if let Some(desc) = prop.get("description").and_then(|d| d.as_str()) { + for line in desc.lines() { + if line.is_empty() { + part.push('\n'); + } else { + part.push_str(&format!("{child_indent}// {line}\n")); + } + } + } + + let type_str = schema_to_typescript(prop, &child_indent, &[]); + let opt_marker = if *optional { "?" } else { "" }; + part.push_str(&format!("{child_indent}{name}{opt_marker}: {type_str}")); + parts.push(part); + } + + format!("{{\n{}\n{indent}}}", parts.join(",\n")) +} + +fn array_to_typescript(schema: &Value, indent: &str) -> String { + let items = schema.get("items"); + let item_type = match items { + Some(item_schema) => { + let child_indent = format!("{indent} "); + // Check if item has description + let item_desc = item_schema.get("description").and_then(|d| d.as_str()); + let type_str = schema_to_typescript(item_schema, &child_indent, &[]); + + if let Some(desc) = item_desc { + return format!( + "Array<\n{child_indent}// {desc}\n{child_indent}{type_str}\n{indent}>" + ); + } + type_str + } + None => "any".to_string(), + }; + format!("Array<{item_type}>") +} + +fn enum_to_typescript(values: &[Value]) -> String { + let parts: Vec = values + .iter() + .map(|v| match v { + Value::String(s) => format!("\"{s}\""), + Value::Number(n) => n.to_string(), + Value::Bool(b) => b.to_string(), + Value::Null => "null".to_string(), + _ => "any".to_string(), + }) + .collect(); + parts.join(" | ") +} + #[cfg(test)] mod tests { use llm_tokenizer::chat_template::ChatTemplateContentFormat; diff --git a/model_gateway/src/routers/grpc/utils/message_utils.rs b/model_gateway/src/routers/grpc/utils/message_utils.rs index eda24a7ad..89b2c98eb 100644 --- a/model_gateway/src/routers/grpc/utils/message_utils.rs +++ b/model_gateway/src/routers/grpc/utils/message_utils.rs @@ -648,6 +648,7 @@ mod tests { top_p: None, container: None, mcp_servers: None, + request_id: None, }; assert_eq!(get_history_tool_calls_count_messages(&request), 0); @@ -696,6 +697,7 @@ mod tests { top_p: None, container: None, mcp_servers: None, + request_id: None, }; assert_eq!(get_history_tool_calls_count_messages(&request), 2); } diff --git a/model_gateway/src/routers/grpc/utils/mod.rs b/model_gateway/src/routers/grpc/utils/mod.rs index 2415494ae..f7c3174ea 100644 --- a/model_gateway/src/routers/grpc/utils/mod.rs +++ b/model_gateway/src/routers/grpc/utils/mod.rs @@ -22,5 +22,5 @@ pub(crate) use metrics::{error_type_from_status, route_to_endpoint}; pub(crate) use parsers::{ check_reasoning_parser_availability, check_tool_parser_availability, create_reasoning_parser, create_tool_parser, extract_thinking_from_kwargs, get_reasoning_parser, get_tool_parser, - should_mark_reasoning_started, + has_constrained_output, should_mark_reasoning_started, }; diff --git a/model_gateway/src/routers/grpc/utils/parsers.rs b/model_gateway/src/routers/grpc/utils/parsers.rs index 106f0a8ba..1c438b70d 100644 --- a/model_gateway/src/routers/grpc/utils/parsers.rs +++ b/model_gateway/src/routers/grpc/utils/parsers.rs @@ -4,6 +4,7 @@ use llm_tokenizer::{ chat_template::{ThinkingKeyName, ThinkingToggle}, traits::Tokenizer, }; +use openai_protocol::common::{ResponseFormat, ToolChoice, ToolChoiceValue}; use reasoning_parser::{ ParserFactory as ReasoningParserFactory, PooledParser as ReasoningPooledParser, ReasoningParser, }; @@ -178,3 +179,178 @@ pub(crate) fn create_tool_parser( tool_parser_factory.registry().create_for_model(model) } } + +/// Returns `true` when constrained decoding is active, meaning the model output +/// is structured JSON rather than free-form text. In that case the reasoning +/// parser must be skipped — otherwise it captures the constrained JSON as +/// `reasoning_content` and leaves `content` empty. +/// +/// Constrained decoding is triggered by: +/// - `tool_choice` = a specific function, `required`, or `allowed_tools` with +/// `mode == "required"` +/// - `response_format` = `json_object` or `json_schema` +pub(crate) fn has_constrained_output( + tool_choice: Option<&ToolChoice>, + response_format: Option<&ResponseFormat>, +) -> bool { + let constrained_tool_choice = matches!( + tool_choice, + Some(ToolChoice::Function { .. }) | Some(ToolChoice::Value(ToolChoiceValue::Required)) + ) || matches!( + tool_choice, + Some(ToolChoice::AllowedTools { mode, .. }) if mode == "required" + ); + + let constrained_response_format = matches!( + response_format, + Some(ResponseFormat::JsonObject) + | Some(ResponseFormat::JsonSchema { .. }) + | Some(ResponseFormat::Regex { .. }) + ); + + constrained_tool_choice || constrained_response_format +} + +#[cfg(test)] +mod tests { + use openai_protocol::common::{FunctionChoice, JsonSchemaFormat, ToolReference}; + + use super::*; + + // ── has_constrained_output: tool_choice variants ──────────────────── + + #[test] + fn no_tool_choice_no_response_format_is_unconstrained() { + assert!(!has_constrained_output(None, None)); + } + + #[test] + fn tool_choice_auto_is_unconstrained() { + let tc = ToolChoice::Value(ToolChoiceValue::Auto); + assert!(!has_constrained_output(Some(&tc), None)); + } + + #[test] + fn tool_choice_none_is_unconstrained() { + let tc = ToolChoice::Value(ToolChoiceValue::None); + assert!(!has_constrained_output(Some(&tc), None)); + } + + #[test] + fn tool_choice_required_is_constrained() { + let tc = ToolChoice::Value(ToolChoiceValue::Required); + assert!(has_constrained_output(Some(&tc), None)); + } + + #[test] + fn tool_choice_specific_function_is_constrained() { + let tc = ToolChoice::Function { + tool_type: "function".to_string(), + function: FunctionChoice { + name: "get_weather".to_string(), + }, + }; + assert!(has_constrained_output(Some(&tc), None)); + } + + #[test] + fn allowed_tools_required_is_constrained() { + let tc = ToolChoice::AllowedTools { + tool_type: "allowed_tools".to_string(), + mode: "required".to_string(), + tools: vec![ToolReference::Function { + name: "search".to_string(), + }], + }; + assert!(has_constrained_output(Some(&tc), None)); + } + + #[test] + fn allowed_tools_auto_is_unconstrained() { + let tc = ToolChoice::AllowedTools { + tool_type: "allowed_tools".to_string(), + mode: "auto".to_string(), + tools: vec![ToolReference::Function { + name: "search".to_string(), + }], + }; + assert!(!has_constrained_output(Some(&tc), None)); + } + + // ── has_constrained_output: response_format variants ──────────────── + + #[test] + fn response_format_text_is_unconstrained() { + assert!(!has_constrained_output(None, Some(&ResponseFormat::Text))); + } + + #[test] + fn response_format_json_object_is_constrained() { + assert!(has_constrained_output( + None, + Some(&ResponseFormat::JsonObject) + )); + } + + #[test] + fn response_format_json_schema_is_constrained() { + let rf = ResponseFormat::JsonSchema { + json_schema: JsonSchemaFormat { + name: "feedback".to_string(), + schema: serde_json::json!({"type": "object"}), + strict: Some(true), + }, + }; + assert!(has_constrained_output(None, Some(&rf))); + } + + #[test] + fn response_format_regex_is_constrained() { + let rf = ResponseFormat::Regex { + pattern: "(positive|neutral|negative)".to_string(), + }; + assert!(has_constrained_output(None, Some(&rf))); + } + + // ── has_constrained_output: combinations ──────────────────────────── + + #[test] + fn tool_choice_auto_with_json_object_is_constrained() { + let tc = ToolChoice::Value(ToolChoiceValue::Auto); + assert!(has_constrained_output( + Some(&tc), + Some(&ResponseFormat::JsonObject) + )); + } + + #[test] + fn tool_choice_auto_with_text_format_is_unconstrained() { + let tc = ToolChoice::Value(ToolChoiceValue::Auto); + assert!(!has_constrained_output( + Some(&tc), + Some(&ResponseFormat::Text) + )); + } + + #[test] + fn tool_choice_required_with_json_schema_both_constrain() { + let tc = ToolChoice::Value(ToolChoiceValue::Required); + let rf = ResponseFormat::JsonSchema { + json_schema: JsonSchemaFormat { + name: "output".to_string(), + schema: serde_json::json!({"type": "object"}), + strict: None, + }, + }; + assert!(has_constrained_output(Some(&tc), Some(&rf))); + } + + #[test] + fn tool_choice_none_with_json_object_is_constrained_via_format() { + let tc = ToolChoice::Value(ToolChoiceValue::None); + assert!(has_constrained_output( + Some(&tc), + Some(&ResponseFormat::JsonObject) + )); + } +} diff --git a/model_gateway/src/routers/openai/chat.rs b/model_gateway/src/routers/openai/chat.rs index 724dda03a..69acf9435 100644 --- a/model_gateway/src/routers/openai/chat.rs +++ b/model_gateway/src/routers/openai/chat.rs @@ -52,6 +52,7 @@ pub(super) async fn route_chat( let start = Instant::now(); let model = model_id; let streaming = body.stream; + let last_token_time = deps.shared_components.last_token_time.clone(); Metrics::record_router_request( metrics_labels::ROUTER_OPENAI, @@ -160,6 +161,7 @@ pub(super) async fn route_chat( let headers = Arc::clone(&headers_cloned); let worker_api_key = Arc::clone(&worker_api_key); let worker = Arc::clone(&worker); + let last_token_time = last_token_time.clone(); async move { let mut req = client.post(&url).json(&*payload); @@ -196,6 +198,7 @@ pub(super) async fn route_chat( let (tx, rx) = mpsc::unbounded_channel(); #[expect(clippy::disallowed_methods, reason = "fire-and-forget stream relay; gateway shutdown need not wait for individual stream forwarding")] tokio::spawn(async move { + use std::sync::atomic::Ordering; let mut s = stream; while let Some(chunk) = s.next().await { match chunk { @@ -203,6 +206,11 @@ pub(super) async fn route_chat( if tx.send(Ok(bytes)).is_err() { break; } + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + last_token_time.store(now, Ordering::Relaxed); } Err(e) => { let _ = tx.send(Err(format!("Stream error: {e}"))); diff --git a/model_gateway/src/routers/openai/context.rs b/model_gateway/src/routers/openai/context.rs index 9e94ccd24..32b709dbc 100644 --- a/model_gateway/src/routers/openai/context.rs +++ b/model_gateway/src/routers/openai/context.rs @@ -1,6 +1,6 @@ //! Request context types for OpenAI router pipeline. -use std::sync::Arc; +use std::sync::{atomic::AtomicU64, Arc}; use axum::http::HeaderMap; use openai_protocol::{chat::ChatCompletionRequest, responses::ResponsesRequest}; @@ -42,6 +42,7 @@ pub enum RequestType { pub struct SharedComponents { pub client: reqwest::Client, pub router_config: Arc, + pub last_token_time: Arc, } pub struct ResponsesComponents { diff --git a/model_gateway/src/routers/openai/router.rs b/model_gateway/src/routers/openai/router.rs index 0bda10e99..ee6da1de3 100644 --- a/model_gateway/src/routers/openai/router.rs +++ b/model_gateway/src/routers/openai/router.rs @@ -91,6 +91,7 @@ impl OpenAIRouter { let shared_components = Arc::new(SharedComponents { client: ctx.client.clone(), router_config: Arc::new(ctx.router_config.clone()), + last_token_time: ctx.last_token_time.clone(), }); let responses_components = Arc::new(ResponsesComponents { diff --git a/model_gateway/src/server.rs b/model_gateway/src/server.rs index 8ec3aa0b4..08d699fed 100644 --- a/model_gateway/src/server.rs +++ b/model_gateway/src/server.rs @@ -87,6 +87,7 @@ pub struct AppState { pub concurrency_queue_tx: Option>, pub router_manager: Option>, pub mesh_handler: Option>, + pub api_port: u16, } async fn parse_function_call( @@ -161,8 +162,143 @@ async fn health(_state: State>) -> Response { liveness().await } -async fn health_generate(State(state): State>, req: Request) -> Response { - state.router.health_generate(req).await +async fn health_generate(State(state): State>, _req: Request) -> Response { + use std::{ + sync::atomic::Ordering, + time::{Instant, SystemTime, UNIX_EPOCH}, + }; + + let registry = &state.context.worker_registry; + let workers = registry.get_all(); + if workers.is_empty() { + return (StatusCode::SERVICE_UNAVAILABLE, "No workers registered").into_response(); + } + + let healthy: Vec<_> = workers.iter().filter(|w| w.is_healthy()).collect(); + if healthy.is_empty() { + let info: Vec<_> = workers + .iter() + .map(|w| format!("{} ({})", w.model_id(), w.url())) + .collect(); + return ( + StatusCode::SERVICE_UNAVAILABLE, + format!("0/{} workers healthy: {}", workers.len(), info.join(", ")), + ) + .into_response(); + } + + // Fast path: if a token was forwarded within the last 10 seconds, skip the + // probe entirely. Under high load this avoids adding a synthetic request that + // could tip the backend over its concurrency limit. + let last_token_unix = state.context.last_token_time.load(Ordering::Relaxed); + if last_token_unix > 0 { + let now_unix = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_secs(); + let age_secs = now_unix.saturating_sub(last_token_unix); + if age_secs < 10 { + info!( + target: "smg::health", + workers = healthy.len(), + last_token_age_secs = age_secs, + "health_generate: skipping probe — token forwarded recently" + ); + return ( + StatusCode::OK, + format!( + "OK - {} workers healthy, last token {}s ago (probe skipped)", + healthy.len(), + age_secs + ), + ) + .into_response(); + } + } + + let model_id = healthy[0].model_id().to_string(); + let probe_url = format!("http://127.0.0.1:{}/v1/chat/completions", state.api_port); + let probe_body = serde_json::json!({ + "model": model_id, + "messages": [{"role": "user", "content": "ping"}], + "max_tokens": 1, + "stream": false, + "temperature": 0 + }); + + info!( + target: "smg::health", + model = %model_id, + probe_url = %probe_url, + timeout_secs = state.context.router_config.health_generate_timeout_secs, + max_tokens = 1, + "health_generate: sending real inference probe" + ); + + let start = Instant::now(); + let probe_result = state + .context + .client + .post(&probe_url) + .json(&probe_body) + .timeout(Duration::from_secs( + state.context.router_config.health_generate_timeout_secs, + )) + .send() + .await; + let duration_ms = start.elapsed().as_millis(); + + match probe_result { + Ok(resp) if resp.status().is_success() => { + info!( + target: "smg::health", + model = %model_id, + duration_ms = %duration_ms, + workers = healthy.len(), + "health_generate: probe succeeded" + ); + ( + StatusCode::OK, + format!( + "OK - {} workers healthy, probe succeeded in {}ms (model: {})", + healthy.len(), + duration_ms, + model_id + ), + ) + .into_response() + } + Ok(resp) => { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + warn!( + target: "smg::health", + model = %model_id, + status = %status, + duration_ms = %duration_ms, + "health_generate: probe failed (backend error)" + ); + ( + StatusCode::SERVICE_UNAVAILABLE, + format!("Probe failed: backend returned {status} in {duration_ms}ms — {body}"), + ) + .into_response() + } + Err(e) => { + warn!( + target: "smg::health", + model = %model_id, + error = %e, + duration_ms = %duration_ms, + "health_generate: probe failed (transport error)" + ); + ( + StatusCode::SERVICE_UNAVAILABLE, + format!("Probe failed: {e} in {duration_ms}ms"), + ) + .into_response() + } + } } async fn engine_metrics(State(state): State>) -> Response { @@ -1372,6 +1508,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box Arc { + use std::sync::atomic::AtomicU64; + use crate::{ config::RouterConfig, middleware::TokenBucket, observability::inflight_tracker::InFlightRequestTracker, @@ -1288,6 +1290,7 @@ mod tests { router_config, )), inflight_tracker: InFlightRequestTracker::new(), + last_token_time: Arc::new(AtomicU64::new(0)), kv_event_monitor: None, realtime_registry: Arc::new(RealtimeRegistry::new()), webrtc_bind_addr: None, diff --git a/model_gateway/src/worker/monitor.rs b/model_gateway/src/worker/monitor.rs index 9be5b225f..583b2c667 100644 --- a/model_gateway/src/worker/monitor.rs +++ b/model_gateway/src/worker/monitor.rs @@ -537,10 +537,7 @@ async fn run_event_loop( } } Ok(WorkerEvent::StatusChanged { - worker, - new_status, - old_status: _, - .. + worker, new_status, .. }) => { if new_status != WorkerStatus::Ready { monitor.evict_worker_loads(worker.url()); diff --git a/model_gateway/tests/api/api_endpoints_test.rs b/model_gateway/tests/api/api_endpoints_test.rs index d5357c27d..66cf103f5 100644 --- a/model_gateway/tests/api/api_endpoints_test.rs +++ b/model_gateway/tests/api/api_endpoints_test.rs @@ -112,6 +112,9 @@ mod health_tests { #[tokio::test] async fn test_health_generate_endpoint() { + use std::sync::atomic::Ordering; + use std::time::{SystemTime, UNIX_EPOCH}; + let ctx = AppTestContext::new(vec![MockWorkerConfig { port: 18005, worker_type: WorkerType::Regular, @@ -121,6 +124,16 @@ mod health_tests { }]) .await; + // Simulate recent traffic so the fast path fires and avoids + // a real probe (which would fail since api_port=0 in tests). + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + ctx.app_context + .last_token_time + .store(now, Ordering::Relaxed); + let app = ctx.create_app(); let req = Request::builder() @@ -135,8 +148,124 @@ mod health_tests { let body = axum::body::to_bytes(resp.into_body(), usize::MAX) .await .unwrap(); - let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap(); - assert!(body_json.is_object()); + let body_str = String::from_utf8_lossy(&body); + assert!( + body_str.contains("workers healthy"), + "Expected workers healthy in body, got: {body_str}" + ); + + ctx.shutdown().await; + } + + /// Fast path: if last_token_time is recent (< 10s), the probe is skipped. + #[tokio::test] + async fn test_health_generate_fast_path_skips_probe() { + use std::sync::atomic::Ordering; + use std::time::{SystemTime, UNIX_EPOCH}; + + let ctx = AppTestContext::new(vec![MockWorkerConfig { + port: 18006, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let now = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs(); + ctx.app_context + .last_token_time + .store(now, Ordering::Relaxed); + + let app = ctx.create_app(); + + let req = Request::builder() + .method("GET") + .uri("/health_generate") + .body(Body::empty()) + .unwrap(); + + let resp = app.oneshot(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + let body = axum::body::to_bytes(resp.into_body(), usize::MAX) + .await + .unwrap(); + let body_str = String::from_utf8_lossy(&body); + assert!( + body_str.contains("probe skipped"), + "Expected 'probe skipped' in body, got: {body_str}" + ); + + ctx.shutdown().await; + } + + /// Stale fast path: last_token_time older than 10s falls through to the probe. + /// The probe will fail (no real server at api_port=0), returning 503. + #[tokio::test] + async fn test_health_generate_stale_token_falls_through_to_probe() { + use std::sync::atomic::Ordering; + use std::time::{SystemTime, UNIX_EPOCH}; + + let ctx = AppTestContext::new(vec![MockWorkerConfig { + port: 18007, + worker_type: WorkerType::Regular, + health_status: HealthStatus::Healthy, + response_delay_ms: 0, + fail_rate: 0.0, + }]) + .await; + + let stale = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_secs() + .saturating_sub(30); + ctx.app_context + .last_token_time + .store(stale, Ordering::Relaxed); + + let app = ctx.create_app(); + + let req = Request::builder() + .method("GET") + .uri("/health_generate") + .body(Body::empty()) + .unwrap(); + + let resp = app.oneshot(req).await.unwrap(); + // Probe to api_port=0 will fail → 503 + assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE); + + let body = axum::body::to_bytes(resp.into_body(), usize::MAX) + .await + .unwrap(); + let body_str = String::from_utf8_lossy(&body); + assert!( + !body_str.contains("probe skipped"), + "Stale token should not skip probe, got: {body_str}" + ); + + ctx.shutdown().await; + } + + /// No workers → 503 immediately, no probe attempted. + #[tokio::test] + async fn test_health_generate_no_workers() { + let ctx = AppTestContext::new(vec![]).await; + let app = ctx.create_app(); + + let req = Request::builder() + .method("GET") + .uri("/health_generate") + .body(Body::empty()) + .unwrap(); + + let resp = app.oneshot(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE); ctx.shutdown().await; } diff --git a/model_gateway/tests/common/test_app.rs b/model_gateway/tests/common/test_app.rs index 40954e12b..8d640b5a5 100644 --- a/model_gateway/tests/common/test_app.rs +++ b/model_gateway/tests/common/test_app.rs @@ -106,6 +106,7 @@ pub fn create_test_app( concurrency_queue_tx: None, router_manager: None, mesh_handler: None, + api_port: 0, }); // Configure request ID headers (use defaults if not specified) @@ -148,6 +149,7 @@ pub fn create_test_app_with_context( concurrency_queue_tx: None, router_manager: None, mesh_handler: None, + api_port: 0, }); // Get config from the context diff --git a/model_gateway/tests/routing/test_openai_routing.rs b/model_gateway/tests/routing/test_openai_routing.rs index 0f66d3c5a..dd94bf78b 100644 --- a/model_gateway/tests/routing/test_openai_routing.rs +++ b/model_gateway/tests/routing/test_openai_routing.rs @@ -94,6 +94,7 @@ fn create_minimal_completion_request() -> CompletionRequest { session_params: None, return_hidden_states: false, sampling_seed: None, + request_id: None, other: serde_json::Map::new(), } } diff --git a/model_gateway/tests/spec/chat_message.rs b/model_gateway/tests/spec/chat_message.rs index 232154a62..92ef8e5c0 100644 --- a/model_gateway/tests/spec/chat_message.rs +++ b/model_gateway/tests/spec/chat_message.rs @@ -66,6 +66,7 @@ fn test_chat_message_tagged_by_role_tool() { ChatMessage::Tool { content, tool_call_id, + .. } => { match content { MessageContent::Text(text) => { diff --git a/model_gateway/tests/wasm_test.rs b/model_gateway/tests/wasm_test.rs index c877499b2..fc7527a24 100644 --- a/model_gateway/tests/wasm_test.rs +++ b/model_gateway/tests/wasm_test.rs @@ -196,6 +196,7 @@ async fn create_test_app_with_wasm() -> (axum::Router, Arc, TempDir) concurrency_queue_tx: None, router_manager: None, mesh_handler: None, + api_port: 0, }); let request_id_headers = vec!["x-request-id".to_string(), "x-correlation-id".to_string()];