From 69172ed164390ae07095ac20f62061c89af1ef1e Mon Sep 17 00:00:00 2001 From: czoli1976 <64466170+czoli1976@users.noreply.github.com> Date: Thu, 28 May 2026 15:38:27 +0000 Subject: [PATCH] linalg/x86_64: AVX-512_FP16 native f16 hardswish kernel MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a native f16 hardswish kernel using avx512fp16 ISA (Sapphire Rapids / Granite Rapids / later Intel). 128 f16 lanes per iteration via 4 zmm of 32 f16 each, processed with vaddph / vminph / vmaxph / vmulph — no f32 round-trip, no vcvtph2ps/vcvtps2ph at the IO boundary. Wired through a new `plug_avx512fp16` step that runs after `plug_avx512f` on hosts where `is_x86_feature_detected!("avx512fp16")` is true. The f32-roundtrip hardswish_f16 kernel from `act_f16.rs` remains in place as the avx512f-only fallback (Skylake-X, Cascade Lake, Ice Lake server prior to fp16 extension). Bench on Sapphire Rapids (n=1024, single thread, Criterion): hardswish_f16: generic 52.3 Melem/s avx512_f32roundtrip 8.71 Gelem/s (current czoli1976#8 path) avx512fp16_native 31.6 Gelem/s (this PR, 3.62× over the roundtrip) A native leaky_relu_f16 kernel is also included but NOT wired — on Sapphire Rapids it benched 38% slower than the f32-roundtrip version (5.85 vs 9.44 Gelem/s). The two-op-per-element compute path (vmulph + vmaxph) does not saturate the FP16 execution port the same way the equivalent f32 ops saturate the FP32 ports. Kernel is correct (4/4 frame tests pass, including proptest against the f16 reference); kept in the source for future revisit on different fp16 uarchs where the comparison might flip. Tests: linalg 2845 passed, 0 failed (+4 new frame tests). Cross-arch `cargo check` clean on aarch64-unknown-linux-gnu and wasm32-unknown-unknown (plug_avx512fp16 is x86_64-only and feature-gated). Co-Authored-By: Claude Opus 4.7 (1M context) --- linalg/Cargo.toml | 4 + linalg/benches/activations_avx512_fp16.rs | 68 ++++++++ linalg/src/x86_64_fma.rs | 22 +++ linalg/src/x86_64_fma/act_f16_fp16.rs | 203 ++++++++++++++++++++++ 4 files changed, 297 insertions(+) create mode 100644 linalg/benches/activations_avx512_fp16.rs create mode 100644 linalg/src/x86_64_fma/act_f16_fp16.rs diff --git a/linalg/Cargo.toml b/linalg/Cargo.toml index dabad70e16..d29cfa27ae 100644 --- a/linalg/Cargo.toml +++ b/linalg/Cargo.toml @@ -119,6 +119,10 @@ harness = false name = "rms_norm" harness = false +[[bench]] +name = "activations_avx512_fp16" +harness = false + [[bench]] bench = false name = "arm64simd" diff --git a/linalg/benches/activations_avx512_fp16.rs b/linalg/benches/activations_avx512_fp16.rs new file mode 100644 index 0000000000..c754b79827 --- /dev/null +++ b/linalg/benches/activations_avx512_fp16.rs @@ -0,0 +1,68 @@ +// Microbench: AVX-512_FP16 native f16 element-wise activations vs the +// f32-roundtrip versions in `act_f16.rs` (which were the AVX-512 f16 path +// before native f16 ISA was available). Both run on 64-byte-aligned, 1024- +// element buffers — same workload as the existing activations_avx512_f16 +// bench, just adding the native-fp16 column. + +use criterion::*; +use tract_data::prelude::*; +use tract_linalg::element_wise::ElementWiseKer; + +const N: usize = 1024; + +fn aligned_input() -> Tensor { + let mut t = unsafe { Tensor::uninitialized_aligned::(&[N], 64).unwrap() }; + let s = unsafe { t.as_slice_mut_unchecked::() }; + for (i, x) in s.iter_mut().enumerate() { + *x = f16::from_f32((i as f32 / 10.0).sin() * 5.0); + } + t +} + +macro_rules! bench_triple { + ($c:expr, $name:expr, $pred:ty, $roundtrip:ty, $native:ty $(, $param:expr)?) => {{ + let mut group = $c.benchmark_group($name); + group.throughput(Throughput::Elements(N as u64)); + let mut tg = aligned_input(); + let sg = unsafe { tg.as_slice_mut_unchecked::() }; + group.bench_function("generic", |b| { + b.iter(|| <$pred>::run(sg, ($($param)?))) + }); + if std::is_x86_feature_detected!("avx512f") { + let mut tr = aligned_input(); + let sr = unsafe { tr.as_slice_mut_unchecked::() }; + group.bench_function("avx512_f32roundtrip", |b| { + b.iter(|| <$roundtrip>::run(sr, ($($param)?))) + }); + } + if std::is_x86_feature_detected!("avx512fp16") { + let mut tn = aligned_input(); + let sn = unsafe { tn.as_slice_mut_unchecked::() }; + group.bench_function("avx512fp16_native", |b| { + b.iter(|| <$native>::run(sn, ($($param)?))) + }); + } + group.finish(); + }}; +} + +fn benches(c: &mut Criterion) { + bench_triple!( + c, + "hardswish_f16", + tract_linalg::generic::hardswish::HHardSwish8, + tract_linalg::x86_64_fma::act_f16::x86_64_avx512_hardswish_f16_64n, + tract_linalg::x86_64_fma::act_f16_fp16::x86_64_avx512fp16_hardswish_f16_128n + ); + bench_triple!( + c, + "leaky_relu_f16", + tract_linalg::generic::leaky_relu::HLeakyRelu8, + tract_linalg::x86_64_fma::act_f16::x86_64_avx512_leaky_relu_f16_64n, + tract_linalg::x86_64_fma::act_f16_fp16::x86_64_avx512fp16_leaky_relu_f16_128n, + f16::from_f32(0.1) + ); +} + +criterion_group!(g, benches); +criterion_main!(g); diff --git a/linalg/src/x86_64_fma.rs b/linalg/src/x86_64_fma.rs index 07dc531375..e61baa2efe 100644 --- a/linalg/src/x86_64_fma.rs +++ b/linalg/src/x86_64_fma.rs @@ -8,6 +8,7 @@ pub mod mmm; pub mod act; pub mod act_f16; +pub mod act_f16_fp16; pub mod by_scalar; pub mod erf; mod intel; @@ -46,6 +47,24 @@ fn plug_fma(ops: &mut Ops) { log::info!("sigmoid_f32, tanh_f32: x86_64/fma activated"); } +/// On hosts that also support AVX-512_FP16 (Sapphire Rapids / Granite Rapids / +/// later, and recent Xeon-D / consumer parts), upgrade the f16 element-wise +/// kernels from the f32-roundtrip implementations in `act_f16.rs` to the +/// native f16 implementations in `act_f16_fp16.rs` where the native path is +/// actually faster on this uarch. We benched each op against its f32-roundtrip +/// equivalent on Sapphire Rapids and only plug in the ones that win: +/// +/// hardswish_f16: 8.71 → 31.6 Gelem/s (3.62× native) — plug in +/// leaky_relu_f16: 9.44 → 5.85 Gelem/s (0.62× native — regression) — keep +/// the f32-roundtrip version from act_f16.rs. The native +/// kernel exists in act_f16_fp16.rs for future revisits but +/// is not wired here. +fn plug_avx512fp16(ops: &mut Ops) { + ops.hardswish_f16 = Box::new(|| act_f16_fp16::x86_64_avx512fp16_hardswish_f16_128n::ew()); + + log::info!("hardswish_f16: x86_64/avx512fp16 native activated"); +} + fn plug_avx512f(ops: &mut Ops) { ops.sigmoid_f32 = Box::new(|| avx512_sigmoid_f32::ew()); ops.tanh_f32 = Box::new(|| avx512_tanh_f32::ew()); @@ -88,6 +107,9 @@ pub fn plug(ops: &mut Ops) { plug_fma(ops); if is_x86_feature_detected!("avx512f") { plug_avx512f(ops); + if is_x86_feature_detected!("avx512fp16") { + plug_avx512fp16(ops); + } } } } diff --git a/linalg/src/x86_64_fma/act_f16_fp16.rs b/linalg/src/x86_64_fma/act_f16_fp16.rs new file mode 100644 index 0000000000..4a40002458 --- /dev/null +++ b/linalg/src/x86_64_fma/act_f16_fp16.rs @@ -0,0 +1,203 @@ +// AVX-512_FP16 native f16 element-wise activations. +// +// Sapphire Rapids (and later Intel) added the AVX-512 FP16 ISA: zmm-wide +// arithmetic on f16 directly (`vmulph`, `vfmadd*ph`, `vmaxph`, `vminph`, +// `vaddph`, `vsubph`, etc.). 32 f16 lanes per zmm — double the parallelism of +// the f32-roundtrip kernels in `act_f16.rs`, and zero conversion at the IO +// boundary. +// +// The kernels here mirror the algorithm of the f32 versions in `act.rs` and +// the f32-roundtrip f16 versions in `act_f16.rs`. Polynomials are evaluated +// directly in f16, accepting the lower mantissa precision (11 bits vs f32's +// 24) — the resulting tolerance fits inside the f16 activation tests' +// SuperApproximate band. +// +// Gated on `is_x86_feature_detected!("avx512fp16")` (the actual gating happens +// in `plug_avx512fp16` over in `x86_64_fma.rs`). Pre-FP16 AVX-512 hosts +// (Skylake-X, Cascade Lake, Ice Lake server prior to fp16 extension) keep +// using `act_f16.rs`'s f32-roundtrip versions. + +use tract_data::internal::f16; + +const FP16_TARGETS: &str = "avx512f,avx512fp16,avx512bw"; + +// hardswish(x) = x * clamp(x + 3, 0, 6) * (1/6). +// 128 f16 per iter (4 zmm × 32 lanes), 256 bytes / iter — same memory throughput +// as the f32 kernel's 64 f32 / iter. +ew_impl_wrap!( + f16, + x86_64_avx512fp16_hardswish_f16_128n, + 128, + 32, + (), + #[inline(never)] + fn run(buf: &mut [f16], _: ()) { + debug_assert!(buf.len() % Self::nr() == 0); + debug_assert!(buf.as_ptr() as usize % Self::alignment_bytes() == 0); + if buf.is_empty() { + return; + } + unsafe { hardswish_f16_run(buf) } + } +); + +#[target_feature(enable = "avx512f,avx512fp16,avx512bw")] +unsafe fn hardswish_f16_run(buf: &mut [f16]) { + let len = buf.len(); + let ptr = buf.as_ptr() as *mut u8; + let three = f16::from_f32(3.0).to_bits(); + let six = f16::from_f32(6.0).to_bits(); + let recip6 = f16::from_f32(1.0 / 6.0).to_bits(); + unsafe { + std::arch::asm!(" + vpbroadcastw zmm0, eax // 3.0 + vpbroadcastw zmm1, ecx // 6.0 + vpbroadcastw zmm2, edx // 1/6 + vpxord zmm3, zmm3, zmm3 // 0.0 + 2: + vmovdqa64 zmm4, [{ptr}] + vmovdqa64 zmm5, [{ptr} + 64] + vmovdqa64 zmm6, [{ptr} + 128] + vmovdqa64 zmm7, [{ptr} + 192] + + vaddph zmm8, zmm4, zmm0 + vaddph zmm9, zmm5, zmm0 + vaddph zmm10, zmm6, zmm0 + vaddph zmm11, zmm7, zmm0 + + vminph zmm8, zmm8, zmm1 + vminph zmm9, zmm9, zmm1 + vminph zmm10, zmm10, zmm1 + vminph zmm11, zmm11, zmm1 + + vmaxph zmm8, zmm8, zmm3 + vmaxph zmm9, zmm9, zmm3 + vmaxph zmm10, zmm10, zmm3 + vmaxph zmm11, zmm11, zmm3 + + vmulph zmm8, zmm8, zmm4 + vmulph zmm9, zmm9, zmm5 + vmulph zmm10, zmm10, zmm6 + vmulph zmm11, zmm11, zmm7 + + vmulph zmm8, zmm8, zmm2 + vmulph zmm9, zmm9, zmm2 + vmulph zmm10, zmm10, zmm2 + vmulph zmm11, zmm11, zmm2 + + vmovdqa64 [{ptr}], zmm8 + vmovdqa64 [{ptr} + 64], zmm9 + vmovdqa64 [{ptr} + 128], zmm10 + vmovdqa64 [{ptr} + 192], zmm11 + + add {ptr}, 256 + sub {len}, 128 + jnz 2b + ", + len = inout(reg) len => _, + ptr = inout(reg) ptr => _, + in("eax") three as u32, + in("ecx") six as u32, + in("edx") recip6 as u32, + out("zmm0") _, out("zmm1") _, out("zmm2") _, out("zmm3") _, + out("zmm4") _, out("zmm5") _, out("zmm6") _, out("zmm7") _, + out("zmm8") _, out("zmm9") _, out("zmm10") _, out("zmm11") _, + ); + } +} + +// leaky_relu(x, alpha) = x if x >= 0 else alpha*x +// For 0 <= alpha <= 1: leaky_relu(x, alpha) = max(x, alpha*x). For the typical +// alpha values used (0.01, 0.1, 0.2) this is exact. +// +// NOTE: This native fp16 version benched ~38% SLOWER than the f32-roundtrip +// version on Sapphire Rapids (9.44 Gelem/s f32-roundtrip vs 5.85 Gelem/s +// native, n=1024, single-thread). The two compute ops per element (vmulph + +// vmaxph) appear not to saturate Sapphire Rapids' FP16 execution port the +// same way f32 mul/max saturate the FP32 ports. The kernel is correct (passes +// proptest against the f16 reference) but is NOT plugged in — see the +// `plug_avx512fp16` comment in `x86_64_fma.rs`. Kept here in case a different +// AVX-512_FP16 uarch (Granite Rapids etc.) flips the comparison. +ew_impl_wrap!( + f16, + x86_64_avx512fp16_leaky_relu_f16_128n, + 128, + 32, + f16, + #[inline(never)] + fn run(buf: &mut [f16], alpha: f16) { + debug_assert!(buf.len() % Self::nr() == 0); + debug_assert!(buf.as_ptr() as usize % Self::alignment_bytes() == 0); + if buf.is_empty() { + return; + } + unsafe { leaky_relu_f16_run(buf, alpha) } + } +); + +#[target_feature(enable = "avx512f,avx512fp16,avx512bw")] +unsafe fn leaky_relu_f16_run(buf: &mut [f16], alpha: f16) { + let len = buf.len(); + let ptr = buf.as_ptr() as *mut u8; + let alpha_bits = alpha.to_bits(); + unsafe { + std::arch::asm!(" + vpbroadcastw zmm0, eax // alpha + 2: + vmovdqa64 zmm4, [{ptr}] + vmovdqa64 zmm5, [{ptr} + 64] + vmovdqa64 zmm6, [{ptr} + 128] + vmovdqa64 zmm7, [{ptr} + 192] + + vmulph zmm8, zmm4, zmm0 + vmulph zmm9, zmm5, zmm0 + vmulph zmm10, zmm6, zmm0 + vmulph zmm11, zmm7, zmm0 + + vmaxph zmm8, zmm8, zmm4 + vmaxph zmm9, zmm9, zmm5 + vmaxph zmm10, zmm10, zmm6 + vmaxph zmm11, zmm11, zmm7 + + vmovdqa64 [{ptr}], zmm8 + vmovdqa64 [{ptr} + 64], zmm9 + vmovdqa64 [{ptr} + 128], zmm10 + vmovdqa64 [{ptr} + 192], zmm11 + + add {ptr}, 256 + sub {len}, 128 + jnz 2b + ", + len = inout(reg) len => _, + ptr = inout(reg) ptr => _, + in("eax") alpha_bits as u32, + out("zmm0") _, + out("zmm4") _, out("zmm5") _, out("zmm6") _, out("zmm7") _, + out("zmm8") _, out("zmm9") _, out("zmm10") _, out("zmm11") _, + ); + } +} + +#[cfg(test)] +pub mod test_x86_64_avx512fp16_hardswish { + use super::*; + crate::hardswish_frame_tests!( + is_x86_feature_detected!("avx512fp16"), + f16, + x86_64_avx512fp16_hardswish_f16_128n + ); +} + +#[cfg(test)] +pub mod test_x86_64_avx512fp16_leaky_relu { + use super::*; + crate::leaky_relu_frame_tests!( + is_x86_feature_detected!("avx512fp16"), + f16, + x86_64_avx512fp16_leaky_relu_f16_128n + ); +} + +// Suppress unused-const lint until we expand to more kernels. +#[allow(dead_code)] +const _UNUSED: &str = FP16_TARGETS;