diff --git a/core/src/ops/einsum/kernel_selection.rs b/core/src/ops/einsum/kernel_selection.rs index 99732acd14..649b5486f7 100644 --- a/core/src/ops/einsum/kernel_selection.rs +++ b/core/src/ops/einsum/kernel_selection.rs @@ -110,7 +110,7 @@ pub fn wire_linear( kit.weight == weight && kit.accumulator == accumulator && kit.activation == activation }) .min_by_key(|kit| kit.generic_fallback as usize) - .with_context(|| format!("No kit found for matmul {a:?} • {b_fact:?}"))?; + .with_context(|| format!("No kit found for matmul {weight:?} {accumulator:?} {activation:?}"))?; let configs = [kit.item_for_mv(), kit.item_for_squarish()]; let packed: Box = if let Some(a_payload) = a_as_bqv { let packed = kit diff --git a/data/src/datum.rs b/data/src/datum.rs index ef8d781a6c..768deccb04 100644 --- a/data/src/datum.rs +++ b/data/src/datum.rs @@ -492,6 +492,7 @@ pub trait Datum: { fn name() -> &'static str; fn datum_type() -> DatumType; + fn is() -> bool; } macro_rules! datum { @@ -510,6 +511,10 @@ macro_rules! datum { fn datum_type() -> DatumType { DatumType::$v } + + fn is() -> bool { + Self::datum_type() == D::datum_type() + } } }; } diff --git a/linalg/build.rs b/linalg/build.rs index d46d2b610b..c7bbb9b9a8 100644 --- a/linalg/build.rs +++ b/linalg/build.rs @@ -269,6 +269,7 @@ fn preprocess_file( } .to_owned(); let long = if msvc { "dd" } else { ".long" }; + let quad = if msvc { "dq" } else { ".quad" }; let g = if os == "macos" || os == "ios" { "_" } else { "" }; // note: use .align with bytes instead of p2align since they both use direct bytes. let align = if msvc { "align" } else { ".align" }; @@ -281,6 +282,7 @@ fn preprocess_file( "G": g, "suffix": suffix, "long": long, + "quad": quad, "jump_table": jump_table(), "align": align, "offset": if msvc { "offset" } else { "rip + "}, diff --git a/linalg/src/arm64/arm64fp16.rs b/linalg/src/arm64/arm64fp16.rs index e69d64d9f5..286f67f3bb 100644 --- a/linalg/src/arm64/arm64fp16.rs +++ b/linalg/src/arm64/arm64fp16.rs @@ -50,3 +50,13 @@ pub fn plug(ops: &mut Ops) { tanh_impl!(f16, arm64fp16_tanh_f16_8n, 8, 8, crate::arm64::has_fp16()); sigmoid_impl!(f16, arm64fp16_sigmoid_f16_8n, 8, 8, crate::arm64::has_fp16()); + +#[cfg(test)] +mod test { + + #[test] + fn kits() { + let mut ops = crate::generic(); + super::plug(&mut ops); + } +} diff --git a/linalg/src/arm64/arm64fp16/panel_extract.rs b/linalg/src/arm64/arm64fp16/panel_extract.rs index 21145f3fbe..bea69a9968 100644 --- a/linalg/src/arm64/arm64fp16/panel_extract.rs +++ b/linalg/src/arm64/arm64fp16/panel_extract.rs @@ -1,8 +1,8 @@ +use super::FP16; use crate::frame::block_quant::{PackedBlockQuantFormat, Q4_0}; -use crate::frame::PackedFormat; +use crate::mmm::Packing; use crate::Ops; use tract_data::internal::*; -use super::FP16; pub fn plug(ops: &mut Ops) { ops.panel_extractors.push(packed_64_q40_to_f16.clone()); @@ -10,7 +10,7 @@ pub fn plug(ops: &mut Ops) { panel_extractor!(kernel_packed_64_q40_to_f16 as packed_64_q40_to_f16( Box::new(PackedBlockQuantFormat::new(&Q4_0, 64, 16, true)), - PackedFormat::new(f16::datum_type(), 64, 16) + f16::packing(64).align(16) ) where(FP16)); #[target_feature(enable = "fp16")] diff --git a/linalg/src/frame/mmm/kernel.rs b/linalg/src/frame/mmm/kernel.rs index 1018483887..eeb671c7db 100644 --- a/linalg/src/frame/mmm/kernel.rs +++ b/linalg/src/frame/mmm/kernel.rs @@ -34,7 +34,6 @@ type Kernel = unsafe fn(&[FusedKerSpec]) -> isize; pub struct DynKernel { pub name: String, pub kernel: Kernel, - pub default_packing_alignments: (usize, usize), pub packings: Vec<(Box, Box)>, pub stores: Vec, pub supported_predicate: fn() -> bool, @@ -45,7 +44,8 @@ impl DynKernel { pub fn new( name: &str, kernel: Kernel, - default_packing_alignments: (usize, usize), + packing_a: PackedFormat, + packing_b: PackedFormat, ) -> Self { let kernel = DynKernel { name: name.to_string(), @@ -53,12 +53,9 @@ impl DynKernel { packings: vec![], stores: vec![Acc::datum_type()], supported_predicate: || true, - default_packing_alignments, can_fuse: |_| true, }; - let a = kernel.regular_pack_a(); - let b = kernel.regular_pack_b(); - kernel.with_packing(a, b) + kernel.with_packing(packing_a, packing_b) } pub fn with_platform_condition(mut self, f: fn() -> bool) -> Self { @@ -77,11 +74,11 @@ impl DynKernel { } pub fn regular_pack_a(&self) -> PackedFormat { - PackedFormat::new(Acc::datum_type(), MR, self.default_packing_alignments.0) + *self.packings[0].0.clone().downcast::().unwrap() } pub fn regular_pack_b(&self) -> PackedFormat { - PackedFormat::new(Acc::datum_type(), NR, self.default_packing_alignments.1) + *self.packings[0].1.clone().downcast::().unwrap() } pub fn with_can_fuse(self, can_fuse: fn(&FusedSpec) -> bool) -> Self { diff --git a/linalg/src/frame/mmm/macros.rs b/linalg/src/frame/mmm/macros.rs index 3e47ab1135..8714bed583 100644 --- a/linalg/src/frame/mmm/macros.rs +++ b/linalg/src/frame/mmm/macros.rs @@ -1,6 +1,7 @@ macro_rules! MMMExternKernel { ( - $func:ident<$ti:ident>($mr: expr, $nr: expr)@($align_a:expr, $align_b:expr) + $func:ident<$ti:ident>($mr: expr, $nr: expr) + $(@($align_a:expr, $align_b:expr))? $(where($where:expr))? $(can_fuse($can_fuse:expr))? $(packing[$pnum:literal] = $pid:ident => $packing:expr;)* @@ -21,7 +22,8 @@ macro_rules! MMMExternKernel { } } - MMMKernel!([]::rusty as $func<$ti>($mr, $nr)@($align_a, $align_b) + MMMKernel!([]::rusty as $func<$ti>($mr, $nr) + $(@($align_a, $align_b))? $(where($where))? $(can_fuse($can_fuse))? $(packing[$pnum] = $pid => $packing;)* @@ -32,7 +34,8 @@ macro_rules! MMMExternKernel { } macro_rules! MMMRustKernel { ( $func: path => - $id:ident<$ti:ident>($mr: expr, $nr: expr)@($align_a:expr, $align_b:expr) + $id:ident<$ti:ident>($mr: expr, $nr: expr) + $(@($align_a:expr, $align_b:expr))? $(where($where:expr))? $(can_fuse($can_fuse:expr))? $(packing[$pnum:literal] = $pid:ident => $packing:expr;)* @@ -49,7 +52,8 @@ macro_rules! MMMRustKernel { $func(op.as_ptr()) } } - MMMKernel!([]::rusty as $id<$ti>($mr, $nr)@($align_a, $align_b) + MMMKernel!([]::rusty as $id<$ti>($mr, $nr) + $(@($align_a, $align_b))? $(where($where))? $(can_fuse($can_fuse))? $(packing[$pnum] = $pid => $packing;)* @@ -62,7 +66,8 @@ macro_rules! MMMRustKernel { macro_rules! MMMKernel { ( $func: path as - $id:ident<$ti:ident>($mr: expr, $nr: expr)@($align_a:expr, $align_b:expr) + $id:ident<$ti:ident>($mr: expr, $nr: expr) + $(@($align_a:expr, $align_b:expr))? $(where($where:expr))? $(can_fuse($can_fuse:expr))? $(packing[$pnum:literal] = $pid:ident => $packing:expr;)* @@ -75,8 +80,15 @@ macro_rules! MMMKernel { use $crate::mmm::DynKernel; #[allow(unused_imports)] use tract_data::prelude::*; + use $crate::frame::mmm::Packing; #[allow(unused_mut)] - let mut k = DynKernel::<$mr, $nr, $ti>::new(stringify!($id), $func, ($align_a, $align_b)); + let (mut packing_a, mut packing_b) = ($ti::packing($mr), $ti::packing($nr)); + $( + packing_a = packing_a.align($align_a); + packing_b = packing_b.align($align_b); + )? + #[allow(unused_mut)] + let mut k = DynKernel::<$mr, $nr, $ti>::new(stringify!($id), $func, packing_a, packing_b); $(k = k.with_platform_condition($where);)? $( assert!(k.packings.len() == $pnum); @@ -102,4 +114,3 @@ macro_rules! MMMKernel { } }; } - diff --git a/linalg/src/frame/mmm/mod.rs b/linalg/src/frame/mmm/mod.rs index 9c63aa915a..797dbfda1e 100644 --- a/linalg/src/frame/mmm/mod.rs +++ b/linalg/src/frame/mmm/mod.rs @@ -30,6 +30,8 @@ pub use kit::*; pub use scratch::*; pub use storage::*; +pub use pack::Packing; + pub fn no_prefetch(_ptr: *const u8, _len: usize) {} pub trait MatMatMul: Debug + dyn_clone::DynClone + Send + Sync + std::any::Any { diff --git a/linalg/src/frame/mmm/pack.rs b/linalg/src/frame/mmm/pack.rs index 770bf38b58..8beba7bda6 100644 --- a/linalg/src/frame/mmm/pack.rs +++ b/linalg/src/frame/mmm/pack.rs @@ -12,7 +12,7 @@ use super::MMMInputFormat; pub struct PackedFormat { pub dt: DatumType, pub r: usize, - pub alignment: usize, + pub alignment_bytes: usize, pub end_padding_record: usize, } @@ -47,22 +47,27 @@ impl Display for PackedFormat { impl Debug for PackedFormat { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - ::fmt(self, f) + write!(f, "Packed{:?}[{}]@{}+{}", self.dt, self.r, self.alignment_bytes, self.end_padding_record) } } impl PackedFormat { - pub const fn new(dt: DatumType, nr: usize, alignment: usize) -> PackedFormat { - PackedFormat { dt, r: nr, alignment, end_padding_record: 1 } + pub const fn new(dt: DatumType, nr: usize, alignment_bytes: usize) -> PackedFormat { + PackedFormat { dt, r: nr, alignment_bytes, end_padding_record: 1 } } pub const fn with_end_padding_record(self, end_padding_record: usize) -> Self { PackedFormat { end_padding_record, ..self } } + #[inline] + pub fn align(self, alignment: usize) -> Self { + Self { alignment_bytes: alignment, ..self } + } + #[inline] pub fn alignment(&self) -> usize { - self.alignment + self.alignment_bytes } #[inline] @@ -102,7 +107,7 @@ impl PackedFormat { let strides = t.strides(); unsafe { let mut packed = - Blob::new_for_size_and_align(t.datum_type().size_of() * packed_len, self.alignment); + Blob::new_for_size_and_align(t.datum_type().size_of() * packed_len, self.alignment_bytes); if cfg!(debug_assertions) { packed.as_bytes_mut().fill(0u8); } @@ -141,7 +146,7 @@ impl PackedFormat { let strides = t.strides(); unsafe { let mut packed = - Blob::new_for_size_and_align(t.datum_type().size_of() * packed_len, self.alignment); + Blob::new_for_size_and_align(t.datum_type().size_of() * packed_len, self.alignment_bytes); if cfg!(debug_assertions) { packed.as_bytes_mut().fill(0u8); } @@ -509,6 +514,21 @@ unsafe fn pack_mn_major( } } +pub trait Packing { + fn packing(r: usize) -> PackedFormat; +} + +impl Packing for D { + fn packing(r: usize) -> PackedFormat { + PackedFormat { + dt: Self::datum_type(), + r, + alignment_bytes: Self::datum_type().alignment(), + end_padding_record: 0, + } + } +} + #[cfg(test)] mod test { use std::ops::Range; diff --git a/linalg/src/frame/mmm/tests/fuse.rs b/linalg/src/frame/mmm/tests/fuse.rs index da909a80b2..48a7ffc230 100644 --- a/linalg/src/frame/mmm/tests/fuse.rs +++ b/linalg/src/frame/mmm/tests/fuse.rs @@ -172,7 +172,7 @@ where let v = c.to_vec(); let c = mmm_stride_storage(&v, ker.nr()); let mut ops = ops.to_vec(); - ops.insert(0, FusedKerSpec::AddUnicast(c)); + ops.insert(0, FusedKerSpec::AddUnicast(c)); // FIXME ops.insert(0, FusedKerSpec::Clear); ops.push(FusedKerSpec::Store(c)); ops.push(FusedKerSpec::Done); diff --git a/linalg/src/frame/mmm/tests/packed_packed.rs b/linalg/src/frame/mmm/tests/packed_packed.rs index 99fc177a62..8089bffdeb 100644 --- a/linalg/src/frame/mmm/tests/packed_packed.rs +++ b/linalg/src/frame/mmm/tests/packed_packed.rs @@ -22,7 +22,7 @@ macro_rules! mmm_packed_packed_tests { #[allow(unused_imports)] use $crate::frame::mmm::tests::packed_packed::*; - mod fuse { + mod packed_packed { use super::*; proptest::proptest! { @@ -56,6 +56,28 @@ macro_rules! mmm_packed_packed_tests { t((1..=$ker.mr() as i64).map(|x| x as f32).collect_vec(), vec![1f32; $ker.nr()]) } + #[test] + fn packed_packed_a_scale_times_2_left() -> TractResult<()> { + t( + (1..=2 * $ker.mr() as i64).map(|x| x as f32).collect_vec(), + vec![1f32; $ker.nr()] + .into_iter() + .chain(vec![0f32; $ker.nr()].into_iter()) + .collect_vec(), + ) + } + + #[test] + fn packed_packed_a_scale_times_2_right() -> TractResult<()> { + t( + (1..=2 * $ker.mr() as i64).map(|x| x as f32).collect_vec(), + vec![0f32; $ker.nr()] + .into_iter() + .chain(vec![1f32; $ker.nr()].into_iter()) + .collect_vec(), + ) + } + #[test] fn packed_packed_a_scale_times_2() -> TractResult<()> { t( diff --git a/linalg/src/generic/mmm.rs b/linalg/src/generic/mmm.rs index cc984b2d68..d97215af86 100644 --- a/linalg/src/generic/mmm.rs +++ b/linalg/src/generic/mmm.rs @@ -1,8 +1,8 @@ #![allow(clippy::needless_range_loop)] use num_traits::AsPrimitive; -use pack::PackedFormat; use tract_data::prelude::f16; +use tract_data::prelude::DatumType::*; use tract_data::prelude::*; use super::*; @@ -280,14 +280,20 @@ where FusedKerSpec::AddMatMul { k, pa, pb, packing } => { use std::mem::transmute; if TI::datum_type().is_float() { - if packing == 0 { - add_mat_mul::(pa, pb, k, &mut ab); - } else if packing == 1 { - add_mat_mul_pq40::(pa, pb, k, &mut ab); - } else if packing == 2 { - add_mat_mul_pq40_scales_at_end::(pa, pb, k, &mut ab) - } else if packing == 3 { - add_mat_mul_pq40::(pa, pb, k, &mut ab); + match packing { + 0 => add_mat_mul::(pa, pb, k, &mut ab), + 1 if TI::is::() => { + add_mat_mul::(pa, pb, k, &mut ab) + } + 1 if TI::is::() => { + add_mat_mul::(pa, pb, k, &mut ab) + } + 2 => add_mat_mul_pq40::(pa, pb, k, &mut ab), + 3 => add_mat_mul_pq40_scales_at_end::( + pa, pb, k, &mut ab, + ), + 4 => add_mat_mul_pq40::(pa, pb, k, &mut ab), + _ => unreachable!(), } } else if TI::datum_type() == i32::datum_type() { // transmute to allow using explicitly i3 in add_mat_mul generic params @@ -331,58 +337,79 @@ where fn pq40_r4() -> PackedBlockQuantFormat { PackedBlockQuantFormat::new(&Q4_0, 4, 0, false) } -fn pq40_r4_se() -> PackedBlockQuantFormat { + +fn pq40_r4se() -> PackedBlockQuantFormat { PackedBlockQuantFormat::new(&Q4_0, 4, 0, true) } // f16 kernels -MMMRustKernel!(kernel:: => generic_f16_4x4(4,4)@(4,4) store(f32, f64)); -MMMRustKernel! {kernel:: => generic_f16_4x1(4,1)@(4,1) - packing[1] = q40f16 => |k| k.with_packing_a(pq40_r4()); - packing[2] = q40f16se => |k| k.with_packing_a(pq40_r4_se()); - packing[3] = q40f32 => |k| k.with_packing(pq40_r4(), PackedFormat::new(DatumType::F32, 1, 4)); +MMMRustKernel!(kernel:: => generic_f16_4x4(4,4) store(f32, f64)); +MMMRustKernel! {kernel:: => generic_f16_4x1(4,1) + packing[1] = f32f32 => |k| k.with_packing(f32::packing(4), f32::packing(1)); + packing[2] = q40f16 => |k| k.with_packing_a(pq40_r4()); + packing[3] = q40f16se => |k| k.with_packing_a(pq40_r4se()); + packing[4] = q40f32 => |k| k.with_packing(pq40_r4(), f32::packing(1)); store(f32, f64) } // f32 kernels -MMMRustKernel!(kernel:: => generic_f32_4x4(4,4)@(4,4) store(f16, f64)); -MMMRustKernel! {kernel:: => generic_f32_4x1(4,1)@(4,1) - packing[1] = q40f16 => |k| k.with_packing(pq40_r4(), PackedFormat::new(DatumType::F16, 1, 4)); - packing[2] = q40f16se => |k| k.with_packing(pq40_r4_se(), PackedFormat::new(DatumType::F16, 1, 4)); - packing[3] = q40f32 => |k| k.with_packing_a(pq40_r4()); +MMMRustKernel!(kernel:: => generic_f32_4x4(4,4) + packing[1] = f16f16 => |k| k.with_packing(f16::packing(4), f16::packing(4)); + store(f16, f64) +); +MMMRustKernel! {kernel:: => generic_f32_4x1(4,1) + packing[1] = f16f16 => |k| k.with_packing(f16::packing(4), f16::packing(1)); + packing[2] = q40f16 => |k| k.with_packing(pq40_r4(), f16::packing(1)); + packing[3] = q40f16se => |k| k.with_packing(pq40_r4se(), f16::packing(1)); + packing[4] = q40f32 => |k| k.with_packing_a(pq40_r4()); store(f16, f64) } // f64 kernels -MMMRustKernel!(kernel:: => generic_f64_4x4(4,4)@(4,4) store(f16, f32)); -MMMRustKernel!(kernel:: => generic_f64_4x1(4,1)@(4,1) store(f16, f32)); +MMMRustKernel!(kernel:: => generic_f64_4x4(4,4) store(f16, f32)); +MMMRustKernel!(kernel:: => generic_f64_4x1(4,1) store(f16, f32)); // I32 kernels -MMMRustKernel! {kernel:: => generic_i32_4x4(4,4)@(4,4) - packing[1] = i8i8 => |k| k.with_packing(PackedFormat::new(DatumType::I8, 4, 4), PackedFormat::new(DatumType::I8, 4, 4)); +MMMRustKernel! {kernel:: => generic_i32_4x4(4,4) + packing[1] = i8i8 => |k| k.with_packing(i8::packing(4), i8::packing(4)); store(i8) } -MMMRustKernel! {kernel:: => generic_i32_4x1(4,1)@(4,4) - packing[1] = i8i8 => |k| k.with_packing(PackedFormat::new(DatumType::I8, 4, 4), PackedFormat::new(DatumType::I8, 1, 4)); +MMMRustKernel! {kernel:: => generic_i32_4x1(4,1) + packing[1] = i8i8 => |k| k.with_packing(i8::packing(4), i8::packing(1)); store(i8) } // extra tests kernels - #[cfg(test)] -MMMRustKernel!(kernel:: => generic_f32_3x2(3,2)@(4,4) store(f16, f64)); +MMMRustKernel!(kernel:: => generic_f32_3x2(3,2) store(f16, f64)); #[cfg(test)] -MMMRustKernel! {kernel:: => generic_i32_3x2(3,2)@(4,4) - packing[1] = i8i8 => |k| k.with_packing(PackedFormat::new(DatumType::I8, 3, 4), PackedFormat::new(DatumType::I8, 2, 4)); +MMMRustKernel! {kernel:: => generic_i32_3x2(3,2) + packing[1] = i8i8 => |k| k.with_packing(i8::packing(3), i8::packing(2)); store(i8) } pub fn plug(ops: &mut Ops) { ops.mmm_kits.push( - MMMKit::new(Q4_0, f32::datum_type(), f32::datum_type(), &pq40_r4()) - .with_native(generic_f32_4x1.mmm(), 3) + MMMKit::new(Q4_0, F32, F32, &pq40_r4()) + .with_native(generic_f32_4x1.mmm(), 4) .with_generic_fallback(true), ); + ops.mmm_kits.push( + MMMKit::new(F16, F32, F16, &f16::packing(4)) + .with_native(generic_f32_4x1.mmm(), 1) + .with_native(generic_f32_4x4.mmm(), 1) + .with_generic_fallback(true), + ); +} + +#[cfg(test)] +mod test { + + #[test] + fn kits() { + let mut ops = crate::generic(); + super::plug(&mut ops); + } } diff --git a/linalg/src/x86_64_fma/mmm.rs b/linalg/src/x86_64_fma/mmm.rs index 89fc56f41b..31fed4d518 100644 --- a/linalg/src/x86_64_fma/mmm.rs +++ b/linalg/src/x86_64_fma/mmm.rs @@ -3,8 +3,9 @@ use crate::frame::PackedFormat; use crate::mmm::MMMKit; use crate::mmm::MatMatMulKer; use crate::Ops; -use panel_extract::packed_32_f16_to_f32; -use panel_extract::packed_32_q40_to_f32; +use panel_extract::avx512_packed_128_q40_to_f32; +use panel_extract::fma_packed_32_f16_to_f32; +use panel_extract::fma_packed_32_q40_to_f32; use tract_data::internal::*; use DatumType::*; @@ -20,6 +21,11 @@ MMMExternKernel!(fma_mmm_f32_64x1(64,1)@(32,4) where(FMA)); pub fn pq40_r32() -> PackedBlockQuantFormat { PackedBlockQuantFormat::new(&Q4_0, 32, 16, false) } + +pub fn pq40_r128() -> PackedBlockQuantFormat { + PackedBlockQuantFormat::new(&Q4_0, 128, 16, false) +} + MMMExternKernel! {fma_mmm_f32_32x1(32,1)@(32,4) where(FMA) packing[1] = q40f32 => |k| k.with_packing_a(pq40_r32()); packing[2] = q40f16 => |k| k.with_packing(pq40_r32(), PackedFormat::new(F16, 1, 2)); @@ -27,11 +33,15 @@ MMMExternKernel! {fma_mmm_f32_32x1(32,1)@(32,4) where(FMA) store(f16) } MMMExternKernel!(fma_mmm_f32_32x3(32,3)@(32,4) where(FMA) - packing[1] = f32f16 => |k| k.with_packing(PackedFormat::new(F32, 32, 32), PackedFormat::new(F16, 3, 2)); + packing[1] = f32f16 => |k| k.with_packing(f32::packing(32).align(32), PackedFormat::new(F16, 3, 2)); store(f16) ); -MMMExternKernel!(avx512_mmm_f32_128x1(128, 1)@(64,4) where (AVX512F)); +MMMExternKernel!(avx512_mmm_f32_128x1(128, 1)@(64,4) where (AVX512F) + packing[1] = q40f32 => |k| k.with_packing_a(pq40_r128()); +); +MMMExternKernel!(avx512_mmm_f32_128x3(128, 3)@(64,4) where (AVX512F)); + MMMExternKernel!(avx512_mmm_f32_16x1 ( 16, 1)@(64,4) where (AVX512F)); MMMExternKernel!(avx512_mmm_f32_16x12( 16,12)@(64,4) where (AVX512F)); MMMExternKernel!(avx512_mmm_f32_16x8 ( 16, 8)@(64,4) where (AVX512F)); @@ -47,6 +57,17 @@ MMMExternKernel! { avx2_mmm_i32_8x8(8,8)@(32,4) where(AVX2) } pub fn plug(ops: &mut Ops) { + if avx512_mmm_f32_16x1.is_supported_here() { + ops.mmm_kits.push( + MMMKit::new(Q4_0, F32, F32, &pq40_r128()) + .with_native(avx512_mmm_f32_128x1.mmm(), 1) + .with_extracting( + avx512_mmm_f32_128x3.mmm(), + 0, + avx512_packed_128_q40_to_f32.clone(), + ), + ) + } if fma_mmm_f32_32x1.is_supported_here() { ops.mmm_kits.push( MMMKit::new_for_mmm(fma_mmm_f32_32x1.mmm(), 0).with_native(fma_mmm_f32_32x3.mmm(), 0), @@ -54,17 +75,26 @@ pub fn plug(ops: &mut Ops) { ops.mmm_kits.push(MMMKit::new_for_mmm(fma_mmm_f32_32x1.mmm(), 1).with_extracting( fma_mmm_f32_32x3.mmm(), 0, - packed_32_q40_to_f32.clone(), + fma_packed_32_q40_to_f32.clone(), )); ops.mmm_kits.push(MMMKit::new_for_mmm(fma_mmm_f32_32x1.mmm(), 2).with_extracting( fma_mmm_f32_32x3.mmm(), 1, - packed_32_q40_to_f32.clone(), + fma_packed_32_q40_to_f32.clone(), )); ops.mmm_kits.push( MMMKit::new(F16, F32, F16, &PackedFormat::new(F16, 32, 32)) .with_native(fma_mmm_f32_32x1.mmm(), 3) - .with_extracting(fma_mmm_f32_32x3.mmm(), 1, packed_32_f16_to_f32.clone()), + .with_extracting(fma_mmm_f32_32x3.mmm(), 1, fma_packed_32_f16_to_f32.clone()), ); } } + +#[cfg(test)] +mod test { + #[test] + fn kits() { + let mut ops = crate::generic(); + super::plug(&mut ops); + } +} diff --git a/linalg/src/x86_64_fma/panel_extract.rs b/linalg/src/x86_64_fma/panel_extract.rs index be976844c8..e62552649b 100644 --- a/linalg/src/x86_64_fma/panel_extract.rs +++ b/linalg/src/x86_64_fma/panel_extract.rs @@ -1,22 +1,39 @@ use super::*; use crate::frame::PackedFormat; +use crate::frame::mmm::Packing; use crate::Ops; use tract_data::internal::*; pub fn plug(ops: &mut Ops) { - ops.panel_extractors.extend([packed_32_q40_to_f32.clone(), packed_32_f16_to_f32.clone()]); + ops.panel_extractors.extend([ + fma_packed_32_q40_to_f32.clone(), + fma_packed_32_f16_to_f32.clone(), + avx512_packed_128_q40_to_f32.clone(), + ]); } -panel_extractor!(kernel_packed_32_q40_to_f32 as packed_32_q40_to_f32( +panel_extractor!(kernel_packed_32_q40_to_f32 as fma_packed_32_q40_to_f32( Box::new(super::mmm::pq40_r32()), - PackedFormat::new(f32::datum_type(), 32, 32) + f32::packing(32).align(32) ) where(AVX2)); -panel_extractor!(kernel_packed_32_f16_to_f32 as packed_32_f16_to_f32( +panel_extractor!(kernel_packed_32_f16_to_f32 as fma_packed_32_f16_to_f32( Box::new(PackedFormat::new(f16::datum_type(), 32, 32)), - PackedFormat::new(f32::datum_type(), 32, 32) + f32::packing(32).align(32) ) where(AVX2)); +panel_extractor!(kernel_packed_128_q40_to_f32::kernel as avx512_packed_128_q40_to_f32( + Box::new(super::mmm::pq40_r128()), + f32::packing(128).align(64) +) where(AVX512F)); + +mod kernel_packed_128_q40_to_f32 { + extern_kernel!(fn avx512_packed_128_q40_to_f32(i: *const u8, output: *mut u8, k: usize) -> ()); + pub unsafe fn kernel(input: *const u8, output: *mut u8, k: usize) { + avx512_packed_128_q40_to_f32(input, output, k) + } +} + #[target_feature(enable = "avx2")] unsafe fn kernel_packed_32_q40_to_f32(input: *const u8, output: *mut u8, k: usize) { debug_assert!(k % 32 == 0); diff --git a/linalg/x86_64/avx512/avx512_mmm_f32_128x1.tmpl b/linalg/x86_64/avx512/avx512_mmm_f32_128x1.tmpl index 382ae2ca68..fda2646aae 100644 --- a/linalg/x86_64/avx512/avx512_mmm_f32_128x1.tmpl +++ b/linalg/x86_64/avx512/avx512_mmm_f32_128x1.tmpl @@ -33,15 +33,113 @@ Windows ABI: mov rax, [rdi + 16] // A mov rbx, [rdi + 8] // k + mov r8, [rdi + 32] // packing test rbx, rbx jz {{L}}non_linear_loop + cmp r8, 1 + jz {{L}}q40f32 + {{align}} 16 -{{L}}main_loop_packed_packed: +{{L}}f32f32: {% include "8x1/packed_packed_loop1/avx-512.tmpli" %} sub rbx, 1 - jnz {{L}}main_loop_packed_packed + jnz {{L}}f32f32 + + jmp {{L}}non_linear_loop + +{{L}}q40f32_mask: +{% if msvc %} + {{long}} 0F0F0F0Fh +{% else %} + {{long}} 0x0F0F0F0F +{% endif %} + +{{L}}q40f32_eight: + {{long}} 8 + +{{L}}q40f32_perm: + {{quad}} 2 + {{quad}} 3 + {{quad}} 4 + {{quad}} 5 + {{quad}} 6 + {{quad}} 7 + {{quad}} 0 // we dont care what's rolling in from the right + {{quad}} 0 + +{{L}}q40f32: + // zmm0-7: acc + // zmm8-16: scales + // zmm30: 8 + // zmm29: mask + // zmm31: b value + vbroadcastss zmm29, dword ptr [{{offset}} {{L}}q40f32_mask] + vbroadcastss zmm30, dword ptr [{{offset}} {{L}}q40f32_eight] + vmovups zmm28, [{{offset}} {{L}}q40f32_perm] + +{{L}}q40f32_outerloop: + // scales + {% for i in (0..7) %} + vmovaps ymm{{i|plus:8}}, [rax + {{i|times:32}}] + {% endfor %} + {% for i in (0..7) %} + vcvtph2ps zmm{{i|plus:8}}, ymm{{i|plus:8}} + {% endfor %} + add rax, 256 + + mov rdx, 32 + +{{L}}q40f32_innerloop: + vbroadcastss zmm31, dword ptr [rcx] + vmovaps zmm27, [rax] // 128 nibbles + + vpandq zmm26, zmm27, zmm29 // 64 bytes + + vpmovzxbd zmm16, xmm26 // 16 u32 + vpermt2q zmm26, zmm28, zmm26 + vpmovzxbd zmm17, xmm26 // 16 u32 + vpermt2q zmm26, zmm28, zmm26 + vpmovzxbd zmm18, xmm26 // 16 u32 + vpermt2q zmm26, zmm28, zmm26 + vpmovzxbd zmm19, xmm26 // 16 u32 + + vpsrlw zmm27, zmm27, 4 + vpandq zmm26, zmm27, zmm29 // 64 bytes + + vpmovzxbd zmm20, xmm26 // 16 u32 + vpermt2q zmm26, zmm28, zmm26 + vpmovzxbd zmm21, xmm26 // 16 u32 + vpermt2q zmm26, zmm28, zmm26 + vpmovzxbd zmm22, xmm26 // 16 u32 + vpermt2q zmm26, zmm28, zmm26 + vpmovzxbd zmm23, xmm26 // 16 u32 + + + {% for i in (16..23) %} + vpsubd zmm{{i}}, zmm{{i}}, zmm30 + {% endfor %} + + {% for i in (16..23) %} + vcvtdq2ps zmm{{i}}, zmm{{i}} + {% endfor %} + + {% for i in (0..7) %} + vmulps zmm{{i|plus:16}}, zmm{{i|plus:16}}, zmm{{i|plus:8}} + {% endfor %} + + {% for i in (0..7) %} + vfmadd231ps zmm{{i}}, zmm{{i|plus:16}}, zmm31 + {% endfor %} + + add rax, 64 + add rcx, 4 + sub rdx, 1 + jnz {{L}}q40f32_innerloop + + sub rbx, 32 + jnz {{L}}q40f32_outerloop jmp {{L}}non_linear_loop diff --git a/linalg/x86_64/avx512/avx512_mmm_f32_128x3.tmpl b/linalg/x86_64/avx512/avx512_mmm_f32_128x3.tmpl new file mode 100644 index 0000000000..077306f289 --- /dev/null +++ b/linalg/x86_64/avx512/avx512_mmm_f32_128x3.tmpl @@ -0,0 +1,155 @@ +{% comment %} +// vim: set syntax=asm : + +/* mmm 128 x 3: + + zmm0 zmm8 zmm816 + zmm1 zmm9 zmm17 + zmm2 zmm10 zmm18 + zmm3 zmm11 zmm19 + zmm4 zmm12 zmm20 + zmm5 zmm13 zmm21 + zmm6 zmm14 zmm22 + zmm7 zmm15 zmm23 + + +System V ABI: + args: rdi, rsi, rdx, rcx, r8, r9 + preserve: rbx, rsp, rbp, r12, r13, r14, r15 + scratch: rax, rdi, rsi, rdx, rcx, r8, r9, r10, r11 + return: rax (+rdx) + +Windows ABI: + args: RCX, RDX, R8, R9 + preserve: RBX, RBP, RDI, RSI, RSP, R12, R13, R14, R15, and XMM6-15 + scratch: RAX, RCX, RDX, R8, R9, R10, R11, XMM0-5, and the upper portions of YMM0-15 and ZMM0-15 + return: rax (+rdx) +*/ +{% endcomment %} + +{% include "preamble.tmpliq" size:"128x3", suffix:suffix, G:G, arch:"avx512" %} + +{{L}}clear: + vzeroall + {% for i in (16..23) %} + vmovapd zmm{{i}}, zmm0 + {% endfor %} + jmp {{L}}non_linear_loop + +{{L}}add_mat_mul: + mov rbx, [rdi + 24] // B + mov rax, [rdi + 16] // A + + mov rcx, [rdi + 8] // k + test rcx, rcx + jz {{L}}non_linear_loop + +{{L}}main_loop_packed_packed: + vbroadcastss zmm29, dword ptr [rbx] + vbroadcastss zmm30, dword ptr [rbx+4] + vbroadcastss zmm31, dword ptr [rbx+8] + +{% for i in (0..7) %} + vmovaps zmm28, zmmword ptr [rax+{{i | times:64}}] + vfmadd231ps zmm{{i}}, zmm28, zmm29 + vfmadd231ps zmm{{i | plus: 8}}, zmm28, zmm30 + vfmadd231ps zmm{{i | plus: 16}}, zmm28, zmm31 +{% endfor %} + + add rbx, 12 + add rax, 512 + + dec rcx + jnz {{L}}main_loop_packed_packed + + jmp {{L}}non_linear_loop + +{% include "f32_scalars.tmpliq" from:0, to:23 %} +{% include "f32_per_rows.tmpliq" mr:128, from:0, to:23 %} +{% include "f32_per_cols.tmpliq" mr:128, from:0, to:23 %} +{% include "avx512_mmm_load_tile.tmpliq" from:0, to:23 %} + +{{L}}range_0_16: +{% for i in (0..15) %} + {{long}} {{i}} +{% endfor %} + +{{L}}add_unicast: + + mov r10, [rdi + 8] // c ptr + mov rsi, [rdi + 16] // row stride + mov rbx, [rdi + 24] // col stride + + vbroadcastss zmm29, dword ptr [rdi+16] // row stride (aka esi) + vmovups zmm26, [{{offset}} {{L}}range_0_16] + vpmulld zmm26, zmm26, zmm29 + +{% for i in (0..2) %} + kxnorw k1,k1,k1 + vgatherdps zmm24{k1}, [ r10 + zmm26 ] + add r10, rbx + vaddps zmm{{i | times: 8}}, zmm{{i | times: 8}}, zmm24 +{% endfor %} + + imul esi, 16 + vpbroadcastd zmm27, esi + +{% for j in (1..7) %} + mov r10, [rdi + 8] + vpaddd zmm26, zmm26, zmm27 + + {% for i in (0..2) %} + kxnorw k1,k1,k1 + vgatherdps zmm24{k1}, [ r10 + zmm26 ] + add r10, rbx + vaddps zmm{{i | times: 8 | plus: j}}, zmm{{i | times: 8 | plus: j}}, zmm24 + {% endfor %} +{% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}add_row_col_products: + mov rax, [ rdi + 8 ] + mov rbx, [ rdi + 16 ] + + vbroadcastss zmm29, dword ptr [rbx] + vbroadcastss zmm30, dword ptr [rbx+4] + vbroadcastss zmm31, dword ptr [rbx+8] + +{% for i in (0..7) %} + vmovups zmm28, zmmword ptr [rax+{{i | times:64}}] + vfmadd231ps zmm{{i}}, zmm28, zmm29 + vfmadd231ps zmm{{i | plus: 8}}, zmm28, zmm30 + vfmadd231ps zmm{{i | plus: 16}}, zmm28, zmm31 +{% endfor %} + + jmp {{L}}non_linear_loop + +{{L}}store: + mov r8, [rdi + 8] // c ptr + mov rsi, [rdi + 16] // row stride + mov rbx, [rdi + 24] // col stride + + // tops of cols + lea r9, [ r8 + rbx ] + lea r10, [ r8 + 2 * rbx ] + lea r11, [ r10 + rbx ] + + {% for word in (0..7) %} + {% for quarter in (0..3) %} + {% for r in (0..2) %} + vextractf32x4 xmm{{r | plus: 24}}, zmm{{r | times: 8 | plus: word}}, {{quarter}} + {% endfor %} + {% for row in (0..3) %} + {% for i in (0..2) %} + vextractps dword ptr [r{{i | plus: 8}}], xmm{{i | plus: 24}}, {{row}} + add r{{i | plus: 8}}, rsi + {% endfor %} + {% endfor %} + {% endfor %} + {% endfor %} + + jmp {{L}}non_linear_loop + +{% include "postamble.tmpliq" size:"128x3", suffix:suffix, G:G, L:L, arch:"avx512" %} + diff --git a/linalg/x86_64/avx512/avx512_packed_128_q40_to_f32.tmpl b/linalg/x86_64/avx512/avx512_packed_128_q40_to_f32.tmpl new file mode 100644 index 0000000000..99a8e78a30 --- /dev/null +++ b/linalg/x86_64/avx512/avx512_packed_128_q40_to_f32.tmpl @@ -0,0 +1,201 @@ +{% comment %} +// vim: set syntax=asm : + +/* +System V ABI: + args: rdi, rsi, rdx, rcx, r8, r9 + preserve: rbx, rsp, rbp, r12, r13, r14, r15 + scratch: rax, rdi, rsi, rdx, rcx, r8, r9, r10, r11 + return: rax (+rdx) + +Windows ABI: + args: RCX, RDX, R8, R9 + preserve: RBX, RBP, RDI, RSI, RSP, R12, R13, R14, R15, and XMM6-15 + scratch: RAX, RCX, RDX, R8, R9, R10, R11, XMM0-5, and the upper portions of ZMM0-15 and ZMM0-15 + return: rax (+rdx) +*/ +{% endcomment %} +{% if msvc %} + +_text segment +avx512_packed_128_q40_to_f32_{{suffix}} proc + +{% else %} + +.intel_syntax noprefix +.text +.p2align 5 +.globl {{G}}avx512_packed_128_q40_to_f32_{{suffix}} +{{G}}avx512_packed_128_q40_to_f32_{{suffix}}: +.cfi_startproc + +{% endif %} + + push rbp + mov rbp, rsp + +{% if family == "windows" %} +// https://www.agner.org/optimize/calling_conventions.pdf xmm6-15 are not scratch +// https://stackoverflow.com/questions/43358429/save-value-of-xmm-registers + and rsp,-16 + lea rsp,[rsp-160] + vmovaps [rsp], xmm6 + vmovaps [rsp+16*1],xmm7 + vmovaps [rsp+16*2],xmm8 + vmovaps [rsp+16*3],xmm9 + vmovaps [rsp+16*4],xmm10 + vmovaps [rsp+16*5],xmm11 + vmovaps [rsp+16*6],xmm12 + vmovaps [rsp+16*7],xmm13 + vmovaps [rsp+16*8],xmm14 + vmovaps [rsp+16*9],xmm15 + +// FIXME calling_conventions + push rdi + push rsi + +// win: rcx:input rdx: output, r8:k + mov rdi, rcx + mov rsi, rdx + mov rdx, r8 + +{% endif %} + + sub rsp, 8 +{% if family == "unix" %} +.cfi_def_cfa_offset 64 +{% endif %} + stmxcsr [rsp + 4] +{% if msvc %} + mov rax, 1FC0h +{% else %} + mov rax, 0x1FC0 +{% endif %} + mov [rsp], eax + ldmxcsr [rsp] + +// unix: rdi:input rsi: output, rdx:k + +{{L}}q40f32: + // zmm0-7: acc + // zmm8-16: scales + // zmm30: 8 + // zmm29: mask + // zmm31: b value + vbroadcastss zmm29, dword ptr [{{offset}} {{L}}q40f32_mask] + vbroadcastss zmm30, dword ptr [{{offset}} {{L}}q40f32_eight] + vmovups zmm28, [{{offset}} {{L}}q40f32_perm] + +{{L}}q40f32_outerloop: + // scales + {% for i in (0..7) %} + vmovaps ymm{{i|plus:8}}, [rdi + {{i|times:32}}] + {% endfor %} + {% for i in (0..7) %} + vcvtph2ps zmm{{i|plus:8}}, ymm{{i|plus:8}} + {% endfor %} + add rdi, 256 + mov rax, 32 + +{{L}}q40f32_innerloop: + vmovaps zmm27, [rdi] // 128 nibbles + + vpandq zmm26, zmm27, zmm29 // 64 bytes + + vpmovzxbd zmm16, xmm26 // 16 u32 + vpermt2q zmm26, zmm28, zmm26 + vpmovzxbd zmm17, xmm26 // 16 u32 + vpermt2q zmm26, zmm28, zmm26 + vpmovzxbd zmm18, xmm26 // 16 u32 + vpermt2q zmm26, zmm28, zmm26 + vpmovzxbd zmm19, xmm26 // 16 u32 + + vpsrlw zmm27, zmm27, 4 + vpandq zmm26, zmm27, zmm29 // 64 bytes + + vpmovzxbd zmm20, xmm26 // 16 u32 + vpermt2q zmm26, zmm28, zmm26 + vpmovzxbd zmm21, xmm26 // 16 u32 + vpermt2q zmm26, zmm28, zmm26 + vpmovzxbd zmm22, xmm26 // 16 u32 + vpermt2q zmm26, zmm28, zmm26 + vpmovzxbd zmm23, xmm26 // 16 u32 + + + {% for i in (16..23) %} + vpsubd zmm{{i}}, zmm{{i}}, zmm30 + {% endfor %} + + {% for i in (16..23) %} + vcvtdq2ps zmm{{i}}, zmm{{i}} + {% endfor %} + + {% for i in (0..7) %} + vmulps zmm{{i|plus:16}}, zmm{{i|plus:16}}, zmm{{i|plus:8}} + {% endfor %} + + {% for i in (0..7) %} + vmovaps [rsi + {{i|times:64}}], zmm{{i|plus:16}} + {% endfor %} + + add rdi, 64 + add rsi, 512 + sub rax, 1 + jnz {{L}}q40f32_innerloop + + sub rdx, 32 + jnz {{L}}q40f32_outerloop + +{{L}}return: + ldmxcsr [rsp + 4] + add rsp, 8 + +{% if family == "windows" %} + pop rsi + pop rdi + + vmovaps xmm15, [rsp+16*9] + vmovaps xmm14, [rsp+16*8] + vmovaps xmm13, [rsp+16*7] + vmovaps xmm12, [rsp+16*6] + vmovaps xmm11, [rsp+16*5] + vmovaps xmm10, [rsp+16*4] + vmovaps xmm9, [rsp+16*3] + vmovaps xmm8, [rsp+16*2] + vmovaps xmm7, [rsp+16*1] + vmovaps xmm6, [rsp] +{% endif %} + + mov rsp, rbp + pop rbp + ret + +{{L}}q40f32_mask: +{% if msvc %} + {{long}} 0F0F0F0Fh +{% else %} + {{long}} 0x0F0F0F0F +{% endif %} + +{{L}}q40f32_eight: + {{long}} 8 + +{{L}}q40f32_perm: + {{quad}} 2 + {{quad}} 3 + {{quad}} 4 + {{quad}} 5 + {{quad}} 6 + {{quad}} 7 + {{quad}} 0 // we dont care what's rolling in from the right + {{quad}} 0 + + +{% if msvc %} +avx512_packed_128_q40_to_f32_{{suffix}} endp +_text ends +end + +{% else %} +.cfi_endproc +{% endif %} diff --git a/linalg/x86_64/avx512/f32_scalars.tmpliq b/linalg/x86_64/avx512/f32_scalars.tmpliq index 7876d6cbaa..65dec989df 100644 --- a/linalg/x86_64/avx512/f32_scalars.tmpliq +++ b/linalg/x86_64/avx512/f32_scalars.tmpliq @@ -10,14 +10,14 @@ {{L}}leaky_relu: // can only use zmm12 to zmm15 // ymm15 <- alpha - vbroadcastss zmm15, dword ptr [rdi + 8] + vbroadcastss zmm31, dword ptr [rdi + 8] // ymm14 <- all zero - vpxorq zmm14, zmm14, zmm14 + vpxorq zmm30, zmm30, zmm30 {% for reg in (from..to) %} - vcmpps k1, zmm{{reg}}, zmm14, 1 // 1 means LT + vcmpps k1, zmm{{reg}}, zmm30, 1 // 1 means LT // ymm12 <- alpha * x if < 0 - vmulps zmm{{reg}} {k1}, zmm{{reg}}, zmm15 + vmulps zmm{{reg}} {k1}, zmm{{reg}}, zmm31 {% endfor %} // select muled of orginal diff --git a/linalg/x86_64/avx512/zmm_scalar.tmpliq b/linalg/x86_64/avx512/zmm_scalar.tmpliq index 43373c9d82..c38a5965cf 100644 --- a/linalg/x86_64/avx512/zmm_scalar.tmpliq +++ b/linalg/x86_64/avx512/zmm_scalar.tmpliq @@ -1,14 +1,14 @@ // vim: set syntax=asm : {{L}}{{label}}: - vbroadcastss zmm12, dword ptr [rdi + 8] + vbroadcastss zmm31, dword ptr [rdi + 8] {% if flipped %} {% for reg in (from..to) %} - {{op}} zmm{{reg}}, zmm{{reg}}, zmm12 + {{op}} zmm{{reg}}, zmm{{reg}}, zmm31 {% endfor %} {% else %} {% for reg in (from..to) %} - {{op}} zmm{{reg}}, zmm12, zmm{{reg}} + {{op}} zmm{{reg}}, zmm31, zmm{{reg}} {% endfor %} {% endif %}