From f26b509de663aeaecf2a6b1c93929c2433766b5c Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Wed, 29 Jan 2025 11:14:19 +0100 Subject: [PATCH 01/21] fix type confusion in temp tile for store --- linalg/src/frame/mmm/scratch.rs | 89 +++++++++++++++++++++++++-------- 1 file changed, 68 insertions(+), 21 deletions(-) diff --git a/linalg/src/frame/mmm/scratch.rs b/linalg/src/frame/mmm/scratch.rs index ac34fa6da4..80e1fb0020 100644 --- a/linalg/src/frame/mmm/scratch.rs +++ b/linalg/src/frame/mmm/scratch.rs @@ -47,7 +47,8 @@ impl TLSScratch { ker_specs.extend_from_slice(&scratch.ker_specs); unsafe { - self.blob.ensure_size_and_align(scratch.blob_size, scratch.blob_align); + self.blob + .ensure_size_and_align(scratch.blob_size, scratch.blob_align); for LocDependant { loc, ker_spec, .. } in &scratch.loc_dependant { #[allow(clippy::single_match)] @@ -121,7 +122,13 @@ impl ScratchSpaceImpl { let mut offset = 0; let mut align = std::mem::size_of::<*const ()>(); fn ld(spec: usize, uspec: usize, loc: usize) -> LocDependant { - LocDependant { spec, ker_spec: uspec, loc, buffer_a: None, buffer_b: None } + LocDependant { + spec, + ker_spec: uspec, + loc, + buffer_a: None, + buffer_b: None, + } } for (ix, spec) in specs.iter().enumerate() { offset = offset.next_multiple_of(&align); @@ -138,25 +145,35 @@ impl ScratchSpaceImpl { FS::RoundingShiftRight(s, rp) => FKS::RoundingShiftRight(*s, *rp), FS::QScale(s, rp, m) => FKS::QScale(*s, *rp, *m), FS::BinPerRow(_, _) => { - self.loc_dependant.push(ld(ix, self.ker_specs.len(), offset)); + self.loc_dependant + .push(ld(ix, self.ker_specs.len(), offset)); offset += TI::datum_type().size_of() * ker.mr(); FusedKerSpec::Done } FS::BinPerCol(_, _) => { - self.loc_dependant.push(ld(ix, self.ker_specs.len(), offset)); + self.loc_dependant + .push(ld(ix, self.ker_specs.len(), offset)); offset += TI::datum_type().size_of() * ker.nr(); FusedKerSpec::Done } FS::AddRowColProducts(_, _) => { - self.loc_dependant.push(ld(ix, self.ker_specs.len(), offset)); + self.loc_dependant + .push(ld(ix, self.ker_specs.len(), offset)); offset += TI::datum_type().size_of() * (ker.mr() + ker.nr()); FusedKerSpec::Done } - FS::Store(_) | FS::AddUnicast(_) => { - self.loc_dependant.push(ld(ix, self.ker_specs.len(), offset)); + FS::AddUnicast(_) => { + self.loc_dependant + .push(ld(ix, self.ker_specs.len(), offset)); offset += TI::datum_type().size_of() * ker.mr() * ker.nr(); FusedKerSpec::Done } + FS::Store(store) => { + self.loc_dependant + .push(ld(ix, self.ker_specs.len(), offset)); + offset += store.item_size * ker.mr() * ker.nr(); + FusedKerSpec::Done + } FS::LeakyRelu(t) => FKS::LeakyRelu(*t.to_scalar()?), FS::AddMatMul { a, b, packing } => { let mut ld = ld(ix, self.ker_specs.len(), offset); @@ -206,10 +223,16 @@ impl ScratchSpaceImpl { let err = ker.kernel(tls.ker_specs()); debug_assert_eq!(err, 0, "Kernel return error {err}"); } else { - let remnant_down = - if down < self.valid_down_tiles { ker.mr() } else { self.remnant_down }; - let remnant_right = - if right < self.valid_right_tiles { ker.nr() } else { self.remnant_right }; + let remnant_down = if down < self.valid_down_tiles { + ker.mr() + } else { + self.remnant_down + }; + let remnant_right = if right < self.valid_right_tiles { + ker.nr() + } else { + self.remnant_right + }; self.for_border_tile(ker, specs, tls, down, right, remnant_down, remnant_right)?; let err = ker.kernel(tls.ker_specs()); debug_assert_eq!(err, 0, "Kernel return error {err}"); @@ -230,9 +253,20 @@ impl ScratchSpaceImpl { ) -> TractResult<()> { use FusedKerSpec as FKS; use FusedSpec as FS; - let ScratchSpaceImpl { ker_specs, loc_dependant, .. } = self; + let ScratchSpaceImpl { + ker_specs, + loc_dependant, + .. + } = self; debug_assert!(specs.len() + 2 == ker_specs.len()); - for LocDependant { spec, ker_spec, loc, buffer_a, buffer_b } in loc_dependant { + for LocDependant { + spec, + ker_spec, + loc, + buffer_a, + buffer_b, + } in loc_dependant + { let spec = specs.get_unchecked(*spec); let it = match spec { FS::BinPerRow(v, op) => { @@ -265,8 +299,9 @@ impl ScratchSpaceImpl { FS::AddUnicast(store) => FKS::AddUnicast(store.tile_c(down, right)), FS::Store(c_store) => FKS::Store(c_store.tile_c(down, right)), FS::AddMatMul { a, b, packing } => { - let scratch = - (tls.blob.as_mut_ptr().add(*loc) as *mut AddMatMulTemp).as_mut().unwrap(); + let scratch = (tls.blob.as_mut_ptr().add(*loc) as *mut AddMatMulTemp) + .as_mut() + .unwrap(); if scratch.panel_a_id != down { scratch.ptr_a = a.panel_bytes(down, buffer_a.map(|o| tls.blob.as_mut_ptr().add(o)))?; @@ -305,7 +340,14 @@ impl ScratchSpaceImpl { ) -> TractResult<()> { use FusedKerSpec as FKS; use FusedSpec as FS; - for LocDependant { spec, ker_spec: uspec, loc, buffer_a, buffer_b } in &self.loc_dependant { + for LocDependant { + spec, + ker_spec: uspec, + loc, + buffer_a, + buffer_b, + } in &self.loc_dependant + { let loc = tls.blob.as_mut_ptr().add(*loc); let spec = specs.get_unchecked(*spec); let it = match spec { @@ -442,13 +484,13 @@ impl ScratchSpaceImpl { FS::AddMatMul { a, b, packing } => { let scratch = (loc as *mut AddMatMulTemp).as_mut().unwrap(); if scratch.panel_a_id != down { - scratch.ptr_a = a - .panel_bytes(down, buffer_a.map(|o| tls.blob.as_mut_ptr().add(o)))?; + scratch.ptr_a = + a.panel_bytes(down, buffer_a.map(|o| tls.blob.as_mut_ptr().add(o)))?; scratch.panel_a_id = down; } if scratch.panel_b_id != right { - scratch.ptr_b = b - .panel_bytes(right, buffer_b.map(|o| tls.blob.as_mut_ptr().add(o)))?; + scratch.ptr_b = + b.panel_bytes(right, buffer_b.map(|o| tls.blob.as_mut_ptr().add(o)))?; scratch.panel_b_id = right; } FKS::AddMatMul { @@ -482,7 +524,12 @@ impl ScratchSpaceImpl { where TI: LADatum, { - for LocDependant { spec, ker_spec: uspec, .. } in self.loc_dependant.iter() { + for LocDependant { + spec, + ker_spec: uspec, + .. + } in self.loc_dependant.iter() + { let spec = specs.get_unchecked(*spec); let ker_spec = tls.ker_specs::().get_unchecked(*uspec); if let (FusedSpec::Store(c_store), FusedKerSpec::Store(tmp)) = (spec, ker_spec) { From 898f89168640bfd137c88af08fca785550b557ff Mon Sep 17 00:00:00 2001 From: Julien Balian Date: Fri, 17 Jan 2025 14:51:49 +0100 Subject: [PATCH 02/21] feat: minimal mapping vptq empty ops --- core/src/lib.rs | 20 +++++----- core/src/ops/macros.rs | 25 +++++++++++- core/src/ops/mod.rs | 1 + core/src/ops/vptq.rs | 43 ++++++++++++++++++++ nnef/src/ops/core.rs | 2 + nnef/src/ops/core/vptq.rs | 82 +++++++++++++++++++++++++++++++++++++++ 6 files changed, 162 insertions(+), 11 deletions(-) create mode 100644 core/src/ops/vptq.rs create mode 100644 nnef/src/ops/core/vptq.rs diff --git a/core/src/lib.rs b/core/src/lib.rs index 3cfd830932..8994faa1cd 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -45,13 +45,13 @@ //! tract-tensorflow or tract-onnx crates. //! -#[cfg(feature="blas")] -extern crate cblas; -#[cfg(feature="accelerate")] +#[cfg(feature = "accelerate")] extern crate accelerate_src; -#[cfg(feature="blis")] +#[cfg(feature = "blis")] extern crate blis_src; -#[cfg(feature="openblas")] +#[cfg(feature = "blas")] +extern crate cblas; +#[cfg(feature = "openblas")] extern crate openblas_src; extern crate bit_set; @@ -81,8 +81,8 @@ pub mod ops; pub mod axes; pub mod broadcast; -pub mod framework; pub mod floats; +pub mod framework; pub mod model; pub mod optim; pub mod plan; @@ -98,7 +98,7 @@ mod late_bind; pub mod prelude { pub use crate::framework::Framework; pub use crate::model::*; - pub use crate::plan::{SimplePlan, SimpleState, PlanOptions}; + pub use crate::plan::{PlanOptions, SimplePlan, SimpleState}; pub use crate::value::{IntoTValue, TValue}; pub use std::sync::Arc; pub use tract_data::prelude::*; @@ -118,8 +118,9 @@ pub mod internal { pub use crate::ops::change_axes::*; pub use crate::ops::element_wise::ElementWiseMiniOp; pub use crate::ops::{Cost, EvalOp, FrozenOpState, Op, OpState, Validation}; - pub use crate::plan::{ SessionState, SessionStateHandler }; + pub use crate::plan::{SessionState, SessionStateHandler}; pub use crate::prelude::*; + pub use crate::runtime::{DefaultRuntime, Runnable, Runtime, State}; pub use dims; pub use downcast_rs as tract_downcast_rs; pub use std::borrow::Cow; @@ -131,10 +132,9 @@ pub mod internal { dispatch_copy, dispatch_datum, dispatch_datum_by_size, dispatch_floatlike, dispatch_numbers, }; pub use tvec; - pub use {args_1, args_2, args_3, args_4, args_5, args_6, args_7, args_8}; + pub use {args_1, args_2, args_3, args_4, args_5, args_6, args_7, args_8, args_9}; pub use {as_op, impl_op_same_as, not_a_typed_op, op_as_typed_op}; pub use {bin_to_super_type, element_wise, element_wise_oop}; - pub use crate::runtime::{Runtime, Runnable, State, DefaultRuntime}; } #[cfg(test)] diff --git a/core/src/ops/macros.rs b/core/src/ops/macros.rs index 3a9794da14..0ee4f9a19f 100644 --- a/core/src/ops/macros.rs +++ b/core/src/ops/macros.rs @@ -173,6 +173,30 @@ macro_rules! args_8 { }}; } +#[allow(unused_macros)] +#[macro_export] +macro_rules! args_9 { + ($inputs:expr) => {{ + let mut inputs = $inputs; + if inputs.len() != 9 { + $crate::internal::bail!("Expected 9 arg, got {:?}", inputs) + } + inputs.reverse(); + let result = ( + inputs.pop().unwrap(), + inputs.pop().unwrap(), + inputs.pop().unwrap(), + inputs.pop().unwrap(), + inputs.pop().unwrap(), + inputs.pop().unwrap(), + inputs.pop().unwrap(), + inputs.pop().unwrap(), + inputs.pop().unwrap(), + ); + result + }}; +} + #[macro_export] macro_rules! impl_op_same_as { () => { @@ -233,4 +257,3 @@ macro_rules! trivial_op_state_freeeze { } }; } - diff --git a/core/src/ops/mod.rs b/core/src/ops/mod.rs index 4760d1207b..f4a95d2636 100644 --- a/core/src/ops/mod.rs +++ b/core/src/ops/mod.rs @@ -32,6 +32,7 @@ pub mod scan; pub mod source; pub mod submodel; pub mod unimpl; +pub mod vptq; pub use downsample::Downsample; pub use memory::*; diff --git a/core/src/ops/vptq.rs b/core/src/ops/vptq.rs new file mode 100644 index 0000000000..402b88754b --- /dev/null +++ b/core/src/ops/vptq.rs @@ -0,0 +1,43 @@ +use crate::internal::*; + +#[derive(Debug, Clone)] +pub struct VPTQGemm {} + +impl Op for VPTQGemm { + fn name(&self) -> Cow { + "VPTQGemm".into() + } + + op_as_typed_op!(); +} + +impl EvalOp for VPTQGemm { + fn is_stateless(&self) -> bool { + true + } + + fn eval(&self, inputs: TVec) -> TractResult> { + let ( + input, + indices, + centroids, + outlier_indices, + outlier_centroids, + perm, + weight_scale, + weight_bias, + bias, + ) = args_9!(inputs); + let mut input = input.into_tensor(); + // todo: implement it now! + Ok(tvec!(input.into_tvalue())) + } +} + +impl TypedOp for VPTQGemm { + fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { + Ok(tvec!(inputs[0].without_value())) + } + + as_op!(); +} diff --git a/nnef/src/ops/core.rs b/nnef/src/ops/core.rs index dec4aeec94..2a607e0f76 100644 --- a/nnef/src/ops/core.rs +++ b/nnef/src/ops/core.rs @@ -27,6 +27,7 @@ mod store; mod submodel; mod topk; mod trilu; +mod vptq; pub fn register(registry: &mut Registry) { registry.register_unit_element_wise("tract_core_round_even", &ops::math::RoundHalfToEven {}); @@ -67,4 +68,5 @@ pub fn register(registry: &mut Registry) { range::register(registry); topk::register(registry); trilu::register(registry); + vptq::register(registry); } diff --git a/nnef/src/ops/core/vptq.rs b/nnef/src/ops/core/vptq.rs new file mode 100644 index 0000000000..ddfb34f8aa --- /dev/null +++ b/nnef/src/ops/core/vptq.rs @@ -0,0 +1,82 @@ +use crate::internal::*; +use crate::ser::*; +use tract_core::ops::cast::cast; +use tract_core::ops::vptq::VPTQGemm; + +pub fn register(registry: &mut Registry) { + registry.register_dumper(ser_vptq_gemm); + registry.register_primitive( + "tract_core_vptq_gemm", + &[ + TypeName::Scalar.tensor().named("input"), + TypeName::Scalar.tensor().named("indices"), + TypeName::Scalar.tensor().named("centroids"), + TypeName::Scalar.tensor().named("outlier_indices"), + TypeName::Scalar.tensor().named("outlier_centroids"), + TypeName::Scalar.tensor().named("perm"), + TypeName::Scalar.tensor().named("weight_scale"), + TypeName::Scalar.tensor().named("weight_bias"), + TypeName::Scalar.tensor().named("bias"), + ], + &[("output", TypeName::Scalar.tensor())], + de_vptq_gemm, + ); +} + +fn ser_vptq_gemm( + ast: &mut IntoAst, + node: &TypedNode, + op: &VPTQGemm, +) -> TractResult>> { + let input = ast.mapping[&node.inputs[0]].clone(); + let indices = ast.mapping[&node.inputs[1]].clone(); + let centroids = ast.mapping[&node.inputs[2]].clone(); + let outlier_indices = ast.mapping[&node.inputs[3]].clone(); + let outlier_centroids = ast.mapping[&node.inputs[4]].clone(); + let perm = ast.mapping[&node.inputs[5]].clone(); + let weight_scale = ast.mapping[&node.inputs[6]].clone(); + let weight_bias = ast.mapping[&node.inputs[7]].clone(); + let bias = ast.mapping[&node.inputs[8]].clone(); + Ok(Some(invocation( + "tract_core_vptq_gemm", + &[ + input, + indices, + centroids, + outlier_indices, + outlier_centroids, + perm, + weight_scale, + weight_bias, + bias, + ], + &[], + ))) +} + +fn de_vptq_gemm(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> TractResult { + let input = invocation.named_arg_as(builder, "input")?; + let indices = invocation.named_arg_as(builder, "indices")?; + let centroids = invocation.named_arg_as(builder, "centroids")?; + let outlier_indices = invocation.named_arg_as(builder, "outlier_indices")?; + let outlier_centroids = invocation.named_arg_as(builder, "outlier_centroids")?; + let perm = invocation.named_arg_as(builder, "perm")?; + let weight_scale = invocation.named_arg_as(builder, "weight_scale")?; + let weight_bias = invocation.named_arg_as(builder, "weight_bias")?; + let bias = invocation.named_arg_as(builder, "bias")?; + + builder.wire( + VPTQGemm {}, + &[ + input, + indices, + centroids, + outlier_indices, + outlier_centroids, + perm, + weight_scale, + weight_bias, + bias, + ], + ) +} From 963fe5220c665508d3686c041db09c535f8a84ee Mon Sep 17 00:00:00 2001 From: Julien Balian Date: Fri, 17 Jan 2025 19:02:20 +0100 Subject: [PATCH 03/21] fix: wip implementation of naive vptq --- core/src/ops/vptq.rs | 129 ++++++++++++++++++++++++++++++++++++-- nnef/src/ops/core/vptq.rs | 19 +++++- 2 files changed, 141 insertions(+), 7 deletions(-) diff --git a/core/src/ops/vptq.rs b/core/src/ops/vptq.rs index 402b88754b..f64b3cb7d1 100644 --- a/core/src/ops/vptq.rs +++ b/core/src/ops/vptq.rs @@ -1,7 +1,11 @@ -use crate::internal::*; +use crate::{internal::*, ops::array::GatherElements, ops::einsum::EinSum}; #[derive(Debug, Clone)] -pub struct VPTQGemm {} +pub struct VPTQGemm { + pub vector_len: usize, + pub in_features: usize, + pub out_features: usize, +} impl Op for VPTQGemm { fn name(&self) -> Cow { @@ -11,6 +15,54 @@ impl Op for VPTQGemm { op_as_typed_op!(); } +impl VPTQGemm { + fn eval_extract_from_vector_quant( + &self, + centroids: Tensor, + indices: Tensor, + ) -> TractResult { + let mut indices = indices.clone(); + let [num_codebooks, num_centroids, vector_len] = *centroids.shape() else { + unimplemented!("unexected centroid shape ?") + }; + + let [_, _, group_size] = *centroids.shape() else { + unimplemented!("unexected indice shape ?") + }; + + let mut vsh = indices.shape().to_vec(); + indices.insert_axis(3)?; + vsh.push(vector_len); + indices = indices.broadcast_to_shape(&vsh)?; + let intermediate_volume = indices.shape()[1..3].iter().fold(1, |r, x| r * x); + indices = indices.into_shape(&[num_codebooks, intermediate_volume, vector_len])?; + + let gather1 = GatherElements { axis: 1 }; + // selected_centroids = torch.gather(centroids, 1, indices) + let selected_centroids = gather1 + .eval(tvec!(centroids.into(), indices.into()))? + .pop() + .context("apply gather to get selected main centroids") + .unwrap() + .into_tensor(); + + let remain = selected_centroids.volume() / (num_codebooks * group_size * vector_len); + let mut qweight = selected_centroids + .into_shape(&[num_codebooks, remain, group_size, vector_len])? + .permute_axes(&[0, 1, 3, 2])? + .into_shape(&[num_codebooks, remain * vector_len, group_size])? + .permute_axes(&[1, 0, 2])? + .into_shape(&[vector_len * remain, num_codebooks * group_size])?; + + let dim0 = qweight.shape()[0]; + let padding = (-(self.out_features as i16) % vector_len as i16) as usize; + if padding > 0 { + qweight = qweight.slice(0, 0, dim0 - padding)?; + } + Ok(qweight) + } +} + impl EvalOp for VPTQGemm { fn is_stateless(&self) -> bool { true @@ -28,9 +80,76 @@ impl EvalOp for VPTQGemm { weight_bias, bias, ) = args_9!(inputs); - let mut input = input.into_tensor(); - // todo: implement it now! - Ok(tvec!(input.into_tvalue())) + let mut indices = indices.into_tensor(); + let mut centroids = centroids.into_tensor(); + let mut outlier_indices = outlier_indices.into_tensor(); + let mut outlier_centroids = outlier_centroids.into_tensor(); + let mut perm = perm.into_tensor(); + let mut weight_scale = weight_scale.into_tensor(); + let mut weight_bias = weight_bias.into_tensor(); + let mut bias = bias.into_tensor(); + + if weight_scale.len() > 1 { + unimplemented!("'weight scale' for vptq not yet supported !"); + } + if weight_bias.len() > 1 { + unimplemented!("'weight bias' for vptq not yet supported !"); + } + let enable_norm = weight_scale.len() > 1 && weight_bias.len() > 1; + if bias.len() > 1 { + unimplemented!("'bias' for vptq not yet supported !"); + } + assert_eq!(input.rank(), 3); + assert!(input.datum_type().is_float()); + + assert_eq!(indices.rank(), 3); + assert_eq!(indices.datum_type(), DatumType::U16); + assert_eq!(centroids.rank(), 3); + assert!(centroids.datum_type().is_float()); + + let enable_outlier = outlier_indices.len() > 0; + if enable_outlier { + assert_eq!(outlier_indices.rank(), 3); + assert_eq!(outlier_indices.datum_type(), DatumType::U16); + assert_eq!(outlier_centroids.rank(), 3); + assert_eq!(outlier_centroids.datum_type().is_float()); + } + + let mut qweight = self.eval_extract_from_vector_quant(centroids, indices)?; + if enable_outlier { + // same as centroids to qweights except for outlier + let outlier_qweight = + self.eval_extract_from_vector_quant(outlier_centroids, outlier_indices)?; + // qweight = torch.cat([qweight_outlier, qweight], dim=1) + qweight = Tensor::stack_tensors(0, &[qweight, outlier_qweight])?; + } + + // let enable_perm = perm.len() <= 1; + // if enable_perm { + // // TODO: manage case with packed indice + // // if self.is_indice_packed: + // // invert_perm = torch.argsort( + // // self.perm.to(torch.uint16).to(torch.int64) + // // ) + // // else: + // // invert_perm = torch.argsort(self.perm) + // // if self.vector_quant_dim == "in": + // // assert True, "Not implemented" + // // # qweight = qweight[invert_perm, :] + // // else: + // // qweight = qweight[:, invert_perm] + // let invert_perm = perm.into_tensor().argsort(); + // } + + if enable_norm { + qweight = (qweight.into_array::()? * weight_scale.into_array::()? + + weight_bias.into_array::()?) + .into_tensor(); + } + // call matmul now with qweight + + let einsum_op = EinSum::new("bik,kj->ij".parse()?, f32::datum_type()); + einsum_op.eval(tvec!(input, qweight.into_tvalue())) } } diff --git a/nnef/src/ops/core/vptq.rs b/nnef/src/ops/core/vptq.rs index ddfb34f8aa..224d16c451 100644 --- a/nnef/src/ops/core/vptq.rs +++ b/nnef/src/ops/core/vptq.rs @@ -17,6 +17,9 @@ pub fn register(registry: &mut Registry) { TypeName::Scalar.tensor().named("weight_scale"), TypeName::Scalar.tensor().named("weight_bias"), TypeName::Scalar.tensor().named("bias"), + TypeName::Integer.named("vector_len"), + TypeName::Integer.tensor().named("in_features"), + TypeName::Integer.tensor().named("out_features"), ], &[("output", TypeName::Scalar.tensor())], de_vptq_gemm, @@ -37,6 +40,10 @@ fn ser_vptq_gemm( let weight_scale = ast.mapping[&node.inputs[6]].clone(); let weight_bias = ast.mapping[&node.inputs[7]].clone(); let bias = ast.mapping[&node.inputs[8]].clone(); + + let vector_len = ast.mapping[&node.inputs[9]].clone(); + let in_features = ast.mapping[&node.inputs[10]].clone(); + let out_features = ast.mapping[&node.inputs[11]].clone(); Ok(Some(invocation( "tract_core_vptq_gemm", &[ @@ -50,7 +57,11 @@ fn ser_vptq_gemm( weight_bias, bias, ], - &[], + &[ + ("vector_len", numeric(vector_len)), + ("in_features", numeric(in_features)), + ("out_features", numeric(out_features)), + ], ))) } @@ -65,8 +76,12 @@ fn de_vptq_gemm(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> let weight_bias = invocation.named_arg_as(builder, "weight_bias")?; let bias = invocation.named_arg_as(builder, "bias")?; + let vector_len = invocation.named_arg_as(builder, "vector_len")?; + let in_features = invocation.named_arg_as(builder, "in_features")?; + let out_features = invocation.named_arg_as(builder, "out_features")?; + builder.wire( - VPTQGemm {}, + VPTQGemm { vector_len, in_features, out_features }, &[ input, indices, From 1cdc707f42ed7f3c1ef9efd4372f78bc958c22af Mon Sep 17 00:00:00 2001 From: Julien Balian Date: Fri, 17 Jan 2025 19:05:13 +0100 Subject: [PATCH 04/21] fix: . --- core/src/ops/vptq.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/ops/vptq.rs b/core/src/ops/vptq.rs index f64b3cb7d1..c8fef533b1 100644 --- a/core/src/ops/vptq.rs +++ b/core/src/ops/vptq.rs @@ -112,7 +112,7 @@ impl EvalOp for VPTQGemm { assert_eq!(outlier_indices.rank(), 3); assert_eq!(outlier_indices.datum_type(), DatumType::U16); assert_eq!(outlier_centroids.rank(), 3); - assert_eq!(outlier_centroids.datum_type().is_float()); + assert!(outlier_centroids.datum_type().is_float()); } let mut qweight = self.eval_extract_from_vector_quant(centroids, indices)?; From 95080d2c2136db5e1fa7844a9d5df571f83a2cdf Mon Sep 17 00:00:00 2001 From: Julien Balian Date: Fri, 17 Jan 2025 19:37:53 +0100 Subject: [PATCH 05/21] fix: tfact --- core/src/ops/vptq.rs | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/core/src/ops/vptq.rs b/core/src/ops/vptq.rs index c8fef533b1..45c895e40f 100644 --- a/core/src/ops/vptq.rs +++ b/core/src/ops/vptq.rs @@ -99,7 +99,7 @@ impl EvalOp for VPTQGemm { if bias.len() > 1 { unimplemented!("'bias' for vptq not yet supported !"); } - assert_eq!(input.rank(), 3); + assert_eq!(input.rank(), 2); assert!(input.datum_type().is_float()); assert_eq!(indices.rank(), 3); @@ -148,14 +148,16 @@ impl EvalOp for VPTQGemm { } // call matmul now with qweight - let einsum_op = EinSum::new("bik,kj->ij".parse()?, f32::datum_type()); + let einsum_op = EinSum::new("ik,kj->".parse()?, f32::datum_type()); einsum_op.eval(tvec!(input, qweight.into_tvalue())) } } impl TypedOp for VPTQGemm { fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { - Ok(tvec!(inputs[0].without_value())) + let mut tfact = inputs[0].without_value(); + tfact.shape.set(1, self.out_features.into()); + Ok(tvec!(tfact)) } as_op!(); From cbef6c37b77a28606ae0a5e18f74dd13457a8531 Mon Sep 17 00:00:00 2001 From: Julien Balian Date: Mon, 20 Jan 2025 11:55:15 +0100 Subject: [PATCH 06/21] fix: simple vptq case working --- core/src/ops/vptq.rs | 85 +++++++++++++++++++++++++++----------------- 1 file changed, 53 insertions(+), 32 deletions(-) diff --git a/core/src/ops/vptq.rs b/core/src/ops/vptq.rs index 45c895e40f..69dea2e14f 100644 --- a/core/src/ops/vptq.rs +++ b/core/src/ops/vptq.rs @@ -1,4 +1,10 @@ -use crate::{internal::*, ops::array::GatherElements, ops::einsum::EinSum}; +use crate::{ + internal::*, + ops::{ + array::{GatherElements, Topk}, + einsum::EinSum, + }, +}; #[derive(Debug, Clone)] pub struct VPTQGemm { @@ -20,13 +26,17 @@ impl VPTQGemm { &self, centroids: Tensor, indices: Tensor, + is_indice_packed: bool, ) -> TractResult { + if is_indice_packed { + unimplemented!("unpacking indices not implemented yet !"); + } let mut indices = indices.clone(); let [num_codebooks, num_centroids, vector_len] = *centroids.shape() else { unimplemented!("unexected centroid shape ?") }; - let [_, _, group_size] = *centroids.shape() else { + let [_, _, group_size] = *indices.shape() else { unimplemented!("unexected indice shape ?") }; @@ -80,14 +90,14 @@ impl EvalOp for VPTQGemm { weight_bias, bias, ) = args_9!(inputs); - let mut indices = indices.into_tensor(); - let mut centroids = centroids.into_tensor(); - let mut outlier_indices = outlier_indices.into_tensor(); - let mut outlier_centroids = outlier_centroids.into_tensor(); - let mut perm = perm.into_tensor(); - let mut weight_scale = weight_scale.into_tensor(); - let mut weight_bias = weight_bias.into_tensor(); - let mut bias = bias.into_tensor(); + let indices = indices.into_tensor(); + let centroids = centroids.into_tensor(); + let outlier_indices = outlier_indices.into_tensor(); + let outlier_centroids = outlier_centroids.into_tensor(); + let perm = perm.into_tensor(); + let weight_scale = weight_scale.into_tensor(); + let weight_bias = weight_bias.into_tensor(); + let bias = bias.into_tensor(); if weight_scale.len() > 1 { unimplemented!("'weight scale' for vptq not yet supported !"); @@ -115,31 +125,42 @@ impl EvalOp for VPTQGemm { assert!(outlier_centroids.datum_type().is_float()); } - let mut qweight = self.eval_extract_from_vector_quant(centroids, indices)?; + let is_indice_packed = false; // TODO: apply + let mut qweight = + self.eval_extract_from_vector_quant(centroids, indices, is_indice_packed)?; if enable_outlier { // same as centroids to qweights except for outlier - let outlier_qweight = - self.eval_extract_from_vector_quant(outlier_centroids, outlier_indices)?; - // qweight = torch.cat([qweight_outlier, qweight], dim=1) - qweight = Tensor::stack_tensors(0, &[qweight, outlier_qweight])?; + let outlier_qweight = self.eval_extract_from_vector_quant( + outlier_centroids, + outlier_indices, + is_indice_packed, + )?; + qweight = Tensor::stack_tensors(1, &[outlier_qweight, qweight])?; } - // let enable_perm = perm.len() <= 1; - // if enable_perm { - // // TODO: manage case with packed indice - // // if self.is_indice_packed: - // // invert_perm = torch.argsort( - // // self.perm.to(torch.uint16).to(torch.int64) - // // ) - // // else: - // // invert_perm = torch.argsort(self.perm) - // // if self.vector_quant_dim == "in": - // // assert True, "Not implemented" - // // # qweight = qweight[invert_perm, :] - // // else: - // // qweight = qweight[:, invert_perm] - // let invert_perm = perm.into_tensor().argsort(); - // } + let enable_perm = perm.len() > 1; + if enable_perm { + unimplemented!("permutation not implemented yet"); + // if is_indice_packed { + // unimplemented!("permutation not implemented yet with indice packed"); + // // if self.is_indice_packed: + // // invert_perm = torch.argsort( + // // self.perm.to(torch.uint16).to(torch.int64) + // // ) + // } else { + // let axis = 1; + // top_k = Topk { axis= axis, largest=false}; + // let dim= perm.shape()[axis].into_tensor(); + // invert_perm = top_k.eval(tvec!(perm, dim)); + // } + // // // TODO: manage case with packed indice + // // // if self.vector_quant_dim == "in": + // // // assert True, "Not implemented" + // // // qweight = qweight[invert_perm, :] + // // // else: + // // // qweight = qweight[:, invert_perm] + // qweight = qweight[:, invert_perm]; + } if enable_norm { qweight = (qweight.into_array::()? * weight_scale.into_array::()? @@ -149,7 +170,7 @@ impl EvalOp for VPTQGemm { // call matmul now with qweight let einsum_op = EinSum::new("ik,kj->".parse()?, f32::datum_type()); - einsum_op.eval(tvec!(input, qweight.into_tvalue())) + einsum_op.eval(tvec!(input, qweight.permute_axes(&[1, 0])?.into_tvalue())) } } From 199144bab90e90cafa9295d160f5a2e09600830d Mon Sep 17 00:00:00 2001 From: Julien Balian Date: Mon, 20 Jan 2025 15:42:58 +0100 Subject: [PATCH 07/21] fix: add perm support --- core/src/ops/vptq.rs | 50 +++++++++++++++++++++++++------------------- 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/core/src/ops/vptq.rs b/core/src/ops/vptq.rs index 69dea2e14f..61533feea7 100644 --- a/core/src/ops/vptq.rs +++ b/core/src/ops/vptq.rs @@ -1,7 +1,8 @@ use crate::{ internal::*, ops::{ - array::{GatherElements, Topk}, + array::{Gather, GatherElements, Topk}, + cast::cast, einsum::EinSum, }, }; @@ -140,26 +141,33 @@ impl EvalOp for VPTQGemm { let enable_perm = perm.len() > 1; if enable_perm { - unimplemented!("permutation not implemented yet"); - // if is_indice_packed { - // unimplemented!("permutation not implemented yet with indice packed"); - // // if self.is_indice_packed: - // // invert_perm = torch.argsort( - // // self.perm.to(torch.uint16).to(torch.int64) - // // ) - // } else { - // let axis = 1; - // top_k = Topk { axis= axis, largest=false}; - // let dim= perm.shape()[axis].into_tensor(); - // invert_perm = top_k.eval(tvec!(perm, dim)); - // } - // // // TODO: manage case with packed indice - // // // if self.vector_quant_dim == "in": - // // // assert True, "Not implemented" - // // // qweight = qweight[invert_perm, :] - // // // else: - // // // qweight = qweight[:, invert_perm] - // qweight = qweight[:, invert_perm]; + let axis = 0; + let dim = perm.shape()[0]; + let top_k = Topk { axis, largest: false, fallback_k: dim.into() }; + let invert_perm = top_k + .eval(tvec!( + if is_indice_packed { + unimplemented!("permutation not implemented yet with indice packed"); + // self.perm.to(torch.uint16).to(torch.int64) + } else { + perm.into_tvalue() + }, + tensor0(dim as u16).into() + ))? + .remove(0); + // TODO: manage case with quant dim == 'in' ? + // if self.vector_quant_dim == "in": + // assert True, "Not implemented" + // qweight = qweight[invert_perm, :] + + let perm_gather_axis = 1; + let gather_perm = Gather { axis: perm_gather_axis }; + qweight = gather_perm + .eval(tvec!(qweight.into(), invert_perm))? + .pop() + .context("apply gather to permutation") + .unwrap() + .into_tensor(); } if enable_norm { From 00650a35f883853a7ca856d179f0f3e281ef2d17 Mon Sep 17 00:00:00 2001 From: Julien Balian Date: Mon, 20 Jan 2025 18:21:39 +0100 Subject: [PATCH 08/21] vptq index bits --- core/src/ops/vptq.rs | 62 ++++++++++++++++++++++++++++++++++----- nnef/src/ops/core/vptq.rs | 3 +- 2 files changed, 56 insertions(+), 9 deletions(-) diff --git a/core/src/ops/vptq.rs b/core/src/ops/vptq.rs index 61533feea7..ee3b1ebff5 100644 --- a/core/src/ops/vptq.rs +++ b/core/src/ops/vptq.rs @@ -2,7 +2,6 @@ use crate::{ internal::*, ops::{ array::{Gather, GatherElements, Topk}, - cast::cast, einsum::EinSum, }, }; @@ -12,6 +11,7 @@ pub struct VPTQGemm { pub vector_len: usize, pub in_features: usize, pub out_features: usize, + pub is_indice_packed: bool, } impl Op for VPTQGemm { @@ -23,17 +23,64 @@ impl Op for VPTQGemm { } impl VPTQGemm { + fn eval_unpack_index_tensor( + &self, + pack_tensor: Tensor, + index_bits: usize, + num_elements: usize, + ) -> TractResult { + // + // + // TODO: implement decompression of index + // + // total_bits = index_bits + res_bits + // wf = torch.arange(0, 32, 1).to(pack_tensor.device).view(1, 1, 1, -1) + // out = torch.bitwise_right_shift(torch.unsqueeze(pack_tensor, -1), wf) + // torch.bitwise_and(out, 1, out=out) + // pad_size = (pack_tensor.shape[-1] * 32) % ( + // index_bits * num_elements + res_bits * num_res_elements + // ) + // out = out.reshape(*pack_tensor.shape[:-1], -1) + // if pad_size > 0: + // out = out[..., :-pad_size] + // out = out.reshape(*pack_tensor.shape[:-1], -1, total_bits) + // wf1 = torch.arange(0, total_bits, 1).to(pack_tensor.device).view(1, 1, 1, -1) + // out = torch.bitwise_left_shift(out, wf1).sum(dim=-1) + // + // unpack_indice = out.to(torch.uint64).view(torch.int64) + // + // indices = ( + // (unpack_indice & ((1 << index_bits) - 1)).view(torch.uint64).to(torch.int64) + // ) + // + // # indices = indices.squeeze() + // + // if res_bits > 0: + // res_indices = ( + // ((unpack_indice >> index_bits) & ((1 << index_bits) - 1)) + // .view(torch.uint64) + // .to(torch.int64) + // ) + // # res_indices = res_indices.squeeze() + // else: + // res_indices = None + // + // return indices, res_indices + + Ok(tensor0(0)) + } + fn eval_extract_from_vector_quant( &self, centroids: Tensor, indices: Tensor, is_indice_packed: bool, ) -> TractResult { - if is_indice_packed { + if self.is_indice_packed { unimplemented!("unpacking indices not implemented yet !"); } let mut indices = indices.clone(); - let [num_codebooks, num_centroids, vector_len] = *centroids.shape() else { + let [num_codebooks, _num_centroids, vector_len] = *centroids.shape() else { unimplemented!("unexected centroid shape ?") }; @@ -45,7 +92,7 @@ impl VPTQGemm { indices.insert_axis(3)?; vsh.push(vector_len); indices = indices.broadcast_to_shape(&vsh)?; - let intermediate_volume = indices.shape()[1..3].iter().fold(1, |r, x| r * x); + let intermediate_volume = indices.shape()[1..3].iter().product(); indices = indices.into_shape(&[num_codebooks, intermediate_volume, vector_len])?; let gather1 = GatherElements { axis: 1 }; @@ -126,15 +173,14 @@ impl EvalOp for VPTQGemm { assert!(outlier_centroids.datum_type().is_float()); } - let is_indice_packed = false; // TODO: apply let mut qweight = - self.eval_extract_from_vector_quant(centroids, indices, is_indice_packed)?; + self.eval_extract_from_vector_quant(centroids, indices, self.is_indice_packed)?; if enable_outlier { // same as centroids to qweights except for outlier let outlier_qweight = self.eval_extract_from_vector_quant( outlier_centroids, outlier_indices, - is_indice_packed, + self.is_indice_packed, )?; qweight = Tensor::stack_tensors(1, &[outlier_qweight, qweight])?; } @@ -146,7 +192,7 @@ impl EvalOp for VPTQGemm { let top_k = Topk { axis, largest: false, fallback_k: dim.into() }; let invert_perm = top_k .eval(tvec!( - if is_indice_packed { + if self.is_indice_packed { unimplemented!("permutation not implemented yet with indice packed"); // self.perm.to(torch.uint16).to(torch.int64) } else { diff --git a/nnef/src/ops/core/vptq.rs b/nnef/src/ops/core/vptq.rs index 224d16c451..433fd2f0cd 100644 --- a/nnef/src/ops/core/vptq.rs +++ b/nnef/src/ops/core/vptq.rs @@ -79,9 +79,10 @@ fn de_vptq_gemm(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> let vector_len = invocation.named_arg_as(builder, "vector_len")?; let in_features = invocation.named_arg_as(builder, "in_features")?; let out_features = invocation.named_arg_as(builder, "out_features")?; + let is_indice_packed = invocation.named_arg_as(builder, "is_indice_packed")?; builder.wire( - VPTQGemm { vector_len, in_features, out_features }, + VPTQGemm { vector_len, in_features, out_features, is_indice_packed }, &[ input, indices, From 186ff57c3be46e6ee74d57411d6209f70f19747d Mon Sep 17 00:00:00 2001 From: Julien Balian Date: Tue, 21 Jan 2025 16:45:43 +0100 Subject: [PATCH 09/21] feat: basic translate from py to rust decompress --- core/src/ops/vptq.rs | 119 +++++++++++++++++++++++++------------------ 1 file changed, 69 insertions(+), 50 deletions(-) diff --git a/core/src/ops/vptq.rs b/core/src/ops/vptq.rs index ee3b1ebff5..9d89e8b1aa 100644 --- a/core/src/ops/vptq.rs +++ b/core/src/ops/vptq.rs @@ -1,8 +1,13 @@ +use tract_ndarray::Array1; + use crate::{ internal::*, ops::{ array::{Gather, GatherElements, Topk}, einsum::EinSum, + logic::and, + math::shift_left, + math::shift_right, }, }; @@ -23,62 +28,75 @@ impl Op for VPTQGemm { } impl VPTQGemm { + /// decompression of indexes fn eval_unpack_index_tensor( &self, pack_tensor: Tensor, index_bits: usize, num_elements: usize, ) -> TractResult { - // - // - // TODO: implement decompression of index - // - // total_bits = index_bits + res_bits - // wf = torch.arange(0, 32, 1).to(pack_tensor.device).view(1, 1, 1, -1) - // out = torch.bitwise_right_shift(torch.unsqueeze(pack_tensor, -1), wf) - // torch.bitwise_and(out, 1, out=out) - // pad_size = (pack_tensor.shape[-1] * 32) % ( - // index_bits * num_elements + res_bits * num_res_elements - // ) - // out = out.reshape(*pack_tensor.shape[:-1], -1) - // if pad_size > 0: - // out = out[..., :-pad_size] - // out = out.reshape(*pack_tensor.shape[:-1], -1, total_bits) - // wf1 = torch.arange(0, total_bits, 1).to(pack_tensor.device).view(1, 1, 1, -1) - // out = torch.bitwise_left_shift(out, wf1).sum(dim=-1) - // - // unpack_indice = out.to(torch.uint64).view(torch.int64) - // - // indices = ( - // (unpack_indice & ((1 << index_bits) - 1)).view(torch.uint64).to(torch.int64) - // ) - // - // # indices = indices.squeeze() - // - // if res_bits > 0: - // res_indices = ( - // ((unpack_indice >> index_bits) & ((1 << index_bits) - 1)) - // .view(torch.uint64) - // .to(torch.int64) - // ) - // # res_indices = res_indices.squeeze() - // else: - // res_indices = None - // - // return indices, res_indices - - Ok(tensor0(0)) + dbg!("A"); + let wf = Tensor::from(Array1::from_iter(0..32u16).to_shape([1, 1, 1, 32])?.into_owned()) + .cast_to_dt(DatumType::U16)? + .into_owned(); + + dbg!("B"); + let mut pack_tensor_shape = pack_tensor.shape().to_vec(); + pack_tensor_shape.push(1); + + dbg!("C"); + let mut out = shift_right() + .eval(tvec!(pack_tensor.clone().into_shape(&pack_tensor_shape)?.into(), wf.into()))? + .pop() + .unwrap(); + + dbg!("D"); + out = and().eval(tvec!(out, tensor0(1).into()))?.pop().unwrap(); + + let pad_size = (pack_tensor_shape[2] * 32) % (index_bits * num_elements); + + dbg!("E"); + let dim_idx = pack_tensor_shape.len() - 1; + pack_tensor_shape[dim_idx] = out.volume() / pack_tensor.volume(); + out = out.into_tensor().clone().into_shape(&pack_tensor_shape)?.into_tvalue(); + if pad_size > 0 { + let end = out.shape()[out.rank() - 1] - pad_size; + out = out.slice(out.rank(), 0, end)?.into(); + } + + dbg!("F"); + pack_tensor_shape.pop(); + let auto = out.volume() / pack_tensor_shape.iter().product::() / index_bits; + pack_tensor_shape.push(auto); + pack_tensor_shape.push(index_bits); + out = out.into_tensor().into_shape(&pack_tensor_shape)?.into(); + + dbg!("G"); + let wf1 = Tensor::from( + Array1::from_iter(0..(index_bits as u16)).to_shape([1, 1, 1, index_bits])?.into_owned(), + ); + + out = shift_left().eval(tvec!(out, wf1.into()))?.pop().unwrap(); + + dbg!("H"); + let unpack_indice = out.cast_to_dt(DatumType::U16)?; + + let mut indices = + unsafe { Tensor::uninitialized_dt(DatumType::U16, unpack_indice.shape())? }; + + dbg!("I"); + crate::ndarray::Zip::from(&mut indices.to_array_view_mut::()?) + .and_broadcast(unpack_indice.to_array_view::()?) + .for_each(|indice, upack_indice| *indice = upack_indice & ((1 << index_bits) - 1)); + + Ok(indices) } fn eval_extract_from_vector_quant( &self, centroids: Tensor, indices: Tensor, - is_indice_packed: bool, ) -> TractResult { - if self.is_indice_packed { - unimplemented!("unpacking indices not implemented yet !"); - } let mut indices = indices.clone(); let [num_codebooks, _num_centroids, vector_len] = *centroids.shape() else { unimplemented!("unexected centroid shape ?") @@ -89,6 +107,11 @@ impl VPTQGemm { }; let mut vsh = indices.shape().to_vec(); + if self.is_indice_packed { + // unimplemented!("unpacking indices not implemented yet !"); + let index_bits = (_num_centroids as f32).log2().ceil() as usize; + indices = self.eval_unpack_index_tensor(indices, index_bits, _num_centroids)?; + } indices.insert_axis(3)?; vsh.push(vector_len); indices = indices.broadcast_to_shape(&vsh)?; @@ -173,15 +196,11 @@ impl EvalOp for VPTQGemm { assert!(outlier_centroids.datum_type().is_float()); } - let mut qweight = - self.eval_extract_from_vector_quant(centroids, indices, self.is_indice_packed)?; + let mut qweight = self.eval_extract_from_vector_quant(centroids, indices)?; if enable_outlier { // same as centroids to qweights except for outlier - let outlier_qweight = self.eval_extract_from_vector_quant( - outlier_centroids, - outlier_indices, - self.is_indice_packed, - )?; + let outlier_qweight = + self.eval_extract_from_vector_quant(outlier_centroids, outlier_indices)?; qweight = Tensor::stack_tensors(1, &[outlier_qweight, qweight])?; } From 76f3164c4fca1bfba62161d9395acd61fec94704 Mon Sep 17 00:00:00 2001 From: Julien Balian Date: Wed, 22 Jan 2025 16:02:32 +0100 Subject: [PATCH 10/21] fix: make packed indexes works --- core/src/ops/vptq.rs | 81 +++++++++++++++++++++++--------------------- 1 file changed, 42 insertions(+), 39 deletions(-) diff --git a/core/src/ops/vptq.rs b/core/src/ops/vptq.rs index 9d89e8b1aa..c7da415ce5 100644 --- a/core/src/ops/vptq.rs +++ b/core/src/ops/vptq.rs @@ -1,3 +1,4 @@ +use tract_data::itertools::Itertools; use tract_ndarray::Array1; use crate::{ @@ -26,6 +27,17 @@ impl Op for VPTQGemm { op_as_typed_op!(); } +fn shift_right_zero_and_1(input: TValue, shift_value: TValue) -> TractResult { + let input = input.to_array_view::()?; + let shift_value = shift_value.to_array_view::()?; + let out_shape = crate::broadcast::multi_broadcast(&[input.shape(), shift_value.shape()])?; + let mut out = Tensor::zero_dt(DatumType::U16, &out_shape)?; + crate::ndarray::Zip::from(out.to_array_view_mut::()?) + .and_broadcast(input) + .and_broadcast(shift_value) + .for_each(|c, a, b| *c = a.checked_shr(*b as u32).unwrap_or(0u16) & 1u16); + Ok(out.into_tvalue()) +} impl VPTQGemm { /// decompression of indexes @@ -35,59 +47,57 @@ impl VPTQGemm { index_bits: usize, num_elements: usize, ) -> TractResult { - dbg!("A"); - let wf = Tensor::from(Array1::from_iter(0..32u16).to_shape([1, 1, 1, 32])?.into_owned()) - .cast_to_dt(DatumType::U16)? - .into_owned(); + // let wf = Tensor::from(Array1::from_iter(0..32u16).to_shape([1, 1, 1, 32])?.into_owned()); + // can be reexpressed + let wf = tensor1(&(0..32u16).collect_vec()).into_shape(&[1, 1, 1, 32])?; - dbg!("B"); let mut pack_tensor_shape = pack_tensor.shape().to_vec(); pack_tensor_shape.push(1); - dbg!("C"); - let mut out = shift_right() - .eval(tvec!(pack_tensor.clone().into_shape(&pack_tensor_shape)?.into(), wf.into()))? - .pop() - .unwrap(); - - dbg!("D"); - out = and().eval(tvec!(out, tensor0(1).into()))?.pop().unwrap(); + let mut out = shift_right_zero_and_1( + pack_tensor.clone().into_shape(&pack_tensor_shape)?.into(), + wf.into(), + )?; let pad_size = (pack_tensor_shape[2] * 32) % (index_bits * num_elements); - dbg!("E"); - let dim_idx = pack_tensor_shape.len() - 1; - pack_tensor_shape[dim_idx] = out.volume() / pack_tensor.volume(); + let mut pack_tensor_shape = pack_tensor.shape().to_vec(); + pack_tensor_shape.pop(); + pack_tensor_shape.push(out.volume() / pack_tensor.volume()); out = out.into_tensor().clone().into_shape(&pack_tensor_shape)?.into_tvalue(); if pad_size > 0 { let end = out.shape()[out.rank() - 1] - pad_size; out = out.slice(out.rank(), 0, end)?.into(); } - dbg!("F"); pack_tensor_shape.pop(); let auto = out.volume() / pack_tensor_shape.iter().product::() / index_bits; pack_tensor_shape.push(auto); pack_tensor_shape.push(index_bits); out = out.into_tensor().into_shape(&pack_tensor_shape)?.into(); - dbg!("G"); let wf1 = Tensor::from( Array1::from_iter(0..(index_bits as u16)).to_shape([1, 1, 1, index_bits])?.into_owned(), ); out = shift_left().eval(tvec!(out, wf1.into()))?.pop().unwrap(); - dbg!("H"); + let axis = out.rank() - 1; + out = out + .into_tensor() + .into_array::()? + .sum_axis(tract_ndarray::Axis(axis)) + .into_tvalue(); + let unpack_indice = out.cast_to_dt(DatumType::U16)?; let mut indices = unsafe { Tensor::uninitialized_dt(DatumType::U16, unpack_indice.shape())? }; - dbg!("I"); crate::ndarray::Zip::from(&mut indices.to_array_view_mut::()?) .and_broadcast(unpack_indice.to_array_view::()?) .for_each(|indice, upack_indice| *indice = upack_indice & ((1 << index_bits) - 1)); + indices = indices.slice(2, 0, num_elements)?; Ok(indices) } @@ -97,8 +107,9 @@ impl VPTQGemm { centroids: Tensor, indices: Tensor, ) -> TractResult { + /// TODO: instead use ndarray for indices to use views transform (--) let mut indices = indices.clone(); - let [num_codebooks, _num_centroids, vector_len] = *centroids.shape() else { + let [num_codebooks, num_centroids, vector_len] = *centroids.shape() else { unimplemented!("unexected centroid shape ?") }; @@ -109,12 +120,12 @@ impl VPTQGemm { let mut vsh = indices.shape().to_vec(); if self.is_indice_packed { // unimplemented!("unpacking indices not implemented yet !"); - let index_bits = (_num_centroids as f32).log2().ceil() as usize; - indices = self.eval_unpack_index_tensor(indices, index_bits, _num_centroids)?; + let index_bits = (num_centroids as f32).log2().ceil() as usize; + indices = self.eval_unpack_index_tensor(indices, index_bits, 1)?; } indices.insert_axis(3)?; vsh.push(vector_len); - indices = indices.broadcast_to_shape(&vsh)?; + indices = indices.broadcast_to_shape(&vsh)?; // NOTE: costly in tract (applied in memory but not in ndarray) let intermediate_volume = indices.shape()[1..3].iter().product(); indices = indices.into_shape(&[num_codebooks, intermediate_volume, vector_len])?; @@ -128,11 +139,12 @@ impl VPTQGemm { .into_tensor(); let remain = selected_centroids.volume() / (num_codebooks * group_size * vector_len); + // split_axes let mut qweight = selected_centroids .into_shape(&[num_codebooks, remain, group_size, vector_len])? - .permute_axes(&[0, 1, 3, 2])? + .permute_axes(&[0, 1, 3, 2])? // NOTE: costly in tract (applied in memory) .into_shape(&[num_codebooks, remain * vector_len, group_size])? - .permute_axes(&[1, 0, 2])? + .permute_axes(&[1, 0, 2])?// NOTE: costly in tract (applied in memory) .into_shape(&[vector_len * remain, num_codebooks * group_size])?; let dim0 = qweight.shape()[0]; @@ -209,17 +221,8 @@ impl EvalOp for VPTQGemm { let axis = 0; let dim = perm.shape()[0]; let top_k = Topk { axis, largest: false, fallback_k: dim.into() }; - let invert_perm = top_k - .eval(tvec!( - if self.is_indice_packed { - unimplemented!("permutation not implemented yet with indice packed"); - // self.perm.to(torch.uint16).to(torch.int64) - } else { - perm.into_tvalue() - }, - tensor0(dim as u16).into() - ))? - .remove(0); + let invert_perm = + top_k.eval(tvec!(perm.into_tvalue(), tensor0(dim as u16).into()))?.remove(0); // TODO: manage case with quant dim == 'in' ? // if self.vector_quant_dim == "in": // assert True, "Not implemented" @@ -241,8 +244,8 @@ impl EvalOp for VPTQGemm { .into_tensor(); } // call matmul now with qweight - - let einsum_op = EinSum::new("ik,kj->".parse()?, f32::datum_type()); + let einsum_op = EinSum::new("ik,kj->ij".parse()?, f32::datum_type()); + // einsum -> matmul imperatif einsum_op.eval(tvec!(input, qweight.permute_axes(&[1, 0])?.into_tvalue())) } } From 215a4608292b0747f6462ebcb7526b26c20372bc Mon Sep 17 00:00:00 2001 From: Julien Balian Date: Thu, 23 Jan 2025 18:06:02 +0100 Subject: [PATCH 11/21] feat: working vptq --- core/src/ops/vptq.rs | 117 +++++++++++++++++++++++--------------- nnef/src/ops/core/vptq.rs | 14 ++++- 2 files changed, 83 insertions(+), 48 deletions(-) diff --git a/core/src/ops/vptq.rs b/core/src/ops/vptq.rs index c7da415ce5..86f639a6e3 100644 --- a/core/src/ops/vptq.rs +++ b/core/src/ops/vptq.rs @@ -6,11 +6,10 @@ use crate::{ ops::{ array::{Gather, GatherElements, Topk}, einsum::EinSum, - logic::and, math::shift_left, - math::shift_right, }, }; +use tract_linalg::{mmm::{FusedSpec, Packing}, ops}; #[derive(Debug, Clone)] pub struct VPTQGemm { @@ -18,6 +17,8 @@ pub struct VPTQGemm { pub in_features: usize, pub out_features: usize, pub is_indice_packed: bool, + pub group_size: usize, + pub outlier_size: usize } impl Op for VPTQGemm { @@ -28,14 +29,14 @@ impl Op for VPTQGemm { op_as_typed_op!(); } fn shift_right_zero_and_1(input: TValue, shift_value: TValue) -> TractResult { - let input = input.to_array_view::()?; - let shift_value = shift_value.to_array_view::()?; + let input = input.to_array_view::()?; + let shift_value = shift_value.to_array_view::()?; let out_shape = crate::broadcast::multi_broadcast(&[input.shape(), shift_value.shape()])?; - let mut out = Tensor::zero_dt(DatumType::U16, &out_shape)?; - crate::ndarray::Zip::from(out.to_array_view_mut::()?) + let mut out = Tensor::zero_dt(DatumType::I32, &out_shape)?; + crate::ndarray::Zip::from(out.to_array_view_mut::()?) .and_broadcast(input) .and_broadcast(shift_value) - .for_each(|c, a, b| *c = a.checked_shr(*b as u32).unwrap_or(0u16) & 1u16); + .for_each(|c, a, b| *c = a.checked_shr(*b as u32).unwrap_or(0i32) & 1i32); Ok(out.into_tvalue()) } @@ -47,37 +48,39 @@ impl VPTQGemm { index_bits: usize, num_elements: usize, ) -> TractResult { - // let wf = Tensor::from(Array1::from_iter(0..32u16).to_shape([1, 1, 1, 32])?.into_owned()); - // can be reexpressed - let wf = tensor1(&(0..32u16).collect_vec()).into_shape(&[1, 1, 1, 32])?; + let wf = tensor1(&(0..32i32).collect_vec()).into_shape(&[1, 1, 1, 32])?; - let mut pack_tensor_shape = pack_tensor.shape().to_vec(); - pack_tensor_shape.push(1); + let pack_tensor_shape = pack_tensor.shape().to_vec(); + + + let mut pre_shift_pack_tensor_shape = pack_tensor_shape.clone(); + pre_shift_pack_tensor_shape.push(1); let mut out = shift_right_zero_and_1( - pack_tensor.clone().into_shape(&pack_tensor_shape)?.into(), + pack_tensor.clone().into_shape(&pre_shift_pack_tensor_shape)?.into(), wf.into(), )?; - let pad_size = (pack_tensor_shape[2] * 32) % (index_bits * num_elements); + let mut post_shift_pack_tensor_shape = pack_tensor_shape.clone(); + let pval = post_shift_pack_tensor_shape.pop().unwrap(); + post_shift_pack_tensor_shape.push(32 * pval); + out = out.into_tensor().clone().into_shape(&post_shift_pack_tensor_shape)?.into_tvalue(); - let mut pack_tensor_shape = pack_tensor.shape().to_vec(); - pack_tensor_shape.pop(); - pack_tensor_shape.push(out.volume() / pack_tensor.volume()); - out = out.into_tensor().clone().into_shape(&pack_tensor_shape)?.into_tvalue(); + let pad_size = (pack_tensor_shape.last().unwrap_or(&0) * 32) % (index_bits * num_elements); if pad_size > 0 { let end = out.shape()[out.rank() - 1] - pad_size; - out = out.slice(out.rank(), 0, end)?.into(); + out = out.slice(out.rank() - 1, 0, end)?.into(); } - pack_tensor_shape.pop(); - let auto = out.volume() / pack_tensor_shape.iter().product::() / index_bits; - pack_tensor_shape.push(auto); - pack_tensor_shape.push(index_bits); - out = out.into_tensor().into_shape(&pack_tensor_shape)?.into(); + let mut post_pad_pack_tensor_shape = pack_tensor_shape.clone(); + post_pad_pack_tensor_shape.pop(); + let auto = out.shape().last().unwrap() / index_bits; + post_pad_pack_tensor_shape.push(auto); + post_pad_pack_tensor_shape.push(index_bits); + out = out.into_tensor().into_shape(&post_pad_pack_tensor_shape)?.into(); let wf1 = Tensor::from( - Array1::from_iter(0..(index_bits as u16)).to_shape([1, 1, 1, index_bits])?.into_owned(), + Array1::from_iter(0..(index_bits as i32)).to_shape([1, 1, 1, index_bits])?.into_owned(), ); out = shift_left().eval(tvec!(out, wf1.into()))?.pop().unwrap(); @@ -85,17 +88,17 @@ impl VPTQGemm { let axis = out.rank() - 1; out = out .into_tensor() - .into_array::()? + .into_array::()? .sum_axis(tract_ndarray::Axis(axis)) .into_tvalue(); - let unpack_indice = out.cast_to_dt(DatumType::U16)?; + let unpack_indice = out.cast_to_dt(DatumType::I32)?; let mut indices = - unsafe { Tensor::uninitialized_dt(DatumType::U16, unpack_indice.shape())? }; + unsafe { Tensor::uninitialized_dt(DatumType::I32, unpack_indice.shape())? }; - crate::ndarray::Zip::from(&mut indices.to_array_view_mut::()?) - .and_broadcast(unpack_indice.to_array_view::()?) + crate::ndarray::Zip::from(&mut indices.to_array_view_mut::()?) + .and_broadcast(unpack_indice.to_array_view::()?) .for_each(|indice, upack_indice| *indice = upack_indice & ((1 << index_bits) - 1)); indices = indices.slice(2, 0, num_elements)?; @@ -106,6 +109,7 @@ impl VPTQGemm { &self, centroids: Tensor, indices: Tensor, + group_size: usize ) -> TractResult { /// TODO: instead use ndarray for indices to use views transform (--) let mut indices = indices.clone(); @@ -113,16 +117,13 @@ impl VPTQGemm { unimplemented!("unexected centroid shape ?") }; - let [_, _, group_size] = *indices.shape() else { - unimplemented!("unexected indice shape ?") - }; - - let mut vsh = indices.shape().to_vec(); if self.is_indice_packed { // unimplemented!("unpacking indices not implemented yet !"); let index_bits = (num_centroids as f32).log2().ceil() as usize; - indices = self.eval_unpack_index_tensor(indices, index_bits, 1)?; + indices = self.eval_unpack_index_tensor(indices, index_bits, group_size)?; } + + let mut vsh = indices.shape().to_vec(); indices.insert_axis(3)?; vsh.push(vector_len); indices = indices.broadcast_to_shape(&vsh)?; // NOTE: costly in tract (applied in memory but not in ndarray) @@ -139,7 +140,7 @@ impl VPTQGemm { .into_tensor(); let remain = selected_centroids.volume() / (num_codebooks * group_size * vector_len); - // split_axes + let mut qweight = selected_centroids .into_shape(&[num_codebooks, remain, group_size, vector_len])? .permute_axes(&[0, 1, 3, 2])? // NOTE: costly in tract (applied in memory) @@ -196,23 +197,23 @@ impl EvalOp for VPTQGemm { assert!(input.datum_type().is_float()); assert_eq!(indices.rank(), 3); - assert_eq!(indices.datum_type(), DatumType::U16); + assert_eq!(indices.datum_type(), DatumType::I32); assert_eq!(centroids.rank(), 3); assert!(centroids.datum_type().is_float()); let enable_outlier = outlier_indices.len() > 0; if enable_outlier { assert_eq!(outlier_indices.rank(), 3); - assert_eq!(outlier_indices.datum_type(), DatumType::U16); + assert_eq!(outlier_indices.datum_type(), DatumType::I32); assert_eq!(outlier_centroids.rank(), 3); assert!(outlier_centroids.datum_type().is_float()); } - let mut qweight = self.eval_extract_from_vector_quant(centroids, indices)?; + let mut qweight = self.eval_extract_from_vector_quant(centroids, indices, self.group_size)?; if enable_outlier { // same as centroids to qweights except for outlier let outlier_qweight = - self.eval_extract_from_vector_quant(outlier_centroids, outlier_indices)?; + self.eval_extract_from_vector_quant(outlier_centroids, outlier_indices, self.outlier_size)?; qweight = Tensor::stack_tensors(1, &[outlier_qweight, qweight])?; } @@ -243,10 +244,36 @@ impl EvalOp for VPTQGemm { + weight_bias.into_array::()?) .into_tensor(); } - // call matmul now with qweight - let einsum_op = EinSum::new("ik,kj->ij".parse()?, f32::datum_type()); - // einsum -> matmul imperatif - einsum_op.eval(tvec!(input, qweight.permute_axes(&[1, 0])?.into_tvalue())) + // // call matmul now with qweight + // let einsum_op = EinSum::new("ik,kj->ij".parse()?, f32::datum_type()); + // // einsum -> matmul imperatif + // einsum_op.eval(tvec!(input, qweight.permute_axes(&[1, 0])?.into_tvalue())) + // + qweight = qweight.permute_axes(&[1, 0])?; + let op = ops(); + let &[m, k] = input.shape() else { bail!("unexpected rank {:?}", input.rank())}; + let &n = qweight.shape().last().unwrap(); + + let mmm = op.mmm(DatumType::F32, Some(m), Some(k), Some(n)).unwrap(); + + let (pack_a, pack_b) = &mmm.packings()[0]; + let cstore = unsafe { + mmm.c_view(0, 1) + }; + + + let a = pack_a.prepare_tensor(&input, 1, 0)?; + let b = pack_b.prepare_tensor(&qweight, 0, 1)?; + unsafe { + let mut out = + Tensor::uninitialized::(&[m, n])?; + let non_linear = &[FusedSpec::AddMatMul { + a: tract_linalg::mmm::AsInputValue::Owned(a), b: tract_linalg::mmm::AsInputValue::Owned(b), packing: 0 + }, FusedSpec::Store(cstore.wrap(&out.view()))]; + mmm.run(m, n, non_linear); + + Ok(tvec!(out.into())) + } } } diff --git a/nnef/src/ops/core/vptq.rs b/nnef/src/ops/core/vptq.rs index 433fd2f0cd..8a26ab96bb 100644 --- a/nnef/src/ops/core/vptq.rs +++ b/nnef/src/ops/core/vptq.rs @@ -1,6 +1,5 @@ use crate::internal::*; use crate::ser::*; -use tract_core::ops::cast::cast; use tract_core::ops::vptq::VPTQGemm; pub fn register(registry: &mut Registry) { @@ -20,6 +19,8 @@ pub fn register(registry: &mut Registry) { TypeName::Integer.named("vector_len"), TypeName::Integer.tensor().named("in_features"), TypeName::Integer.tensor().named("out_features"), + TypeName::Integer.tensor().named("group_size"), + TypeName::Integer.tensor().named("outlier_size"), ], &[("output", TypeName::Scalar.tensor())], de_vptq_gemm, @@ -29,7 +30,7 @@ pub fn register(registry: &mut Registry) { fn ser_vptq_gemm( ast: &mut IntoAst, node: &TypedNode, - op: &VPTQGemm, + _op: &VPTQGemm, ) -> TractResult>> { let input = ast.mapping[&node.inputs[0]].clone(); let indices = ast.mapping[&node.inputs[1]].clone(); @@ -44,6 +45,8 @@ fn ser_vptq_gemm( let vector_len = ast.mapping[&node.inputs[9]].clone(); let in_features = ast.mapping[&node.inputs[10]].clone(); let out_features = ast.mapping[&node.inputs[11]].clone(); + let group_size = ast.mapping[&node.inputs[12]].clone(); + let outlier_size = ast.mapping[&node.inputs[13]].clone(); Ok(Some(invocation( "tract_core_vptq_gemm", &[ @@ -61,6 +64,8 @@ fn ser_vptq_gemm( ("vector_len", numeric(vector_len)), ("in_features", numeric(in_features)), ("out_features", numeric(out_features)), + ("group_size", numeric(group_size)), + ("outlier_size", numeric(outlier_size)), ], ))) } @@ -81,8 +86,11 @@ fn de_vptq_gemm(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> let out_features = invocation.named_arg_as(builder, "out_features")?; let is_indice_packed = invocation.named_arg_as(builder, "is_indice_packed")?; + let group_size = invocation.named_arg_as(builder, "group_size")?; + let outlier_size = invocation.named_arg_as(builder, "outlier_size")?; + builder.wire( - VPTQGemm { vector_len, in_features, out_features, is_indice_packed }, + VPTQGemm { vector_len, in_features, out_features, is_indice_packed, group_size, outlier_size}, &[ input, indices, From e22bff96a5c84ae94b707cc5f5d91e3174281991 Mon Sep 17 00:00:00 2001 From: Julien Balian Date: Fri, 24 Jan 2025 11:52:27 +0100 Subject: [PATCH 12/21] fix: cleanup --- core/src/ops/vptq.rs | 50 +++++++++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 24 deletions(-) diff --git a/core/src/ops/vptq.rs b/core/src/ops/vptq.rs index 86f639a6e3..395ecff8f1 100644 --- a/core/src/ops/vptq.rs +++ b/core/src/ops/vptq.rs @@ -5,11 +5,10 @@ use crate::{ internal::*, ops::{ array::{Gather, GatherElements, Topk}, - einsum::EinSum, math::shift_left, }, }; -use tract_linalg::{mmm::{FusedSpec, Packing}, ops}; +use tract_linalg::{mmm::FusedSpec, ops}; #[derive(Debug, Clone)] pub struct VPTQGemm { @@ -18,7 +17,7 @@ pub struct VPTQGemm { pub out_features: usize, pub is_indice_packed: bool, pub group_size: usize, - pub outlier_size: usize + pub outlier_size: usize, } impl Op for VPTQGemm { @@ -52,7 +51,6 @@ impl VPTQGemm { let pack_tensor_shape = pack_tensor.shape().to_vec(); - let mut pre_shift_pack_tensor_shape = pack_tensor_shape.clone(); pre_shift_pack_tensor_shape.push(1); @@ -109,7 +107,7 @@ impl VPTQGemm { &self, centroids: Tensor, indices: Tensor, - group_size: usize + group_size: usize, ) -> TractResult { /// TODO: instead use ndarray for indices to use views transform (--) let mut indices = indices.clone(); @@ -209,11 +207,15 @@ impl EvalOp for VPTQGemm { assert!(outlier_centroids.datum_type().is_float()); } - let mut qweight = self.eval_extract_from_vector_quant(centroids, indices, self.group_size)?; + let mut qweight = + self.eval_extract_from_vector_quant(centroids, indices, self.group_size)?; if enable_outlier { // same as centroids to qweights except for outlier - let outlier_qweight = - self.eval_extract_from_vector_quant(outlier_centroids, outlier_indices, self.outlier_size)?; + let outlier_qweight = self.eval_extract_from_vector_quant( + outlier_centroids, + outlier_indices, + self.outlier_size, + )?; qweight = Tensor::stack_tensors(1, &[outlier_qweight, qweight])?; } @@ -244,35 +246,35 @@ impl EvalOp for VPTQGemm { + weight_bias.into_array::()?) .into_tensor(); } - // // call matmul now with qweight + // NOTE: next step is fast matmul equivalent of { // let einsum_op = EinSum::new("ik,kj->ij".parse()?, f32::datum_type()); - // // einsum -> matmul imperatif // einsum_op.eval(tvec!(input, qweight.permute_axes(&[1, 0])?.into_tvalue())) - // + // } qweight = qweight.permute_axes(&[1, 0])?; let op = ops(); - let &[m, k] = input.shape() else { bail!("unexpected rank {:?}", input.rank())}; + let &[m, k] = input.shape() else { bail!("unexpected rank {:?}", input.rank()) }; let &n = qweight.shape().last().unwrap(); let mmm = op.mmm(DatumType::F32, Some(m), Some(k), Some(n)).unwrap(); let (pack_a, pack_b) = &mmm.packings()[0]; - let cstore = unsafe { - mmm.c_view(0, 1) - }; - + let cstore = unsafe { mmm.c_view(0, 1) }; let a = pack_a.prepare_tensor(&input, 1, 0)?; let b = pack_b.prepare_tensor(&qweight, 0, 1)?; unsafe { - let mut out = - Tensor::uninitialized::(&[m, n])?; - let non_linear = &[FusedSpec::AddMatMul { - a: tract_linalg::mmm::AsInputValue::Owned(a), b: tract_linalg::mmm::AsInputValue::Owned(b), packing: 0 - }, FusedSpec::Store(cstore.wrap(&out.view()))]; - mmm.run(m, n, non_linear); - - Ok(tvec!(out.into())) + let out = Tensor::uninitialized::(&[m, n])?; + let non_linear = &[ + FusedSpec::AddMatMul { + a: tract_linalg::mmm::AsInputValue::Owned(a), + b: tract_linalg::mmm::AsInputValue::Owned(b), + packing: 0, + }, + FusedSpec::Store(cstore.wrap(&out.view())), + ]; + mmm.run(m, n, non_linear)?; + + Ok(tvec!(out.into())) } } } From 5aee7e7b19efc5ff97864160b40903913cd38377 Mon Sep 17 00:00:00 2001 From: Julien Balian Date: Fri, 24 Jan 2025 15:26:58 +0100 Subject: [PATCH 13/21] fix: fdtypes check and align --- core/src/ops/vptq.rs | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/core/src/ops/vptq.rs b/core/src/ops/vptq.rs index 395ecff8f1..593708fdfc 100644 --- a/core/src/ops/vptq.rs +++ b/core/src/ops/vptq.rs @@ -1,3 +1,5 @@ +use std::collections::HashSet; + use tract_data::itertools::Itertools; use tract_ndarray::Array1; @@ -206,6 +208,12 @@ impl EvalOp for VPTQGemm { assert_eq!(outlier_centroids.rank(), 3); assert!(outlier_centroids.datum_type().is_float()); } + let fdtypes = HashSet::from([ + input.datum_type(), + centroids.datum_type(), + outlier_centroids.datum_type(), + ]); + assert!(fdtypes.len() == 1); let mut qweight = self.eval_extract_from_vector_quant(centroids, indices, self.group_size)?; @@ -246,7 +254,7 @@ impl EvalOp for VPTQGemm { + weight_bias.into_array::()?) .into_tensor(); } - // NOTE: next step is fast matmul equivalent of { + // NOTE: next steps is fast matmul equivalent of { // let einsum_op = EinSum::new("ik,kj->ij".parse()?, f32::datum_type()); // einsum_op.eval(tvec!(input, qweight.permute_axes(&[1, 0])?.into_tvalue())) // } @@ -255,7 +263,7 @@ impl EvalOp for VPTQGemm { let &[m, k] = input.shape() else { bail!("unexpected rank {:?}", input.rank()) }; let &n = qweight.shape().last().unwrap(); - let mmm = op.mmm(DatumType::F32, Some(m), Some(k), Some(n)).unwrap(); + let mmm = op.mmm(*fdtypes.iter().next().unwrap(), Some(m), Some(k), Some(n)).unwrap(); let (pack_a, pack_b) = &mmm.packings()[0]; let cstore = unsafe { mmm.c_view(0, 1) }; From a96c24260a60163abc9a67b9ff0c136671b27298 Mon Sep 17 00:00:00 2001 From: Julien Balian Date: Sun, 26 Jan 2025 18:46:02 +0100 Subject: [PATCH 14/21] fix: better vptq --- core/src/ops/vptq.rs | 24 ++++++++++++++---------- nnef/src/ops/core/vptq.rs | 26 ++++++++++++++------------ 2 files changed, 28 insertions(+), 22 deletions(-) diff --git a/core/src/ops/vptq.rs b/core/src/ops/vptq.rs index 593708fdfc..96cb630b8c 100644 --- a/core/src/ops/vptq.rs +++ b/core/src/ops/vptq.rs @@ -175,9 +175,9 @@ impl EvalOp for VPTQGemm { bias, ) = args_9!(inputs); let indices = indices.into_tensor(); - let centroids = centroids.into_tensor(); + let mut centroids = centroids.into_tensor(); let outlier_indices = outlier_indices.into_tensor(); - let outlier_centroids = outlier_centroids.into_tensor(); + let mut outlier_centroids = outlier_centroids.into_tensor(); let perm = perm.into_tensor(); let weight_scale = weight_scale.into_tensor(); let weight_bias = weight_bias.into_tensor(); @@ -193,7 +193,7 @@ impl EvalOp for VPTQGemm { if bias.len() > 1 { unimplemented!("'bias' for vptq not yet supported !"); } - assert_eq!(input.rank(), 2); + assert!([2, 3].contains(&input.rank())); assert!(input.datum_type().is_float()); assert_eq!(indices.rank(), 3); @@ -208,12 +208,16 @@ impl EvalOp for VPTQGemm { assert_eq!(outlier_centroids.rank(), 3); assert!(outlier_centroids.datum_type().is_float()); } - let fdtypes = HashSet::from([ - input.datum_type(), - centroids.datum_type(), - outlier_centroids.datum_type(), - ]); - assert!(fdtypes.len() == 1); + let _fdtypes = [input.datum_type(), centroids.datum_type(), outlier_centroids.datum_type()]; + let fdtypes = HashSet::from(_fdtypes); + if fdtypes.len() != 1 { + log::warn!("force cast centroids to be same type as input: {:?}", input.datum_type()); + centroids = centroids.cast_to_dt(input.datum_type())?.into_owned(); + outlier_centroids = outlier_centroids.cast_to_dt(input.datum_type())?.into_owned(); + } + let _fdtypes = [input.datum_type(), centroids.datum_type(), outlier_centroids.datum_type()]; + let fdtypes = HashSet::from(_fdtypes); + assert!(fdtypes.len() == 1, "mixed dtypes: {_fdtypes:?}"); let mut qweight = self.eval_extract_from_vector_quant(centroids, indices, self.group_size)?; @@ -290,7 +294,7 @@ impl EvalOp for VPTQGemm { impl TypedOp for VPTQGemm { fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult> { let mut tfact = inputs[0].without_value(); - tfact.shape.set(1, self.out_features.into()); + tfact.shape.set(tfact.rank() - 1, self.out_features.into()); Ok(tvec!(tfact)) } diff --git a/nnef/src/ops/core/vptq.rs b/nnef/src/ops/core/vptq.rs index 8a26ab96bb..118a0885b3 100644 --- a/nnef/src/ops/core/vptq.rs +++ b/nnef/src/ops/core/vptq.rs @@ -30,7 +30,7 @@ pub fn register(registry: &mut Registry) { fn ser_vptq_gemm( ast: &mut IntoAst, node: &TypedNode, - _op: &VPTQGemm, + op: &VPTQGemm, ) -> TractResult>> { let input = ast.mapping[&node.inputs[0]].clone(); let indices = ast.mapping[&node.inputs[1]].clone(); @@ -42,11 +42,6 @@ fn ser_vptq_gemm( let weight_bias = ast.mapping[&node.inputs[7]].clone(); let bias = ast.mapping[&node.inputs[8]].clone(); - let vector_len = ast.mapping[&node.inputs[9]].clone(); - let in_features = ast.mapping[&node.inputs[10]].clone(); - let out_features = ast.mapping[&node.inputs[11]].clone(); - let group_size = ast.mapping[&node.inputs[12]].clone(); - let outlier_size = ast.mapping[&node.inputs[13]].clone(); Ok(Some(invocation( "tract_core_vptq_gemm", &[ @@ -61,11 +56,11 @@ fn ser_vptq_gemm( bias, ], &[ - ("vector_len", numeric(vector_len)), - ("in_features", numeric(in_features)), - ("out_features", numeric(out_features)), - ("group_size", numeric(group_size)), - ("outlier_size", numeric(outlier_size)), + ("vector_len", numeric(op.vector_len)), + ("in_features", numeric(op.in_features)), + ("out_features", numeric(op.out_features)), + ("group_size", numeric(op.group_size)), + ("outlier_size", numeric(op.outlier_size)), ], ))) } @@ -90,7 +85,14 @@ fn de_vptq_gemm(builder: &mut ModelBuilder, invocation: &ResolvedInvocation) -> let outlier_size = invocation.named_arg_as(builder, "outlier_size")?; builder.wire( - VPTQGemm { vector_len, in_features, out_features, is_indice_packed, group_size, outlier_size}, + VPTQGemm { + vector_len, + in_features, + out_features, + is_indice_packed, + group_size, + outlier_size, + }, &[ input, indices, From 8658ec5c6a1b464dc8017a5b6290733d66187c65 Mon Sep 17 00:00:00 2001 From: Julien Balian Date: Tue, 28 Jan 2025 13:30:33 +0100 Subject: [PATCH 15/21] fix: rank 3 vptq --- core/src/ops/vptq.rs | 49 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 39 insertions(+), 10 deletions(-) diff --git a/core/src/ops/vptq.rs b/core/src/ops/vptq.rs index 96cb630b8c..b59e75df63 100644 --- a/core/src/ops/vptq.rs +++ b/core/src/ops/vptq.rs @@ -68,8 +68,9 @@ impl VPTQGemm { let pad_size = (pack_tensor_shape.last().unwrap_or(&0) * 32) % (index_bits * num_elements); if pad_size > 0 { - let end = out.shape()[out.rank() - 1] - pad_size; - out = out.slice(out.rank() - 1, 0, end)?.into(); + let axis = out.rank() - 1; + let end = out.shape()[axis] - pad_size; + out = out.slice(axis, 0, end)?.into(); } let mut post_pad_pack_tensor_shape = pack_tensor_shape.clone(); @@ -149,9 +150,10 @@ impl VPTQGemm { .into_shape(&[vector_len * remain, num_codebooks * group_size])?; let dim0 = qweight.shape()[0]; - let padding = (-(self.out_features as i16) % vector_len as i16) as usize; + let padding = (-(self.out_features as i16)).wrapping_rem_euclid(vector_len as i16); if padding > 0 { - qweight = qweight.slice(0, 0, dim0 - padding)?; + let end = dim0 as i16 - padding; + qweight = qweight.slice(0, 0, end as usize)?; } Ok(qweight) } @@ -228,7 +230,15 @@ impl EvalOp for VPTQGemm { outlier_indices, self.outlier_size, )?; - qweight = Tensor::stack_tensors(1, &[outlier_qweight, qweight])?; + + qweight = + Tensor::stack_tensors(1, &[&outlier_qweight, &qweight]).with_context(|| { + format!( + "outlier.shape:{:?}, main.shape:{:?}", + &outlier_qweight.shape(), + &qweight.shape() + ) + })?; } let enable_perm = perm.len() > 1; @@ -264,18 +274,37 @@ impl EvalOp for VPTQGemm { // } qweight = qweight.permute_axes(&[1, 0])?; let op = ops(); - let &[m, k] = input.shape() else { bail!("unexpected rank {:?}", input.rank()) }; + let ishape = input.shape(); + let &n = qweight.shape().last().unwrap(); + let (&[m, k], out_shape, offset) = match ishape.len() { + 2 => { + let &[m, k] = ishape else { + bail!("unexpected rank: {:?}", input.len()); + }; + (&[m, k], vec![m, n], 0usize) + } + 3 => { + let &[b, m, k] = ishape else { + bail!("unexpected rank: {:?}", input.len()); + }; + (&[m, k], vec![b, m, n], 1usize) + } + _ => { + bail!("unexpected rank {:?}", ishape.len()) + } + }; + let mmm = op.mmm(*fdtypes.iter().next().unwrap(), Some(m), Some(k), Some(n)).unwrap(); let (pack_a, pack_b) = &mmm.packings()[0]; - let cstore = unsafe { mmm.c_view(0, 1) }; + let cstore = unsafe { mmm.c_view(0 + offset, 1 + offset) }; - let a = pack_a.prepare_tensor(&input, 1, 0)?; - let b = pack_b.prepare_tensor(&qweight, 0, 1)?; + let a = pack_a.prepare_tensor(&input, 1 + offset, 0 + offset)?; + let b = pack_b.prepare_tensor(&qweight, 0 + offset, 1 + offset)?; unsafe { - let out = Tensor::uninitialized::(&[m, n])?; + let out = Tensor::uninitialized::(out_shape.iter().as_slice().try_into()?)?; let non_linear = &[ FusedSpec::AddMatMul { a: tract_linalg::mmm::AsInputValue::Owned(a), From 044f72b210d1d309fa6578f4d5110964073e00e5 Mon Sep 17 00:00:00 2001 From: Julien Balian Date: Tue, 28 Jan 2025 17:42:58 +0100 Subject: [PATCH 16/21] fix: vptq working on large model --- core/src/ops/vptq.rs | 34 +++++++++++++++++++++------------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/core/src/ops/vptq.rs b/core/src/ops/vptq.rs index b59e75df63..d56ce91484 100644 --- a/core/src/ops/vptq.rs +++ b/core/src/ops/vptq.rs @@ -263,6 +263,8 @@ impl EvalOp for VPTQGemm { .into_tensor(); } + let data_type = *fdtypes.iter().next().unwrap(); + if enable_norm { qweight = (qweight.into_array::()? * weight_scale.into_array::()? + weight_bias.into_array::()?) @@ -278,33 +280,36 @@ impl EvalOp for VPTQGemm { let &n = qweight.shape().last().unwrap(); - let (&[m, k], out_shape, offset) = match ishape.len() { + let (&[m, k], out_shape) = match ishape.len() { 2 => { let &[m, k] = ishape else { - bail!("unexpected rank: {:?}", input.len()); + bail!("unexpected rank: {:?}", ishape.len()); }; - (&[m, k], vec![m, n], 0usize) + (&[m, k], vec![m, n]) } 3 => { let &[b, m, k] = ishape else { - bail!("unexpected rank: {:?}", input.len()); + bail!("unexpected rank: {:?}", ishape.len()); }; - (&[m, k], vec![b, m, n], 1usize) + (&[m, k], vec![b, m, n]) } _ => { bail!("unexpected rank {:?}", ishape.len()) } }; - let mmm = op.mmm(*fdtypes.iter().next().unwrap(), Some(m), Some(k), Some(n)).unwrap(); + let input_offset = input.rank() - 2; + let weight_offset = qweight.rank() - 2; + let mmm = op.mmm(data_type, Some(m), Some(k), Some(n)).unwrap(); let (pack_a, pack_b) = &mmm.packings()[0]; - let cstore = unsafe { mmm.c_view(0 + offset, 1 + offset) }; - let a = pack_a.prepare_tensor(&input, 1 + offset, 0 + offset)?; - let b = pack_b.prepare_tensor(&qweight, 0 + offset, 1 + offset)?; - unsafe { - let out = Tensor::uninitialized::(out_shape.iter().as_slice().try_into()?)?; + let cstore = unsafe { mmm.c_view(input_offset, 1 + input_offset) }; + + let a = pack_a.prepare_tensor(&input, 1 + input_offset, input_offset)?; + let b = pack_b.prepare_tensor(&qweight, weight_offset, 1 + weight_offset)?; + let last = unsafe { + let out = Tensor::uninitialized::(out_shape.iter().as_slice())?; let non_linear = &[ FusedSpec::AddMatMul { a: tract_linalg::mmm::AsInputValue::Owned(a), @@ -315,8 +320,11 @@ impl EvalOp for VPTQGemm { ]; mmm.run(m, n, non_linear)?; - Ok(tvec!(out.into())) - } + out + }; + // force down cast for now + let last_cdt = last.cast_to_dt(input.datum_type())?.into_owned().into_tvalue(); + Ok(tvec!(last_cdt)) } } From 40d2e636f87e9de6f8fc979dce066639337dce3d Mon Sep 17 00:00:00 2001 From: Mathieu Poumeyrol Date: Wed, 29 Jan 2025 11:51:01 +0100 Subject: [PATCH 17/21] make vtpq work on top of available kernels --- core/src/ops/vptq.rs | 96 ++++++++++++++++++++++++++++---------------- 1 file changed, 62 insertions(+), 34 deletions(-) diff --git a/core/src/ops/vptq.rs b/core/src/ops/vptq.rs index d56ce91484..a119d60097 100644 --- a/core/src/ops/vptq.rs +++ b/core/src/ops/vptq.rs @@ -57,14 +57,21 @@ impl VPTQGemm { pre_shift_pack_tensor_shape.push(1); let mut out = shift_right_zero_and_1( - pack_tensor.clone().into_shape(&pre_shift_pack_tensor_shape)?.into(), + pack_tensor + .clone() + .into_shape(&pre_shift_pack_tensor_shape)? + .into(), wf.into(), )?; let mut post_shift_pack_tensor_shape = pack_tensor_shape.clone(); let pval = post_shift_pack_tensor_shape.pop().unwrap(); post_shift_pack_tensor_shape.push(32 * pval); - out = out.into_tensor().clone().into_shape(&post_shift_pack_tensor_shape)?.into_tvalue(); + out = out + .into_tensor() + .clone() + .into_shape(&post_shift_pack_tensor_shape)? + .into_tvalue(); let pad_size = (pack_tensor_shape.last().unwrap_or(&0) * 32) % (index_bits * num_elements); if pad_size > 0 { @@ -78,10 +85,15 @@ impl VPTQGemm { let auto = out.shape().last().unwrap() / index_bits; post_pad_pack_tensor_shape.push(auto); post_pad_pack_tensor_shape.push(index_bits); - out = out.into_tensor().into_shape(&post_pad_pack_tensor_shape)?.into(); + out = out + .into_tensor() + .into_shape(&post_pad_pack_tensor_shape)? + .into(); let wf1 = Tensor::from( - Array1::from_iter(0..(index_bits as i32)).to_shape([1, 1, 1, index_bits])?.into_owned(), + Array1::from_iter(0..(index_bits as i32)) + .to_shape([1, 1, 1, index_bits])? + .into_owned(), ); out = shift_left().eval(tvec!(out, wf1.into()))?.pop().unwrap(); @@ -146,7 +158,7 @@ impl VPTQGemm { .into_shape(&[num_codebooks, remain, group_size, vector_len])? .permute_axes(&[0, 1, 3, 2])? // NOTE: costly in tract (applied in memory) .into_shape(&[num_codebooks, remain * vector_len, group_size])? - .permute_axes(&[1, 0, 2])?// NOTE: costly in tract (applied in memory) + .permute_axes(&[1, 0, 2])? // NOTE: costly in tract (applied in memory) .into_shape(&[vector_len * remain, num_codebooks * group_size])?; let dim0 = qweight.shape()[0]; @@ -210,14 +222,27 @@ impl EvalOp for VPTQGemm { assert_eq!(outlier_centroids.rank(), 3); assert!(outlier_centroids.datum_type().is_float()); } - let _fdtypes = [input.datum_type(), centroids.datum_type(), outlier_centroids.datum_type()]; + let _fdtypes = [ + input.datum_type(), + centroids.datum_type(), + outlier_centroids.datum_type(), + ]; let fdtypes = HashSet::from(_fdtypes); if fdtypes.len() != 1 { - log::warn!("force cast centroids to be same type as input: {:?}", input.datum_type()); + log::warn!( + "force cast centroids to be same type as input: {:?}", + input.datum_type() + ); centroids = centroids.cast_to_dt(input.datum_type())?.into_owned(); - outlier_centroids = outlier_centroids.cast_to_dt(input.datum_type())?.into_owned(); + outlier_centroids = outlier_centroids + .cast_to_dt(input.datum_type())? + .into_owned(); } - let _fdtypes = [input.datum_type(), centroids.datum_type(), outlier_centroids.datum_type()]; + let _fdtypes = [ + input.datum_type(), + centroids.datum_type(), + outlier_centroids.datum_type(), + ]; let fdtypes = HashSet::from(_fdtypes); assert!(fdtypes.len() == 1, "mixed dtypes: {_fdtypes:?}"); @@ -245,16 +270,23 @@ impl EvalOp for VPTQGemm { if enable_perm { let axis = 0; let dim = perm.shape()[0]; - let top_k = Topk { axis, largest: false, fallback_k: dim.into() }; - let invert_perm = - top_k.eval(tvec!(perm.into_tvalue(), tensor0(dim as u16).into()))?.remove(0); + let top_k = Topk { + axis, + largest: false, + fallback_k: dim.into(), + }; + let invert_perm = top_k + .eval(tvec!(perm.into_tvalue(), tensor0(dim as u16).into()))? + .remove(0); // TODO: manage case with quant dim == 'in' ? // if self.vector_quant_dim == "in": // assert True, "Not implemented" // qweight = qweight[invert_perm, :] let perm_gather_axis = 1; - let gather_perm = Gather { axis: perm_gather_axis }; + let gather_perm = Gather { + axis: perm_gather_axis, + }; qweight = gather_perm .eval(tvec!(qweight.into(), invert_perm))? .pop() @@ -280,19 +312,9 @@ impl EvalOp for VPTQGemm { let &n = qweight.shape().last().unwrap(); - let (&[m, k], out_shape) = match ishape.len() { - 2 => { - let &[m, k] = ishape else { - bail!("unexpected rank: {:?}", ishape.len()); - }; - (&[m, k], vec![m, n]) - } - 3 => { - let &[b, m, k] = ishape else { - bail!("unexpected rank: {:?}", ishape.len()); - }; - (&[m, k], vec![b, m, n]) - } + let (m, k, out_shape) = match ishape { + &[m, k] => (m, k, vec![m, n]), + &[b, m, k] => (m, k, vec![b, m, n]), _ => { bail!("unexpected rank {:?}", ishape.len()) } @@ -301,6 +323,16 @@ impl EvalOp for VPTQGemm { let input_offset = input.rank() - 2; let weight_offset = qweight.rank() - 2; + /* this would be better for Intel where there is no f16 support, but the kernel selection + APIs are not up to the task (yet) + + let acc_type = if tract_linalg::has_fp16() { + f16::datum_type() + } else { + f32::datum_type() + }; + + */ let mmm = op.mmm(data_type, Some(m), Some(k), Some(n)).unwrap(); let (pack_a, pack_b) = &mmm.packings()[0]; @@ -308,8 +340,8 @@ impl EvalOp for VPTQGemm { let a = pack_a.prepare_tensor(&input, 1 + input_offset, input_offset)?; let b = pack_b.prepare_tensor(&qweight, weight_offset, 1 + weight_offset)?; - let last = unsafe { - let out = Tensor::uninitialized::(out_shape.iter().as_slice())?; + unsafe { + let out = Tensor::uninitialized_dt(data_type, &out_shape)?; let non_linear = &[ FusedSpec::AddMatMul { a: tract_linalg::mmm::AsInputValue::Owned(a), @@ -319,12 +351,8 @@ impl EvalOp for VPTQGemm { FusedSpec::Store(cstore.wrap(&out.view())), ]; mmm.run(m, n, non_linear)?; - - out - }; - // force down cast for now - let last_cdt = last.cast_to_dt(input.datum_type())?.into_owned().into_tvalue(); - Ok(tvec!(last_cdt)) + Ok(tvec!(out.into_tvalue())) + } } } From d4d331739904faa08d8cad4b7219149b00beca00 Mon Sep 17 00:00:00 2001 From: Julien Balian Date: Fri, 31 Jan 2025 15:36:33 +0100 Subject: [PATCH 18/21] fix: permutation applied in vptq correctly --- core/src/ops/vptq.rs | 59 +++++------------- harness/nnef-test-cases/vptq-basic/io.npz | Bin 0 -> 522 bytes .../nnef-test-cases/vptq-basic/model.nnef.tgz | Bin 0 -> 20518 bytes harness/nnef-test-cases/vptq-basic/runme.sh | 8 +++ harness/nnef-test-cases/vptq-with-perm/io.npz | Bin 0 -> 526 bytes .../vptq-with-perm/model.nnef.tgz | Bin 0 -> 20518 bytes .../nnef-test-cases/vptq-with-perm/runme.sh | 8 +++ 7 files changed, 30 insertions(+), 45 deletions(-) create mode 100644 harness/nnef-test-cases/vptq-basic/io.npz create mode 100644 harness/nnef-test-cases/vptq-basic/model.nnef.tgz create mode 100755 harness/nnef-test-cases/vptq-basic/runme.sh create mode 100644 harness/nnef-test-cases/vptq-with-perm/io.npz create mode 100644 harness/nnef-test-cases/vptq-with-perm/model.nnef.tgz create mode 100755 harness/nnef-test-cases/vptq-with-perm/runme.sh diff --git a/core/src/ops/vptq.rs b/core/src/ops/vptq.rs index a119d60097..c3d6ee6d7c 100644 --- a/core/src/ops/vptq.rs +++ b/core/src/ops/vptq.rs @@ -1,5 +1,6 @@ -use std::collections::HashSet; +use std::{collections::HashSet, path::Path}; +use ndarray_npy::NpzWriter; use tract_data::itertools::Itertools; use tract_ndarray::Array1; @@ -57,21 +58,14 @@ impl VPTQGemm { pre_shift_pack_tensor_shape.push(1); let mut out = shift_right_zero_and_1( - pack_tensor - .clone() - .into_shape(&pre_shift_pack_tensor_shape)? - .into(), + pack_tensor.clone().into_shape(&pre_shift_pack_tensor_shape)?.into(), wf.into(), )?; let mut post_shift_pack_tensor_shape = pack_tensor_shape.clone(); let pval = post_shift_pack_tensor_shape.pop().unwrap(); post_shift_pack_tensor_shape.push(32 * pval); - out = out - .into_tensor() - .clone() - .into_shape(&post_shift_pack_tensor_shape)? - .into_tvalue(); + out = out.into_tensor().clone().into_shape(&post_shift_pack_tensor_shape)?.into_tvalue(); let pad_size = (pack_tensor_shape.last().unwrap_or(&0) * 32) % (index_bits * num_elements); if pad_size > 0 { @@ -85,15 +79,10 @@ impl VPTQGemm { let auto = out.shape().last().unwrap() / index_bits; post_pad_pack_tensor_shape.push(auto); post_pad_pack_tensor_shape.push(index_bits); - out = out - .into_tensor() - .into_shape(&post_pad_pack_tensor_shape)? - .into(); + out = out.into_tensor().into_shape(&post_pad_pack_tensor_shape)?.into(); let wf1 = Tensor::from( - Array1::from_iter(0..(index_bits as i32)) - .to_shape([1, 1, 1, index_bits])? - .into_owned(), + Array1::from_iter(0..(index_bits as i32)).to_shape([1, 1, 1, index_bits])?.into_owned(), ); out = shift_left().eval(tvec!(out, wf1.into()))?.pop().unwrap(); @@ -222,27 +211,14 @@ impl EvalOp for VPTQGemm { assert_eq!(outlier_centroids.rank(), 3); assert!(outlier_centroids.datum_type().is_float()); } - let _fdtypes = [ - input.datum_type(), - centroids.datum_type(), - outlier_centroids.datum_type(), - ]; + let _fdtypes = [input.datum_type(), centroids.datum_type(), outlier_centroids.datum_type()]; let fdtypes = HashSet::from(_fdtypes); if fdtypes.len() != 1 { - log::warn!( - "force cast centroids to be same type as input: {:?}", - input.datum_type() - ); + log::warn!("force cast centroids to be same type as input: {:?}", input.datum_type()); centroids = centroids.cast_to_dt(input.datum_type())?.into_owned(); - outlier_centroids = outlier_centroids - .cast_to_dt(input.datum_type())? - .into_owned(); + outlier_centroids = outlier_centroids.cast_to_dt(input.datum_type())?.into_owned(); } - let _fdtypes = [ - input.datum_type(), - centroids.datum_type(), - outlier_centroids.datum_type(), - ]; + let _fdtypes = [input.datum_type(), centroids.datum_type(), outlier_centroids.datum_type()]; let fdtypes = HashSet::from(_fdtypes); assert!(fdtypes.len() == 1, "mixed dtypes: {_fdtypes:?}"); @@ -270,23 +246,16 @@ impl EvalOp for VPTQGemm { if enable_perm { let axis = 0; let dim = perm.shape()[0]; - let top_k = Topk { - axis, - largest: false, - fallback_k: dim.into(), - }; - let invert_perm = top_k - .eval(tvec!(perm.into_tvalue(), tensor0(dim as u16).into()))? - .remove(0); + let top_k = Topk { axis, largest: false, fallback_k: dim.into() }; + let invert_perm = + top_k.eval(tvec!(perm.into_tvalue(), tensor0(dim as u16).into()))?.remove(1); // TODO: manage case with quant dim == 'in' ? // if self.vector_quant_dim == "in": // assert True, "Not implemented" // qweight = qweight[invert_perm, :] let perm_gather_axis = 1; - let gather_perm = Gather { - axis: perm_gather_axis, - }; + let gather_perm = Gather { axis: perm_gather_axis }; qweight = gather_perm .eval(tvec!(qweight.into(), invert_perm))? .pop() diff --git a/harness/nnef-test-cases/vptq-basic/io.npz b/harness/nnef-test-cases/vptq-basic/io.npz new file mode 100644 index 0000000000000000000000000000000000000000..cfdaabb9053b2d037baeed5fd39b78138053d996 GIT binary patch literal 522 zcmWIWW@gc4fB;2?8=c;h{zCyfg9t-rUV&a-K_w%D07C~<5saSf7wQ`j$;eQ~P_3Sl zTAW;@Zl$1ZlV+l>qoAIaUsO_*m=~X4l#&V(cT3DEP6dh=XCxM+0{I$-ItoUbItsN4 zWCN}zdGhv8eu~)zpm^Zu<-HNe9>^~(fq7s8)B`Y@QV*C=?1A|X{USRYn(nP|&}}*B z5a7+oq|1yN7oflfVRoq7K~w`Hh=j%rx+akSK^_EQR80qgBGAwX@MdKL@)($aFbha` HflUMed>?bv literal 0 HcmV?d00001 diff --git a/harness/nnef-test-cases/vptq-basic/model.nnef.tgz b/harness/nnef-test-cases/vptq-basic/model.nnef.tgz new file mode 100644 index 0000000000000000000000000000000000000000..32afbde3496423c930e7598fc5a3ee47b8240456 GIT binary patch literal 20518 zcmeHN(N5bi6m{7iP@nh%{MI6M!A_HwiNV<0o+jQWWQEj}XlWXpv;rn1-uDUH!~S5u zuuqw^57;NL>y~tG(zTNyrHpHFF}`-LuaA8>vE$88>+4sazjVn>uiufO6-DyIO1qb& zLJn`fS=-k3!J&8lR(c(&NGXYp6ZuQdSkuYsbQ8ijv$;*^c`4rHj1H32Q_y`*{Zxwf z4(+8uPww${bEnB$t=3Bm=7aCJt?E<3;3e2Q?pKAZgL}_F zDj;jOIb?-ILG!$zMa5%A0Qn#JANl_=NATTzH~HU}QL6eu=f2+nG{?{7f4ji{O?X{E z{=eJQ{?pT%|2Lqe$^ZHtU{<2Dly}#?A!GxDf-n$F%R~VAANe2of0+@tT77B$FN^oH zkuKzaC(r+O%Wfk7XZV$wYyPLuLB{`jf4^k@haRtA!>zd!@7|C6r$)Gt{|ozn&K!>XKLuh*g>3&{0sCtU`Z$34vd;f? z|CpAK0P;Wbf06(3`~UJI@nL$@`~R}O`)4Cs$p3}!|F-Sm`F{pqsrf2=|3~gGb&BU2 zNB&3tFY-U`|K|qHe^l93UVJ*b6I|Y)7VtlJ;r8wH{XcHoxc@J2LJZ48sgj@{QEsu_ zYE_=6G8&ar<@u>lvKu6+x}f8t-&Yb^jn)i}WGWJQHh@bbffov!*{)z*GLL}vMbeoJ@k%c2sJgMVXZ-@>^CSsGR#fBe3=Fl-wVCGy-CZQMMpbr zScJaKyR@d=*^cHO$eFiu4ri?=?wB3SbiD9Hv4^Pv=;J`sXwmEOk)dhEGyu>3SuA20 zr)Q!odp*4<8Wfz#b{pfuWE|sS=Irf`WLlYPs~v4nW-;^EQ2J?KiBLx12Y?qu;zYuN zQZgAUvqoAIaUsO_*m=~X4l#&V(cT3DEP6dh=XCxM+0{I$-Its>`ItsN4 zWCJdUa~kY{v_k-jC%D(|s6_TeerXBJ6BD4GfYFqC!h~W^2skk?$N({f4)A7V(q%@C z4^V)EFgw)kAgX~8L_%W+T@%RvAP<5ts-~kr5ol-xc(bwrc??WIm<6Q!z$O9!t1oEl literal 0 HcmV?d00001 diff --git a/harness/nnef-test-cases/vptq-with-perm/model.nnef.tgz b/harness/nnef-test-cases/vptq-with-perm/model.nnef.tgz new file mode 100644 index 0000000000000000000000000000000000000000..21fc4f14834b9768d28b8f6313c5242d5f4e8f51 GIT binary patch literal 20518 zcmeHNL2uJA6m|ijsiz&e;afW?WR8=jlg385aYDVLDP*nJMWkuwq|-J)+8v2sz=40j zU*O2UfIA0%!gy|*&P_{;w5D|IS$%c=>^whz@$2U}e*9r$!~XuML$12Lwg}BA5=Um* zxgaIdy!v8po7;`1fA&K7Z6Qe^i47C^L(Z7(WOb?uVHV>SC-kfkZ*oTaN$N}JKBvAb zMZQD3Y0wpWR^8sQZ3k|K=~y+m0rm1-!Qd{Bv)4~OkT4L^3!-+=5{cRN)1^Y9g4%Tl zF>cjs5Z{btxvGM>#)#$Aow`-4JFX4!?=Z_EbR{vF#lMJyxr1I*KlkGGt$g{bL^i;^ z`yeHdd!rn(sYF5ZU`mUM+l&D6Kk`5F|80)ooB6Kue@jHE>;>&*uK*Os&*gtU#s9YJ zBLCm)&i>uiivKsEsm}ju4KORwS;`yh8qW!7g1iQaK`^Zn0px$=f8_sFM%;S!rTD)n z-p@w5kpJC0|8tkuk^eLNip&-Nt9_P?{~yBd)$+d@-Q_d*YQP8N667<;SCAhd1l9zb zAQXh*`Tzh~W@I!Fe*%Ps{BP&^-(rq~{67I=K?TMCkHB^b zM6Chlr3O?kd;kC5s6*91rqv^W{Ez&P{J;9RTbCZy^M6tA{j)JHc$v&;Z;&Q9}woD;`j_zq@Jk^zq{IHrF`>}w!SC&ET1B1fp(|*@W&j-rN zkIcR|j4-r=nE*d4utnT}6gG2O$&0MvG%Zgkq~;g+Fp z#?}ykYp*AZcN(WBUPpAhYEl>!oXMUxhK2DshQ-X;(G$s}GS^l)YEWh}{p(b;(w_7} z5rH27UJ!Xl0w$Cc2`KELK~3X8lVs$Q7yGR@qCFTuUl_C_d;Qo;f(tb}VRB~9jWitQ z(R(~2fBjd2zYqaL01-e05CKF05kLeG0Ym^1Km-s0L;w*$1P}p401-e05CKF05kLeG V0Ym^1Km`6F0zbdK`$(FE{05@AKt=!n literal 0 HcmV?d00001 diff --git a/harness/nnef-test-cases/vptq-with-perm/runme.sh b/harness/nnef-test-cases/vptq-with-perm/runme.sh new file mode 100755 index 0000000000..828dcb90de --- /dev/null +++ b/harness/nnef-test-cases/vptq-with-perm/runme.sh @@ -0,0 +1,8 @@ +#!/bin/sh + +cd $(dirname $0) +set -ex + +: ${TRACT_RUN:=cargo run -p tract $CARGO_OPTS --} + +$TRACT_RUN ./model.nnef.tgz --nnef-tract-core run --input-from-bundle ./io.npz --assert-output-bundle ./io.npz -q From 06d3c9070fe300f5cff1c1da9388239f95b61f12 Mon Sep 17 00:00:00 2001 From: Julien Balian Date: Fri, 31 Jan 2025 15:42:31 +0100 Subject: [PATCH 19/21] fix: clean modules import --- core/src/ops/vptq.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/core/src/ops/vptq.rs b/core/src/ops/vptq.rs index c3d6ee6d7c..386d823d31 100644 --- a/core/src/ops/vptq.rs +++ b/core/src/ops/vptq.rs @@ -1,6 +1,5 @@ -use std::{collections::HashSet, path::Path}; +use std::collections::HashSet; -use ndarray_npy::NpzWriter; use tract_data::itertools::Itertools; use tract_ndarray::Array1; From 506df6acbd44d6d0a24f29a6bc5f92c6e25a5692 Mon Sep 17 00:00:00 2001 From: epi Date: Wed, 5 Feb 2025 00:26:52 +0100 Subject: [PATCH 20/21] fix: avoid clone and useless array copy --- core/src/ops/vptq.rs | 54 ++++++++++++++++++++++---------------------- 1 file changed, 27 insertions(+), 27 deletions(-) diff --git a/core/src/ops/vptq.rs b/core/src/ops/vptq.rs index 386d823d31..8678cb5521 100644 --- a/core/src/ops/vptq.rs +++ b/core/src/ops/vptq.rs @@ -6,7 +6,7 @@ use tract_ndarray::Array1; use crate::{ internal::*, ops::{ - array::{Gather, GatherElements, Topk}, + array::{Gather, Topk}, math::shift_left, }, }; @@ -33,7 +33,7 @@ fn shift_right_zero_and_1(input: TValue, shift_value: TValue) -> TractResult()?; let shift_value = shift_value.to_array_view::()?; let out_shape = crate::broadcast::multi_broadcast(&[input.shape(), shift_value.shape()])?; - let mut out = Tensor::zero_dt(DatumType::I32, &out_shape)?; + let mut out = unsafe { Tensor::uninitialized_dt(DatumType::I32, &out_shape)? }; crate::ndarray::Zip::from(out.to_array_view_mut::()?) .and_broadcast(input) .and_broadcast(shift_value) @@ -41,6 +41,21 @@ fn shift_right_zero_and_1(input: TValue, shift_value: TValue) -> TractResult TractResult { + let &[_, _, vlen] = centroids.shape() else { bail!("wrong centroids shape") }; + let &[b, n_indices_x, n_indices_y] = indices.shape() else { + bail!("wrong indice shape {:?}", indices.shape()) + }; + let mut out = unsafe { + Tensor::uninitialized_dt(centroids.datum_type(), &[b, n_indices_x * n_indices_y, vlen])? + }; + indices.to_array_view::()?.iter().enumerate().for_each(|(idx, idx_val)| { + let idx_val = *idx_val as usize; + out.assign_slice(idx..idx + 1, centroids, idx_val..(idx_val + 1), 1).unwrap(); + }); + Ok(out) +} + impl VPTQGemm { /// decompression of indexes fn eval_unpack_index_tensor( @@ -57,14 +72,14 @@ impl VPTQGemm { pre_shift_pack_tensor_shape.push(1); let mut out = shift_right_zero_and_1( - pack_tensor.clone().into_shape(&pre_shift_pack_tensor_shape)?.into(), + pack_tensor.into_shape(&pre_shift_pack_tensor_shape)?.into(), wf.into(), )?; let mut post_shift_pack_tensor_shape = pack_tensor_shape.clone(); let pval = post_shift_pack_tensor_shape.pop().unwrap(); post_shift_pack_tensor_shape.push(32 * pval); - out = out.into_tensor().clone().into_shape(&post_shift_pack_tensor_shape)?.into_tvalue(); + out = out.into_tensor().into_shape(&post_shift_pack_tensor_shape)?.into_tvalue(); let pad_size = (pack_tensor_shape.last().unwrap_or(&0) * 32) % (index_bits * num_elements); if pad_size > 0 { @@ -89,7 +104,7 @@ impl VPTQGemm { let axis = out.rank() - 1; out = out .into_tensor() - .into_array::()? + .to_array_view_mut::()? .sum_axis(tract_ndarray::Axis(axis)) .into_tvalue(); @@ -112,33 +127,17 @@ impl VPTQGemm { indices: Tensor, group_size: usize, ) -> TractResult { - /// TODO: instead use ndarray for indices to use views transform (--) let mut indices = indices.clone(); let [num_codebooks, num_centroids, vector_len] = *centroids.shape() else { unimplemented!("unexected centroid shape ?") }; if self.is_indice_packed { - // unimplemented!("unpacking indices not implemented yet !"); let index_bits = (num_centroids as f32).log2().ceil() as usize; indices = self.eval_unpack_index_tensor(indices, index_bits, group_size)?; } - let mut vsh = indices.shape().to_vec(); - indices.insert_axis(3)?; - vsh.push(vector_len); - indices = indices.broadcast_to_shape(&vsh)?; // NOTE: costly in tract (applied in memory but not in ndarray) - let intermediate_volume = indices.shape()[1..3].iter().product(); - indices = indices.into_shape(&[num_codebooks, intermediate_volume, vector_len])?; - - let gather1 = GatherElements { axis: 1 }; - // selected_centroids = torch.gather(centroids, 1, indices) - let selected_centroids = gather1 - .eval(tvec!(centroids.into(), indices.into()))? - .pop() - .context("apply gather to get selected main centroids") - .unwrap() - .into_tensor(); + let selected_centroids = gather_all_elements(¢roids, &indices)?; let remain = selected_centroids.volume() / (num_codebooks * group_size * vector_len); @@ -223,6 +222,7 @@ impl EvalOp for VPTQGemm { let mut qweight = self.eval_extract_from_vector_quant(centroids, indices, self.group_size)?; + dbg!(&qweight); if enable_outlier { // same as centroids to qweights except for outlier let outlier_qweight = self.eval_extract_from_vector_quant( @@ -266,8 +266,8 @@ impl EvalOp for VPTQGemm { let data_type = *fdtypes.iter().next().unwrap(); if enable_norm { - qweight = (qweight.into_array::()? * weight_scale.into_array::()? - + weight_bias.into_array::()?) + qweight = (qweight.into_array::()? * weight_scale.to_array_view::()? + + weight_bias.to_array_view::()?) .into_tensor(); } // NOTE: next steps is fast matmul equivalent of { @@ -280,9 +280,9 @@ impl EvalOp for VPTQGemm { let &n = qweight.shape().last().unwrap(); - let (m, k, out_shape) = match ishape { - &[m, k] => (m, k, vec![m, n]), - &[b, m, k] => (m, k, vec![b, m, n]), + let (m, k, out_shape) = match *ishape { + [m, k] => (m, k, vec![m, n]), + [b, m, k] => (m, k, vec![b, m, n]), _ => { bail!("unexpected rank {:?}", ishape.len()) } From 4375139b552690bec7a43074babb51b6a54575b3 Mon Sep 17 00:00:00 2001 From: Julien Balian Date: Mon, 10 Feb 2025 11:56:01 +0100 Subject: [PATCH 21/21] fix: rm dbg --- core/src/ops/vptq.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/ops/vptq.rs b/core/src/ops/vptq.rs index 8678cb5521..98a9f6900a 100644 --- a/core/src/ops/vptq.rs +++ b/core/src/ops/vptq.rs @@ -222,7 +222,6 @@ impl EvalOp for VPTQGemm { let mut qweight = self.eval_extract_from_vector_quant(centroids, indices, self.group_size)?; - dbg!(&qweight); if enable_outlier { // same as centroids to qweights except for outlier let outlier_qweight = self.eval_extract_from_vector_quant(