Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 21 additions & 11 deletions crates/higgs-engine/src/model_loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,14 @@ pub fn load_model<P: AsRef<Path>>(model_dir: P) -> Result<AnyModel, EngineError>
match config.model_type.as_str() {
"qwen2" | "qwen3" | "llama" | "mistral" => {
// Packed 1.25-bpw Bonsai-Q1 checkpoints declare model_type="qwen3"
// but the weights are quantized to bits=1. Keep detection ahead of
// the fp16/Q4 transformer loader so users get an explicit error
// while the workspace remains on upstream oxideai/mlx-rs.
// but the weights are quantized to bits=1. Route them to the
// dedicated packed engine, whose bits=1 matvec/dequant run through
// runtime JIT Metal kernels (higgs-models::metal_kernel) — so it
// runs on stock oxideai/mlx-rs with no forked bits=1 MLX kernel.
if is_bonsai_q1(&config.model_dir)? {
return Err(EngineError::Model(ModelError::UnsupportedModel(
"Bonsai-Q1 requires MLX bits=1 affine quantization support; \
the workspace stays on upstream oxideai/mlx-rs until that support lands"
.to_owned(),
)));
let gpu = higgs_models::bonsai_q1::load_bonsai_q1(&config.model_dir)
.map_err(EngineError::Model)?;
return Ok(AnyModel::BonsaiQ1(gpu));
}
let model = transformer::load_model(&config.model_dir).map_err(EngineError::Model)?;
Ok(AnyModel::Transformer(model))
Expand Down Expand Up @@ -315,16 +314,27 @@ mod tests {
}

#[test]
fn load_model_rejects_bonsai_q1_without_runtime_support() {
fn load_model_routes_bonsai_q1_to_packed_engine() {
// A bits=1 / group=128 qwen3 config now routes to the packed Bonsai-Q1
// engine (its bits=1 kernels live in higgs-models::metal_kernel) instead
// of being rejected up front. With no weights in the dir the load still
// fails inside the engine — but it must no longer be gated out, and the
// old "requires MLX bits=1" guard error must be gone.
let (dir, _result) = config_from_raw(
r#"{
"model_type": "qwen3",
"quantization": {"bits": 1, "group_size": 128}
}"#,
);
match load_model(dir.path()) {
Err(err) => assert!(err.to_string().contains("Bonsai-Q1 requires MLX bits=1")),
Ok(_) => panic!("Expected unsupported Bonsai-Q1 runtime error"),
Ok(_) => panic!("expected load failure: config-only dir has no weights"),
Err(EngineError::Model(ModelError::UnsupportedModel(_))) => {
panic!("Bonsai-Q1 must route to the packed engine, not be rejected as unsupported")
}
Err(err) => assert!(
!err.to_string().contains("requires MLX bits=1"),
"stale bits=1 guard error should be gone, got: {err}"
),
}
}

Expand Down
229 changes: 207 additions & 22 deletions crates/higgs-models/src/bonsai_q1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
//! Unlike `DiffusionEngine::load_q1` which dequantizes to fp32 at load (32 GB
//! residency on 8B), this engine holds MLX's `Q1_0_g128` affine encoding
//! verbatim: `w[row, col] = scales[row, col/128] * bit(col) + biases[row,
//! col/128]`. Dequant happens inline inside the MLX quantized matmul kernel
//! once upstream MLX provides bits=1 affine support.
//! col/128]`. Because stock `oxideai/mlx-rs` ships no bits=1 affine kernel, the
//! matvec/dequant run through runtime JIT Metal kernels in
//! [`crate::metal_kernel`] (decode uses a fused matvec over the packed weights;
//! prefill/embedding dequantize to dense f16).
//!
//! Residency: ~1.25 GB for Bonsai-8B-mlx-1bit, ~260 MB for Bonsai-1.7B-mlx-1bit.
//!
//! Scope: Rust-side loader and engine implementation. Runtime routing is held
//! back in `higgs-engine` until the upstream MLX dependency supports bits=1
//! affine quantization.
//! Scope: Rust-side loader and engine implementation. Routing is enabled in
//! `higgs-engine::model_loader` since the kernels run on the stock MLX pin.

#![allow(
clippy::too_many_arguments,
Expand Down Expand Up @@ -65,7 +66,6 @@ pub fn load_bonsai_q1<P: AsRef<Path>>(model_dir: P) -> Result<BonsaiQ1Gpu, Model
}

pub const GROUP_SIZE: usize = 128;
const BITS: i32 = 1;
const GROUP_SIZE_I32: i32 = GROUP_SIZE as i32;

/// Packed 1-bit linear layer with affine per-group dequant.
Expand All @@ -92,8 +92,9 @@ impl PackedQ1Linear {

/// Dequantize a single row to fp32 (reference path for correctness tests).
///
/// Not used on the hot path — P2 replaces this with a Metal kernel that
/// fuses dequant into the matmul.
/// Not used on the hot path — the hot path uses the fused matvec/dequant
/// kernels in [`crate::metal_kernel`]; this CPU path is the oracle those
/// kernels are tested against.
pub fn dequant_row_to_fp32(&self, row: usize, out: &mut [f32]) {
debug_assert_eq!(out.len(), self.in_features);
let n_groups = self.in_features / GROUP_SIZE;
Expand Down Expand Up @@ -367,7 +368,8 @@ impl BonsaiQ1Engine {
// ---------------------------------------------------------------------------

/// MLX-resident 1-bit linear: weight as uint32 packed, scales/biases as f16,
/// same shape as `PackedQ1Linear` but ready for `ops::quantized_matmul`.
/// same shape as `PackedQ1Linear`, consumed by the bits=1 matvec/dequant
/// kernels in [`crate::metal_kernel`].
pub struct BonsaiQ1GpuLinear {
pub w: Array,
pub scales: Array,
Expand Down Expand Up @@ -400,17 +402,44 @@ impl BonsaiQ1GpuLinear {
})
}

/// `y = x @ dequant(w, scales, biases).T` via fused bits=1 qmm.
/// `y = x @ dequant(w, scales, biases).T`.
///
/// Decode (M = B·T = 1) uses the fused bits=1 matvec kernel, which reads the
/// packed weights once — the performance path. Prefill (M > 1) dequantizes to
/// dense f16 and uses a regular matmul (the transient dense weight amortizes
/// over the M rows). Both run on stock `oxideai/mlx-rs` via
/// [`crate::metal_kernel`] — no `bits=1` MLX kernel required.
pub fn forward(&self, x: &Array) -> Result<Array, Exception> {
ops::quantized_matmul(
x,
&self.w,
&self.scales,
&self.biases,
true,
GROUP_SIZE_I32,
BITS,
)
// `m` is derived from the element count, so a shape whose trailing dim
// isn't `in_features` would be silently treated as a decode row of the
// wrong size. Reject it loudly instead.
let shape = x.shape();
let last_dim = shape.last().copied().unwrap_or(0);
if self.in_features <= 0 || last_dim != self.in_features {
return Err(Exception::custom(format!(
"BonsaiQ1GpuLinear::forward: expected last dim {}, got shape {shape:?}",
self.in_features
)));
}
let total: i32 = shape.iter().product();
let m = total / self.in_features;
if m == 1 {
crate::metal_kernel::bonsai_q1_qmv(
x,
&self.w,
&self.scales,
&self.biases,
GROUP_SIZE_I32,
)
} else {
let wd = crate::metal_kernel::bonsai_q1_dequant(
&self.w,
&self.scales,
&self.biases,
GROUP_SIZE_I32,
)?;
ops::matmul(x, &wd.transpose_axes(&[1, 0])?)
}
}
}

Expand Down Expand Up @@ -521,15 +550,15 @@ impl BonsaiQ1Gpu {

/// Gather embedding rows for a token-ID tensor.
///
/// Uses MLX dequantize after gathering the selected packed rows. This path
/// requires bits=1 affine support in the active MLX runtime.
/// Gathers the selected packed rows, then dequantizes them to dense f16 via
/// the [`crate::metal_kernel`] bits=1 kernel (runs on stock `oxideai/mlx-rs`).
fn embed_rows(&self, ids: &Array) -> Result<Array, Exception> {
let shape = ids.shape().to_vec();
let flat = ids.flatten(None, None)?;
let w = self.embed.w.take_axis(&flat, 0)?;
let s = self.embed.scales.take_axis(&flat, 0)?;
let b = self.embed.biases.take_axis(&flat, 0)?;
let out = ops::dequantize(&w, &s, &b, GROUP_SIZE_I32, BITS)?;
let out = crate::metal_kernel::bonsai_q1_dequant(&w, &s, &b, GROUP_SIZE_I32)?;
let mut ret_shape: Vec<i32> = shape;
ret_shape.push(-1);
out.reshape(&ret_shape)
Expand Down Expand Up @@ -1202,3 +1231,159 @@ fn bytes_to_f16_vec(b: &[u8]) -> Vec<f16> {
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
use super::*;
use crate::metal_kernel::{bonsai_q1_dequant, bonsai_q1_qmv_legacy};

/// Deterministic pseudo-random source for the kernel tests.
///
/// Performs one wrapping 64-bit linear-congruential step on `state`
/// (multiply, then add) and returns the upper 32 bits of the updated state.
/// A given seed always yields the same sequence, so the tests compare the
/// kernels with the CPU reference on reproducible pseudo-random data rather
/// than on hand-picked constants.
fn lcg(state: &mut u64) -> u32 {
    const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    const INCREMENT: u64 = 1_442_695_040_888_963_407;

    let advanced = state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
    *state = advanced;
    // Only the high half of the state is handed out.
    (advanced >> 32) as u32
}

/// Builds a pseudo-random [`PackedQ1Linear`] of the requested shape.
///
/// The packed weight words come from [`lcg`] seeded with `seed`. Scales and
/// biases follow short fixed cycles over the flat (row, group) index, so
/// neighbouring groups carry different affine parameters and a kernel that
/// reads the wrong group produces a visible mismatch.
fn make_packed(out_features: usize, in_features: usize, seed: u64) -> PackedQ1Linear {
    // One `u32` word holds 32 one-bit weights; one scale/bias pair covers
    // `GROUP_SIZE` columns.
    let words_per_row = in_features / 32;
    let groups_per_row = in_features / GROUP_SIZE;
    let affine_len = out_features * groups_per_row;

    let mut rng_state = seed;
    let w_packed: Vec<u32> = std::iter::repeat_with(|| lcg(&mut rng_state))
        .take(out_features * words_per_row)
        .collect();

    // `base + step * (i mod period)`, rounded to f16.
    let cyclic = |base: f32, step: f32, period: usize| -> Vec<f16> {
        (0..affine_len)
            .map(|i| f16::from_f32(base + step * ((i % period) as f32)))
            .collect()
    };
    // Scales stay positive; biases straddle zero.
    let scales = cyclic(0.05, 0.013, 7);
    let biases = cyclic(-0.03, 0.011, 5);

    PackedQ1Linear {
        w_packed,
        scales,
        biases,
        out_features,
        in_features,
    }
}

/// CPU oracle: expands every row of `p` into one dense row-major `f32`
/// buffer of `out_features * in_features` elements by calling
/// `dequant_row_to_fp32` once per row.
fn dense_reference(p: &PackedQ1Linear) -> Vec<f32> {
    let width = p.in_features;
    let mut dense = vec![0.0f32; p.out_features * width];
    (0..p.out_features).for_each(|row| {
        let start = row * width;
        p.dequant_row_to_fp32(row, &mut dense[start..start + width]);
    });
    dense
}

/// The dequant kernel's dense f16 output must agree with the CPU per-row
/// reference at every element, within a small absolute tolerance.
#[test]
fn dequant_kernel_matches_cpu_reference() {
    // 256 input columns span two quantization groups of 128.
    const OUT_FEATURES: usize = 96;
    const IN_FEATURES: usize = 256;
    const ABS_TOL: f32 = 2e-3;

    let packed = make_packed(OUT_FEATURES, IN_FEATURES, 0xDEAD_BEEF);
    let layer = BonsaiQ1GpuLinear::from_packed(&packed).unwrap();

    let dense =
        bonsai_q1_dequant(&layer.w, &layer.scales, &layer.biases, GROUP_SIZE_I32).unwrap();
    dense.eval().unwrap();

    let got = dense.as_slice::<f16>();
    let want = dense_reference(&packed);
    assert_eq!(got.len(), want.len());

    for (i, (g, &w)) in got.iter().zip(&want).enumerate() {
        let gv = g.to_f32();
        let within_tolerance = (gv - w).abs() <= ABS_TOL;
        assert!(
            within_tolerance,
            "dequant mismatch at {i}: got {gv} want {w}"
        );
    }
}

/// The legacy matvec kernel (`bonsai_q1_qmv_legacy`) must agree with a CPU
/// dot product over the dense reference weights, row by row, within a
/// relative tolerance of 1e-2 (floored at an absolute 1e-2).
#[test]
fn qmv_kernel_matches_cpu_reference() {
    const OUT_F: usize = 96;
    const IN_F: usize = 256;

    let packed = make_packed(OUT_F, IN_F, 0x1234_5678);
    let layer = BonsaiQ1GpuLinear::from_packed(&packed).unwrap();

    // Deterministic activations in [-1, 1].
    let mut rng_state = 0xABCD_EF01_u64;
    let mut x_f32: Vec<f32> = Vec::with_capacity(IN_F);
    for _ in 0..IN_F {
        let unit = lcg(&mut rng_state) as f32 / u32::MAX as f32;
        x_f32.push(unit.mul_add(2.0, -1.0));
    }
    let x = Array::from_slice(&x_f32, &[1, IN_F as i32])
        .as_dtype(Dtype::Float16)
        .unwrap();
    // The kernel is fed f16 input, so the reference uses the same rounded
    // values to keep the comparison like-for-like.
    let x_ref: Vec<f32> = x_f32
        .iter()
        .map(|&v| f16::from_f32(v).to_f32())
        .collect();

    let y = bonsai_q1_qmv_legacy(
        &x,
        &layer.w,
        &layer.scales,
        &layer.biases,
        GROUP_SIZE_I32,
    )
    .unwrap();
    y.eval().unwrap();
    let got = y.as_slice::<f16>();
    assert_eq!(got.len(), OUT_F);

    let dense = dense_reference(&packed);
    for (r, (g, row)) in got.iter().zip(dense.chunks_exact(IN_F)).enumerate() {
        // Left-to-right f32 accumulation, one product per column.
        let acc = x_ref
            .iter()
            .zip(row)
            .fold(0.0f32, |sum, (xv, wv)| sum + xv * wv);
        let gv = g.to_f32();
        let tol = 1e-2 * acc.abs().max(1.0);
        assert!(
            (gv - acc).abs() <= tol,
            "qmv mismatch at row {r}: got {gv} want {acc}"
        );
    }
}

/// Oracle for the `qmv_fast`-class kernel (`bonsai_q1_qmv_fast`).
///
/// Runs two shapes: a small-K case (N = 96, K = 256) and a large-K case
/// (N = 130, K = 4096) whose row count is not a multiple of 4. Per the
/// original author's note these are meant to hit the kernel's tail path and
/// its main-block path with a row remainder (the lm_head case); the kernel
/// source is not visible from this module, so treat that mapping as intent
/// to be confirmed against `metal_kernel`.
///
/// The comparison is tolerance-based, not bit-exact: each output row must be
/// within `1e-2 * max(|reference|, 1.0)` of a CPU f32 dot product over the
/// dense reference weights, using the f16-rounded activations the kernel is
/// given.
#[test]
fn qmv_fast_kernel_matches_cpu_reference() {
    // (output rows N, input columns K, weight seed)
    let cases = [
        (96usize, 256usize, 0x1234_5678_u64),
        (130usize, 4096usize, 0x0BAD_F00D_u64),
    ];

    for &(out_f, in_f, seed) in &cases {
        let p = make_packed(out_f, in_f, seed);
        let gpu = BonsaiQ1GpuLinear::from_packed(&p).unwrap();

        // Deterministic activations in [-1, 1]; same activation seed for
        // every case.
        let mut st = 0xABCD_EF01_u64;
        let x_f32: Vec<f32> = (0..in_f)
            .map(|_| (lcg(&mut st) as f32 / u32::MAX as f32).mul_add(2.0, -1.0))
            .collect();
        let x = Array::from_slice(&x_f32, &[1, in_f as i32])
            .as_dtype(Dtype::Float16)
            .unwrap();
        // Round-trip through f16 so the reference uses the values the kernel
        // is fed.
        let x_ref: Vec<f32> = x_f32.iter().map(|&v| f16::from_f32(v).to_f32()).collect();

        let y = crate::metal_kernel::bonsai_q1_qmv_fast(
            &x,
            &gpu.w,
            &gpu.scales,
            &gpu.biases,
            GROUP_SIZE_I32,
        )
        .unwrap();
        y.eval().unwrap();
        let got = y.as_slice::<f16>();
        assert_eq!(got.len(), out_f);

        let wd = dense_reference(&p);
        for (r, (g, row)) in got.iter().zip(wd.chunks_exact(in_f)).enumerate() {
            // Left-to-right f32 accumulation, one product per column.
            let acc = x_ref
                .iter()
                .zip(row)
                .fold(0.0f32, |sum, (xv, wv)| sum + xv * wv);
            let gv = g.to_f32();
            let tol = 1e-2 * acc.abs().max(1.0);
            assert!(
                (gv - acc).abs() <= tol,
                "qmv_fast mismatch ({out_f}x{in_f}) row {r}: got {gv} want {acc}"
            );
        }
    }
}
}
2 changes: 2 additions & 0 deletions crates/higgs-models/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ pub mod deepseek_v2;
pub mod error;
pub mod gemma2;
pub mod llava_qwen2;
/// Internal: runtime JIT Metal kernels (Bonsai-Q1 bits=1 matvec/dequant).
mod metal_kernel;
pub mod phi3;
pub mod qwen3_moe;
pub mod qwen3_next;
Expand Down
Loading