diff --git a/crates/higgs-engine/src/model_loader.rs b/crates/higgs-engine/src/model_loader.rs index bd2f91d1..ef3c9882 100644 --- a/crates/higgs-engine/src/model_loader.rs +++ b/crates/higgs-engine/src/model_loader.rs @@ -39,15 +39,14 @@ pub fn load_model>(model_dir: P) -> Result match config.model_type.as_str() { "qwen2" | "qwen3" | "llama" | "mistral" => { // Packed 1.25-bpw Bonsai-Q1 checkpoints declare model_type="qwen3" - // but the weights are quantized to bits=1. Keep detection ahead of - // the fp16/Q4 transformer loader so users get an explicit error - // while the workspace remains on upstream oxideai/mlx-rs. + // but the weights are quantized to bits=1. Route them to the + // dedicated packed engine, whose bits=1 matvec/dequant run through + // runtime JIT Metal kernels (higgs-models::metal_kernel) — so it + // runs on stock oxideai/mlx-rs with no forked bits=1 MLX kernel. if is_bonsai_q1(&config.model_dir)? { - return Err(EngineError::Model(ModelError::UnsupportedModel( - "Bonsai-Q1 requires MLX bits=1 affine quantization support; \ - the workspace stays on upstream oxideai/mlx-rs until that support lands" - .to_owned(), - ))); + let gpu = higgs_models::bonsai_q1::load_bonsai_q1(&config.model_dir) + .map_err(EngineError::Model)?; + return Ok(AnyModel::BonsaiQ1(gpu)); } let model = transformer::load_model(&config.model_dir).map_err(EngineError::Model)?; Ok(AnyModel::Transformer(model)) @@ -315,7 +314,12 @@ mod tests { } #[test] - fn load_model_rejects_bonsai_q1_without_runtime_support() { + fn load_model_routes_bonsai_q1_to_packed_engine() { + // A bits=1 / group=128 qwen3 config now routes to the packed Bonsai-Q1 + // engine (its bits=1 kernels live in higgs-models::metal_kernel) instead + // of being rejected up front. With no weights in the dir the load still + // fails inside the engine — but it must no longer be gated out, and the + // old "requires MLX bits=1" guard error must be gone. 
let (dir, _result) = config_from_raw( r#"{ "model_type": "qwen3", @@ -323,8 +327,14 @@ mod tests { }"#, ); match load_model(dir.path()) { - Err(err) => assert!(err.to_string().contains("Bonsai-Q1 requires MLX bits=1")), - Ok(_) => panic!("Expected unsupported Bonsai-Q1 runtime error"), + Ok(_) => panic!("expected load failure: config-only dir has no weights"), + Err(EngineError::Model(ModelError::UnsupportedModel(_))) => { + panic!("Bonsai-Q1 must route to the packed engine, not be rejected as unsupported") + } + Err(err) => assert!( + !err.to_string().contains("requires MLX bits=1"), + "stale bits=1 guard error should be gone, got: {err}" + ), } } diff --git a/crates/higgs-models/src/bonsai_q1.rs b/crates/higgs-models/src/bonsai_q1.rs index 93fe10e2..99cff8ba 100644 --- a/crates/higgs-models/src/bonsai_q1.rs +++ b/crates/higgs-models/src/bonsai_q1.rs @@ -3,14 +3,15 @@ //! Unlike `DiffusionEngine::load_q1` which dequantizes to fp32 at load (32 GB //! residency on 8B), this engine holds MLX's `Q1_0_g128` affine encoding //! verbatim: `w[row, col] = scales[row, col/128] * bit(col) + biases[row, -//! col/128]`. Dequant happens inline inside the MLX quantized matmul kernel -//! once upstream MLX provides bits=1 affine support. +//! col/128]`. Because stock `oxideai/mlx-rs` ships no bits=1 affine kernel, the +//! matvec/dequant run through runtime JIT Metal kernels in +//! [`crate::metal_kernel`] (decode uses a fused matvec over the packed weights; +//! prefill/embedding dequantize to dense f16). //! //! Residency: ~1.25 GB for Bonsai-8B-mlx-1bit, ~260 MB for Bonsai-1.7B-mlx-1bit. //! -//! Scope: Rust-side loader and engine implementation. Runtime routing is held -//! back in `higgs-engine` until the upstream MLX dependency supports bits=1 -//! affine quantization. +//! Scope: Rust-side loader and engine implementation. Routing is enabled in +//! `higgs-engine::model_loader` since the kernels run on the stock MLX pin. 
#![allow( clippy::too_many_arguments, @@ -65,7 +66,6 @@ pub fn load_bonsai_q1>(model_dir: P) -> Result 1) dequantizes to + /// dense f16 and uses a regular matmul (the transient dense weight amortizes + /// over the M rows). Both run on stock `oxideai/mlx-rs` via + /// [`crate::metal_kernel`] — no `bits=1` MLX kernel required. pub fn forward(&self, x: &Array) -> Result { - ops::quantized_matmul( - x, - &self.w, - &self.scales, - &self.biases, - true, - GROUP_SIZE_I32, - BITS, - ) + // `m` is derived from the element count, so a shape whose trailing dim + // isn't `in_features` would be silently treated as a decode row of the + // wrong size. Reject it loudly instead. + let shape = x.shape(); + let last_dim = shape.last().copied().unwrap_or(0); + if self.in_features <= 0 || last_dim != self.in_features { + return Err(Exception::custom(format!( + "BonsaiQ1GpuLinear::forward: expected last dim {}, got shape {shape:?}", + self.in_features + ))); + } + let total: i32 = shape.iter().product(); + let m = total / self.in_features; + if m == 1 { + crate::metal_kernel::bonsai_q1_qmv( + x, + &self.w, + &self.scales, + &self.biases, + GROUP_SIZE_I32, + ) + } else { + let wd = crate::metal_kernel::bonsai_q1_dequant( + &self.w, + &self.scales, + &self.biases, + GROUP_SIZE_I32, + )?; + ops::matmul(x, &wd.transpose_axes(&[1, 0])?) + } } } @@ -521,15 +550,15 @@ impl BonsaiQ1Gpu { /// Gather embedding rows for a token-ID tensor. /// - /// Uses MLX dequantize after gathering the selected packed rows. This path - /// requires bits=1 affine support in the active MLX runtime. + /// Gathers the selected packed rows, then dequantizes them to dense f16 via + /// the [`crate::metal_kernel`] bits=1 kernel (runs on stock `oxideai/mlx-rs`). 
fn embed_rows(&self, ids: &Array) -> Result { let shape = ids.shape().to_vec(); let flat = ids.flatten(None, None)?; let w = self.embed.w.take_axis(&flat, 0)?; let s = self.embed.scales.take_axis(&flat, 0)?; let b = self.embed.biases.take_axis(&flat, 0)?; - let out = ops::dequantize(&w, &s, &b, GROUP_SIZE_I32, BITS)?; + let out = crate::metal_kernel::bonsai_q1_dequant(&w, &s, &b, GROUP_SIZE_I32)?; let mut ret_shape: Vec = shape; ret_shape.push(-1); out.reshape(&ret_shape) @@ -1202,3 +1231,159 @@ fn bytes_to_f16_vec(b: &[u8]) -> Vec { // --------------------------------------------------------------------------- // Tests // --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use crate::metal_kernel::{bonsai_q1_dequant, bonsai_q1_qmv_legacy}; + + /// Deterministic PRNG (64-bit LCG, high word). Tests prove the kernels match the + /// CPU reference over pseudo-random data, not against hand-picked constants. + fn lcg(state: &mut u64) -> u32 { + *state = state + .wrapping_mul(6_364_136_223_846_793_005) + .wrapping_add(1_442_695_040_888_963_407); + (*state >> 32) as u32 + } + + fn make_packed(out_features: usize, in_features: usize, seed: u64) -> PackedQ1Linear { + let packed_cols = in_features / 32; + let n_groups = in_features / GROUP_SIZE; + let mut st = seed; + let w_packed: Vec = (0..out_features * packed_cols) + .map(|_| lcg(&mut st)) + .collect(); + // Per-(row,group) scales/biases so a wrong group index produces a clear + // mismatch. Magnitudes are small: positive scales, biases of both signs. + let scales: Vec = (0..out_features * n_groups) + .map(|i| f16::from_f32(0.05 + 0.013 * ((i % 7) as f32))) + .collect(); + let biases: Vec = (0..out_features * n_groups) + .map(|i| f16::from_f32(-0.03 + 0.011 * ((i % 5) as f32))) + .collect(); + PackedQ1Linear { + w_packed, + scales, + biases, + out_features, + in_features, + } + } + + /// CPU reference: full dense dequant via the documented per-row path. 
+ fn dense_reference(p: &PackedQ1Linear) -> Vec { + let mut wd = vec![0.0f32; p.out_features * p.in_features]; + for r in 0..p.out_features { + let (lo, hi) = (r * p.in_features, (r + 1) * p.in_features); + p.dequant_row_to_fp32(r, &mut wd[lo..hi]); + } + wd + } + + #[test] + fn dequant_kernel_matches_cpu_reference() { + let (out_f, in_f) = (96usize, 256usize); // 256 cols => 2 groups of 128 + let p = make_packed(out_f, in_f, 0xDEAD_BEEF); + let gpu = BonsaiQ1GpuLinear::from_packed(&p).unwrap(); + + let wd = bonsai_q1_dequant(&gpu.w, &gpu.scales, &gpu.biases, GROUP_SIZE_I32).unwrap(); + wd.eval().unwrap(); + let got = wd.as_slice::(); + let want = dense_reference(&p); + + assert_eq!(got.len(), want.len()); + for (i, (g, w)) in got.iter().zip(want.iter()).enumerate() { + let gv = g.to_f32(); + assert!( + (gv - w).abs() <= 2e-3, + "dequant mismatch at {i}: got {gv} want {w}" + ); + } + } + + #[test] + fn qmv_kernel_matches_cpu_reference() { + let (out_f, in_f) = (96usize, 256usize); + let p = make_packed(out_f, in_f, 0x1234_5678); + let gpu = BonsaiQ1GpuLinear::from_packed(&p).unwrap(); + + // x in [-1, 1], deterministic; the kernel reads it as f16, so the + // reference uses the f16-rounded values for an apples-to-apples compare. 
+ let mut st = 0xABCD_EF01_u64; + let x_f32: Vec = (0..in_f) + .map(|_| (lcg(&mut st) as f32 / u32::MAX as f32).mul_add(2.0, -1.0)) + .collect(); + let x = Array::from_slice(&x_f32, &[1, in_f as i32]) + .as_dtype(Dtype::Float16) + .unwrap(); + let x_ref: Vec = x_f32.iter().map(|&v| f16::from_f32(v).to_f32()).collect(); + + let y = bonsai_q1_qmv_legacy(&x, &gpu.w, &gpu.scales, &gpu.biases, GROUP_SIZE_I32).unwrap(); + y.eval().unwrap(); + let got = y.as_slice::(); + assert_eq!(got.len(), out_f); + + let wd = dense_reference(&p); + for r in 0..out_f { + let mut acc = 0.0f32; + for c in 0..in_f { + acc += x_ref[c] * wd[r * in_f + c]; + } + let gv = got[r].to_f32(); + let tol = 1e-2 * acc.abs().max(1.0); + assert!( + (gv - acc).abs() <= tol, + "qmv mismatch at row {r}: got {gv} want {acc}" + ); + } + } + + /// Oracle for the `qmv_fast`-class kernel (`bonsai_q1_qmv_fast`). Covers the + /// tail path (K = 256 < block) and the main-block path (K = 4096) with an + /// N % 4 != 0 row remainder (the lm_head case). Matches CPU reference to 1e-2. 
+ #[test] + fn qmv_fast_kernel_matches_cpu_reference() { + for &(out_f, in_f, seed) in &[ + (96usize, 256usize, 0x1234_5678_u64), + (130usize, 4096usize, 0x0BAD_F00D_u64), + ] { + let p = make_packed(out_f, in_f, seed); + let gpu = BonsaiQ1GpuLinear::from_packed(&p).unwrap(); + + let mut st = 0xABCD_EF01_u64; + let x_f32: Vec = (0..in_f) + .map(|_| (lcg(&mut st) as f32 / u32::MAX as f32).mul_add(2.0, -1.0)) + .collect(); + let x = Array::from_slice(&x_f32, &[1, in_f as i32]) + .as_dtype(Dtype::Float16) + .unwrap(); + let x_ref: Vec = x_f32.iter().map(|&v| f16::from_f32(v).to_f32()).collect(); + + let y = crate::metal_kernel::bonsai_q1_qmv_fast( + &x, + &gpu.w, + &gpu.scales, + &gpu.biases, + GROUP_SIZE_I32, + ) + .unwrap(); + y.eval().unwrap(); + let got = y.as_slice::(); + assert_eq!(got.len(), out_f); + + let wd = dense_reference(&p); + for r in 0..out_f { + let mut acc = 0.0f32; + for c in 0..in_f { + acc += x_ref[c] * wd[r * in_f + c]; + } + let gv = got[r].to_f32(); + let tol = 1e-2 * acc.abs().max(1.0); + assert!( + (gv - acc).abs() <= tol, + "qmv_fast mismatch ({out_f}x{in_f}) row {r}: got {gv} want {acc}" + ); + } + } + } +} diff --git a/crates/higgs-models/src/lib.rs b/crates/higgs-models/src/lib.rs index 8ae2716a..ed81b5c1 100644 --- a/crates/higgs-models/src/lib.rs +++ b/crates/higgs-models/src/lib.rs @@ -4,6 +4,8 @@ pub mod deepseek_v2; pub mod error; pub mod gemma2; pub mod llava_qwen2; +/// Internal: runtime JIT Metal kernels (Bonsai-Q1 bits=1 matvec/dequant). +mod metal_kernel; pub mod phi3; pub mod qwen3_moe; pub mod qwen3_next; diff --git a/crates/higgs-models/src/metal_kernel.rs b/crates/higgs-models/src/metal_kernel.rs new file mode 100644 index 00000000..ce81ab06 --- /dev/null +++ b/crates/higgs-models/src/metal_kernel.rs @@ -0,0 +1,814 @@ +//! Runtime JIT Metal kernels for Bonsai-Q1 (1-bit affine quantization). +//! +//! Upstream `oxideai/mlx-rs` ships no `bits=1` affine kernels (MLX gates affine +//! 
quant to `bits >= 2`), so `ops::quantized_matmul`/`ops::dequantize` with +//! `bits=1` fail at runtime with `Unable to load kernel affine_dequantize_*_b_1`. +//! +//! Rather than fork mlx-rs (which forces a full from-source mlx-c rebuild), we +//! add the missing kernels *from this crate* using the runtime JIT facility that +//! mlx-c already exposes (`mlx_fast_metal_kernel_*`) and that `mlx-sys` compiles +//! in. The kernels below are JIT-compiled by Metal at first use and cached by +//! MLX internally per template instantiation. This keeps us on the stock +//! `oxideai/mlx-rs` pin with no extra native recompile. +//! +//! The FFI plumbing (kernel handle wrapper, `Array` <-> `mlx_array`, vector +//! construction, error capture) mirrors the proven `qgemv_4bit` path in +//! [`crate::qwen3_next`]; the kernel math mirrors +//! [`crate::bonsai_q1::PackedQ1Linear::dequant_row_to_fp32`]: +//! `W[r,c] = scale[r, c/G] * bit + bias[r, c/G]`, `bit = (w[r, c/32] >> (c%32)) & 1`. + +use std::ffi::{CStr, CString, c_char, c_void}; +use std::sync::OnceLock; + +use mlx_rs::{Array, Stream, error::Exception}; + +// --------------------------------------------------------------------------- +// FFI error capture (per-thread, mirrors qwen3_next). +// --------------------------------------------------------------------------- + +thread_local! { + static FFI_LAST_ERROR: std::cell::RefCell> = + const { std::cell::RefCell::new(None) }; +} + +/// Error handler registered once with MLX to capture error messages on the +/// calling thread. 
+#[allow(unsafe_code)] +unsafe extern "C" fn ffi_error_handler(msg: *const c_char, _data: *mut c_void) { + let s = unsafe { CStr::from_ptr(msg) } + .to_string_lossy() + .into_owned(); + FFI_LAST_ERROR.with(|cell| *cell.borrow_mut() = Some(s)); +} + +fn ensure_ffi_error_handler() { + static REGISTERED: OnceLock<()> = OnceLock::new(); + REGISTERED.get_or_init(|| { + #[allow(unsafe_code)] + unsafe { + mlx_sys::mlx_set_error_handler(Some(ffi_error_handler), std::ptr::null_mut(), None); + } + }); +} + +fn take_last_error() -> String { + FFI_LAST_ERROR + .with(|cell| cell.borrow_mut().take()) + .unwrap_or_else(|| "(no MLX error message captured)".to_owned()) +} + +// --------------------------------------------------------------------------- +// Cached kernel handle. +// --------------------------------------------------------------------------- + +/// Wraps a compiled `mlx_fast_metal_kernel`, freed on drop. +struct CachedMetalKernel(mlx_sys::mlx_fast_metal_kernel); + +// SAFETY: the handle is created once and only ever read (passed by value to +// `mlx_fast_metal_kernel_apply`); no interior mutability is shared across threads. +#[allow(unsafe_code)] +unsafe impl Send for CachedMetalKernel {} +#[allow(unsafe_code)] +unsafe impl Sync for CachedMetalKernel {} + +impl Drop for CachedMetalKernel { + fn drop(&mut self) { + #[allow(unsafe_code)] + unsafe { + mlx_sys::mlx_fast_metal_kernel_free(self.0); + } + } +} + +/// Number of simdgroups per threadgroup for the fused matvec. More simdgroups +/// help large-K layers (fewer chunk barriers). Overridable for tuning. +fn qmv_nsg(k_dim: i32) -> i32 { + static OVERRIDE: OnceLock> = OnceLock::new(); + let ovr = *OVERRIDE.get_or_init(|| { + std::env::var("HIGGS_BONSAI_QMV_NSG") + .ok() + .and_then(|s| s.parse::().ok()) + .filter(|n| matches!(n, 4 | 8 | 16 | 32)) + }); + ovr.unwrap_or(if k_dim > 8192 { 16 } else { 8 }) +} + +/// Build the vector-of-strings that names kernel inputs/outputs. 
+#[allow(unsafe_code)] +fn cstr_vec(names: &[&CStr]) -> mlx_sys::mlx_vector_string { + let ptrs: Vec<*const c_char> = names.iter().map(|s| s.as_ptr()).collect(); + unsafe { mlx_sys::mlx_vector_string_new_data(ptrs.as_ptr().cast_mut(), ptrs.len()) } +} + +// --------------------------------------------------------------------------- +// Fused 1-bit quantized matvec (decode hot path). +// +// y = x @ dequant(W).T for a single token (M = 1). +// Mirrors qgemv_4bit but unpacks 32 1-bit weights per uint32 word. +// One simdgroup per output row; x staged in threadgroup memory; simd_sum reduce. +// --------------------------------------------------------------------------- + +const QMV_KERNEL_SOURCE: &str = r" +constexpr int CHUNK = (K <= 8192) ? K : 8192; + +threadgroup OutT x_sh[CHUNK]; + +auto tg = threadgroup_position_in_grid.x; +auto sg = simdgroup_index_in_threadgroup; +auto lane = thread_index_in_simdgroup; +auto tid = thread_index_in_threadgroup; +auto n_sg = simdgroups_per_threadgroup; +uint tg_sz = n_sg * 32u; + +int row = tg * int(n_sg) + int(sg); +bool valid = (row < n_param); + +float acc = 0.0f; + +for (int k_off = 0; k_off < K; k_off += CHUNK) { + int k_end = min(k_off + CHUNK, K); + int k_len = k_end - k_off; + + for (uint i = tid; i < uint(k_len); i += tg_sz) { + x_sh[i] = x[k_off + i]; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (valid) { + int wp_off = k_off / 32; + int wp_end = k_end / 32; + auto w_row = w + row * KPacked; + + for (int idx = wp_off + int(lane); idx < wp_end; idx += 32) { + uint packed = w_row[idx]; + int kl = (idx - wp_off) * 32; + + float dot_val = 0.0f; + float sum_x = 0.0f; + for (uint j = 0u; j < 32u; ++j) { + float xv = float(x_sh[kl + int(j)]); + float bit = float((packed >> j) & 1u); + dot_val += bit * xv; + sum_x += xv; + } + + int g = idx * 32 / GroupSize; + float s_val = float(sc[row * NumGroups + g]); + float b_val = float(bi[row * NumGroups + g]); + acc += s_val * dot_val + b_val * sum_x; + } + } + + 
threadgroup_barrier(mem_flags::mem_threadgroup); +} + +if (valid) { + acc = simd_sum(acc); + if (lane == 0) { + y[row] = OutT(acc); + } +} +"; + +#[allow(unsafe_code)] +fn create_qmv_kernel() -> mlx_sys::mlx_fast_metal_kernel { + let in_vec = cstr_vec(&[c"w", c"sc", c"bi", c"x", c"n_param"]); + let out_vec = cstr_vec(&[c"y"]); + let source = CString::new(QMV_KERNEL_SOURCE).unwrap_or_default(); + unsafe { + let kernel = mlx_sys::mlx_fast_metal_kernel_new( + c"higgs_bonsai_q1_qmv".as_ptr(), + in_vec, + out_vec, + source.as_ptr(), + c"".as_ptr(), + false, // ensure_row_contiguous + false, // atomic_outputs + ); + mlx_sys::mlx_vector_string_free(in_vec); + mlx_sys::mlx_vector_string_free(out_vec); + kernel + } +} + +#[allow(unsafe_code)] +fn configure_qmv_kernel( + out_dtype: mlx_sys::mlx_dtype, + n_rows: i32, + k_dim: i32, + group_size: i32, +) -> mlx_sys::mlx_fast_metal_kernel_config { + unsafe { + let config = mlx_sys::mlx_fast_metal_kernel_config_new(); + mlx_sys::mlx_fast_metal_kernel_config_add_template_arg_dtype( + config, + c"OutT".as_ptr(), + out_dtype, + ); + mlx_sys::mlx_fast_metal_kernel_config_add_template_arg_int(config, c"K".as_ptr(), k_dim); + mlx_sys::mlx_fast_metal_kernel_config_add_template_arg_int( + config, + c"GroupSize".as_ptr(), + group_size, + ); + mlx_sys::mlx_fast_metal_kernel_config_add_template_arg_int( + config, + c"KPacked".as_ptr(), + k_dim / 32, + ); + mlx_sys::mlx_fast_metal_kernel_config_add_template_arg_int( + config, + c"NumGroups".as_ptr(), + k_dim / group_size, + ); + + let nsg = qmv_nsg(k_dim); + let n_tgs = (n_rows + nsg - 1) / nsg; + mlx_sys::mlx_fast_metal_kernel_config_set_grid(config, n_tgs * 32, nsg, 1); + mlx_sys::mlx_fast_metal_kernel_config_set_thread_group(config, 32, nsg, 1); + + let y_shape = [1, n_rows]; + mlx_sys::mlx_fast_metal_kernel_config_add_output_arg( + config, + y_shape.as_ptr(), + y_shape.len(), + out_dtype, + ); + config + } +} + +/// Original per-row 1-bit matvec: one simdgroup computes one output row, 
with +/// `x` staged in threadgroup memory. Kept as the A/B baseline (selected when +/// `HIGGS_BONSAI_QMV_KERNEL=legacy`). See [`bonsai_q1_qmv`] for the dispatcher. +#[allow(unsafe_code)] +pub fn bonsai_q1_qmv_legacy( + x: &Array, + weight: &Array, + scales: &Array, + biases: &Array, + group_size: i32, +) -> Result { + ensure_ffi_error_handler(); + + let x_shape = x.shape(); + let weight_shape = weight.shape(); + let n_rows = weight_shape + .first() + .copied() + .ok_or_else(|| Exception::custom("bonsai_q1_qmv: weight has no rows"))?; + let k_packed = weight_shape + .get(1) + .copied() + .ok_or_else(|| Exception::custom("bonsai_q1_qmv: weight has no columns"))?; + let k_dim = k_packed * 32; // 32 one-bit weights per uint32 word + + let x_flat = x.reshape(&[k_dim])?; + let w_flat = weight.reshape(&[-1])?; + let s_flat = scales.flatten(None, None)?; + let b_flat = biases.flatten(None, None)?; + + let stream = Stream::task_local_or_default(); + let out_dtype = unsafe { mlx_sys::mlx_array_dtype(x.as_ptr()) }; + + let cached = QMV_KERNEL.get_or_init(|| CachedMetalKernel(create_qmv_kernel())); + let config = configure_qmv_kernel(out_dtype, n_rows, k_dim, group_size); + + let n_scalar = unsafe { mlx_sys::mlx_array_new_int(n_rows) }; + let input_ptrs = [ + w_flat.as_ptr(), + s_flat.as_ptr(), + b_flat.as_ptr(), + x_flat.as_ptr(), + n_scalar, + ]; + let inputs_vec = + unsafe { mlx_sys::mlx_vector_array_new_data(input_ptrs.as_ptr(), input_ptrs.len()) }; + + let mut outputs_vec = unsafe { mlx_sys::mlx_vector_array_new() }; + let status = unsafe { + mlx_sys::mlx_fast_metal_kernel_apply( + &raw mut outputs_vec, + cached.0, + inputs_vec, + config, + stream.as_ptr(), + ) + }; + + let result = if status != 0 { + Err(Exception::custom(format!( + "bonsai_q1_qmv failed: {}", + take_last_error() + ))) + } else { + let mut y_ptr = unsafe { mlx_sys::mlx_array_new() }; + unsafe { mlx_sys::mlx_vector_array_get(&raw mut y_ptr, outputs_vec, 0) }; + let y = unsafe { Array::from_ptr(y_ptr) }; 
+ let trim_to = x_shape.len().saturating_sub(1); + let mut out_shape = x_shape + .get(..trim_to) + .ok_or_else(|| Exception::custom("bonsai_q1_qmv: x_shape too small"))? + .to_vec(); + out_shape.push(n_rows); + y.reshape(&out_shape) + }; + + unsafe { + mlx_sys::mlx_fast_metal_kernel_config_free(config); + mlx_sys::mlx_vector_array_free(inputs_vec); + mlx_sys::mlx_vector_array_free(outputs_vec); + mlx_sys::mlx_array_free(n_scalar); + } + result +} + +static QMV_KERNEL: OnceLock = OnceLock::new(); +static FAST_QMV_KERNEL: OnceLock = OnceLock::new(); + +/// Simdgroups per threadgroup for the `qmv_fast`-class kernel. Each simdgroup +/// computes `RESULTS_PER_SIMDGROUP` (= 4) output rows. Tunable via +/// `HIGGS_BONSAI_FAST_NSG` (Phase-2 sweep); MLX's reference uses 2. +fn fast_qmv_nsg() -> i32 { + static OVERRIDE: OnceLock = OnceLock::new(); + *OVERRIDE.get_or_init(|| { + std::env::var("HIGGS_BONSAI_FAST_NSG") + .ok() + .and_then(|s| s.parse::().ok()) + .filter(|n| matches!(n, 1 | 2 | 4 | 8)) + .unwrap_or(2) + }) +} + +/// Whether to route the decode matvec through the `qmv_fast`-class kernel. +/// It is the **default** (measured 2.3× faster on Bonsai-8B decode and within +/// test tolerance of the CPU reference); opt back to the per-row kernel with +/// `HIGGS_BONSAI_QMV_KERNEL=legacy`. +fn use_fast_qmv() -> bool { + static FAST: OnceLock = OnceLock::new(); + *FAST.get_or_init(|| { + !std::env::var("HIGGS_BONSAI_QMV_KERNEL").is_ok_and(|v| v.eq_ignore_ascii_case("legacy")) + }) +} + +/// Fused 1-bit quantized matvec: `y = x @ dequant(weight).T` for a single token. +/// +/// `x` must hold exactly `in_features` elements (M = 1). `weight` is the packed +/// `[out_features, in_features/32]` uint32 matrix; `scales`/`biases` are +/// `[out_features, in_features/group_size]`. Output dtype matches `x`. +/// +/// Dispatches to the `qmv_fast`-class kernel ([`bonsai_q1_qmv_fast`]) by +/// default; set `HIGGS_BONSAI_QMV_KERNEL=legacy` to force the per-row kernel. 
+pub fn bonsai_q1_qmv( + x: &Array, + weight: &Array, + scales: &Array, + biases: &Array, + group_size: i32, +) -> Result { + if use_fast_qmv() { + bonsai_q1_qmv_fast(x, weight, scales, biases, group_size) + } else { + bonsai_q1_qmv_legacy(x, weight, scales, biases, group_size) + } +} + +// --------------------------------------------------------------------------- +// `qmv_fast`-class 1-bit matvec (decode hot path). +// +// Ports MLX/PrismML `qmv_fast` tiling onto our uint32 packing: each simdgroup +// computes RESULTS_PER_SIMDGROUP (4) output rows; each of its 32 lanes holds +// VPT (64) input values in registers (no threadgroup memory, no barriers) and +// reuses them across all 4 rows. block_size = 64 * 32 = 2048. The bits=1 affine +// math is identical to the legacy kernel — `scale * sum(bit*x) + bias * sum(x)` +// — only the data movement differs. Group scales/biases are per-lane (a lane's +// 64 values lie in one 128-wide group); per-row partials are simd_sum-reduced. +// --------------------------------------------------------------------------- + +const FAST_QMV_KERNEL_SOURCE: &str = r" +constexpr int VPT = 64; // values_per_thread +constexpr int RPS = 4; // results_per_simdgroup +constexpr int WPT = VPT / 32; // packed uint32 words per thread (2) +constexpr int BLK = VPT * 32; // block_size = 2048 + +uint tgx = threadgroup_position_in_grid.x; +uint sg = simdgroup_index_in_threadgroup; +uint lid = thread_index_in_simdgroup; +uint nsg = simdgroups_per_threadgroup; + +int out_row = int(tgx) * (int(nsg) * RPS) + int(sg) * RPS; + +float xt[VPT]; +float result[RPS]; +for (int r = 0; r < RPS; ++r) { result[r] = 0.0f; } + +int aligned_end = (K / BLK) * BLK; + +// Main loop: full 2048-element blocks (covers every real Bonsai layer, since +// all K are multiples of 2048). 
+for (int k = 0; k < aligned_end; k += BLK) { + int xbase = k + int(lid) * VPT; + float sum = 0.0f; + for (int i = 0; i < VPT; ++i) { float v = float(x[xbase + i]); xt[i] = v; sum += v; } + + int wcol = (k / 32) + int(lid) * WPT; + int g = xbase / GroupSize; // all VPT values fall in one group + + for (int r = 0; r < RPS; ++r) { + int row = out_row + r; + if (row >= n_param) { continue; } + float accum = 0.0f; + for (int wp = 0; wp < WPT; ++wp) { + uint packed = w[row * KPacked + wcol + wp]; + int xo = wp * 32; + for (int bk = 0; bk < 4; ++bk) { + uint wb = (packed >> (uint(bk) * 8u)) & 0xFFu; + int b = xo + bk * 8; + accum += select(0.0f, xt[b + 0], (wb & 0x01u) != 0u); + accum += select(0.0f, xt[b + 1], (wb & 0x02u) != 0u); + accum += select(0.0f, xt[b + 2], (wb & 0x04u) != 0u); + accum += select(0.0f, xt[b + 3], (wb & 0x08u) != 0u); + accum += select(0.0f, xt[b + 4], (wb & 0x10u) != 0u); + accum += select(0.0f, xt[b + 5], (wb & 0x20u) != 0u); + accum += select(0.0f, xt[b + 6], (wb & 0x40u) != 0u); + accum += select(0.0f, xt[b + 7], (wb & 0x80u) != 0u); + } + } + float s_val = float(sc[row * NumGroups + g]); + float b_val = float(bi[row * NumGroups + g]); + result[r] += s_val * accum + b_val * sum; + } +} + +// Tail: only exercised by tests with K < 2048 or K % 2048 != 0. +if (aligned_end < K) { + int xbase = aligned_end + int(lid) * VPT; + bool in_bounds = xbase < K; + float sum = 0.0f; + for (int i = 0; i < VPT; ++i) { + float v = (in_bounds && (xbase + i) < K) ? float(x[xbase + i]) : 0.0f; + xt[i] = v; + sum += v; + } + int wcol = (aligned_end / 32) + int(lid) * WPT; + int g = in_bounds ? 
(xbase / GroupSize) : 0; + for (int r = 0; r < RPS; ++r) { + int row = out_row + r; + if (row >= n_param || !in_bounds) { continue; } + float accum = 0.0f; + for (int wp = 0; wp < WPT; ++wp) { + int widx = wcol + wp; + if (widx >= KPacked) { continue; } + uint packed = w[row * KPacked + widx]; + int xo = wp * 32; + for (int bk = 0; bk < 4; ++bk) { + uint wb = (packed >> (uint(bk) * 8u)) & 0xFFu; + int b = xo + bk * 8; + accum += select(0.0f, xt[b + 0], (wb & 0x01u) != 0u); + accum += select(0.0f, xt[b + 1], (wb & 0x02u) != 0u); + accum += select(0.0f, xt[b + 2], (wb & 0x04u) != 0u); + accum += select(0.0f, xt[b + 3], (wb & 0x08u) != 0u); + accum += select(0.0f, xt[b + 4], (wb & 0x10u) != 0u); + accum += select(0.0f, xt[b + 5], (wb & 0x20u) != 0u); + accum += select(0.0f, xt[b + 6], (wb & 0x40u) != 0u); + accum += select(0.0f, xt[b + 7], (wb & 0x80u) != 0u); + } + } + float s_val = float(sc[row * NumGroups + g]); + float b_val = float(bi[row * NumGroups + g]); + result[r] += s_val * accum + b_val * sum; + } +} + +for (int r = 0; r < RPS; ++r) { + int row = out_row + r; + float v = simd_sum(result[r]); + if (lid == 0u && row < n_param) { + y[row] = OutT(v); + } +} +"; + +#[allow(unsafe_code)] +fn create_fast_qmv_kernel() -> mlx_sys::mlx_fast_metal_kernel { + let in_vec = cstr_vec(&[c"w", c"sc", c"bi", c"x", c"n_param"]); + let out_vec = cstr_vec(&[c"y"]); + let source = CString::new(FAST_QMV_KERNEL_SOURCE).unwrap_or_default(); + unsafe { + let kernel = mlx_sys::mlx_fast_metal_kernel_new( + c"higgs_bonsai_q1_qmv_fast".as_ptr(), + in_vec, + out_vec, + source.as_ptr(), + c"".as_ptr(), + false, // ensure_row_contiguous + false, // atomic_outputs + ); + mlx_sys::mlx_vector_string_free(in_vec); + mlx_sys::mlx_vector_string_free(out_vec); + kernel + } +} + +#[allow(unsafe_code)] +fn configure_fast_qmv_kernel( + out_dtype: mlx_sys::mlx_dtype, + n_rows: i32, + k_dim: i32, + group_size: i32, +) -> mlx_sys::mlx_fast_metal_kernel_config { + unsafe { + let config = 
mlx_sys::mlx_fast_metal_kernel_config_new(); + mlx_sys::mlx_fast_metal_kernel_config_add_template_arg_dtype( + config, + c"OutT".as_ptr(), + out_dtype, + ); + mlx_sys::mlx_fast_metal_kernel_config_add_template_arg_int(config, c"K".as_ptr(), k_dim); + mlx_sys::mlx_fast_metal_kernel_config_add_template_arg_int( + config, + c"GroupSize".as_ptr(), + group_size, + ); + mlx_sys::mlx_fast_metal_kernel_config_add_template_arg_int( + config, + c"KPacked".as_ptr(), + k_dim / 32, + ); + mlx_sys::mlx_fast_metal_kernel_config_add_template_arg_int( + config, + c"NumGroups".as_ptr(), + k_dim / group_size, + ); + + // Each simdgroup computes 4 rows; nsg simdgroups per threadgroup. + let nsg = fast_qmv_nsg(); + let rows_per_tg = nsg * 4; + let n_tgs = (n_rows + rows_per_tg - 1) / rows_per_tg; + mlx_sys::mlx_fast_metal_kernel_config_set_grid(config, n_tgs * 32, nsg, 1); + mlx_sys::mlx_fast_metal_kernel_config_set_thread_group(config, 32, nsg, 1); + + let y_shape = [1, n_rows]; + mlx_sys::mlx_fast_metal_kernel_config_add_output_arg( + config, + y_shape.as_ptr(), + y_shape.len(), + out_dtype, + ); + config + } +} + +/// `qmv_fast`-class variant of [`bonsai_q1_qmv_legacy`]. Same inputs/outputs, +/// equal result up to float rounding; faster tiling. See [`bonsai_q1_qmv`] for dispatch. 
+#[allow(unsafe_code)] +pub fn bonsai_q1_qmv_fast( + x: &Array, + weight: &Array, + scales: &Array, + biases: &Array, + group_size: i32, +) -> Result { + ensure_ffi_error_handler(); + + let x_shape = x.shape(); + let weight_shape = weight.shape(); + let n_rows = weight_shape + .first() + .copied() + .ok_or_else(|| Exception::custom("bonsai_q1_qmv_fast: weight has no rows"))?; + let k_packed = weight_shape + .get(1) + .copied() + .ok_or_else(|| Exception::custom("bonsai_q1_qmv_fast: weight has no columns"))?; + let k_dim = k_packed * 32; + + let x_flat = x.reshape(&[k_dim])?; + let w_flat = weight.reshape(&[-1])?; + let s_flat = scales.flatten(None, None)?; + let b_flat = biases.flatten(None, None)?; + + let stream = Stream::task_local_or_default(); + let out_dtype = unsafe { mlx_sys::mlx_array_dtype(x.as_ptr()) }; + + let cached = FAST_QMV_KERNEL.get_or_init(|| CachedMetalKernel(create_fast_qmv_kernel())); + let config = configure_fast_qmv_kernel(out_dtype, n_rows, k_dim, group_size); + + let n_scalar = unsafe { mlx_sys::mlx_array_new_int(n_rows) }; + let input_ptrs = [ + w_flat.as_ptr(), + s_flat.as_ptr(), + b_flat.as_ptr(), + x_flat.as_ptr(), + n_scalar, + ]; + let inputs_vec = + unsafe { mlx_sys::mlx_vector_array_new_data(input_ptrs.as_ptr(), input_ptrs.len()) }; + + let mut outputs_vec = unsafe { mlx_sys::mlx_vector_array_new() }; + let status = unsafe { + mlx_sys::mlx_fast_metal_kernel_apply( + &raw mut outputs_vec, + cached.0, + inputs_vec, + config, + stream.as_ptr(), + ) + }; + + let result = if status != 0 { + Err(Exception::custom(format!( + "bonsai_q1_qmv_fast failed: {}", + take_last_error() + ))) + } else { + let mut y_ptr = unsafe { mlx_sys::mlx_array_new() }; + unsafe { mlx_sys::mlx_vector_array_get(&raw mut y_ptr, outputs_vec, 0) }; + let y = unsafe { Array::from_ptr(y_ptr) }; + let trim_to = x_shape.len().saturating_sub(1); + let mut out_shape = x_shape + .get(..trim_to) + .ok_or_else(|| Exception::custom("bonsai_q1_qmv_fast: x_shape too small"))? 
+ .to_vec(); + out_shape.push(n_rows); + y.reshape(&out_shape) + }; + + unsafe { + mlx_sys::mlx_fast_metal_kernel_config_free(config); + mlx_sys::mlx_vector_array_free(inputs_vec); + mlx_sys::mlx_vector_array_free(outputs_vec); + mlx_sys::mlx_array_free(n_scalar); + } + result +} + +// --------------------------------------------------------------------------- +// 1-bit dequantize to dense (embedding gather + prefill matmul path). +// +// wd[n, c] = scales[n, c/G] * bit(w[n, c/32], c%32) + biases[n, c/G]. +// One thread per packed uint32 word (writes 32 dense outputs). +// --------------------------------------------------------------------------- + +const DEQUANT_KERNEL_SOURCE: &str = r" +uint gid = thread_position_in_grid.x; +if (gid >= uint(NWords)) { return; } + +uint n = gid / uint(KPacked); +uint idx = gid % uint(KPacked); +uint packed = w[gid]; + +int g = int(idx) * 32 / GroupSize; +float s_val = float(sc[n * uint(NumGroups) + uint(g)]); +float b_val = float(bi[n * uint(NumGroups) + uint(g)]); + +uint base = n * uint(K) + idx * 32u; +for (uint j = 0u; j < 32u; ++j) { + float bit = float((packed >> j) & 1u); + wd[base + j] = OutT(s_val * bit + b_val); +} +"; + +#[allow(unsafe_code)] +fn create_dequant_kernel() -> mlx_sys::mlx_fast_metal_kernel { + let in_vec = cstr_vec(&[c"w", c"sc", c"bi"]); + let out_vec = cstr_vec(&[c"wd"]); + let source = CString::new(DEQUANT_KERNEL_SOURCE).unwrap_or_default(); + unsafe { + let kernel = mlx_sys::mlx_fast_metal_kernel_new( + c"higgs_bonsai_q1_dequant".as_ptr(), + in_vec, + out_vec, + source.as_ptr(), + c"".as_ptr(), + false, + false, + ); + mlx_sys::mlx_vector_string_free(in_vec); + mlx_sys::mlx_vector_string_free(out_vec); + kernel + } +} + +#[allow(unsafe_code)] +fn configure_dequant_kernel( + out_dtype: mlx_sys::mlx_dtype, + n_rows: i32, + k_dim: i32, + group_size: i32, +) -> mlx_sys::mlx_fast_metal_kernel_config { + let k_packed = k_dim / 32; + let n_words = n_rows * k_packed; + unsafe { + let config = 
mlx_sys::mlx_fast_metal_kernel_config_new(); + mlx_sys::mlx_fast_metal_kernel_config_add_template_arg_dtype( + config, + c"OutT".as_ptr(), + out_dtype, + ); + mlx_sys::mlx_fast_metal_kernel_config_add_template_arg_int(config, c"K".as_ptr(), k_dim); + mlx_sys::mlx_fast_metal_kernel_config_add_template_arg_int( + config, + c"KPacked".as_ptr(), + k_packed, + ); + mlx_sys::mlx_fast_metal_kernel_config_add_template_arg_int( + config, + c"GroupSize".as_ptr(), + group_size, + ); + mlx_sys::mlx_fast_metal_kernel_config_add_template_arg_int( + config, + c"NumGroups".as_ptr(), + k_dim / group_size, + ); + mlx_sys::mlx_fast_metal_kernel_config_add_template_arg_int( + config, + c"NWords".as_ptr(), + n_words, + ); + + let tg: i32 = 256; + let grid = ((n_words + tg - 1) / tg) * tg; + mlx_sys::mlx_fast_metal_kernel_config_set_grid(config, grid, 1, 1); + mlx_sys::mlx_fast_metal_kernel_config_set_thread_group(config, tg, 1, 1); + + let wd_shape = [n_rows, k_dim]; + mlx_sys::mlx_fast_metal_kernel_config_add_output_arg( + config, + wd_shape.as_ptr(), + wd_shape.len(), + out_dtype, + ); + config + } +} + +/// Dequantize a packed 1-bit matrix to a dense `[out_features, in_features]` +/// array (dtype matches `scales`). Used for embedding gather and the prefill +/// (M > 1) matmul path. 
+#[allow(unsafe_code)] +pub fn bonsai_q1_dequant( + weight: &Array, + scales: &Array, + biases: &Array, + group_size: i32, +) -> Result { + ensure_ffi_error_handler(); + + let weight_shape = weight.shape(); + let n_rows = weight_shape + .first() + .copied() + .ok_or_else(|| Exception::custom("bonsai_q1_dequant: weight has no rows"))?; + let k_packed = weight_shape + .get(1) + .copied() + .ok_or_else(|| Exception::custom("bonsai_q1_dequant: weight has no columns"))?; + let k_dim = k_packed * 32; + + let w_flat = weight.reshape(&[-1])?; + let s_flat = scales.flatten(None, None)?; + let b_flat = biases.flatten(None, None)?; + + let stream = Stream::task_local_or_default(); + let out_dtype = unsafe { mlx_sys::mlx_array_dtype(scales.as_ptr()) }; + + let cached = DEQUANT_KERNEL.get_or_init(|| CachedMetalKernel(create_dequant_kernel())); + let config = configure_dequant_kernel(out_dtype, n_rows, k_dim, group_size); + + let input_ptrs = [w_flat.as_ptr(), s_flat.as_ptr(), b_flat.as_ptr()]; + let inputs_vec = + unsafe { mlx_sys::mlx_vector_array_new_data(input_ptrs.as_ptr(), input_ptrs.len()) }; + + let mut outputs_vec = unsafe { mlx_sys::mlx_vector_array_new() }; + let status = unsafe { + mlx_sys::mlx_fast_metal_kernel_apply( + &raw mut outputs_vec, + cached.0, + inputs_vec, + config, + stream.as_ptr(), + ) + }; + + let result = if status != 0 { + Err(Exception::custom(format!( + "bonsai_q1_dequant failed: {}", + take_last_error() + ))) + } else { + let mut wd_ptr = unsafe { mlx_sys::mlx_array_new() }; + unsafe { mlx_sys::mlx_vector_array_get(&raw mut wd_ptr, outputs_vec, 0) }; + Ok(unsafe { Array::from_ptr(wd_ptr) }) + }; + + unsafe { + mlx_sys::mlx_fast_metal_kernel_config_free(config); + mlx_sys::mlx_vector_array_free(inputs_vec); + mlx_sys::mlx_vector_array_free(outputs_vec); + } + result +} + +static DEQUANT_KERNEL: OnceLock = OnceLock::new();