From e60ab1356a2a59481e3b79911820473b9b3f19de Mon Sep 17 00:00:00 2001 From: Zhuo Li Date: Mon, 4 May 2026 15:03:50 -0700 Subject: [PATCH 1/8] feat(mlx-grpc): support string stop sequences for chat and completion (#1099) Signed-off-by: Zhuo Li --- Cargo.toml | 1 + crates/grpc_client/src/mlx_engine.rs | 9 +- crates/protocols/src/completion.rs | 2 + crates/tokenizer/src/mock.rs | 13 +- model_gateway/Cargo.toml | 1 + .../src/routers/grpc/common/stages/helpers.rs | 44 ++++++ .../src/routers/grpc/proto_wrapper.rs | 8 + .../src/routers/grpc/regular/processor.rs | 24 ++- .../regular/stages/chat/request_building.rs | 8 + .../stages/completion/request_building.rs | 8 +- .../src/routers/grpc/regular/streaming.rs | 32 +++- .../src/routers/grpc/utils/chat_utils.rs | 138 +++++++++++++++++- model_gateway/src/routers/grpc/utils/mod.rs | 2 +- 13 files changed, 274 insertions(+), 16 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 1cacf31fb..fca10cf91 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,6 +60,7 @@ str0m = { version = "0.19", default-features = false, features = ["openssl"] } scopeguard = "1.2" bitflags = "2.10.0" schemars = "0.8" +test-case = "3.3.1" [workspace.lints.rust] unsafe_code = "deny" diff --git a/crates/grpc_client/src/mlx_engine.rs b/crates/grpc_client/src/mlx_engine.rs index fab8f44fe..b581f73b9 100644 --- a/crates/grpc_client/src/mlx_engine.rs +++ b/crates/grpc_client/src/mlx_engine.rs @@ -242,11 +242,8 @@ impl MlxEngineClient { // - response_format — same as constrained decoding // // Servicer limitations (fixable without mlx-lm changes): - // - TODO(mlx): String stop sequences — mlx-lm supports this via - // tokenizer.encode() → SequenceStateMachine. Fix by converting stop - // strings to token IDs in the preparation stage (which already has the - // Rust tokenizer) and passing them as stop_token_ids in the proto. - // + // - String stop sequences: supported in chat and completion pipelines. + // Messages and Generate pipelines still reject string stops (see reject_stop_strings). // Track upstream: https://github.com/ml-explore/mlx-lm fn reject_constraint(constraint: Option<&(String, String)>) -> Result<(), String> { @@ -309,7 +306,6 @@ impl MlxEngineClient { ) -> Result { Self::reject_constraint(constraint.as_ref())?; Self::reject_n(body.n)?; - Self::reject_stop_strings(body.stop.as_ref().is_some_and(|s| !s.is_empty()))?; Self::reject_response_format(body.response_format.is_some())?; let sampling_params = Self::build_sampling_params_from_chat(body); @@ -335,7 +331,6 @@ impl MlxEngineClient { token_ids: Vec, ) -> Result { Self::reject_n(body.n)?; - Self::reject_stop_strings(body.stop.as_ref().is_some_and(|s| !s.is_empty()))?; Self::reject_if_any_constraint( body.json_schema.as_ref(), body.regex.as_ref(), diff --git a/crates/protocols/src/completion.rs b/crates/protocols/src/completion.rs index 393596fe9..d21a58a91 100644 --- a/crates/protocols/src/completion.rs +++ b/crates/protocols/src/completion.rs @@ -273,4 +273,6 @@ pub struct CompletionStreamChoice { #[serde(skip_serializing_if = "Option::is_none")] pub logprobs: Option, pub finish_reason: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub matched_stop: Option, } diff --git a/crates/tokenizer/src/mock.rs b/crates/tokenizer/src/mock.rs index f057aabe2..7b0e2a0a0 100644 --- a/crates/tokenizer/src/mock.rs +++ b/crates/tokenizer/src/mock.rs @@ -11,6 +11,7 @@ pub struct MockTokenizer { vocab: HashMap, reverse_vocab: HashMap, special_tokens: SpecialTokens, + fail_encode: bool, } impl Default for MockTokenizer { @@ -62,19 +63,29 @@ impl MockTokenizer { vocab, reverse_vocab, special_tokens, + fail_encode: false, + } + } + + pub fn failing() -> Self { + Self { + fail_encode: true, + ..Self::new() } } } impl Encoder for MockTokenizer { fn encode(&self, input: &str, _add_special_tokens: bool) -> Result { + if self.fail_encode { + return Err(anyhow::anyhow!("test encode error")); + } // Simple word-based tokenization using the vocab // Split by whitespace and look up each word (decoder adds spaces back) let tokens: Vec = input .split_whitespace() .filter_map(|word| self.vocab.get(word).copied()) .collect(); - Ok(Encoding::Plain(tokens)) } diff --git a/model_gateway/Cargo.toml b/model_gateway/Cargo.toml index 5de28d1bb..83415bf6b 100644 --- a/model_gateway/Cargo.toml +++ b/model_gateway/Cargo.toml @@ -144,6 +144,7 @@ rsa = { version = "0.9", features = ["sha2"] } jsonwebtoken = "9.3" validator = "0.20.0" kv-index.workspace = true +test-case = { workspace = true } wasmtime-wasi = { workspace = true } lru = { workspace = true } wat = "1.244" diff --git a/model_gateway/src/routers/grpc/common/stages/helpers.rs b/model_gateway/src/routers/grpc/common/stages/helpers.rs index 65ef6dded..0a5cee5bd 100644 --- a/model_gateway/src/routers/grpc/common/stages/helpers.rs +++ b/model_gateway/src/routers/grpc/common/stages/helpers.rs @@ -2,6 +2,9 @@ use std::sync::Arc; +use axum::response::Response; +use llm_tokenizer::traits::Tokenizer; +use openai_protocol::common::StringOrArray; use rand::Rng; use smg_grpc_client::{ mlx_proto, @@ -19,6 +22,14 @@ use crate::{ sampling_defaults::SamplingDefaults, RuntimeType, Worker, DEFAULT_BOOTSTRAP_PORT, DEFAULT_SAMPLING_PARAMS_LABEL, }, + routers::{ + error, + grpc::{ + context::WorkerSelection, proto_wrapper::ProtoGenerateRequest, + utils::resolve_mlx_stop_ids, + }, + }, + worker::{RuntimeType, Worker, DEFAULT_BOOTSTRAP_PORT}, }; #[derive(Clone, Copy, Debug, Default)] @@ -263,3 +274,36 @@ fn inject_sglang_bootstrap_metadata( hostname, bootstrap_port, room_id ); } + +/// Convert string stop sequences to token IDs and append them to the MLX proto request. +/// +/// The MLX proto only supports stop_token_ids; string stop sequences from the +/// CompletionRequest must be tokenized here before the request is dispatched. +/// No-op if the request has no string stop sequences. +#[expect( + clippy::result_large_err, + reason = "Response is the standard error type in the pipeline stage pattern" +)] +pub(crate) fn apply_mlx_stop_sequences( + proto_request: &mut ProtoGenerateRequest, + stop: Option<&StringOrArray>, + tokenizer: Option<&dyn Tokenizer>, +) -> Result<(), Response> { + let Some(stop) = stop else { + return Ok(()); + }; + + let token_ids = resolve_mlx_stop_ids(stop, tokenizer)?; + + if let ProtoGenerateRequest::Mlx(req) = proto_request { + let sampling = req.sampling_params.as_mut().ok_or_else(|| { + error::internal_error( + "mlx_sampling_params_missing", + "MLX GenerateRequest has no sampling_params; cannot inject stop IDs", + ) + })?; + sampling.stop_token_ids.extend(token_ids); + } + + Ok(()) +} diff --git a/model_gateway/src/routers/grpc/proto_wrapper.rs b/model_gateway/src/routers/grpc/proto_wrapper.rs index 971ff388b..0bd0f3818 100644 --- a/model_gateway/src/routers/grpc/proto_wrapper.rs +++ b/model_gateway/src/routers/grpc/proto_wrapper.rs @@ -738,6 +738,14 @@ impl ProtoGenerateComplete { matches!(self, Self::Mlx(_)) } + /// Return the raw matched stop token ID for MLX responses; None for all other backends. + pub fn mlx_matched_stop_token_id(&self) -> Option { + match self { + Self::Mlx(c) => c.matched_stop_token_id, + _ => None, + } + } + /// Get token IDs from either backend (output_ids in proto) pub fn token_ids(&self) -> &[u32] { match self { diff --git a/model_gateway/src/routers/grpc/regular/processor.rs b/model_gateway/src/routers/grpc/regular/processor.rs index 582b401ad..45d4b896d 100644 --- a/model_gateway/src/routers/grpc/regular/processor.rs +++ b/model_gateway/src/routers/grpc/regular/processor.rs @@ -184,7 +184,16 @@ impl ResponseProcessor { finish_reason_str }; - let matched_stop = complete.matched_stop_json(); + let matched_stop = if complete.is_mlx() { + utils::resolve_mlx_matched_stop_json( + complete.mlx_matched_stop_token_id(), + original_request.stop.as_ref(), + original_request.stop_token_ids.as_ref(), + tokenizer.as_ref(), + ) + } else { + complete.matched_stop_json() + }; // Step 4: Convert output logprobs if present let logprobs = complete.output_logprobs().map(|ref proto_logprobs| { @@ -760,7 +769,7 @@ impl ResponseProcessor { execution_result: ExecutionResult, completion_req: Arc, dispatch: DispatchMetadata, - _tokenizer: Arc, + tokenizer: Arc, stop_decoder: &mut StopSequenceDecoder, prompt_text: &str, ) -> Result { @@ -822,7 +831,16 @@ impl ResponseProcessor { } }; - let matched_stop = complete.matched_stop_json(); + let matched_stop = if complete.is_mlx() { + utils::resolve_mlx_matched_stop_json( + complete.mlx_matched_stop_token_id(), + completion_req.stop.as_ref(), + completion_req.stop_token_ids.as_ref(), + tokenizer.as_ref(), + ) + } else { + complete.matched_stop_json() + }; let suffix_len = completion_req.suffix.as_ref().map_or(0, |s| s.len()); let echo_len = if completion_req.echo { diff --git a/model_gateway/src/routers/grpc/regular/stages/chat/request_building.rs b/model_gateway/src/routers/grpc/regular/stages/chat/request_building.rs index d8be6b682..885ef26a9 100644 --- a/model_gateway/src/routers/grpc/regular/stages/chat/request_building.rs +++ b/model_gateway/src/routers/grpc/regular/stages/chat/request_building.rs @@ -114,6 +114,14 @@ impl PipelineStage for ChatRequestBuildingStage { } } + if builder_client.is_mlx() { + helpers::apply_mlx_stop_sequences( + &mut proto_request, + chat_request.stop.as_ref(), + ctx.state.tokenizer.as_deref(), + )?; + } + ctx.state.proto_request = Some(ProtoRequest::Generate(proto_request)); Ok(None) } diff --git a/model_gateway/src/routers/grpc/regular/stages/completion/request_building.rs b/model_gateway/src/routers/grpc/regular/stages/completion/request_building.rs index d5f7db249..af2ebcb18 100644 --- a/model_gateway/src/routers/grpc/regular/stages/completion/request_building.rs +++ b/model_gateway/src/routers/grpc/regular/stages/completion/request_building.rs @@ -93,7 +93,13 @@ impl PipelineStage for CompletionRequestBuildingStage { helpers::maybe_inject_pd_metadata(&mut proto_request, workers); } } - + if builder_client.is_mlx() { + helpers::apply_mlx_stop_sequences( + &mut proto_request, + completion_request.stop.as_ref(), + ctx.state.tokenizer.as_deref(), + )?; + } ctx.state.proto_request = Some(ProtoRequest::Generate(proto_request)); Ok(None) } diff --git a/model_gateway/src/routers/grpc/regular/streaming.rs b/model_gateway/src/routers/grpc/regular/streaming.rs index 031c93ab0..3043f1a6d 100644 --- a/model_gateway/src/routers/grpc/regular/streaming.rs +++ b/model_gateway/src/routers/grpc/regular/streaming.rs @@ -492,7 +492,19 @@ impl StreamingProcessor { cached_tokens.insert(index, complete.cached_tokens()); finish_reasons.insert(index, complete.finish_reason().to_string()); - matched_stops.insert(index, complete.matched_stop_json()); + matched_stops.insert( + index, + if complete.is_mlx() { + utils::resolve_mlx_matched_stop_json( + complete.mlx_matched_stop_token_id(), + original_request.stop.as_ref(), + original_request.stop_token_ids.as_ref(), + tokenizer.as_ref(), + ) + } else { + complete.matched_stop_json() + }, + ); // Don't break - continue reading all Complete messages for n>1 } @@ -2407,6 +2419,7 @@ impl StreamingProcessor { index, logprobs: None, finish_reason: None, + matched_stop: None, }], model: model.clone(), system_fingerprint: system_fingerprint.map(String::from), @@ -2434,6 +2447,7 @@ impl StreamingProcessor { index, logprobs: None, finish_reason: None, + matched_stop: None, }], model: model.clone(), system_fingerprint: system_fingerprint.map(String::from), @@ -2453,6 +2467,7 @@ impl StreamingProcessor { index, logprobs: None, finish_reason: Some("stop".to_string()), + matched_stop: None, }], model: model.clone(), system_fingerprint: system_fingerprint.map(String::from), @@ -2487,6 +2502,7 @@ impl StreamingProcessor { index, logprobs: None, finish_reason: None, + matched_stop: None, }], model: model.clone(), system_fingerprint: system_fingerprint.map(String::from), @@ -2510,6 +2526,7 @@ impl StreamingProcessor { index, logprobs: None, finish_reason: None, + matched_stop: None, }], model: model.clone(), system_fingerprint: system_fingerprint.map(String::from), @@ -2532,6 +2549,7 @@ impl StreamingProcessor { index, logprobs: None, finish_reason: None, + matched_stop: None, }], model: model.clone(), system_fingerprint: system_fingerprint.map(String::from), @@ -2568,6 +2586,17 @@ impl StreamingProcessor { } }; + let matched_stop = if complete.is_mlx() { + utils::resolve_mlx_matched_stop_json( + complete.mlx_matched_stop_token_id(), + completion_request.stop.as_ref(), + completion_request.stop_token_ids.as_ref(), + tokenizer.as_ref(), + ) + } else { + complete.matched_stop_json() + }; + let final_chunk = CompletionStreamResponse { id: request_id.clone(), object: "text_completion".to_string(), @@ -2577,6 +2606,7 @@ impl StreamingProcessor { index, logprobs: None, finish_reason, + matched_stop, }], model: model.clone(), system_fingerprint: system_fingerprint.map(String::from), diff --git a/model_gateway/src/routers/grpc/utils/chat_utils.rs b/model_gateway/src/routers/grpc/utils/chat_utils.rs index b75b8abf6..53dcb04de 100644 --- a/model_gateway/src/routers/grpc/utils/chat_utils.rs +++ b/model_gateway/src/routers/grpc/utils/chat_utils.rs @@ -17,7 +17,7 @@ use openai_protocol::{ }; use serde_json::{json, Value}; use tokio::sync::mpsc; -use tracing::error; +use tracing::{error, warn}; use uuid::Uuid; use crate::routers::{ @@ -419,6 +419,94 @@ pub fn create_stop_decoder( builder.build() } +/// Tokenizes stop strings into token IDs for the MLX backend. +/// +/// Returns `Err` if any string encodes to more than one token — the caller +/// should surface this as an HTTP 400 so the client knows the stop condition +/// was not honored rather than silently ignoring it. +/// Strings that encode to zero tokens (unknown vocab) are skipped with a warning. +pub(crate) fn stop_strings_to_token_ids<'a>( + stop: impl IntoIterator, + tokenizer: &dyn Tokenizer, +) -> Result, String> { + let mut ids = Vec::new(); + for s in stop { + match tokenizer.encode(s, false) { + Ok(enc) => match enc.token_ids() { + [id] => ids.push(*id), + tokens if !tokens.is_empty() => { + return Err(format!( + "stop string {s:?} encodes to {} tokens; \ + MLX backend only supports single-token stop strings", + tokens.len() + )); + } + _ => warn!( + stop_string = s, + "stop string produced no tokens for MLX, skipping" + ), + }, + Err(e) => warn!(stop_string = s, error = %e, "failed to tokenize stop string for MLX"), + } + } + Ok(ids) +} + +/// Resolve the `matched_stop` JSON value for an MLX response. +/// +/// MLX only returns a token ID; this reverses the mapping back to the user-facing form: +/// - If the token ID was tokenized from a user stop string → return the string. +/// - If the token ID was an explicit user stop_token_id → return the integer. +/// - Otherwise (EOS or other internal stop) → return None. +pub(crate) fn resolve_mlx_matched_stop_json( + matched_token_id: Option, + stop: Option<&StringOrArray>, + stop_token_ids: Option<&Vec>, + tokenizer: &dyn Tokenizer, +) -> Option { + let id = matched_token_id?; + + // Check stop strings first: find the string that tokenizes to this single token. + if let Some(stop_strings) = stop { + for s in stop_strings.iter() { + if let Ok(enc) = tokenizer.encode(s, false) { + if enc.token_ids() == [id] { + return Some(Value::String(s.to_string())); + } + } + } + } + + // Check explicit stop_token_ids provided by the user. + if stop_token_ids.is_some_and(|ids| ids.contains(&id)) { + return Some(Value::Number(id.into())); + } + + // EOS or other internal stop condition — don't surface to the caller. + None +} + +/// For MLX: tokenize string stop sequences and merge with existing token IDs. +/// Returns an HTTP error response if the tokenizer is missing or a stop string encodes +/// to more than one token (propagate with `?` from a pipeline stage). +#[expect( + clippy::result_large_err, + reason = "Response is the standard error type in the pipeline stage pattern" +)] +pub(crate) fn resolve_mlx_stop_ids( + stop_strings: &StringOrArray, + tokenizer: Option<&dyn Tokenizer>, +) -> Result, Response> { + let tok = tokenizer.ok_or_else(|| { + error::bad_request( + "tokenizer_unavailable", + "MLX backend requires a tokenizer to convert string stop sequences", + ) + })?; + stop_strings_to_token_ids(stop_strings.iter(), tok) + .map_err(|e| error::bad_request("unsupported_stop_string", e)) +} + /// Parse tool calls from JSON schema constrained response pub(crate) fn parse_json_schema_response( processed_text: &str, @@ -582,12 +670,13 @@ pub(crate) fn parse_finish_reason( #[cfg(test)] mod tests { - use llm_tokenizer::chat_template::ChatTemplateContentFormat; + use llm_tokenizer::{chat_template::ChatTemplateContentFormat, MockTokenizer}; use openai_protocol::{ chat::{ChatMessage, MessageContent}, common::{ContentPart, ImageUrl}, }; use serde_json::json; + use test_case::test_case; use super::*; @@ -780,4 +869,49 @@ mod tests { assert_eq!(content_array[0]["type"], "text"); assert_eq!(content_array[1], json!({"type": "image"})); } + // MockTokenizer vocab used below: "Hello"→1, "world"→2, "test"→3, + // "<|im_end|>"→1002. expected = None means the call should return Err. + #[test_case(&["Hello"], Some(&[1u32]) ; "single token regular")] + #[test_case(&["world"], Some(&[2]) ; "single token another regular")] + #[test_case(&["<|im_end|>"], Some(&[1002]) ; "single token special")] + #[test_case(&["Hello world"], None ; "multi token returns err")] + #[test_case(&["zzzunknown"], Some(&[]) ; "unknown vocab skipped")] + #[test_case(&["Hello", "Hello world"], None ; "array with multi token err")] + #[test_case(&["Hello", "test"], Some(&[1, 3]) ; "array all single token")] + #[test_case(&[], Some(&[]) ; "empty array")] + fn test_stop_strings_to_token_ids(inputs: &[&str], expected: Option<&[u32]>) { + let tok = MockTokenizer::new(); + let result = stop_strings_to_token_ids(inputs.iter().copied(), &tok); + match expected { + Some(ids) => assert_eq!(result.unwrap(), ids), + None => assert!(result.is_err()), + } + } + + #[test] + fn test_stop_encode_error_skipped() { + // Tokenizer errors are silently skipped; the function returns Ok(empty). + let tok = MockTokenizer::failing(); + let result = stop_strings_to_token_ids(["Hello", "test"].iter().copied(), &tok); + assert!(result.unwrap().is_empty()); + } + + // MockTokenizer vocab: "Hello"→1, "world"→2, "test"→3, "<|im_end|>"→1002. + // stop_ids=&[] is treated as None (no user stop_token_ids supplied). + #[test_case(None, None, &[] => None ; "no id returns none")] + #[test_case(Some(1), Some("Hello"), &[] => Some(Value::String("Hello".to_string())) ; "string match")] + #[test_case(Some(42), None, &[42] => Some(Value::Number(42u32.into())) ; "token id match")] + #[test_case(Some(1), Some("Hello"), &[1] => Some(Value::String("Hello".to_string())) ; "string wins over token id")] + #[test_case(Some(999), None, &[] => None ; "eos returns none")] + fn test_resolve_mlx_matched_stop( + id: Option, + stop_str: Option<&str>, + stop_ids: &[u32], + ) -> Option { + let tok = MockTokenizer::new(); + let stop = stop_str.map(|s| StringOrArray::String(s.to_string())); + let ids: Vec = stop_ids.to_vec(); + let ids_opt = if ids.is_empty() { None } else { Some(&ids) }; + resolve_mlx_matched_stop_json(id, stop.as_ref(), ids_opt, &tok) + } } diff --git a/model_gateway/src/routers/grpc/utils/mod.rs b/model_gateway/src/routers/grpc/utils/mod.rs index 2415494ae..a4cc89b7f 100644 --- a/model_gateway/src/routers/grpc/utils/mod.rs +++ b/model_gateway/src/routers/grpc/utils/mod.rs @@ -12,7 +12,7 @@ pub use chat_utils::{create_stop_decoder, process_chat_messages}; pub(crate) use chat_utils::{ filter_chat_request_by_tool_choice, filter_tools_by_tool_choice, generate_tool_call_id, get_history_tool_calls_count, parse_finish_reason, parse_json_schema_response, - resolve_tokenizer, send_error_sse, + resolve_mlx_matched_stop_json, resolve_mlx_stop_ids, resolve_tokenizer, send_error_sse, }; pub(crate) use logprobs::{ convert_generate_input_logprobs, convert_generate_output_logprobs, convert_proto_logprobs, From 6f3908991ae33d2211ee4b98bdaa2cc046803742 Mon Sep 17 00:00:00 2001 From: Zhuo Li Date: Fri, 8 May 2026 22:06:01 -0700 Subject: [PATCH 2/8] move mlx match_stop processing logic into proto wrapper Signed-off-by: Zhuo Li --- .../src/routers/grpc/proto_wrapper.rs | 32 ++++++++++++++++--- .../src/routers/grpc/regular/processor.rs | 30 ++++++----------- .../src/routers/grpc/regular/streaming.rs | 30 ++++++----------- 3 files changed, 48 insertions(+), 44 deletions(-) diff --git a/model_gateway/src/routers/grpc/proto_wrapper.rs b/model_gateway/src/routers/grpc/proto_wrapper.rs index 0bd0f3818..4a84080ee 100644 --- a/model_gateway/src/routers/grpc/proto_wrapper.rs +++ b/model_gateway/src/routers/grpc/proto_wrapper.rs @@ -6,6 +6,8 @@ use std::collections::HashMap; use futures_util::StreamExt; +use llm_tokenizer::traits::Tokenizer; +use openai_protocol::common::StringOrArray; use smg_grpc_client::{ mlx_engine::AbortOnDropStream as MlxStream, mlx_proto::{self as mlx}, @@ -739,7 +741,7 @@ impl ProtoGenerateComplete { } /// Return the raw matched stop token ID for MLX responses; None for all other backends. - pub fn mlx_matched_stop_token_id(&self) -> Option { + fn mlx_matched_stop_token_id(&self) -> Option { match self { Self::Mlx(c) => c.matched_stop_token_id, _ => None, @@ -828,9 +830,31 @@ impl ProtoGenerateComplete { TrtllmMatchedStop::MatchedTokenId, TrtllmMatchedStop::MatchedStopStr ), - Self::Mlx(c) => c - .matched_stop_token_id - .map(|id| serde_json::Value::Number(id.into())), + // MLX requires request context to resolve the token ID; use matched_stop_json_with_context. + Self::Mlx(_) => unreachable!("matched_stop_json called for MLX backend"), + } + } + + /// Resolve the matched stop for any backend, using request context for MLX. + /// + /// MLX only stores a token ID; this maps it back to the user-facing string or integer + /// (see `chat_utils::resolve_mlx_matched_stop_json`). All other backends return + /// `matched_stop_json()` directly. + pub fn matched_stop_json_with_context( + &self, + stop: Option<&StringOrArray>, + stop_token_ids: Option<&Vec>, + tokenizer: &dyn Tokenizer, + ) -> Option { + if self.is_mlx() { + crate::routers::grpc::utils::resolve_mlx_matched_stop_json( + self.mlx_matched_stop_token_id(), + stop, + stop_token_ids, + tokenizer, + ) + } else { + self.matched_stop_json() } } diff --git a/model_gateway/src/routers/grpc/regular/processor.rs b/model_gateway/src/routers/grpc/regular/processor.rs index 45d4b896d..d6e504f90 100644 --- a/model_gateway/src/routers/grpc/regular/processor.rs +++ b/model_gateway/src/routers/grpc/regular/processor.rs @@ -184,16 +184,11 @@ impl ResponseProcessor { finish_reason_str }; - let matched_stop = if complete.is_mlx() { - utils::resolve_mlx_matched_stop_json( - complete.mlx_matched_stop_token_id(), - original_request.stop.as_ref(), - original_request.stop_token_ids.as_ref(), - tokenizer.as_ref(), - ) - } else { - complete.matched_stop_json() - }; + let matched_stop = complete.matched_stop_json_with_context( + original_request.stop.as_ref(), + original_request.stop_token_ids.as_ref(), + tokenizer.as_ref(), + ); // Step 4: Convert output logprobs if present let logprobs = complete.output_logprobs().map(|ref proto_logprobs| { @@ -831,16 +826,11 @@ impl ResponseProcessor { } }; - let matched_stop = if complete.is_mlx() { - utils::resolve_mlx_matched_stop_json( - complete.mlx_matched_stop_token_id(), - completion_req.stop.as_ref(), - completion_req.stop_token_ids.as_ref(), - tokenizer.as_ref(), - ) - } else { - complete.matched_stop_json() - }; + let matched_stop = complete.matched_stop_json_with_context( + completion_req.stop.as_ref(), + completion_req.stop_token_ids.as_ref(), + tokenizer.as_ref(), + ); let suffix_len = completion_req.suffix.as_ref().map_or(0, |s| s.len()); let echo_len = if completion_req.echo { diff --git a/model_gateway/src/routers/grpc/regular/streaming.rs b/model_gateway/src/routers/grpc/regular/streaming.rs index 3043f1a6d..352b953ce 100644 --- a/model_gateway/src/routers/grpc/regular/streaming.rs +++ b/model_gateway/src/routers/grpc/regular/streaming.rs @@ -494,16 +494,11 @@ impl StreamingProcessor { matched_stops.insert( index, - if complete.is_mlx() { - utils::resolve_mlx_matched_stop_json( - complete.mlx_matched_stop_token_id(), - original_request.stop.as_ref(), - original_request.stop_token_ids.as_ref(), - tokenizer.as_ref(), - ) - } else { - complete.matched_stop_json() - }, + complete.matched_stop_json_with_context( + original_request.stop.as_ref(), + original_request.stop_token_ids.as_ref(), + tokenizer.as_ref(), + ), ); // Don't break - continue reading all Complete messages for n>1 @@ -2586,16 +2581,11 @@ impl StreamingProcessor { } }; - let matched_stop = if complete.is_mlx() { - utils::resolve_mlx_matched_stop_json( - complete.mlx_matched_stop_token_id(), - completion_request.stop.as_ref(), - completion_request.stop_token_ids.as_ref(), - tokenizer.as_ref(), - ) - } else { - complete.matched_stop_json() - }; + let matched_stop = complete.matched_stop_json_with_context( + completion_request.stop.as_ref(), + completion_request.stop_token_ids.as_ref(), + tokenizer.as_ref(), + ); let final_chunk = CompletionStreamResponse { id: request_id.clone(), From 15c122715af1204e5e79c194c65cb2fcbe4a0883 Mon Sep 17 00:00:00 2001 From: Zhuo Li Date: Fri, 8 May 2026 23:09:03 -0700 Subject: [PATCH 3/8] fix double gated apply_mlx_stop_sequences: helper is unconditional and no-ops on non-MLX Signed-off-by: Zhuo Li --- .../src/routers/grpc/common/stages/helpers.rs | 4 ++-- .../grpc/regular/stages/chat/request_building.rs | 12 +++++------- .../regular/stages/completion/request_building.rs | 12 +++++------- 3 files changed, 12 insertions(+), 16 deletions(-) diff --git a/model_gateway/src/routers/grpc/common/stages/helpers.rs b/model_gateway/src/routers/grpc/common/stages/helpers.rs index 0a5cee5bd..211b6890e 100644 --- a/model_gateway/src/routers/grpc/common/stages/helpers.rs +++ b/model_gateway/src/routers/grpc/common/stages/helpers.rs @@ -293,9 +293,8 @@ pub(crate) fn apply_mlx_stop_sequences( return Ok(()); }; - let token_ids = resolve_mlx_stop_ids(stop, tokenizer)?; - if let ProtoGenerateRequest::Mlx(req) = proto_request { + let token_ids = resolve_mlx_stop_ids(stop, tokenizer)?; let sampling = req.sampling_params.as_mut().ok_or_else(|| { error::internal_error( "mlx_sampling_params_missing", @@ -305,5 +304,6 @@ pub(crate) fn apply_mlx_stop_sequences( sampling.stop_token_ids.extend(token_ids); } + // Non-MLX backends handle string stop sequences natively; no-op for them. Ok(()) } diff --git a/model_gateway/src/routers/grpc/regular/stages/chat/request_building.rs b/model_gateway/src/routers/grpc/regular/stages/chat/request_building.rs index 885ef26a9..cac790142 100644 --- a/model_gateway/src/routers/grpc/regular/stages/chat/request_building.rs +++ b/model_gateway/src/routers/grpc/regular/stages/chat/request_building.rs @@ -114,13 +114,11 @@ impl PipelineStage for ChatRequestBuildingStage { } } - if builder_client.is_mlx() { - helpers::apply_mlx_stop_sequences( - &mut proto_request, - chat_request.stop.as_ref(), - ctx.state.tokenizer.as_deref(), - )?; - } + helpers::apply_mlx_stop_sequences( + &mut proto_request, + chat_request.stop.as_ref(), + ctx.state.tokenizer.as_deref(), + )?; ctx.state.proto_request = Some(ProtoRequest::Generate(proto_request)); Ok(None) diff --git a/model_gateway/src/routers/grpc/regular/stages/completion/request_building.rs b/model_gateway/src/routers/grpc/regular/stages/completion/request_building.rs index af2ebcb18..f829d7b2e 100644 --- a/model_gateway/src/routers/grpc/regular/stages/completion/request_building.rs +++ b/model_gateway/src/routers/grpc/regular/stages/completion/request_building.rs @@ -93,13 +93,11 @@ impl PipelineStage for CompletionRequestBuildingStage { helpers::maybe_inject_pd_metadata(&mut proto_request, workers); } } - if builder_client.is_mlx() { - helpers::apply_mlx_stop_sequences( - &mut proto_request, - completion_request.stop.as_ref(), - ctx.state.tokenizer.as_deref(), - )?; - } + helpers::apply_mlx_stop_sequences( + &mut proto_request, + completion_request.stop.as_ref(), + ctx.state.tokenizer.as_deref(), + )?; ctx.state.proto_request = Some(ProtoRequest::Generate(proto_request)); Ok(None) } From 271cec93c8568e7c38429a5e6b2ab88f84b5bdeb Mon Sep 17 00:00:00 2001 From: Zhuo Li Date: Fri, 8 May 2026 23:36:26 -0700 Subject: [PATCH 4/8] fix silent encode-error: zero token and failed tokenizer throw 400 error Signed-off-by: Zhuo Li --- .../src/routers/grpc/utils/chat_utils.rs | 57 ++++++++++++++----- 1 file changed, 43 insertions(+), 14 deletions(-) diff --git a/model_gateway/src/routers/grpc/utils/chat_utils.rs b/model_gateway/src/routers/grpc/utils/chat_utils.rs index 53dcb04de..f798921c8 100644 --- a/model_gateway/src/routers/grpc/utils/chat_utils.rs +++ b/model_gateway/src/routers/grpc/utils/chat_utils.rs @@ -17,7 +17,7 @@ use openai_protocol::{ }; use serde_json::{json, Value}; use tokio::sync::mpsc; -use tracing::{error, warn}; +use tracing::error; use uuid::Uuid; use crate::routers::{ @@ -421,10 +421,12 @@ pub fn create_stop_decoder( /// Tokenizes stop strings into token IDs for the MLX backend. /// -/// Returns `Err` if any string encodes to more than one token — the caller -/// should surface this as an HTTP 400 so the client knows the stop condition -/// was not honored rather than silently ignoring it. -/// Strings that encode to zero tokens (unknown vocab) are skipped with a warning. +/// Returns `Err` for any string that cannot be honored as a stop condition: +/// - encodes to more than one token +/// - encodes to zero tokens (not in vocabulary) +/// - tokenizer returns an error +/// +/// The caller should surface all errors as HTTP 400. pub(crate) fn stop_strings_to_token_ids<'a>( stop: impl IntoIterator, tokenizer: &dyn Tokenizer, @@ -441,12 +443,16 @@ pub(crate) fn stop_strings_to_token_ids<'a>( tokens.len() )); } - _ => warn!( - stop_string = s, - "stop string produced no tokens for MLX, skipping" - ), + _ => { + return Err(format!( + "stop string {s:?} produced no tokens; \ + it may not be present in the model vocabulary" + )); + } }, - Err(e) => warn!(stop_string = s, error = %e, "failed to tokenize stop string for MLX"), + Err(e) => { + return Err(format!("failed to tokenize stop string {s:?}: {e}")); + } } } Ok(ids) @@ -670,6 +676,7 @@ pub(crate) fn parse_finish_reason( #[cfg(test)] mod tests { + use axum::http::StatusCode; use llm_tokenizer::{chat_template::ChatTemplateContentFormat, MockTokenizer}; use openai_protocol::{ chat::{ChatMessage, MessageContent}, @@ -875,7 +882,7 @@ mod tests { #[test_case(&["world"], Some(&[2]) ; "single token another regular")] #[test_case(&["<|im_end|>"], Some(&[1002]) ; "single token special")] #[test_case(&["Hello world"], None ; "multi token returns err")] - #[test_case(&["zzzunknown"], Some(&[]) ; "unknown vocab skipped")] + #[test_case(&["zzzunknown"], None ; "unknown vocab returns err")] #[test_case(&["Hello", "Hello world"], None ; "array with multi token err")] #[test_case(&["Hello", "test"], Some(&[1, 3]) ; "array all single token")] #[test_case(&[], Some(&[]) ; "empty array")] @@ -889,11 +896,33 @@ mod tests { } #[test] - fn test_stop_encode_error_skipped() { - // Tokenizer errors are silently skipped; the function returns Ok(empty). + fn test_stop_encode_error_returns_err() { let tok = MockTokenizer::failing(); let result = stop_strings_to_token_ids(["Hello", "test"].iter().copied(), &tok); - assert!(result.unwrap().is_empty()); + assert!(result.is_err()); + } + + #[test] + fn test_resolve_mlx_stop_ids_zero_token_is_400() { + let tok = MockTokenizer::new(); + let stop = StringOrArray::String("zzzunknown".to_string()); + let resp = resolve_mlx_stop_ids(&stop, Some(&tok)).unwrap_err(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + } + + #[test] + fn test_resolve_mlx_stop_ids_tokenizer_error_is_400() { + let tok = MockTokenizer::failing(); + let stop = StringOrArray::String("Hello".to_string()); + let resp = resolve_mlx_stop_ids(&stop, Some(&tok)).unwrap_err(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); + } + + #[test] + fn test_resolve_mlx_stop_ids_missing_tokenizer_is_400() { + let stop = StringOrArray::String("Hello".to_string()); + let resp = resolve_mlx_stop_ids(&stop, None).unwrap_err(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); } // MockTokenizer vocab: "Hello"→1, "world"→2, "test"→3, "<|im_end|>"→1002. From 486da46e2c8f672cb8691736b291c1d7de7800ab Mon Sep 17 00:00:00 2001 From: Zhuo Li Date: Fri, 8 May 2026 23:54:02 -0700 Subject: [PATCH 5/8] fix: remove test-case dep and refactor unit tests in chat utils Signed-off-by: Zhuo Li --- Cargo.toml | 1 - model_gateway/Cargo.toml | 1 - .../src/routers/grpc/utils/chat_utils.rs | 74 ++++++++++--------- 3 files changed, 41 insertions(+), 35 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index fca10cf91..1cacf31fb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,7 +60,6 @@ str0m = { version = "0.19", default-features = false, features = ["openssl"] } scopeguard = "1.2" bitflags = "2.10.0" schemars = "0.8" -test-case = "3.3.1" [workspace.lints.rust] unsafe_code = "deny" diff --git a/model_gateway/Cargo.toml b/model_gateway/Cargo.toml index 83415bf6b..5de28d1bb 100644 --- a/model_gateway/Cargo.toml +++ b/model_gateway/Cargo.toml @@ -144,7 +144,6 @@ rsa = { version = "0.9", features = ["sha2"] } jsonwebtoken = "9.3" validator = "0.20.0" kv-index.workspace = true -test-case = { workspace = true } wasmtime-wasi = { workspace = true } lru = { workspace = true } wat = "1.244" diff --git a/model_gateway/src/routers/grpc/utils/chat_utils.rs b/model_gateway/src/routers/grpc/utils/chat_utils.rs index f798921c8..d3f050de7 100644 --- a/model_gateway/src/routers/grpc/utils/chat_utils.rs +++ b/model_gateway/src/routers/grpc/utils/chat_utils.rs @@ -683,8 +683,6 @@ mod tests { common::{ContentPart, ImageUrl}, }; use serde_json::json; - use test_case::test_case; - use super::*; #[test] @@ -876,22 +874,27 @@ mod tests { assert_eq!(content_array[0]["type"], "text"); assert_eq!(content_array[1], json!({"type": "image"})); } - // MockTokenizer vocab used below: "Hello"→1, "world"→2, "test"→3, - // "<|im_end|>"→1002. expected = None means the call should return Err. - #[test_case(&["Hello"], Some(&[1u32]) ; "single token regular")] - #[test_case(&["world"], Some(&[2]) ; "single token another regular")] - #[test_case(&["<|im_end|>"], Some(&[1002]) ; "single token special")] - #[test_case(&["Hello world"], None ; "multi token returns err")] - #[test_case(&["zzzunknown"], None ; "unknown vocab returns err")] - #[test_case(&["Hello", "Hello world"], None ; "array with multi token err")] - #[test_case(&["Hello", "test"], Some(&[1, 3]) ; "array all single token")] - #[test_case(&[], Some(&[]) ; "empty array")] - fn test_stop_strings_to_token_ids(inputs: &[&str], expected: Option<&[u32]>) { + #[test] + fn test_stop_strings_to_token_ids() { + // MockTokenizer vocab: "Hello"→1, "world"→2, "test"→3, "<|im_end|>"→1002. + // expected = None means the call should return Err. + let cases: &[(&[&str], Option<&[u32]>, &str)] = &[ + (&["Hello"], Some(&[1]), "single token regular"), + (&["world"], Some(&[2]), "single token another regular"), + (&["<|im_end|>"], Some(&[1002]), "single token special"), + (&["Hello world"], None, "multi token returns err"), + (&["zzzunknown"], None, "unknown vocab returns err"), + (&["Hello", "Hello world"], None, "array with multi token err"), + (&["Hello", "test"], Some(&[1, 3]), "array all single token"), + (&[], Some(&[]), "empty array"), + ]; let tok = MockTokenizer::new(); - let result = stop_strings_to_token_ids(inputs.iter().copied(), &tok); - match expected { - Some(ids) => assert_eq!(result.unwrap(), ids), - None => assert!(result.is_err()), + for &(inputs, expected, name) in cases { + let result = stop_strings_to_token_ids(inputs.iter().copied(), &tok); + match expected { + Some(ids) => assert_eq!(result.unwrap(), ids, "{name}"), + None => assert!(result.is_err(), "{name}"), + } } } @@ -925,22 +928,27 @@ mod tests { assert_eq!(resp.status(), StatusCode::BAD_REQUEST); } - // MockTokenizer vocab: "Hello"→1, "world"→2, "test"→3, "<|im_end|>"→1002. - // stop_ids=&[] is treated as None (no user stop_token_ids supplied). - #[test_case(None, None, &[] => None ; "no id returns none")] - #[test_case(Some(1), Some("Hello"), &[] => Some(Value::String("Hello".to_string())) ; "string match")] - #[test_case(Some(42), None, &[42] => Some(Value::Number(42u32.into())) ; "token id match")] - #[test_case(Some(1), Some("Hello"), &[1] => Some(Value::String("Hello".to_string())) ; "string wins over token id")] - #[test_case(Some(999), None, &[] => None ; "eos returns none")] - fn test_resolve_mlx_matched_stop( - id: Option, - stop_str: Option<&str>, - stop_ids: &[u32], - ) -> Option { + #[test] + fn test_resolve_mlx_matched_stop() { + // MockTokenizer vocab: "Hello"→1, "world"→2, "test"→3, "<|im_end|>"→1002. + // stop_ids=&[] is treated as None (no user stop_token_ids supplied). + let cases: &[(Option, Option<&str>, &[u32], Option, &str)] = &[ + (None, None, &[], None, "no id returns none"), + (Some(1), Some("Hello"), &[], Some(Value::String("Hello".to_string())), "string match"), + (Some(42), None, &[42], Some(Value::Number(42u32.into())), "token id match"), + (Some(1), Some("Hello"), &[1], Some(Value::String("Hello".to_string())), "string wins over token id"), + (Some(999), None, &[], None, "eos returns none"), + ]; let tok = MockTokenizer::new(); - let stop = stop_str.map(|s| StringOrArray::String(s.to_string())); - let ids: Vec = stop_ids.to_vec(); - let ids_opt = if ids.is_empty() { None } else { Some(&ids) }; - resolve_mlx_matched_stop_json(id, stop.as_ref(), ids_opt, &tok) + for (id, stop_str, stop_ids, expected, name) in cases { + let stop = stop_str.map(|s| StringOrArray::String(s.to_string())); + let ids: Vec = stop_ids.to_vec(); + let ids_opt = if ids.is_empty() { None } else { Some(&ids) }; + assert_eq!( + resolve_mlx_matched_stop_json(*id, stop.as_ref(), ids_opt, &tok), + *expected, + "{name}", + ); + } } } From 94eef5a2dfc93a56fc2522a31b15a8f3ac30a80a Mon Sep 17 00:00:00 2001 From: Zhuo Li Date: Sat, 9 May 2026 11:56:55 -0700 Subject: [PATCH 6/8] use default values for CompletionStreamChoice fields Signed-off-by: Zhuo Li --- crates/protocols/src/completion.rs | 2 +- .../src/routers/grpc/regular/streaming.rs | 24 +++++-------------- 2 files changed, 7 insertions(+), 19 deletions(-) diff --git a/crates/protocols/src/completion.rs b/crates/protocols/src/completion.rs index d21a58a91..82f91020d 100644 --- a/crates/protocols/src/completion.rs +++ b/crates/protocols/src/completion.rs @@ -266,7 +266,7 @@ pub struct CompletionStreamResponse { pub usage: Option, } -#[derive(Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)] +#[derive(Default, Debug, Clone, Deserialize, Serialize, schemars::JsonSchema)] pub struct CompletionStreamChoice { pub text: String, pub index: u32, diff --git a/model_gateway/src/routers/grpc/regular/streaming.rs b/model_gateway/src/routers/grpc/regular/streaming.rs index 352b953ce..b9dbbab73 100644 --- a/model_gateway/src/routers/grpc/regular/streaming.rs +++ b/model_gateway/src/routers/grpc/regular/streaming.rs @@ -2412,9 +2412,7 @@ impl StreamingProcessor { choices: vec![CompletionStreamChoice { text: std::mem::take(&mut chunk_text), index, - logprobs: None, - finish_reason: None, - matched_stop: None, + ..Default::default() }], model: model.clone(), system_fingerprint: system_fingerprint.map(String::from), @@ -2440,9 +2438,7 @@ impl StreamingProcessor { choices: vec![CompletionStreamChoice { text: sfx.to_string(), index, - logprobs: None, - finish_reason: None, - matched_stop: None, + ..Default::default() }], model: model.clone(), system_fingerprint: system_fingerprint.map(String::from), @@ -2458,11 +2454,9 @@ impl StreamingProcessor { object: "text_completion".to_string(), created, choices: vec![CompletionStreamChoice { - text: String::new(), index, - logprobs: None, finish_reason: Some("stop".to_string()), - matched_stop: None, + ..Default::default() }], model: model.clone(), system_fingerprint: system_fingerprint.map(String::from), @@ -2495,9 +2489,7 @@ impl StreamingProcessor { choices: vec![CompletionStreamChoice { text: prompt_text.to_string(), index, - logprobs: None, - finish_reason: None, - matched_stop: None, + ..Default::default() }], model: model.clone(), system_fingerprint: system_fingerprint.map(String::from), @@ -2519,9 +2511,7 @@ impl StreamingProcessor { choices: vec![CompletionStreamChoice { text, index, - logprobs: None, - finish_reason: None, - matched_stop: None, + ..Default::default() }], model: model.clone(), system_fingerprint: system_fingerprint.map(String::from), @@ -2542,9 +2532,7 @@ impl StreamingProcessor { choices: vec![CompletionStreamChoice { text: sfx.to_string(), index, - logprobs: None, - finish_reason: None, - matched_stop: None, + ..Default::default() }], model: model.clone(), system_fingerprint: system_fingerprint.map(String::from), From b938bc92b3e9b9c613036b1a88b7589adaa180ed Mon Sep 17 00:00:00 2001 From: Zhuo Li Date: Sun, 10 May 2026 11:02:45 -0700 Subject: [PATCH 7/8] rebase before pushing Signed-off-by: Zhuo Li --- model_gateway/src/routers/grpc/common/stages/helpers.rs | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/model_gateway/src/routers/grpc/common/stages/helpers.rs b/model_gateway/src/routers/grpc/common/stages/helpers.rs index 211b6890e..69502d4e6 100644 --- a/model_gateway/src/routers/grpc/common/stages/helpers.rs +++ b/model_gateway/src/routers/grpc/common/stages/helpers.rs @@ -14,10 +14,6 @@ use smg_grpc_client::{ use tracing::{debug, warn}; use crate::{ - routers::grpc::{ - context::{RequestType, WorkerSelection}, - proto_wrapper::ProtoGenerateRequest, - }, worker::{ sampling_defaults::SamplingDefaults, RuntimeType, Worker, DEFAULT_BOOTSTRAP_PORT, DEFAULT_SAMPLING_PARAMS_LABEL, @@ -25,11 +21,11 @@ use crate::{ routers::{ error, grpc::{ - context::WorkerSelection, proto_wrapper::ProtoGenerateRequest, + context::{RequestType, WorkerSelection}, + proto_wrapper::ProtoGenerateRequest, utils::resolve_mlx_stop_ids, }, }, - worker::{RuntimeType, Worker, DEFAULT_BOOTSTRAP_PORT}, }; #[derive(Clone, Copy, Debug, Default)] From fac846fbd67b26fd0912d6b93536a5295d75ade8 Mon Sep 17 00:00:00 2001 From: Zhuo Li Date: Sun, 10 May 2026 11:36:21 -0700 Subject: [PATCH 8/8] fix fmt and clippy issues Signed-off-by: Zhuo Li --- .../src/routers/grpc/common/stages/helpers.rs | 8 +-- .../src/routers/grpc/proto_wrapper.rs | 8 ++- .../src/routers/grpc/utils/chat_utils.rs | 62 ++++++++++++++----- 3 files changed, 58 insertions(+), 20 deletions(-) diff --git a/model_gateway/src/routers/grpc/common/stages/helpers.rs b/model_gateway/src/routers/grpc/common/stages/helpers.rs index 69502d4e6..3cacf0ff6 100644 --- a/model_gateway/src/routers/grpc/common/stages/helpers.rs +++ b/model_gateway/src/routers/grpc/common/stages/helpers.rs @@ -14,10 +14,6 @@ use smg_grpc_client::{ use tracing::{debug, warn}; use crate::{ - worker::{ - sampling_defaults::SamplingDefaults, RuntimeType, Worker, DEFAULT_BOOTSTRAP_PORT, - DEFAULT_SAMPLING_PARAMS_LABEL, - }, routers::{ error, grpc::{ @@ -26,6 +22,10 @@ use crate::{ utils::resolve_mlx_stop_ids, }, }, + worker::{ + sampling_defaults::SamplingDefaults, RuntimeType, Worker, DEFAULT_BOOTSTRAP_PORT, + DEFAULT_SAMPLING_PARAMS_LABEL, + }, }; #[derive(Clone, Copy, Debug, Default)] diff --git a/model_gateway/src/routers/grpc/proto_wrapper.rs b/model_gateway/src/routers/grpc/proto_wrapper.rs index 4a84080ee..279a280a9 100644 --- a/model_gateway/src/routers/grpc/proto_wrapper.rs +++ b/model_gateway/src/routers/grpc/proto_wrapper.rs @@ -19,6 +19,8 @@ use smg_grpc_client::{ vllm_proto::{self as vllm, generate_complete::MatchedStop as VllmMatchedStop}, }; +use crate::routers::grpc::utils::resolve_mlx_matched_stop_json; + // ===================== // Multimodal Data // ===================== @@ -805,6 +807,10 @@ impl ProtoGenerateComplete { /// - MatchedTokenId → Number /// - MatchedStopStr → String /// - None → None + #[expect( + clippy::unreachable, + reason = "MLX must use matched_stop_json_with_context" + )] pub fn matched_stop_json(&self) -> Option { macro_rules! convert { ($oneof:expr, $token_id:path, $stop_str:path) => { @@ -847,7 +853,7 @@ impl ProtoGenerateComplete { tokenizer: &dyn Tokenizer, ) -> Option { if self.is_mlx() { - crate::routers::grpc::utils::resolve_mlx_matched_stop_json( + resolve_mlx_matched_stop_json( self.mlx_matched_stop_token_id(), stop, stop_token_ids, diff --git a/model_gateway/src/routers/grpc/utils/chat_utils.rs b/model_gateway/src/routers/grpc/utils/chat_utils.rs index d3f050de7..504260ccb 100644 --- a/model_gateway/src/routers/grpc/utils/chat_utils.rs +++ b/model_gateway/src/routers/grpc/utils/chat_utils.rs @@ -683,8 +683,18 @@ mod tests { common::{ContentPart, ImageUrl}, }; use serde_json::json; + use super::*; + type StopTokenCase<'a> = (&'a [&'a str], Option<&'a [u32]>, &'a str); + type MatchedStopCase<'a> = ( + Option, + Option<&'a str>, + &'a [u32], + Option, + &'a str, + ); + #[test] fn test_transform_messages_string_format() { let messages = vec![ChatMessage::User { @@ -878,15 +888,19 @@ mod tests { fn test_stop_strings_to_token_ids() { // MockTokenizer vocab: "Hello"→1, "world"→2, "test"→3, "<|im_end|>"→1002. // expected = None means the call should return Err. - let cases: &[(&[&str], Option<&[u32]>, &str)] = &[ - (&["Hello"], Some(&[1]), "single token regular"), - (&["world"], Some(&[2]), "single token another regular"), - (&["<|im_end|>"], Some(&[1002]), "single token special"), - (&["Hello world"], None, "multi token returns err"), - (&["zzzunknown"], None, "unknown vocab returns err"), - (&["Hello", "Hello world"], None, "array with multi token err"), - (&["Hello", "test"], Some(&[1, 3]), "array all single token"), - (&[], Some(&[]), "empty array"), + let cases: &[StopTokenCase<'_>] = &[ + (&["Hello"], Some(&[1]), "single token regular"), + (&["world"], Some(&[2]), "single token another regular"), + (&["<|im_end|>"], Some(&[1002]), "single token special"), + (&["Hello world"], None, "multi token returns err"), + (&["zzzunknown"], None, "unknown vocab returns err"), + ( + &["Hello", "Hello world"], + None, + "array with multi token err", + ), + (&["Hello", "test"], Some(&[1, 3]), "array all single token"), + (&[], Some(&[]), "empty array"), ]; let tok = MockTokenizer::new(); for &(inputs, expected, name) in cases { @@ -932,12 +946,30 @@ mod tests { fn test_resolve_mlx_matched_stop() { // MockTokenizer vocab: "Hello"→1, "world"→2, "test"→3, "<|im_end|>"→1002. // stop_ids=&[] is treated as None (no user stop_token_ids supplied). - let cases: &[(Option, Option<&str>, &[u32], Option, &str)] = &[ - (None, None, &[], None, "no id returns none"), - (Some(1), Some("Hello"), &[], Some(Value::String("Hello".to_string())), "string match"), - (Some(42), None, &[42], Some(Value::Number(42u32.into())), "token id match"), - (Some(1), Some("Hello"), &[1], Some(Value::String("Hello".to_string())), "string wins over token id"), - (Some(999), None, &[], None, "eos returns none"), + let cases: &[MatchedStopCase<'_>] = &[ + (None, None, &[], None, "no id returns none"), + ( + Some(1), + Some("Hello"), + &[], + Some(Value::String("Hello".to_string())), + "string match", + ), + ( + Some(42), + None, + &[42], + Some(Value::Number(42u32.into())), + "token id match", + ), + ( + Some(1), + Some("Hello"), + &[1], + Some(Value::String("Hello".to_string())), + "string wins over token id", + ), + (Some(999), None, &[], None, "eos returns none"), ]; let tok = MockTokenizer::new(); for (id, stop_str, stop_ids, expected, name) in cases {