From 663e9dccbb6474ab5e9a66b900b3df70e1fc8064 Mon Sep 17 00:00:00 2001 From: Peppi Littera Date: Sat, 6 Jun 2026 14:00:11 +0200 Subject: [PATCH 1/4] =?UTF-8?q?feat(qwen35):=20MoE-structured=20MTP=20head?= =?UTF-8?q?=20=E2=80=94=20Qwen3.6-A3B=20speculative=20decode?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Qwen3.6-A3B ships its MTP layer as a full MoE decoder layer (router gate + 256 stacked experts + shared expert/gate) with a quantized fc, both bundled and as standalone drafter sidecars (e.g. mlx-community Qwen3.6-35B-A3B-MTP- 4bit). The existing MtpHead/DenseMtpHead are dense-MLP only, so these checkpoints could not speculate. - MoeMtpHead / MoeMtpTransformerLayer: full attention + SparseMoeBlock, quantized fc (these sidecars ship fc.{weight,scales,biases} triples). Constructed at the checkpoint's uniform quantization; the main model's gate_quantization override is deliberately NOT applied (sidecar router gates are quantized at the default width). - Layout detection: MoE-structured MTP keys classify as MoeQuantized and enable the head (use_moe_mtp). Truly unprefixed sidecar keys (fc.weight, layers.0....) are mtp.-prefixed at detection and load time (normalize_sidecar_mtp_key), so mlx-community drafters work as mtp.safetensors drop-ins. - Loading: mtp.* -> moe_mtp.* param remap through both the fused and direct loaders; no dense rmsnorm adjustment for MoE targets. - Forward: MoE branches in mtp_step_hidden and mtp_advance_many; has_mtp / make_mtp_cache cover the new head. Measured on Qwen3.6-35B-A3B-4bit + MTP drafter (M-series, kv-bits 4): short/structured output 60-62 tok/s at 100% accept (vs 40 tok/s baseline, +50%); long prose 37-41 tok/s at 60-71% accept (breakeven). Outputs verified exact (drafts go through the verify path). Tests: MoE layout classification, mtp->moe_mtp key remap, sidecar key normalization (aux-only, idempotent). 
Co-Authored-By: Claude Opus 4.8 (1M context) --- crates/higgs-models/src/lib.rs | 2 +- crates/higgs-models/src/qwen3_next.rs | 258 +++++++++++++++++++++++++- 2 files changed, 254 insertions(+), 6 deletions(-) diff --git a/crates/higgs-models/src/lib.rs b/crates/higgs-models/src/lib.rs index 8ae2716a..95903b07 100644 --- a/crates/higgs-models/src/lib.rs +++ b/crates/higgs-models/src/lib.rs @@ -1135,7 +1135,7 @@ pub struct WeightMapIndex { pub weight_map: HashMap, } -const AUXILIARY_SAFETENSORS_FILES: &[&str] = &["mtp.safetensors", "model-mtp.safetensors"]; +pub(crate) const AUXILIARY_SAFETENSORS_FILES: &[&str] = &["mtp.safetensors", "model-mtp.safetensors"]; /// Load a tokenizer from a model directory. pub fn load_tokenizer>(model_dir: P) -> Result { diff --git a/crates/higgs-models/src/qwen3_next.rs b/crates/higgs-models/src/qwen3_next.rs index ba3e5da4..6a66ce0d 100644 --- a/crates/higgs-models/src/qwen3_next.rs +++ b/crates/higgs-models/src/qwen3_next.rs @@ -215,6 +215,15 @@ pub struct Qwen3NextModelArgs { /// loader after inspecting checkpoint keys; it is not expected in configs. #[serde(default)] pub use_dense_mtp: bool, + + /// Use an MoE-structured MTP head (Qwen3.6-A3B style). + /// + /// These sidecars ship the MTP layer as a full MoE decoder layer + /// (`mlp.gate`, `mlp.switch_mlp.*`, `mlp.shared_expert*`) with a quantized + /// `fc`. Set by the loader after inspecting checkpoint keys; not expected + /// in configs. + #[serde(default)] + pub use_moe_mtp: bool, } // --------------------------------------------------------------------------- @@ -1951,6 +1960,86 @@ impl DenseMtpHead { } } +/// Single MoE MTP transformer layer (Qwen3.6-A3B style). +/// +/// Qwen3.6-A3B sidecars ship the MTP layer as a full MoE decoder layer: +/// full attention (with q/k norms) + `SparseMoeBlock` +/// (router gate + stacked experts + shared expert + shared-expert gate). 
+#[derive(Debug, Clone, ModuleParameters)] +struct MoeMtpTransformerLayer { + #[param] + self_attn: Qwen3NextAttention, + #[param] + input_layernorm: nn::RmsNorm, + #[param] + post_attention_layernorm: nn::RmsNorm, + #[param] + mlp: SparseMoeBlock, +} + +/// MTP head with an MoE transformer layer (Qwen3.6-A3B style sidecars). +/// +/// Unlike [`MtpHead`], the fusion projection `fc` is a quantized [`QLinear`] +/// (these sidecars ship `fc.{weight,scales,biases}` triples), and the MLP is +/// a [`SparseMoeBlock`]. All projections use the checkpoint's uniform +/// quantization — the main model's `gate_quantization` override must NOT be +/// applied here (the sidecar's router gate is quantized at the default width). +#[derive(Debug, Clone, ModuleParameters)] +pub struct MoeMtpHead { + #[param] + pre_fc_norm_hidden: nn::RmsNorm, + #[param] + pre_fc_norm_embedding: nn::RmsNorm, + #[param] + fc: QLinear, + #[param] + layers: Vec, + #[param] + norm: nn::RmsNorm, +} + +impl MoeMtpHead { + fn new(args: &Qwen3NextModelArgs, ql: i32, qb: i32) -> Result { + let n = usize::try_from(args.mtp_num_hidden_layers) + .map_err(|_| Exception::custom("mtp_num_hidden_layers must be non-negative"))?; + + // The sidecar's MoE block is uniformly quantized at the default width; + // strip the main model's per-layer gate override so the router gate's + // QLinear dequantizes with the right parameters. 
+ let mut mtp_args = args.clone(); + mtp_args.gate_quantization = None; + + let layers = (0..n) + .map(|_| { + Ok(MoeMtpTransformerLayer { + self_attn: Qwen3NextAttention::new(args, ql, qb)?, + input_layernorm: nn::RmsNormBuilder::new(args.hidden_size) + .eps(args.rms_norm_eps) + .build()?, + post_attention_layernorm: nn::RmsNormBuilder::new(args.hidden_size) + .eps(args.rms_norm_eps) + .build()?, + mlp: SparseMoeBlock::new(&mtp_args, ql, qb)?, + }) + }) + .collect::, Exception>>()?; + + Ok(Self { + pre_fc_norm_hidden: nn::RmsNormBuilder::new(args.hidden_size) + .eps(args.rms_norm_eps) + .build()?, + pre_fc_norm_embedding: nn::RmsNormBuilder::new(args.hidden_size) + .eps(args.rms_norm_eps) + .build()?, + fc: QLinear::new(ql, qb)?, + layers, + norm: nn::RmsNormBuilder::new(args.hidden_size) + .eps(args.rms_norm_eps) + .build()?, + }) + } +} + // --------------------------------------------------------------------------- // SwitchMLP weights (stacked expert weights for MoE) // --------------------------------------------------------------------------- @@ -3335,6 +3424,8 @@ pub struct Qwen3NextCausalLM { mtp: Option, #[param] dense_mtp: Option, + #[param] + moe_mtp: Option, } // Manual RoPE implementation for arbitrary positions @@ -3448,7 +3539,7 @@ impl Qwen3NextCausalLM { } else { Some(QLinear::new(ql, qb)?) }; - let mtp = if args.mtp_num_hidden_layers > 0 && !args.use_dense_mtp { + let mtp = if args.mtp_num_hidden_layers > 0 && !args.use_dense_mtp && !args.use_moe_mtp { Some(MtpHead::new(&args, ql, qb)?) } else { None @@ -3458,6 +3549,11 @@ impl Qwen3NextCausalLM { } else { None }; + let moe_mtp = if args.mtp_num_hidden_layers > 0 && args.use_moe_mtp { + Some(MoeMtpHead::new(&args, ql, qb)?) + } else { + None + }; Ok(Self { args, @@ -3465,6 +3561,7 @@ impl Qwen3NextCausalLM { lm_head, mtp, dense_mtp, + moe_mtp, }) } @@ -3814,7 +3911,7 @@ impl Qwen3NextCausalLM { /// Whether this model has an MTP head loaded. 
pub const fn has_mtp(&self) -> bool { - self.mtp.is_some() || self.dense_mtp.is_some() + self.mtp.is_some() || self.dense_mtp.is_some() || self.moe_mtp.is_some() } /// Create a fresh KV cache for the MTP head (one entry per MTP layer). @@ -3824,7 +3921,8 @@ impl Qwen3NextCausalLM { .mtp .as_ref() .map(|mtp| mtp.layers.len()) - .or_else(|| self.dense_mtp.as_ref().map(|mtp| mtp.layers.len()))?; + .or_else(|| self.dense_mtp.as_ref().map(|mtp| mtp.layers.len())) + .or_else(|| self.moe_mtp.as_ref().map(|mtp| mtp.layers.len()))?; Some( (0..layer_count) .map(|_| SteppingKeyValueCache::new()) @@ -3910,6 +4008,25 @@ impl Qwen3NextCausalLM { return mtp.norm.forward(&x); } + // MoE MTP head (Qwen3.6-A3B style): same loop, MoE MLP. + if let Some(mtp) = self.moe_mtp.as_mut() { + let h_norm = mtp.pre_fc_norm_hidden.forward(hidden)?; + let e_norm = mtp.pre_fc_norm_embedding.forward(&next_embed)?; + let concat = ops::concatenate_axis(&[&e_norm, &h_norm], -1)?; + let mut x = mtp.fc.forward(&concat)?; + + for (layer, kv) in mtp.layers.iter_mut().zip(mtp_cache.iter_mut()) { + let normed = layer.input_layernorm.forward(&x)?; + let attn_out = layer.self_attn.forward(&normed, None, kv)?; + let h2 = x.add(attn_out)?; + let normed_post = layer.post_attention_layernorm.forward(&h2)?; + let mlp_out = layer.mlp.forward(&normed_post)?; + x = h2.add(mlp_out)?; + } + + return mtp.norm.forward(&x); + } + let mtp = self .dense_mtp .as_mut() @@ -4002,6 +4119,7 @@ impl Qwen3NextCausalLM { .as_ref() .map(|mtp| mtp.layers.len()) .or_else(|| self.dense_mtp.as_ref().map(|mtp| mtp.layers.len())) + .or_else(|| self.moe_mtp.as_ref().map(|mtp| mtp.layers.len())) .ok_or_else(|| Exception::custom("MTP head not loaded"))?; Self::validate_mtp_advance_many_inputs(hidden, mtp_cache, expected_layers, seq_len)?; @@ -4028,6 +4146,26 @@ impl Qwen3NextCausalLM { return Ok(()); } + // MoE MTP head (Qwen3.6-A3B style): same loop, MoE MLP. 
+ if let Some(mtp) = self.moe_mtp.as_mut() { + let h_norm = mtp.pre_fc_norm_hidden.forward(hidden)?; + let e_norm = mtp.pre_fc_norm_embedding.forward(&next_embed)?; + let concat = ops::concatenate_axis(&[&e_norm, &h_norm], -1)?; + let mut x = mtp.fc.forward(&concat)?; + + for (layer, kv) in mtp.layers.iter_mut().zip(mtp_cache.iter_mut()) { + let normed = layer.input_layernorm.forward(&x)?; + let attn_out = layer.self_attn.forward(&normed, mask_ref, kv)?; + let h2 = x.add(attn_out)?; + let normed_post = layer.post_attention_layernorm.forward(&h2)?; + let mlp_out = layer.mlp.forward(&normed_post)?; + x = h2.add(mlp_out)?; + } + + let _ = mtp.norm.forward(&x)?; + return Ok(()); + } + let mtp = self .dense_mtp .as_mut() @@ -4217,6 +4355,9 @@ enum MtpWeightLayout { None, Quantized, Dense, + /// MTP layer is MoE-structured (`mlp.gate` / `shared_expert` / experts), + /// e.g. Qwen3.6-A3B sidecars — loaded via [`MoeMtpHead`]. + MoeQuantized, } fn is_mtp_key(key: &str) -> bool { @@ -4227,6 +4368,7 @@ fn mtp_weight_layout_from_keys<'a>(keys: impl IntoIterator) -> M let mut has_mtp = false; let mut has_unprefixed_mtp = false; let mut has_quantized_aux = false; + let mut has_moe_mlp = false; for key in keys { if !is_mtp_key(key) { @@ -4235,9 +4377,15 @@ fn mtp_weight_layout_from_keys<'a>(keys: impl IntoIterator) -> M has_mtp = true; has_unprefixed_mtp |= key.starts_with("mtp."); has_quantized_aux |= key.ends_with(".scales") || key.ends_with(".biases"); + has_moe_mlp |= key.contains(".mlp.gate.") + || key.contains(".mlp.shared_expert") + || key.contains(".mlp.switch_mlp") + || key.contains(".mlp.experts"); } - if has_quantized_aux { + if has_mtp && has_moe_mlp { + MtpWeightLayout::MoeQuantized + } else if has_quantized_aux { MtpWeightLayout::Quantized } else if has_unprefixed_mtp { MtpWeightLayout::Dense @@ -4253,7 +4401,16 @@ fn checkpoint_mtp_weight_layout(model_path: &Path) -> Result = metadata + .names() + .into_iter() + .map(|name| normalize_sidecar_mtp_key(file_path, 
name.to_owned())) + .collect(); + Ok(mtp_weight_layout_from_keys( + normalized.iter().map(String::as_str), + )) } let index_path = model_path.join("model.safetensors.index.json"); @@ -4326,6 +4483,15 @@ fn maybe_disable_mtp_without_checkpoint_weights( args.use_dense_mtp = true; return Ok(()); } + MtpWeightLayout::MoeQuantized => { + tracing::info!( + "Checkpoint ships an MoE-structured MTP head (Qwen3.6-A3B style); \ + loading via MoeMtpHead" + ); + args.use_dense_mtp = false; + args.use_moe_mtp = true; + return Ok(()); + } MtpWeightLayout::None => {} } @@ -4770,12 +4936,41 @@ fn dense_mtp_param_key(stripped: &str) -> Option { .map(|rest| format!("dense_mtp.{rest}")) } +fn moe_mtp_param_key(stripped: &str) -> Option { + stripped + .strip_prefix("mtp.") + .map(|rest| format!("moe_mtp.{rest}")) +} + +/// Normalize a tensor key loaded from an auxiliary MTP sidecar file. +/// +/// Some sidecars (e.g. mlx-community MTP drafters) ship truly unprefixed keys +/// (`fc.weight`, `layers.0....`); prefix them with `mtp.` so they map onto the +/// model's MTP head params the same way prefixed sidecars do. Keys from +/// non-auxiliary files are returned unchanged. +fn normalize_sidecar_mtp_key(file_path: &Path, key: String) -> String { + let is_aux = file_path + .file_name() + .and_then(|n| n.to_str()) + .is_some_and(|n| crate::AUXILIARY_SAFETENSORS_FILES.contains(&n)); + if is_aux && !key.starts_with("mtp.") { + format!("mtp.{key}") + } else { + key + } +} + fn qwen35_target_param_key( params: &HashMap, &mut Array>, stripped: &str, ) -> Option<(String, bool)> { if params.contains_key(stripped) { Some((stripped.to_owned(), false)) + } else if let Some(moe_key) = + moe_mtp_param_key(stripped).filter(|key| params.contains_key(key.as_str())) + { + // MoE MTP head: plain remap, no dense rmsnorm adjustment. 
+ Some((moe_key, false)) } else { dense_mtp_param_key(stripped) .filter(|dense_key| params.contains_key(dense_key.as_str())) @@ -4828,6 +5023,7 @@ fn load_qwen3_5_moe_weights_direct( .map_err(|e| crate::error::ModelError::Io(std::io::Error::other(e.to_string())))?; for (key, value) in loaded { + let key = normalize_sidecar_mtp_key(file_path, key); let Some(stripped) = qwen35_checkpoint_param_key(&key) else { unmatched.push(key); continue; @@ -4906,6 +5102,7 @@ fn load_qwen3_5_moe_weights_fused( .map_err(|e| crate::error::ModelError::Io(std::io::Error::other(e.to_string())))?; for (key, value) in loaded { + let key = normalize_sidecar_mtp_key(file_path, key); let Some(stripped) = qwen35_checkpoint_param_key(&key) else { continue; }; @@ -13908,6 +14105,57 @@ mod tests { ); } + #[test] + fn test_mtp_layout_detects_moe_structured_head() { + // Qwen3.6-A3B style: the MTP layer is a full MoE layer. + let layout = mtp_weight_layout_from_keys([ + "mtp.layers.0.self_attn.q_proj.weight", + "mtp.layers.0.mlp.gate.weight", + "mtp.layers.0.mlp.shared_expert.up_proj.scales", + "mtp.layers.0.mlp.switch_mlp.down_proj.weight", + "mtp.fc.weight", + ]); + assert_eq!(layout, MtpWeightLayout::MoeQuantized); + } + + #[test] + fn test_moe_mtp_param_key_remaps_mtp_prefix() { + assert_eq!( + moe_mtp_param_key("mtp.layers.0.mlp.gate.weight").as_deref(), + Some("moe_mtp.layers.0.mlp.gate.weight") + ); + assert_eq!( + moe_mtp_param_key("mtp.fc.scales").as_deref(), + Some("moe_mtp.fc.scales") + ); + assert!(moe_mtp_param_key("model.layers.0.mlp.gate.weight").is_none()); + } + + #[test] + fn test_normalize_sidecar_mtp_key_prefixes_aux_files_only() { + let aux = Path::new("/models/x/mtp.safetensors"); + let main = Path::new("/models/x/model-00001-of-00004.safetensors"); + // Unprefixed keys from the sidecar get the mtp. prefix. 
+ assert_eq!( + normalize_sidecar_mtp_key(aux, "fc.weight".to_owned()), + "mtp.fc.weight" + ); + assert_eq!( + normalize_sidecar_mtp_key(aux, "layers.0.mlp.gate.weight".to_owned()), + "mtp.layers.0.mlp.gate.weight" + ); + // Already-prefixed sidecar keys are unchanged. + assert_eq!( + normalize_sidecar_mtp_key(aux, "mtp.fc.weight".to_owned()), + "mtp.fc.weight" + ); + // Keys from main shards are never touched. + assert_eq!( + normalize_sidecar_mtp_key(main, "fc.weight".to_owned()), + "fc.weight" + ); + } + #[test] fn test_checkpoint_mtp_weight_layout_detects_dense_auxiliary_mtp_file() { let dir = tempfile::tempdir().unwrap(); From df751599cfbcbadc815b71cab237399fd48baf1a Mon Sep 17 00:00:00 2001 From: Peppi Littera Date: Sat, 6 Jun 2026 14:01:42 +0200 Subject: [PATCH 2/4] fix(qwen35): survive mixed-bit GDN quants (OptiQ-style per-projection bits) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Mixed-precision checkpoints (e.g. mlx-community OptiQ quants) assign different bit-widths per GDN projection on sensitive layers. Two loader gaps turned those into hard load failures: - The mixed-bit detector only compared the in_proj_a/in_proj_b pair; a mismatched in_proj_qkv/in_proj_z pair slipped through to the fused loader, which then failed concatenating packed shapes like (8192,256) vs (4096,512). Check both fusion pairs. - In separate-GDN mode the fused in_proj_qkvz/in_proj_ba QLinears are still constructed (as unused placeholders — the forward dispatches on use_separate_projections), and the direct loader's completeness check flagged them as missing weights, rejecting every checkpoint that *requires* separate projections. Exempt the unused fused placeholders. Note: fully running OptiQ-style quants also needs per-projection quantization plumbed into every QLinear (their overrides span attention, shared experts, etc.) 
— this fix makes the GDN layer detection/loading correct and turns the failure mode from a crash into a clean report. Test: mixed in_proj_qkv/in_proj_z bits force separate GDN projections. Co-Authored-By: Claude Opus 4.8 (1M context) --- crates/higgs-models/src/qwen3_next.rs | 55 ++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 1 deletion(-) diff --git a/crates/higgs-models/src/qwen3_next.rs b/crates/higgs-models/src/qwen3_next.rs index 6a66ce0d..ad71a4da 100644 --- a/crates/higgs-models/src/qwen3_next.rs +++ b/crates/higgs-models/src/qwen3_next.rs @@ -4679,7 +4679,17 @@ fn qwen3_5_mixed_ba_quantization_layers( }; let a_quant = projection_quantization("in_proj_a"); let b_quant = projection_quantization("in_proj_b"); - a_quant.bits != b_quant.bits || a_quant.group_size != b_quant.group_size + // The qkvz fusion pair has the same constraint: `in_proj_qkv` and + // `in_proj_z` are concatenated into `in_proj_qkvz`, which is only + // possible when their packed (quantized) shapes agree. Mixed- + // precision quants (e.g. OptiQ) assign different bit-widths per + // projection on sensitive layers, so check both fusion pairs. + let qkv_quant = projection_quantization("in_proj_qkv"); + let z_quant = projection_quantization("in_proj_z"); + a_quant.bits != b_quant.bits + || a_quant.group_size != b_quant.group_size + || qkv_quant.bits != z_quant.bits + || qkv_quant.group_size != z_quant.group_size }) .collect() } @@ -5057,9 +5067,19 @@ fn load_qwen3_5_moe_weights_direct( } } let param_count = params.len(); + // This loader is only used with separate GDN projections (see + // `load_qwen3_5_model_with_gdn_fallback`). In that mode the fused + // `in_proj_qkvz` / `in_proj_ba` QLinears are still constructed — as unused + // placeholders, since the forward path dispatches on + // `use_separate_projections` — so they must be exempt from the + // completeness check. Flagging them would reject every mixed-bit + // checkpoint that *requires* separate projections (e.g. 
OptiQ quants). ensure_all_model_params_loaded( params .iter() + .filter(|(name, _)| { + !(name.contains(".in_proj_qkvz.") || name.contains(".in_proj_ba.")) + }) .map(|(name, value)| (std::rc::Rc::::clone(name), &**value)), )?; tracing::info!(param_count, matched, "Total model parameters loaded"); @@ -14301,6 +14321,39 @@ mod tests { ); } + #[test] + fn test_load_qwen35_mixed_qkvz_quantization_forces_separate_gdn() { + // Mixed-precision quants (e.g. OptiQ) can also put `in_proj_qkv` and + // `in_proj_z` at different bit-widths — that breaks the qkvz fusion + // concat exactly like a mixed BA pair does. + let dir = tempfile::tempdir().unwrap(); + let config = format!( + r#"{{ + "text_config": {}, + "tie_word_embeddings": false, + "quantization": {{ + "group_size": 64, + "bits": 4, + "mode": "affine", + "language_model.model.layers.2.linear_attn.in_proj_z": {{ + "group_size": 64, + "bits": 8, + "mode": "affine" + }} + }} + }}"#, + qwen35_dense_text_config() + ); + std::fs::write(dir.path().join("config.json"), config).unwrap(); + + let args = load_qwen3_5_moe_text_config_args(dir.path()).unwrap(); + + assert!( + args.use_separate_gdn_projections, + "mixed-bit in_proj_qkv/in_proj_z must force separate GDN projections" + ); + } + #[test] fn test_load_qwen35_matching_ba_quantization_keeps_fused_gdn() { let dir = tempfile::tempdir().unwrap(); From 40bafa8192bf8b122cf2d36233fdde996f881204 Mon Sep 17 00:00:00 2001 From: Peppi Littera Date: Wed, 10 Jun 2026 13:00:45 +0200 Subject: [PATCH 3/4] =?UTF-8?q?fix(mtp):=20eval-before-deep=5Fclone=20+=20?= =?UTF-8?q?trim-based=20rollback=20=E2=80=94=20kill=20speculative-decode?= =?UTF-8?q?=20double-free?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Speculative-decode checkpoints shared MLX buffers with the live cache. 
The backbone rollback did `*cache = base_cache` (shallow clone) and the MTP head cache was shallow-cloned too; the live cache's in-place `slice_update` during verify then let MLX donate/free a buffer the checkpoint still held, double- freeing on drop — the `malloc: pointer being freed was not allocated` abort that crashed MTP decode at ~44 tokens. Backbone: roll KV layers back by offset (`trim_by`), never clone — no buffer aliasing. Only hybrid SSM/recurrent state (can't be offset-trimmed) still clone-restores via `AnyCache::deep_clone`. MTP head + hybrid clones: deep-copy via `deep_clone_mtp_cache` / `SteppingKeyValueCache::deep_clone`, which allocate fresh buffers. deep_clone itself was unsafe: `Array::deep_clone` copies straight from the buffer pointer (valid only once evaluated), but the cache stores lazy `slice_update` results — cloning read an unmaterialized pointer and segfaulted. `eval_deep_clone` forces eval first. Tests: 2 deep_clone unit tests (lazy-pointer + live-update independence); higgs-models lib suite 371 passed. Soak: Qwen3.6-35B-A3B MTP server, 5x200-tok requests = 432 mtp_cycle iterations / ~1300 cache deep-clones, accept 62-69%, zero aborts (RED crashed ~44). Co-Authored-By: Claude Opus 4.8 (1M context) --- crates/higgs-engine/src/mtp.rs | 51 ++++++++++--- crates/higgs-models/src/cache.rs | 122 +++++++++++++++++++++++++++++++ crates/higgs-models/src/lib.rs | 40 ++++++++++ 3 files changed, 202 insertions(+), 11 deletions(-) diff --git a/crates/higgs-engine/src/mtp.rs b/crates/higgs-engine/src/mtp.rs index d29ef9c4..ed483684 100644 --- a/crates/higgs-engine/src/mtp.rs +++ b/crates/higgs-engine/src/mtp.rs @@ -6,7 +6,7 @@ //! //! Expected speedup: ~1.5x on dense models at ~80% acceptance rate. 
-use higgs_models::{AnyCache, AnyModel, MtpCache}; +use higgs_models::{AnyCache, AnyModel, MtpCache, deep_clone_mtp_cache}; use mlx_rs::{ Array, argmax_axis, ops::{self, concatenate_axis, indexing::IndexOp}, @@ -19,6 +19,35 @@ const fn draft_matches_target(draft_token_id: u32, target_id: u32) -> bool { draft_token_id == target_id } +/// Capture a backbone-cache rollback point before a speculative verify. +/// +/// KV caches are rolled back by *trimming the offset* (see [`rollback_backbone`]), +/// so we deliberately return `None` and never clone them. Cloning a KV cache and +/// later restoring it (`*cache = base`) makes the checkpoint share the live +/// cache's underlying MLX buffers; the in-place `slice_update` writes during +/// verify then let MLX donate a buffer that the checkpoint still references, +/// corrupting it and double-freeing on drop (the `malloc: pointer being freed +/// was not allocated` abort). Hybrid SSM/recurrent state cannot be offset- +/// trimmed, so those still need a full clone-restore. +fn capture_backbone_checkpoint(cache: &AnyCache) -> Option { + match cache { + AnyCache::KV(_) => None, + AnyCache::Hybrid(_) => Some(cache.deep_clone()), + } +} + +/// Roll the backbone cache back after a rejected speculative verify. +/// +/// `verify_len` is the number of tokens the verify batch advanced the cache by. +/// KV caches rewind by `trim_by(verify_len)` (no clone, no buffer aliasing); +/// hybrid caches restore the clone captured by [`capture_backbone_checkpoint`]. +fn rollback_backbone(cache: &mut AnyCache, checkpoint: Option, verify_len: usize) { + match checkpoint { + Some(base) => *cache = base, + None => cache.trim_by(verify_len), + } +} + /// Aggregate MTP decode counters. /// /// Tracks per-cycle telemetry for MTP speculative decoding. 
@@ -191,8 +220,8 @@ pub fn mtp_prompt_lookup_cycle( return Ok(None); } - let base_cache = cache.clone(); - let base_mtp_cache = mtp_cache.clone(); + let base_cache = capture_backbone_checkpoint(cache); + let base_mtp_cache = deep_clone_mtp_cache(mtp_cache); let mut verify_tokens = Vec::with_capacity(drafts.len().saturating_add(1)); verify_tokens.push(confirmed_token_id); verify_tokens.extend(drafts.iter().copied()); @@ -218,7 +247,7 @@ pub fn mtp_prompt_lookup_cycle( })?; (verify_hidden, next) } else { - *cache = base_cache; + rollback_backbone(cache, base_cache, verify_tokens.len()); let (replay_hidden, replay_targets) = backbone_verify_batch(model, cache, &tokens)?; let next = *replay_targets.get(accepted_drafts).ok_or_else(|| { EngineError::Generation(format!( @@ -351,7 +380,7 @@ pub fn prompt_lookup_cycle( config.max_window, ); - let base_cache = cache.clone(); + let base_cache = capture_backbone_checkpoint(cache); let mut verify_tokens = Vec::with_capacity(drafts.len().saturating_add(1)); verify_tokens.push(confirmed_token_id); verify_tokens.extend(drafts.iter().copied()); @@ -378,7 +407,7 @@ pub fn prompt_lookup_cycle( )) })? 
} else { - *cache = base_cache; + rollback_backbone(cache, base_cache, verify_tokens.len()); let replay_logits = model .forward_all_logits(&token_input(&tokens)?, None, cache) .map_err(EngineError::Mlx)?; @@ -617,9 +646,9 @@ pub fn mtp_cycle( draft_n_max: usize, ) -> Result { let draft_limit = draft_n_max.max(1); - let base_cache = cache.clone(); - let base_mtp_cache = mtp_cache.clone(); - let mut speculative_mtp_cache = mtp_cache.clone(); + let base_cache = capture_backbone_checkpoint(cache); + let base_mtp_cache = deep_clone_mtp_cache(mtp_cache); + let mut speculative_mtp_cache = deep_clone_mtp_cache(mtp_cache); let mut confirmed_mtp_cache: Option = None; let mut speculative_hidden = hidden.clone(); let mut speculative_token = confirmed_token_id; @@ -638,7 +667,7 @@ pub fn mtp_cycle( speculative_hidden = next_hidden; speculative_token = draft_token_id; if draft_idx == 0 { - confirmed_mtp_cache = Some(speculative_mtp_cache.clone()); + confirmed_mtp_cache = Some(deep_clone_mtp_cache(&speculative_mtp_cache)); } } @@ -678,7 +707,7 @@ pub fn mtp_cycle( })?; (verify_hidden, next) } else { - *cache = base_cache; + rollback_backbone(cache, base_cache, verify_tokens.len()); let (replay_hidden, replay_targets) = backbone_verify_batch(model, cache, &tokens)?; let next = *replay_targets.get(accepted_drafts).ok_or_else(|| { EngineError::Generation(format!( diff --git a/crates/higgs-models/src/cache.rs b/crates/higgs-models/src/cache.rs index 14f5bed8..9ce3683b 100644 --- a/crates/higgs-models/src/cache.rs +++ b/crates/higgs-models/src/cache.rs @@ -428,6 +428,26 @@ impl SteppingKeyValueCache { self.offset = self.offset.saturating_sub(trim).max(0); } + /// An **independent** deep copy: every MLX buffer is materialized into a + /// fresh buffer (`Array::deep_clone`), so the result shares no storage with + /// `self`. 
The derived `Clone` only bumps MLX refcounts (shared buffers), + /// which is unsafe as a speculative-decode rollback checkpoint: the live + /// cache's in-place `slice_update` lets MLX donate (reuse/free) a buffer the + /// checkpoint still references, double-freeing it (the MTP `malloc: pointer + /// being freed was not allocated` abort). Use this for any checkpoint that + /// will outlive an in-place update of the live cache. + #[must_use] + pub fn deep_clone(&self) -> Self { + Self { + keys: self.keys.as_ref().map(eval_deep_clone), + values: self.values.as_ref().map(eval_deep_clone), + turbo: self.turbo.as_ref().map(TurboQuantStorage::deep_clone), + config: self.config, + offset: self.offset, + step: self.step, + } + } + /// References to internal arrays that must be eval'd between chunked-prefill steps. pub fn eval_targets(&self) -> Vec<&Array> { let mut targets = Vec::with_capacity(8); @@ -783,6 +803,22 @@ impl TurboQuantStorage { } } + /// Independent deep copy (see [`SteppingKeyValueCache::deep_clone`]). The + /// shared read-only `context` is refcounted (safe to share); every packed + /// array is materialized into its own buffer so an in-place update of the + /// live cache cannot donate/free a buffer this snapshot holds. 
+ fn deep_clone(&self) -> Self { + Self { + context: Arc::clone(&self.context), + key_codes: self.key_codes.as_ref().map(eval_deep_clone), + key_norms: self.key_norms.as_ref().map(eval_deep_clone), + key_gammas: self.key_gammas.as_ref().map(eval_deep_clone), + value_codes: self.value_codes.as_ref().map(eval_deep_clone), + value_norms: self.value_norms.as_ref().map(eval_deep_clone), + capacity: self.capacity, + } + } + fn ensure_capacity(&mut self, required: i32, step: i32) -> Result<(), Exception> { if required <= self.capacity { return Ok(()); @@ -1020,6 +1056,25 @@ pub fn slice_axis1(arr: &Array, start: i32, end: i32) -> Result Array { + a.eval().expect("eval before deep_clone checkpoint"); + a.deep_clone() +} + /// Write `update` into `target` at `[..., start:start+n, ...]` on axis 2. #[allow(unsafe_code, clippy::indexing_slicing)] fn slice_update_axis2( @@ -1337,6 +1392,73 @@ mod tests { assert!((k_data[8] - 2.0).abs() < 1e-6); } + #[test] + fn deep_clone_preserves_contents_and_offset() { + // deep_clone must be a faithful, independent copy. + let mut cache = SteppingKeyValueCache::new(); + let ones_k = Array::ones::(&[1, 2, 2, 8]).unwrap(); + let ones_v = Array::ones::(&[1, 2, 2, 8]).unwrap(); + cache.update_and_fetch(ones_k, ones_v).unwrap(); + + let copy = cache.deep_clone(); + assert_eq!(copy.offset(), cache.offset()); + + let orig_k: Vec = { + let k = cache.keys.as_ref().unwrap(); + k.eval().unwrap(); + k.as_slice().to_vec() + }; + let copy_k: Vec = { + let k = copy.keys.as_ref().unwrap(); + k.eval().unwrap(); + k.as_slice().to_vec() + }; + assert_eq!(orig_k, copy_k, "deep_clone must copy contents faithfully"); + } + + #[test] + fn deep_clone_checkpoint_survives_live_in_place_update() { + // The speculative-decode invariant: a checkpoint captured before the + // live cache is advanced must NOT change when the live cache does an + // in-place `slice_update`. 
A shallow `clone()` shares the KV buffer, so + // MLX can donate/free it under the checkpoint (the double-free abort); + // `deep_clone()` is independent. + let mut cache = SteppingKeyValueCache::new(); + let ones_k = Array::ones::(&[1, 2, 2, 8]).unwrap(); + let ones_v = Array::ones::(&[1, 2, 2, 8]).unwrap(); + cache.update_and_fetch(ones_k, ones_v).unwrap(); + + let checkpoint = cache.deep_clone(); + let before: Vec = { + let k = checkpoint.keys.as_ref().unwrap(); + k.eval().unwrap(); + k.as_slice().to_vec() + }; + + // Advance the LIVE cache in place with a token of value 2.0; force eval + // so any buffer donation would fire. + let two = Array::from_f32(2.0); + let twos_k = Array::full::(&[1, 2, 1, 8], &two).unwrap(); + let twos_v = Array::full::(&[1, 2, 1, 8], &two).unwrap(); + let (rk, _) = cache.update_and_fetch(twos_k, twos_v).unwrap(); + rk.eval().unwrap(); + + assert_eq!( + checkpoint.offset(), + 2, + "checkpoint offset must be unchanged" + ); + let after: Vec = { + let k = checkpoint.keys.as_ref().unwrap(); + k.eval().unwrap(); + k.as_slice().to_vec() + }; + assert_eq!( + before, after, + "deep_clone checkpoint must survive the live cache's in-place update" + ); + } + #[test] fn test_turboquant_cache_round_trips_dense_fetch() { let config = KvCacheConfig { diff --git a/crates/higgs-models/src/lib.rs b/crates/higgs-models/src/lib.rs index 95903b07..7f7adf81 100644 --- a/crates/higgs-models/src/lib.rs +++ b/crates/higgs-models/src/lib.rs @@ -117,6 +117,46 @@ impl AnyCache { } } } + + /// An **independent** deep copy for use as a speculative-decode checkpoint. + /// KV layers are deep-cloned (their in-place `slice_update` buffers must not + /// be shared — see [`cache::SteppingKeyValueCache::deep_clone`]); GDN/SSM + /// (`Arrays`) layers update by full reassignment, never in place, so a cheap + /// shallow `clone()` of those is safe. 
+ #[must_use] + pub fn deep_clone(&self) -> Self { + match self { + Self::KV(layers) => Self::KV( + layers + .iter() + .map(|l| l.as_ref().map(cache::SteppingKeyValueCache::deep_clone)) + .collect(), + ), + Self::Hybrid(layers) => Self::Hybrid( + layers + .iter() + .map(|l| { + l.as_ref().map(|lc| match lc { + LayerCache::KV(kv) => LayerCache::KV(kv.deep_clone()), + recurrent => recurrent.clone(), + }) + }) + .collect(), + ), + } + } +} + +/// Independent deep copy of an MTP head cache (`Vec`). +/// +/// For use as a speculative-decode checkpoint. See +/// [`cache::SteppingKeyValueCache::deep_clone`] for why a shallow clone is +/// unsafe (buffer donation double-free). +#[must_use] +pub fn deep_clone_mtp_cache(c: &MtpCache) -> MtpCache { + c.iter() + .map(cache::SteppingKeyValueCache::deep_clone) + .collect() } /// Unified model wrapper dispatching to the correct architecture. From 4cd7f47a9dbe3d020b3e4b3d1a5d1c0f7d68ce16 Mon Sep 17 00:00:00 2001 From: Peppi Littera Date: Thu, 11 Jun 2026 14:36:37 +0200 Subject: [PATCH 4/4] fix(qwen35): address CodeRabbit review + unblock Lint CI CodeRabbit review: - normalize_sidecar_mtp_key: only prefix truly un-namespaced sidecar keys. Gate on !is_mtp_key() instead of !starts_with("mtp.") so already-namespaced keys (e.g. language_model.mtp.*) aren't mangled into unmatchable mtp.language_model.mtp.*. Extends the existing unit test to cover it. - load_qwen3_next_model: route through a new MTP-aware load_qwen3_next_weights instead of the plain loader. maybe_disable_mtp_without_checkpoint_weights can select the dense/MoE head (params dense_mtp.* / moe_mtp.*) while the checkpoint ships the head under mtp.*; the plain loader did no remap and silently left the draft head uninitialized. Backbone keys still match directly, so behaviour is unchanged for the common Quantized layout. Lint CI was failing at the fmt step, masking clippy -Dwarnings errors that the feature commits introduced. 
Fixed so the full Lint job passes: - cargo fmt (AUXILIARY_SAFETENSORS_FILES wrap, in_proj filter closure) - clippy::shadow_reuse on the three sidecar-key loaders (file convention) - backtick bare `MoE` in MoE-MTP doc comments (doc_markdown) - LayerCache wildcard -> explicit Arrays(_) arm (match_wildcard_for_single_variants) - allow large_enum_variant on AnyModel (singleton dispatch handle; boxing every variant would add forward-path indirection for no real benefit) Co-Authored-By: Claude Opus 4.8 (1M context) --- crates/higgs-models/src/lib.rs | 10 +++- crates/higgs-models/src/qwen3_next.rs | 76 ++++++++++++++++++++++----- 2 files changed, 72 insertions(+), 14 deletions(-) diff --git a/crates/higgs-models/src/lib.rs b/crates/higgs-models/src/lib.rs index 7f7adf81..4eece9fb 100644 --- a/crates/higgs-models/src/lib.rs +++ b/crates/higgs-models/src/lib.rs @@ -138,7 +138,7 @@ impl AnyCache { .map(|l| { l.as_ref().map(|lc| match lc { LayerCache::KV(kv) => LayerCache::KV(kv.deep_clone()), - recurrent => recurrent.clone(), + recurrent @ LayerCache::Arrays(_) => recurrent.clone(), }) }) .collect(), @@ -160,6 +160,11 @@ pub fn deep_clone_mtp_cache(c: &MtpCache) -> MtpCache { } /// Unified model wrapper dispatching to the correct architecture. +// One `AnyModel` exists per loaded model (held by the engine for the process +// lifetime), never stored in bulk, so the size spread between variants costs a +// few hundred bytes once. Boxing a dispatch variant would add an indirection on +// the forward path for no practical benefit. +#[allow(clippy::large_enum_variant)] pub enum AnyModel { /// Standard transformer architectures: Llama, Mistral, Qwen2/2.5, Qwen3. 
Transformer(Model), @@ -1175,7 +1180,8 @@ pub struct WeightMapIndex { pub weight_map: HashMap, } -pub(crate) const AUXILIARY_SAFETENSORS_FILES: &[&str] = &["mtp.safetensors", "model-mtp.safetensors"]; +pub(crate) const AUXILIARY_SAFETENSORS_FILES: &[&str] = + &["mtp.safetensors", "model-mtp.safetensors"]; /// Load a tokenizer from a model directory. pub fn load_tokenizer>(model_dir: P) -> Result { diff --git a/crates/higgs-models/src/qwen3_next.rs b/crates/higgs-models/src/qwen3_next.rs index ad71a4da..dfeaf45b 100644 --- a/crates/higgs-models/src/qwen3_next.rs +++ b/crates/higgs-models/src/qwen3_next.rs @@ -218,7 +218,7 @@ pub struct Qwen3NextModelArgs { /// Use an MoE-structured MTP head (Qwen3.6-A3B style). /// - /// These sidecars ship the MTP layer as a full MoE decoder layer + /// These sidecars ship the MTP layer as a full `MoE` decoder layer /// (`mlp.gate`, `mlp.switch_mlp.*`, `mlp.shared_expert*`) with a quantized /// `fc`. Set by the loader after inspecting checkpoint keys; not expected /// in configs. @@ -1960,9 +1960,9 @@ impl DenseMtpHead { } } -/// Single MoE MTP transformer layer (Qwen3.6-A3B style). +/// Single `MoE` MTP transformer layer (Qwen3.6-A3B style). /// -/// Qwen3.6-A3B sidecars ship the MTP layer as a full MoE decoder layer: +/// Qwen3.6-A3B sidecars ship the MTP layer as a full `MoE` decoder layer: /// full attention (with q/k norms) + `SparseMoeBlock` /// (router gate + stacked experts + shared expert + shared-expert gate). #[derive(Debug, Clone, ModuleParameters)] @@ -1977,7 +1977,7 @@ struct MoeMtpTransformerLayer { mlp: SparseMoeBlock, } -/// MTP head with an MoE transformer layer (Qwen3.6-A3B style sidecars). +/// MTP head with an `MoE` transformer layer (Qwen3.6-A3B style sidecars). 
/// /// Unlike [`MtpHead`], the fusion projection `fc` is a quantized [`QLinear`] /// (these sidecars ship `fc.{weight,scales,biases}` triples), and the MLP is @@ -4524,9 +4524,12 @@ pub fn load_qwen3_next_model>( let mut model = Qwen3NextCausalLM::new(args)?; - // Load weights directly from safetensors (no key remapping needed - // since our param names match the safetensors keys exactly) - crate::load_safetensors_weights(&mut model, model_path)?; + // Backbone keys match model params directly, but the MTP sidecar may need + // remapping: `maybe_disable_mtp_without_checkpoint_weights` can select the + // dense or MoE head (params `dense_mtp.*` / `moe_mtp.*`) while the checkpoint + // still ships the head under the `mtp.*` namespace. The plain loader can't + // bridge that, so it would silently leave the draft head uninitialized. + load_qwen3_next_weights(&mut model, model_path)?; tracing::info!("Qwen3Next model loaded successfully"); Ok(model) @@ -4963,7 +4966,11 @@ fn normalize_sidecar_mtp_key(file_path: &Path, key: String) -> String { .file_name() .and_then(|n| n.to_str()) .is_some_and(|n| crate::AUXILIARY_SAFETENSORS_FILES.contains(&n)); - if is_aux && !key.starts_with("mtp.") { + // Only prefix truly un-namespaced sidecar keys (`fc.weight`, + // `layers.0....`). Already-namespaced keys (`mtp.*`, `language_model.mtp.*`) + // are left intact so `qwen35_checkpoint_param_key` can still strip/remap + // them — prefixing those would produce unmatchable `mtp.language_model.mtp.*`. + if is_aux && !is_mtp_key(&key) { format!("mtp.{key}") } else { key @@ -5015,10 +5022,51 @@ fn qwen35_loaded_value( } } +/// Load `Qwen3Next` weights, remapping the `mtp.*` sidecar onto whichever MTP +/// head is active (`mtp` / `dense_mtp` / `moe_mtp`). +/// +/// Backbone keys match params directly (via `qwen35_target_param_key`'s +/// direct-match branch), so this is behaviour-compatible with the plain loader +/// for the common `Quantized` layout. 
The only added behaviour is the +/// `mtp.*` → `dense_mtp.*` / `moe_mtp.*` remap, which the plain loader lacks — +/// without it a dense/MoE draft head selected by +/// `maybe_disable_mtp_without_checkpoint_weights` is silently left uninitialized. +#[allow(clippy::shadow_reuse)] +fn load_qwen3_next_weights( + model: &mut M, + model_path: &Path, +) -> Result<(), crate::error::ModelError> { + let safetensors_files = crate::collect_safetensors_files(model_path)?; + let mut params = model.parameters_mut().flatten(); + + for file_path in &safetensors_files { + let loaded = Array::load_safetensors(file_path) + .map_err(|e| crate::error::ModelError::Io(std::io::Error::other(e.to_string())))?; + + for (key, value) in loaded { + let key = normalize_sidecar_mtp_key(file_path, key); + if let Some((target_key, dense_mtp_target)) = qwen35_target_param_key(&params, &key) { + if let Some(param) = params.get_mut(target_key.as_str()) { + **param = qwen35_loaded_value(&key, value, dense_mtp_target)?; + continue; + } + } + tracing::warn!(key = %key, "Weight key not found in model parameters"); + } + } + + model + .eval() + .map_err(|e| crate::error::ModelError::Io(std::io::Error::other(e.to_string())))?; + + Ok(()) +} + /// Load Qwen3.5-MoE weights with GDN projection fusion. /// /// Direct weight loader: strip `language_model.` prefix, no rearrangement. /// Used when `use_separate_gdn_projections = true`.
+#[allow(clippy::shadow_reuse)] fn load_qwen3_5_moe_weights_direct( model: &mut M, model_path: &Path, @@ -5077,9 +5125,7 @@ fn load_qwen3_5_moe_weights_direct( ensure_all_model_params_loaded( params .iter() - .filter(|(name, _)| { - !(name.contains(".in_proj_qkvz.") || name.contains(".in_proj_ba.")) - }) + .filter(|(name, _)| !(name.contains(".in_proj_qkvz.") || name.contains(".in_proj_ba."))) .map(|(name, value)| (std::rc::Rc::::clone(name), &**value)), )?; tracing::info!(param_count, matched, "Total model parameters loaded"); @@ -5093,7 +5139,7 @@ fn load_qwen3_5_moe_weights_direct( /// Rearranges flat (qkv,z,b,a) projections to per-head-grouped (qkvz,ba) /// so the model uses the fused 2-dispatch forward path instead of 4 separate. -#[allow(clippy::too_many_lines)] +#[allow(clippy::too_many_lines, clippy::shadow_reuse)] fn load_qwen3_5_moe_weights_fused( model: &mut M, model_path: &Path, @@ -14169,6 +14215,12 @@ mod tests { normalize_sidecar_mtp_key(aux, "mtp.fc.weight".to_owned()), "mtp.fc.weight" ); + // Already-namespaced sidecar keys (e.g. `language_model.mtp.*`) must NOT + // be over-prefixed into unmatchable `mtp.language_model.mtp.*`. + assert_eq!( + normalize_sidecar_mtp_key(aux, "language_model.mtp.layers.0.fc.weight".to_owned()), + "language_model.mtp.layers.0.fc.weight" + ); // Keys from main shards are never touched. assert_eq!( normalize_sidecar_mtp_key(main, "fc.weight".to_owned()),