diff --git a/crates/higgs-engine/src/mtp.rs b/crates/higgs-engine/src/mtp.rs index d29ef9c4..ed483684 100644 --- a/crates/higgs-engine/src/mtp.rs +++ b/crates/higgs-engine/src/mtp.rs @@ -6,7 +6,7 @@ //! //! Expected speedup: ~1.5x on dense models at ~80% acceptance rate. -use higgs_models::{AnyCache, AnyModel, MtpCache}; +use higgs_models::{AnyCache, AnyModel, MtpCache, deep_clone_mtp_cache}; use mlx_rs::{ Array, argmax_axis, ops::{self, concatenate_axis, indexing::IndexOp}, @@ -19,6 +19,35 @@ const fn draft_matches_target(draft_token_id: u32, target_id: u32) -> bool { draft_token_id == target_id } +/// Capture a backbone-cache rollback point before a speculative verify. +/// +/// KV caches are rolled back by *trimming the offset* (see [`rollback_backbone`]), +/// so we deliberately return `None` and never clone them. Cloning a KV cache and +/// later restoring it (`*cache = base`) makes the checkpoint share the live +/// cache's underlying MLX buffers; the in-place `slice_update` writes during +/// verify then let MLX donate a buffer that the checkpoint still references, +/// corrupting it and double-freeing on drop (the `malloc: pointer being freed +/// was not allocated` abort). Hybrid SSM/recurrent state cannot be offset- +/// trimmed, so those still need a full clone-restore. +fn capture_backbone_checkpoint(cache: &AnyCache) -> Option { + match cache { + AnyCache::KV(_) => None, + AnyCache::Hybrid(_) => Some(cache.deep_clone()), + } +} + +/// Roll the backbone cache back after a rejected speculative verify. +/// +/// `verify_len` is the number of tokens the verify batch advanced the cache by. +/// KV caches rewind by `trim_by(verify_len)` (no clone, no buffer aliasing); +/// hybrid caches restore the clone captured by [`capture_backbone_checkpoint`]. +fn rollback_backbone(cache: &mut AnyCache, checkpoint: Option, verify_len: usize) { + match checkpoint { + Some(base) => *cache = base, + None => cache.trim_by(verify_len), + } +} + /// Aggregate MTP decode counters. /// /// Tracks per-cycle telemetry for MTP speculative decoding. @@ -191,8 +220,8 @@ pub fn mtp_prompt_lookup_cycle( return Ok(None); } - let base_cache = cache.clone(); - let base_mtp_cache = mtp_cache.clone(); + let base_cache = capture_backbone_checkpoint(cache); + let base_mtp_cache = deep_clone_mtp_cache(mtp_cache); let mut verify_tokens = Vec::with_capacity(drafts.len().saturating_add(1)); verify_tokens.push(confirmed_token_id); verify_tokens.extend(drafts.iter().copied()); @@ -218,7 +247,7 @@ pub fn mtp_prompt_lookup_cycle( })?; (verify_hidden, next) } else { - *cache = base_cache; + rollback_backbone(cache, base_cache, verify_tokens.len()); let (replay_hidden, replay_targets) = backbone_verify_batch(model, cache, &tokens)?; let next = *replay_targets.get(accepted_drafts).ok_or_else(|| { EngineError::Generation(format!( @@ -351,7 +380,7 @@ pub fn prompt_lookup_cycle( config.max_window, ); - let base_cache = cache.clone(); + let base_cache = capture_backbone_checkpoint(cache); let mut verify_tokens = Vec::with_capacity(drafts.len().saturating_add(1)); verify_tokens.push(confirmed_token_id); verify_tokens.extend(drafts.iter().copied()); @@ -378,7 +407,7 @@ pub fn prompt_lookup_cycle( )) })? } else { - *cache = base_cache; + rollback_backbone(cache, base_cache, verify_tokens.len()); let replay_logits = model .forward_all_logits(&token_input(&tokens)?, None, cache) .map_err(EngineError::Mlx)?; @@ -617,9 +646,9 @@ pub fn mtp_cycle( draft_n_max: usize, ) -> Result { let draft_limit = draft_n_max.max(1); - let base_cache = cache.clone(); - let base_mtp_cache = mtp_cache.clone(); - let mut speculative_mtp_cache = mtp_cache.clone(); + let base_cache = capture_backbone_checkpoint(cache); + let base_mtp_cache = deep_clone_mtp_cache(mtp_cache); + let mut speculative_mtp_cache = deep_clone_mtp_cache(mtp_cache); let mut confirmed_mtp_cache: Option = None; let mut speculative_hidden = hidden.clone(); let mut speculative_token = confirmed_token_id; @@ -638,7 +667,7 @@ pub fn mtp_cycle( speculative_hidden = next_hidden; speculative_token = draft_token_id; if draft_idx == 0 { - confirmed_mtp_cache = Some(speculative_mtp_cache.clone()); + confirmed_mtp_cache = Some(deep_clone_mtp_cache(&speculative_mtp_cache)); } } @@ -678,7 +707,7 @@ pub fn mtp_cycle( })?; (verify_hidden, next) } else { - *cache = base_cache; + rollback_backbone(cache, base_cache, verify_tokens.len()); let (replay_hidden, replay_targets) = backbone_verify_batch(model, cache, &tokens)?; let next = *replay_targets.get(accepted_drafts).ok_or_else(|| { EngineError::Generation(format!( diff --git a/crates/higgs-models/src/cache.rs b/crates/higgs-models/src/cache.rs index 14f5bed8..9ce3683b 100644 --- a/crates/higgs-models/src/cache.rs +++ b/crates/higgs-models/src/cache.rs @@ -428,6 +428,26 @@ impl SteppingKeyValueCache { self.offset = self.offset.saturating_sub(trim).max(0); } + /// An **independent** deep copy: every MLX buffer is materialized into a + /// fresh buffer (`Array::deep_clone`), so the result shares no storage with + /// `self`. The derived `Clone` only bumps MLX refcounts (shared buffers), + /// which is unsafe as a speculative-decode rollback checkpoint: the live + /// cache's in-place `slice_update` lets MLX donate (reuse/free) a buffer the + /// checkpoint still references, double-freeing it (the MTP `malloc: pointer + /// being freed was not allocated` abort). Use this for any checkpoint that + /// will outlive an in-place update of the live cache. + #[must_use] + pub fn deep_clone(&self) -> Self { + Self { + keys: self.keys.as_ref().map(eval_deep_clone), + values: self.values.as_ref().map(eval_deep_clone), + turbo: self.turbo.as_ref().map(TurboQuantStorage::deep_clone), + config: self.config, + offset: self.offset, + step: self.step, + } + } + /// References to internal arrays that must be eval'd between chunked-prefill steps. pub fn eval_targets(&self) -> Vec<&Array> { let mut targets = Vec::with_capacity(8); @@ -783,6 +803,22 @@ impl TurboQuantStorage { } } + /// Independent deep copy (see [`SteppingKeyValueCache::deep_clone`]). The + /// shared read-only `context` is refcounted (safe to share); every packed + /// array is materialized into its own buffer so an in-place update of the + /// live cache cannot donate/free a buffer this snapshot holds. + fn deep_clone(&self) -> Self { + Self { + context: Arc::clone(&self.context), + key_codes: self.key_codes.as_ref().map(eval_deep_clone), + key_norms: self.key_norms.as_ref().map(eval_deep_clone), + key_gammas: self.key_gammas.as_ref().map(eval_deep_clone), + value_codes: self.value_codes.as_ref().map(eval_deep_clone), + value_norms: self.value_norms.as_ref().map(eval_deep_clone), + capacity: self.capacity, + } + } + fn ensure_capacity(&mut self, required: i32, step: i32) -> Result<(), Exception> { if required <= self.capacity { return Ok(()); @@ -1020,6 +1056,25 @@ pub fn slice_axis1(arr: &Array, start: i32, end: i32) -> Result Array { + a.eval().expect("eval before deep_clone checkpoint"); + a.deep_clone() +} + /// Write `update` into `target` at `[..., start:start+n, ...]` on axis 2. #[allow(unsafe_code, clippy::indexing_slicing)] fn slice_update_axis2( @@ -1337,6 +1392,73 @@ mod tests { assert!((k_data[8] - 2.0).abs() < 1e-6); } + #[test] + fn deep_clone_preserves_contents_and_offset() { + // deep_clone must be a faithful, independent copy. + let mut cache = SteppingKeyValueCache::new(); + let ones_k = Array::ones::(&[1, 2, 2, 8]).unwrap(); + let ones_v = Array::ones::(&[1, 2, 2, 8]).unwrap(); + cache.update_and_fetch(ones_k, ones_v).unwrap(); + + let copy = cache.deep_clone(); + assert_eq!(copy.offset(), cache.offset()); + + let orig_k: Vec = { + let k = cache.keys.as_ref().unwrap(); + k.eval().unwrap(); + k.as_slice().to_vec() + }; + let copy_k: Vec = { + let k = copy.keys.as_ref().unwrap(); + k.eval().unwrap(); + k.as_slice().to_vec() + }; + assert_eq!(orig_k, copy_k, "deep_clone must copy contents faithfully"); + } + + #[test] + fn deep_clone_checkpoint_survives_live_in_place_update() { + // The speculative-decode invariant: a checkpoint captured before the + // live cache is advanced must NOT change when the live cache does an + // in-place `slice_update`. A shallow `clone()` shares the KV buffer, so + // MLX can donate/free it under the checkpoint (the double-free abort); + // `deep_clone()` is independent. + let mut cache = SteppingKeyValueCache::new(); + let ones_k = Array::ones::(&[1, 2, 2, 8]).unwrap(); + let ones_v = Array::ones::(&[1, 2, 2, 8]).unwrap(); + cache.update_and_fetch(ones_k, ones_v).unwrap(); + + let checkpoint = cache.deep_clone(); + let before: Vec = { + let k = checkpoint.keys.as_ref().unwrap(); + k.eval().unwrap(); + k.as_slice().to_vec() + }; + + // Advance the LIVE cache in place with a token of value 2.0; force eval + // so any buffer donation would fire. + let two = Array::from_f32(2.0); + let twos_k = Array::full::(&[1, 2, 1, 8], &two).unwrap(); + let twos_v = Array::full::(&[1, 2, 1, 8], &two).unwrap(); + let (rk, _) = cache.update_and_fetch(twos_k, twos_v).unwrap(); + rk.eval().unwrap(); + + assert_eq!( + checkpoint.offset(), + 2, + "checkpoint offset must be unchanged" + ); + let after: Vec = { + let k = checkpoint.keys.as_ref().unwrap(); + k.eval().unwrap(); + k.as_slice().to_vec() + }; + assert_eq!( + before, after, + "deep_clone checkpoint must survive the live cache's in-place update" + ); + } + #[test] fn test_turboquant_cache_round_trips_dense_fetch() { let config = KvCacheConfig { diff --git a/crates/higgs-models/src/lib.rs b/crates/higgs-models/src/lib.rs index 8ae2716a..4eece9fb 100644 --- a/crates/higgs-models/src/lib.rs +++ b/crates/higgs-models/src/lib.rs @@ -117,9 +117,54 @@ impl AnyCache { } } } + + /// An **independent** deep copy for use as a speculative-decode checkpoint. + /// KV layers are deep-cloned (their in-place `slice_update` buffers must not + /// be shared — see [`cache::SteppingKeyValueCache::deep_clone`]); GDN/SSM + /// (`Arrays`) layers update by full reassignment, never in place, so a cheap + /// shallow `clone()` of those is safe. + #[must_use] + pub fn deep_clone(&self) -> Self { + match self { + Self::KV(layers) => Self::KV( + layers + .iter() + .map(|l| l.as_ref().map(cache::SteppingKeyValueCache::deep_clone)) + .collect(), + ), + Self::Hybrid(layers) => Self::Hybrid( + layers + .iter() + .map(|l| { + l.as_ref().map(|lc| match lc { + LayerCache::KV(kv) => LayerCache::KV(kv.deep_clone()), + recurrent @ LayerCache::Arrays(_) => recurrent.clone(), + }) + }) + .collect(), + ), + } + } +} + +/// Independent deep copy of an MTP head cache (`Vec`). +/// +/// For use as a speculative-decode checkpoint. See +/// [`cache::SteppingKeyValueCache::deep_clone`] for why a shallow clone is +/// unsafe (buffer donation double-free). +#[must_use] +pub fn deep_clone_mtp_cache(c: &MtpCache) -> MtpCache { + c.iter() + .map(cache::SteppingKeyValueCache::deep_clone) + .collect() } /// Unified model wrapper dispatching to the correct architecture. +// One `AnyModel` exists per loaded model (held by the engine for the process +// lifetime), never stored in bulk, so the size spread between variants costs a +// few hundred bytes once. Boxing a dispatch variant would add an indirection on +// the forward path for no practical benefit. +#[allow(clippy::large_enum_variant)] pub enum AnyModel { /// Standard transformer architectures: Llama, Mistral, Qwen2/2.5, Qwen3. Transformer(Model), @@ -1135,7 +1180,8 @@ pub struct WeightMapIndex { pub weight_map: HashMap, } -const AUXILIARY_SAFETENSORS_FILES: &[&str] = &["mtp.safetensors", "model-mtp.safetensors"]; +pub(crate) const AUXILIARY_SAFETENSORS_FILES: &[&str] = + &["mtp.safetensors", "model-mtp.safetensors"]; /// Load a tokenizer from a model directory. pub fn load_tokenizer>(model_dir: P) -> Result { diff --git a/crates/higgs-models/src/qwen3_next.rs b/crates/higgs-models/src/qwen3_next.rs index ba3e5da4..dfeaf45b 100644 --- a/crates/higgs-models/src/qwen3_next.rs +++ b/crates/higgs-models/src/qwen3_next.rs @@ -215,6 +215,15 @@ pub struct Qwen3NextModelArgs { /// loader after inspecting checkpoint keys; it is not expected in configs. #[serde(default)] pub use_dense_mtp: bool, + + /// Use an MoE-structured MTP head (Qwen3.6-A3B style). + /// + /// These sidecars ship the MTP layer as a full `MoE` decoder layer + /// (`mlp.gate`, `mlp.switch_mlp.*`, `mlp.shared_expert*`) with a quantized + /// `fc`. Set by the loader after inspecting checkpoint keys; not expected + /// in configs. + #[serde(default)] + pub use_moe_mtp: bool, } // --------------------------------------------------------------------------- @@ -1951,6 +1960,86 @@ impl DenseMtpHead { } } +/// Single `MoE` MTP transformer layer (Qwen3.6-A3B style). +/// +/// Qwen3.6-A3B sidecars ship the MTP layer as a full `MoE` decoder layer: +/// full attention (with q/k norms) + `SparseMoeBlock` +/// (router gate + stacked experts + shared expert + shared-expert gate). +#[derive(Debug, Clone, ModuleParameters)] +struct MoeMtpTransformerLayer { + #[param] + self_attn: Qwen3NextAttention, + #[param] + input_layernorm: nn::RmsNorm, + #[param] + post_attention_layernorm: nn::RmsNorm, + #[param] + mlp: SparseMoeBlock, +} + +/// MTP head with an `MoE` transformer layer (Qwen3.6-A3B style sidecars). +/// +/// Unlike [`MtpHead`], the fusion projection `fc` is a quantized [`QLinear`] +/// (these sidecars ship `fc.{weight,scales,biases}` triples), and the MLP is +/// a [`SparseMoeBlock`]. All projections use the checkpoint's uniform +/// quantization — the main model's `gate_quantization` override must NOT be +/// applied here (the sidecar's router gate is quantized at the default width). +#[derive(Debug, Clone, ModuleParameters)] +pub struct MoeMtpHead { + #[param] + pre_fc_norm_hidden: nn::RmsNorm, + #[param] + pre_fc_norm_embedding: nn::RmsNorm, + #[param] + fc: QLinear, + #[param] + layers: Vec, + #[param] + norm: nn::RmsNorm, +} + +impl MoeMtpHead { + fn new(args: &Qwen3NextModelArgs, ql: i32, qb: i32) -> Result { + let n = usize::try_from(args.mtp_num_hidden_layers) + .map_err(|_| Exception::custom("mtp_num_hidden_layers must be non-negative"))?; + + // The sidecar's MoE block is uniformly quantized at the default width; + // strip the main model's per-layer gate override so the router gate's + // QLinear dequantizes with the right parameters. + let mut mtp_args = args.clone(); + mtp_args.gate_quantization = None; + + let layers = (0..n) + .map(|_| { + Ok(MoeMtpTransformerLayer { + self_attn: Qwen3NextAttention::new(args, ql, qb)?, + input_layernorm: nn::RmsNormBuilder::new(args.hidden_size) + .eps(args.rms_norm_eps) + .build()?, + post_attention_layernorm: nn::RmsNormBuilder::new(args.hidden_size) + .eps(args.rms_norm_eps) + .build()?, + mlp: SparseMoeBlock::new(&mtp_args, ql, qb)?, + }) + }) + .collect::, Exception>>()?; + + Ok(Self { + pre_fc_norm_hidden: nn::RmsNormBuilder::new(args.hidden_size) + .eps(args.rms_norm_eps) + .build()?, + pre_fc_norm_embedding: nn::RmsNormBuilder::new(args.hidden_size) + .eps(args.rms_norm_eps) + .build()?, + fc: QLinear::new(ql, qb)?, + layers, + norm: nn::RmsNormBuilder::new(args.hidden_size) + .eps(args.rms_norm_eps) + .build()?, + }) + } +} + // --------------------------------------------------------------------------- // SwitchMLP weights (stacked expert weights for MoE) // --------------------------------------------------------------------------- @@ -3335,6 +3424,8 @@ pub struct Qwen3NextCausalLM { mtp: Option, #[param] dense_mtp: Option, + #[param] + moe_mtp: Option, } // Manual RoPE implementation for arbitrary positions @@ -3448,7 +3539,7 @@ impl Qwen3NextCausalLM { } else { Some(QLinear::new(ql, qb)?) }; - let mtp = if args.mtp_num_hidden_layers > 0 && !args.use_dense_mtp { + let mtp = if args.mtp_num_hidden_layers > 0 && !args.use_dense_mtp && !args.use_moe_mtp { Some(MtpHead::new(&args, ql, qb)?) } else { None @@ -3458,6 +3549,11 @@ impl Qwen3NextCausalLM { } else { None }; + let moe_mtp = if args.mtp_num_hidden_layers > 0 && args.use_moe_mtp { + Some(MoeMtpHead::new(&args, ql, qb)?) + } else { + None + }; Ok(Self { args, @@ -3465,6 +3561,7 @@ impl Qwen3NextCausalLM { lm_head, mtp, dense_mtp, + moe_mtp, }) } @@ -3814,7 +3911,7 @@ impl Qwen3NextCausalLM { /// Whether this model has an MTP head loaded. pub const fn has_mtp(&self) -> bool { - self.mtp.is_some() || self.dense_mtp.is_some() + self.mtp.is_some() || self.dense_mtp.is_some() || self.moe_mtp.is_some() } /// Create a fresh KV cache for the MTP head (one entry per MTP layer). @@ -3824,7 +3921,8 @@ impl Qwen3NextCausalLM { .mtp .as_ref() .map(|mtp| mtp.layers.len()) - .or_else(|| self.dense_mtp.as_ref().map(|mtp| mtp.layers.len()))?; + .or_else(|| self.dense_mtp.as_ref().map(|mtp| mtp.layers.len())) + .or_else(|| self.moe_mtp.as_ref().map(|mtp| mtp.layers.len()))?; Some( (0..layer_count) .map(|_| SteppingKeyValueCache::new()) @@ -3910,6 +4008,25 @@ impl Qwen3NextCausalLM { return mtp.norm.forward(&x); } + // MoE MTP head (Qwen3.6-A3B style): same loop, MoE MLP. + if let Some(mtp) = self.moe_mtp.as_mut() { + let h_norm = mtp.pre_fc_norm_hidden.forward(hidden)?; + let e_norm = mtp.pre_fc_norm_embedding.forward(&next_embed)?; + let concat = ops::concatenate_axis(&[&e_norm, &h_norm], -1)?; + let mut x = mtp.fc.forward(&concat)?; + + for (layer, kv) in mtp.layers.iter_mut().zip(mtp_cache.iter_mut()) { + let normed = layer.input_layernorm.forward(&x)?; + let attn_out = layer.self_attn.forward(&normed, None, kv)?; + let h2 = x.add(attn_out)?; + let normed_post = layer.post_attention_layernorm.forward(&h2)?; + let mlp_out = layer.mlp.forward(&normed_post)?; + x = h2.add(mlp_out)?; + } + + return mtp.norm.forward(&x); + } + let mtp = self .dense_mtp .as_mut() @@ -4002,6 +4119,7 @@ impl Qwen3NextCausalLM { .as_ref() .map(|mtp| mtp.layers.len()) .or_else(|| self.dense_mtp.as_ref().map(|mtp| mtp.layers.len())) + .or_else(|| self.moe_mtp.as_ref().map(|mtp| mtp.layers.len())) .ok_or_else(|| Exception::custom("MTP head not loaded"))?; Self::validate_mtp_advance_many_inputs(hidden, mtp_cache, expected_layers, seq_len)?; @@ -4028,6 +4146,26 @@ impl Qwen3NextCausalLM { return Ok(()); } + // MoE MTP head (Qwen3.6-A3B style): same loop, MoE MLP. + if let Some(mtp) = self.moe_mtp.as_mut() { + let h_norm = mtp.pre_fc_norm_hidden.forward(hidden)?; + let e_norm = mtp.pre_fc_norm_embedding.forward(&next_embed)?; + let concat = ops::concatenate_axis(&[&e_norm, &h_norm], -1)?; + let mut x = mtp.fc.forward(&concat)?; + + for (layer, kv) in mtp.layers.iter_mut().zip(mtp_cache.iter_mut()) { + let normed = layer.input_layernorm.forward(&x)?; + let attn_out = layer.self_attn.forward(&normed, mask_ref, kv)?; + let h2 = x.add(attn_out)?; + let normed_post = layer.post_attention_layernorm.forward(&h2)?; + let mlp_out = layer.mlp.forward(&normed_post)?; + x = h2.add(mlp_out)?; + } + + let _ = mtp.norm.forward(&x)?; + return Ok(()); + } + let mtp = self .dense_mtp .as_mut() @@ -4217,6 +4355,9 @@ enum MtpWeightLayout { None, Quantized, Dense, + /// MTP layer is MoE-structured (`mlp.gate` / `shared_expert` / experts), + /// e.g. Qwen3.6-A3B sidecars — loaded via [`MoeMtpHead`]. + MoeQuantized, } fn is_mtp_key(key: &str) -> bool { @@ -4227,6 +4368,7 @@ fn mtp_weight_layout_from_keys<'a>(keys: impl IntoIterator) -> M let mut has_mtp = false; let mut has_unprefixed_mtp = false; let mut has_quantized_aux = false; + let mut has_moe_mlp = false; for key in keys { if !is_mtp_key(key) { @@ -4235,9 +4377,15 @@ fn mtp_weight_layout_from_keys<'a>(keys: impl IntoIterator) -> M has_mtp = true; has_unprefixed_mtp |= key.starts_with("mtp."); has_quantized_aux |= key.ends_with(".scales") || key.ends_with(".biases"); + has_moe_mlp |= key.contains(".mlp.gate.") + || key.contains(".mlp.shared_expert") + || key.contains(".mlp.switch_mlp") + || key.contains(".mlp.experts"); } - if has_quantized_aux { + if has_mtp && has_moe_mlp { + MtpWeightLayout::MoeQuantized + } else if has_quantized_aux { MtpWeightLayout::Quantized } else if has_unprefixed_mtp { MtpWeightLayout::Dense @@ -4253,7 +4401,16 @@ fn checkpoint_mtp_weight_layout(model_path: &Path) -> Result = metadata + .names() + .into_iter() + .map(|name| normalize_sidecar_mtp_key(file_path, name.to_owned())) + .collect(); + Ok(mtp_weight_layout_from_keys( + normalized.iter().map(String::as_str), + )) } let index_path = model_path.join("model.safetensors.index.json"); @@ -4326,6 +4483,15 @@ fn maybe_disable_mtp_without_checkpoint_weights( args.use_dense_mtp = true; return Ok(()); } + MtpWeightLayout::MoeQuantized => { + tracing::info!( + "Checkpoint ships an MoE-structured MTP head (Qwen3.6-A3B style); \ + loading via MoeMtpHead" + ); + args.use_dense_mtp = false; + args.use_moe_mtp = true; + return Ok(()); + } MtpWeightLayout::None => {} } @@ -4358,9 +4524,12 @@ pub fn load_qwen3_next_model>( let mut model = Qwen3NextCausalLM::new(args)?; - // Load weights directly from safetensors (no key remapping needed - // since our param names match the safetensors keys exactly) - crate::load_safetensors_weights(&mut model, model_path)?; + // Backbone keys match model params directly, but the MTP sidecar may need + // remapping: `maybe_disable_mtp_without_checkpoint_weights` can select the + // dense or MoE head (params `dense_mtp.*` / `moe_mtp.*`) while the checkpoint + // still ships the head under the `mtp.*` namespace. The plain loader can't + // bridge that, so it would silently leave the draft head uninitialized. + load_qwen3_next_weights(&mut model, model_path)?; tracing::info!("Qwen3Next model loaded successfully"); Ok(model) @@ -4513,7 +4682,17 @@ fn qwen3_5_mixed_ba_quantization_layers( }; let a_quant = projection_quantization("in_proj_a"); let b_quant = projection_quantization("in_proj_b"); - a_quant.bits != b_quant.bits || a_quant.group_size != b_quant.group_size + // The qkvz fusion pair has the same constraint: `in_proj_qkv` and + // `in_proj_z` are concatenated into `in_proj_qkvz`, which is only + // possible when their packed (quantized) shapes agree. Mixed- + // precision quants (e.g. OptiQ) assign different bit-widths per + // projection on sensitive layers, so check both fusion pairs. + let qkv_quant = projection_quantization("in_proj_qkv"); + let z_quant = projection_quantization("in_proj_z"); + a_quant.bits != b_quant.bits + || a_quant.group_size != b_quant.group_size + || qkv_quant.bits != z_quant.bits + || qkv_quant.group_size != z_quant.group_size }) .collect() } @@ -4770,12 +4949,45 @@ fn dense_mtp_param_key(stripped: &str) -> Option { .map(|rest| format!("dense_mtp.{rest}")) } +fn moe_mtp_param_key(stripped: &str) -> Option { + stripped + .strip_prefix("mtp.") + .map(|rest| format!("moe_mtp.{rest}")) +} + +/// Normalize a tensor key loaded from an auxiliary MTP sidecar file. +/// +/// Some sidecars (e.g. mlx-community MTP drafters) ship truly unprefixed keys +/// (`fc.weight`, `layers.0....`); prefix them with `mtp.` so they map onto the +/// model's MTP head params the same way prefixed sidecars do. Keys from +/// non-auxiliary files are returned unchanged. +fn normalize_sidecar_mtp_key(file_path: &Path, key: String) -> String { + let is_aux = file_path + .file_name() + .and_then(|n| n.to_str()) + .is_some_and(|n| crate::AUXILIARY_SAFETENSORS_FILES.contains(&n)); + // Only prefix truly un-namespaced sidecar keys (`fc.weight`, + // `layers.0....`). Already-namespaced keys (`mtp.*`, `language_model.mtp.*`) + // are left intact so `qwen35_checkpoint_param_key` can still strip/remap + // them — prefixing those would produce unmatchable `mtp.language_model.mtp.*`. + if is_aux && !is_mtp_key(&key) { + format!("mtp.{key}") + } else { + key + } +} + fn qwen35_target_param_key( params: &HashMap, &mut Array>, stripped: &str, ) -> Option<(String, bool)> { if params.contains_key(stripped) { Some((stripped.to_owned(), false)) + } else if let Some(moe_key) = + moe_mtp_param_key(stripped).filter(|key| params.contains_key(key.as_str())) + { + // MoE MTP head: plain remap, no dense rmsnorm adjustment. + Some((moe_key, false)) } else { dense_mtp_param_key(stripped) .filter(|dense_key| params.contains_key(dense_key.as_str())) @@ -4810,10 +5022,51 @@ fn qwen35_loaded_value( } } +/// Load `Qwen3Next` weights, remapping the `mtp.*` sidecar onto whichever MTP +/// head is active (`mtp` / `dense_mtp` / `moe_mtp`). +/// +/// Backbone keys match params directly (via `qwen35_target_param_key`'s +/// direct-match branch), so this is behaviour-compatible with the plain loader +/// for the common `Quantized` layout. The only added behaviour is the +/// `mtp.*` → `dense_mtp.*` / `moe_mtp.*` remap, which the plain loader lacks — +/// without it a dense/MoE draft head selected by +/// `maybe_disable_mtp_without_checkpoint_weights` is silently left uninitialized. +#[allow(clippy::shadow_reuse)] +fn load_qwen3_next_weights( + model: &mut M, + model_path: &Path, +) -> Result<(), crate::error::ModelError> { + let safetensors_files = crate::collect_safetensors_files(model_path)?; + let mut params = model.parameters_mut().flatten(); + + for file_path in &safetensors_files { + let loaded = Array::load_safetensors(file_path) + .map_err(|e| crate::error::ModelError::Io(std::io::Error::other(e.to_string())))?; + + for (key, value) in loaded { + let key = normalize_sidecar_mtp_key(file_path, key); + if let Some((target_key, dense_mtp_target)) = qwen35_target_param_key(¶ms, &key) { + if let Some(param) = params.get_mut(target_key.as_str()) { + **param = qwen35_loaded_value(&key, value, dense_mtp_target)?; + continue; + } + } + tracing::warn!(key = %key, "Weight key not found in model parameters"); + } + } + + model + .eval() + .map_err(|e| crate::error::ModelError::Io(std::io::Error::other(e.to_string())))?; + + Ok(()) +} + /// Load Qwen3.5-MoE weights with GDN projection fusion. /// /// Direct weight loader: strip `language_model.` prefix, no rearrangement. /// Used when `use_separate_gdn_projections = true`. +#[allow(clippy::shadow_reuse)] fn load_qwen3_5_moe_weights_direct( model: &mut M, model_path: &Path, @@ -4828,6 +5081,7 @@ fn load_qwen3_5_moe_weights_direct( .map_err(|e| crate::error::ModelError::Io(std::io::Error::other(e.to_string())))?; for (key, value) in loaded { + let key = normalize_sidecar_mtp_key(file_path, key); let Some(stripped) = qwen35_checkpoint_param_key(&key) else { unmatched.push(key); continue; @@ -4861,9 +5115,17 @@ fn load_qwen3_5_moe_weights_direct( } } let param_count = params.len(); + // This loader is only used with separate GDN projections (see + // `load_qwen3_5_model_with_gdn_fallback`). In that mode the fused + // `in_proj_qkvz` / `in_proj_ba` QLinears are still constructed — as unused + // placeholders, since the forward path dispatches on + // `use_separate_projections` — so they must be exempt from the + // completeness check. Flagging them would reject every mixed-bit + // checkpoint that *requires* separate projections (e.g. OptiQ quants). ensure_all_model_params_loaded( params .iter() + .filter(|(name, _)| !(name.contains(".in_proj_qkvz.") || name.contains(".in_proj_ba."))) .map(|(name, value)| (std::rc::Rc::::clone(name), &**value)), )?; tracing::info!(param_count, matched, "Total model parameters loaded"); @@ -4877,7 +5139,7 @@ fn load_qwen3_5_moe_weights_direct( /// Rearranges flat (qkv,z,b,a) projections to per-head-grouped (qkvz,ba) /// so the model uses the fused 2-dispatch forward path instead of 4 separate. -#[allow(clippy::too_many_lines)] +#[allow(clippy::too_many_lines, clippy::shadow_reuse)] fn load_qwen3_5_moe_weights_fused( model: &mut M, model_path: &Path, @@ -4906,6 +5168,7 @@ fn load_qwen3_5_moe_weights_fused( .map_err(|e| crate::error::ModelError::Io(std::io::Error::other(e.to_string())))?; for (key, value) in loaded { + let key = normalize_sidecar_mtp_key(file_path, key); let Some(stripped) = qwen35_checkpoint_param_key(&key) else { continue; }; @@ -13908,6 +14171,63 @@ mod tests { ); } + #[test] + fn test_mtp_layout_detects_moe_structured_head() { + // Qwen3.6-A3B style: the MTP layer is a full MoE layer. + let layout = mtp_weight_layout_from_keys([ + "mtp.layers.0.self_attn.q_proj.weight", + "mtp.layers.0.mlp.gate.weight", + "mtp.layers.0.mlp.shared_expert.up_proj.scales", + "mtp.layers.0.mlp.switch_mlp.down_proj.weight", + "mtp.fc.weight", + ]); + assert_eq!(layout, MtpWeightLayout::MoeQuantized); + } + + #[test] + fn test_moe_mtp_param_key_remaps_mtp_prefix() { + assert_eq!( + moe_mtp_param_key("mtp.layers.0.mlp.gate.weight").as_deref(), + Some("moe_mtp.layers.0.mlp.gate.weight") + ); + assert_eq!( + moe_mtp_param_key("mtp.fc.scales").as_deref(), + Some("moe_mtp.fc.scales") + ); + assert!(moe_mtp_param_key("model.layers.0.mlp.gate.weight").is_none()); + } + + #[test] + fn test_normalize_sidecar_mtp_key_prefixes_aux_files_only() { + let aux = Path::new("/models/x/mtp.safetensors"); + let main = Path::new("/models/x/model-00001-of-00004.safetensors"); + // Unprefixed keys from the sidecar get the mtp. prefix. + assert_eq!( + normalize_sidecar_mtp_key(aux, "fc.weight".to_owned()), + "mtp.fc.weight" + ); + assert_eq!( + normalize_sidecar_mtp_key(aux, "layers.0.mlp.gate.weight".to_owned()), + "mtp.layers.0.mlp.gate.weight" + ); + // Already-prefixed sidecar keys are unchanged. + assert_eq!( + normalize_sidecar_mtp_key(aux, "mtp.fc.weight".to_owned()), + "mtp.fc.weight" + ); + // Already-namespaced sidecar keys (e.g. `language_model.mtp.*`) must NOT + // be over-prefixed into unmatchable `mtp.language_model.mtp.*`. + assert_eq!( + normalize_sidecar_mtp_key(aux, "language_model.mtp.layers.0.fc.weight".to_owned()), + "language_model.mtp.layers.0.fc.weight" + ); + // Keys from main shards are never touched. + assert_eq!( + normalize_sidecar_mtp_key(main, "fc.weight".to_owned()), + "fc.weight" + ); + } + #[test] fn test_checkpoint_mtp_weight_layout_detects_dense_auxiliary_mtp_file() { let dir = tempfile::tempdir().unwrap(); @@ -14053,6 +14373,39 @@ mod tests { ); } + #[test] + fn test_load_qwen35_mixed_qkvz_quantization_forces_separate_gdn() { + // Mixed-precision quants (e.g. OptiQ) can also put `in_proj_qkv` and + // `in_proj_z` at different bit-widths — that breaks the qkvz fusion + // concat exactly like a mixed BA pair does. + let dir = tempfile::tempdir().unwrap(); + let config = format!( + r#"{{ + "text_config": {}, + "tie_word_embeddings": false, + "quantization": {{ + "group_size": 64, + "bits": 4, + "mode": "affine", + "language_model.model.layers.2.linear_attn.in_proj_z": {{ + "group_size": 64, + "bits": 8, + "mode": "affine" + }} + }} + }}"#, + qwen35_dense_text_config() + ); + std::fs::write(dir.path().join("config.json"), config).unwrap(); + + let args = load_qwen3_5_moe_text_config_args(dir.path()).unwrap(); + + assert!( + args.use_separate_gdn_projections, + "mixed-bit in_proj_qkv/in_proj_z must force separate GDN projections" + ); + } + #[test] fn test_load_qwen35_matching_ba_quantization_keeps_fused_gdn() { let dir = tempfile::tempdir().unwrap();