Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 40 additions & 11 deletions crates/higgs-engine/src/mtp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
//!
//! Expected speedup: ~1.5x on dense models at ~80% acceptance rate.

use higgs_models::{AnyCache, AnyModel, MtpCache};
use higgs_models::{AnyCache, AnyModel, MtpCache, deep_clone_mtp_cache};
use mlx_rs::{
Array, argmax_axis,
ops::{self, concatenate_axis, indexing::IndexOp},
Expand All @@ -19,6 +19,35 @@ const fn draft_matches_target(draft_token_id: u32, target_id: u32) -> bool {
draft_token_id == target_id
}

/// Capture a backbone-cache rollback point before a speculative verify.
///
/// KV caches are rolled back by *trimming the offset* (see [`rollback_backbone`]),
/// so we deliberately return `None` and never clone them. Cloning a KV cache and
/// later restoring it (`*cache = base`) makes the checkpoint share the live
/// cache's underlying MLX buffers; the in-place `slice_update` writes during
/// verify then let MLX donate a buffer that the checkpoint still references,
/// corrupting it and double-freeing on drop (the `malloc: pointer being freed
/// was not allocated` abort). Hybrid SSM/recurrent state cannot be offset-
/// trimmed, so those still need a full clone-restore.
fn capture_backbone_checkpoint(cache: &AnyCache) -> Option<AnyCache> {
match cache {
AnyCache::KV(_) => None,
AnyCache::Hybrid(_) => Some(cache.deep_clone()),
}
}

/// Roll the backbone cache back after a rejected speculative verify.
///
/// `verify_len` is the number of tokens the verify batch advanced the cache by.
/// KV caches rewind by `trim_by(verify_len)` (no clone, no buffer aliasing);
/// hybrid caches restore the clone captured by [`capture_backbone_checkpoint`].
fn rollback_backbone(cache: &mut AnyCache, checkpoint: Option<AnyCache>, verify_len: usize) {
match checkpoint {
Some(base) => *cache = base,
None => cache.trim_by(verify_len),
}
}

/// Aggregate MTP decode counters.
///
/// Tracks per-cycle telemetry for MTP speculative decoding.
Expand Down Expand Up @@ -191,8 +220,8 @@ pub fn mtp_prompt_lookup_cycle(
return Ok(None);
}

let base_cache = cache.clone();
let base_mtp_cache = mtp_cache.clone();
let base_cache = capture_backbone_checkpoint(cache);
let base_mtp_cache = deep_clone_mtp_cache(mtp_cache);
let mut verify_tokens = Vec::with_capacity(drafts.len().saturating_add(1));
verify_tokens.push(confirmed_token_id);
verify_tokens.extend(drafts.iter().copied());
Expand All @@ -218,7 +247,7 @@ pub fn mtp_prompt_lookup_cycle(
})?;
(verify_hidden, next)
} else {
*cache = base_cache;
rollback_backbone(cache, base_cache, verify_tokens.len());
let (replay_hidden, replay_targets) = backbone_verify_batch(model, cache, &tokens)?;
let next = *replay_targets.get(accepted_drafts).ok_or_else(|| {
EngineError::Generation(format!(
Expand Down Expand Up @@ -351,7 +380,7 @@ pub fn prompt_lookup_cycle(
config.max_window,
);

let base_cache = cache.clone();
let base_cache = capture_backbone_checkpoint(cache);
let mut verify_tokens = Vec::with_capacity(drafts.len().saturating_add(1));
verify_tokens.push(confirmed_token_id);
verify_tokens.extend(drafts.iter().copied());
Expand All @@ -378,7 +407,7 @@ pub fn prompt_lookup_cycle(
))
})?
} else {
*cache = base_cache;
rollback_backbone(cache, base_cache, verify_tokens.len());
let replay_logits = model
.forward_all_logits(&token_input(&tokens)?, None, cache)
.map_err(EngineError::Mlx)?;
Expand Down Expand Up @@ -617,9 +646,9 @@ pub fn mtp_cycle(
draft_n_max: usize,
) -> Result<MtpCycleResult, EngineError> {
let draft_limit = draft_n_max.max(1);
let base_cache = cache.clone();
let base_mtp_cache = mtp_cache.clone();
let mut speculative_mtp_cache = mtp_cache.clone();
let base_cache = capture_backbone_checkpoint(cache);
let base_mtp_cache = deep_clone_mtp_cache(mtp_cache);
let mut speculative_mtp_cache = deep_clone_mtp_cache(mtp_cache);
let mut confirmed_mtp_cache: Option<MtpCache> = None;
let mut speculative_hidden = hidden.clone();
let mut speculative_token = confirmed_token_id;
Expand All @@ -638,7 +667,7 @@ pub fn mtp_cycle(
speculative_hidden = next_hidden;
speculative_token = draft_token_id;
if draft_idx == 0 {
confirmed_mtp_cache = Some(speculative_mtp_cache.clone());
confirmed_mtp_cache = Some(deep_clone_mtp_cache(&speculative_mtp_cache));
}
}

Expand Down Expand Up @@ -678,7 +707,7 @@ pub fn mtp_cycle(
})?;
(verify_hidden, next)
} else {
*cache = base_cache;
rollback_backbone(cache, base_cache, verify_tokens.len());
let (replay_hidden, replay_targets) = backbone_verify_batch(model, cache, &tokens)?;
let next = *replay_targets.get(accepted_drafts).ok_or_else(|| {
EngineError::Generation(format!(
Expand Down
122 changes: 122 additions & 0 deletions crates/higgs-models/src/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,26 @@ impl SteppingKeyValueCache {
self.offset = self.offset.saturating_sub(trim).max(0);
}

/// An **independent** deep copy: every MLX buffer is materialized into a
/// fresh buffer (`Array::deep_clone`), so the result shares no storage with
/// `self`. The derived `Clone` only bumps MLX refcounts (shared buffers),
/// which is unsafe as a speculative-decode rollback checkpoint: the live
/// cache's in-place `slice_update` lets MLX donate (reuse/free) a buffer the
/// checkpoint still references, double-freeing it (the MTP `malloc: pointer
/// being freed was not allocated` abort). Use this for any checkpoint that
/// will outlive an in-place update of the live cache.
#[must_use]
pub fn deep_clone(&self) -> Self {
Self {
keys: self.keys.as_ref().map(eval_deep_clone),
values: self.values.as_ref().map(eval_deep_clone),
turbo: self.turbo.as_ref().map(TurboQuantStorage::deep_clone),
config: self.config,
offset: self.offset,
step: self.step,
}
}

/// References to internal arrays that must be eval'd between chunked-prefill steps.
pub fn eval_targets(&self) -> Vec<&Array> {
let mut targets = Vec::with_capacity(8);
Expand Down Expand Up @@ -783,6 +803,22 @@ impl TurboQuantStorage {
}
}

/// Independent deep copy (see [`SteppingKeyValueCache::deep_clone`]). The
/// shared read-only `context` is refcounted (safe to share); every packed
/// array is materialized into its own buffer so an in-place update of the
/// live cache cannot donate/free a buffer this snapshot holds.
fn deep_clone(&self) -> Self {
Self {
context: Arc::clone(&self.context),
key_codes: self.key_codes.as_ref().map(eval_deep_clone),
key_norms: self.key_norms.as_ref().map(eval_deep_clone),
key_gammas: self.key_gammas.as_ref().map(eval_deep_clone),
value_codes: self.value_codes.as_ref().map(eval_deep_clone),
value_norms: self.value_norms.as_ref().map(eval_deep_clone),
capacity: self.capacity,
}
}

fn ensure_capacity(&mut self, required: i32, step: i32) -> Result<(), Exception> {
if required <= self.capacity {
return Ok(());
Expand Down Expand Up @@ -1020,6 +1056,25 @@ pub fn slice_axis1(arr: &Array, start: i32, end: i32) -> Result<Array, Exception
slice_axis(arr, 1, start, end)
}

/// Evaluate, then independently deep-copy an MLX array.
///
/// `Array::deep_clone` copies bytes straight from the buffer pointer
/// (`mlx_array_data_*`), which is only valid once the array is evaluated. The
/// cache stores *lazy* `slice_update` results (see `update_dense`), so cloning a
/// checkpoint without forcing evaluation first reads an unmaterialized pointer
/// and segfaults. Eval makes the snapshot both valid and buffer-independent of
/// the live cache (no shared buffer for a later in-place update to donate/free).
//
// `expect`: a failed `eval` on a live cache array leaves nothing safe to copy —
// a loud panic here is strictly better than the segfault a lazy `deep_clone`
// would cause, and the infallible signature keeps every checkpoint call site
// (`AnyCache::deep_clone`, `deep_clone_mtp_cache`, the MTP cycle) clone-and-go.
#[allow(clippy::expect_used)]
fn eval_deep_clone(a: &Array) -> Array {
a.eval().expect("eval before deep_clone checkpoint");
a.deep_clone()
}

/// Write `update` into `target` at `[..., start:start+n, ...]` on axis 2.
#[allow(unsafe_code, clippy::indexing_slicing)]
fn slice_update_axis2(
Expand Down Expand Up @@ -1337,6 +1392,73 @@ mod tests {
assert!((k_data[8] - 2.0).abs() < 1e-6);
}

#[test]
fn deep_clone_preserves_contents_and_offset() {
// deep_clone must be a faithful, independent copy.
let mut cache = SteppingKeyValueCache::new();
let ones_k = Array::ones::<f32>(&[1, 2, 2, 8]).unwrap();
let ones_v = Array::ones::<f32>(&[1, 2, 2, 8]).unwrap();
cache.update_and_fetch(ones_k, ones_v).unwrap();

let copy = cache.deep_clone();
assert_eq!(copy.offset(), cache.offset());

let orig_k: Vec<f32> = {
let k = cache.keys.as_ref().unwrap();
k.eval().unwrap();
k.as_slice().to_vec()
};
let copy_k: Vec<f32> = {
let k = copy.keys.as_ref().unwrap();
k.eval().unwrap();
k.as_slice().to_vec()
};
assert_eq!(orig_k, copy_k, "deep_clone must copy contents faithfully");
}

#[test]
fn deep_clone_checkpoint_survives_live_in_place_update() {
// The speculative-decode invariant: a checkpoint captured before the
// live cache is advanced must NOT change when the live cache does an
// in-place `slice_update`. A shallow `clone()` shares the KV buffer, so
// MLX can donate/free it under the checkpoint (the double-free abort);
// `deep_clone()` is independent.
let mut cache = SteppingKeyValueCache::new();
let ones_k = Array::ones::<f32>(&[1, 2, 2, 8]).unwrap();
let ones_v = Array::ones::<f32>(&[1, 2, 2, 8]).unwrap();
cache.update_and_fetch(ones_k, ones_v).unwrap();

let checkpoint = cache.deep_clone();
let before: Vec<f32> = {
let k = checkpoint.keys.as_ref().unwrap();
k.eval().unwrap();
k.as_slice().to_vec()
};

// Advance the LIVE cache in place with a token of value 2.0; force eval
// so any buffer donation would fire.
let two = Array::from_f32(2.0);
let twos_k = Array::full::<f32>(&[1, 2, 1, 8], &two).unwrap();
let twos_v = Array::full::<f32>(&[1, 2, 1, 8], &two).unwrap();
let (rk, _) = cache.update_and_fetch(twos_k, twos_v).unwrap();
rk.eval().unwrap();

assert_eq!(
checkpoint.offset(),
2,
"checkpoint offset must be unchanged"
);
let after: Vec<f32> = {
let k = checkpoint.keys.as_ref().unwrap();
k.eval().unwrap();
k.as_slice().to_vec()
};
assert_eq!(
before, after,
"deep_clone checkpoint must survive the live cache's in-place update"
);
}

#[test]
fn test_turboquant_cache_round_trips_dense_fetch() {
let config = KvCacheConfig {
Expand Down
48 changes: 47 additions & 1 deletion crates/higgs-models/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,54 @@ impl AnyCache {
}
}
}

/// An **independent** deep copy for use as a speculative-decode checkpoint.
/// KV layers are deep-cloned (their in-place `slice_update` buffers must not
/// be shared — see [`cache::SteppingKeyValueCache::deep_clone`]); GDN/SSM
/// (`Arrays`) layers update by full reassignment, never in place, so a cheap
/// shallow `clone()` of those is safe.
#[must_use]
pub fn deep_clone(&self) -> Self {
match self {
Self::KV(layers) => Self::KV(
layers
.iter()
.map(|l| l.as_ref().map(cache::SteppingKeyValueCache::deep_clone))
.collect(),
),
Self::Hybrid(layers) => Self::Hybrid(
layers
.iter()
.map(|l| {
l.as_ref().map(|lc| match lc {
LayerCache::KV(kv) => LayerCache::KV(kv.deep_clone()),
recurrent @ LayerCache::Arrays(_) => recurrent.clone(),
})
})
.collect(),
),
}
}
}

/// Independent deep copy of an MTP head cache (`Vec<SteppingKeyValueCache>`).
///
/// For use as a speculative-decode checkpoint. See
/// [`cache::SteppingKeyValueCache::deep_clone`] for why a shallow clone is
/// unsafe (buffer donation double-free).
#[must_use]
pub fn deep_clone_mtp_cache(c: &MtpCache) -> MtpCache {
c.iter()
.map(cache::SteppingKeyValueCache::deep_clone)
.collect()
}

/// Unified model wrapper dispatching to the correct architecture.
// One `AnyModel` exists per loaded model (held by the engine for the process
// lifetime), never stored in bulk, so the size spread between variants costs a
// few hundred bytes once. Boxing a dispatch variant would add an indirection on
// the forward path for no practical benefit.
#[allow(clippy::large_enum_variant)]
pub enum AnyModel {
/// Standard transformer architectures: Llama, Mistral, Qwen2/2.5, Qwen3.
Transformer(Model),
Expand Down Expand Up @@ -1135,7 +1180,8 @@ pub struct WeightMapIndex {
pub weight_map: HashMap<String, String>,
}

const AUXILIARY_SAFETENSORS_FILES: &[&str] = &["mtp.safetensors", "model-mtp.safetensors"];
pub(crate) const AUXILIARY_SAFETENSORS_FILES: &[&str] =
&["mtp.safetensors", "model-mtp.safetensors"];

/// Load a tokenizer from a model directory.
pub fn load_tokenizer<P: AsRef<Path>>(model_dir: P) -> Result<tokenizers::Tokenizer, ModelError> {
Expand Down
Loading