Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions linalg/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ harness = false
name = "rms_norm"
harness = false

# Microbench comparing the native AVX-512_FP16 f16 activation kernels with the
# f32-roundtrip and generic ones. `harness = false` because the bench file
# provides its own entry point through `criterion_main!`.
[[bench]]
name = "activations_avx512_fp16"
harness = false

[[bench]]
bench = false
name = "arm64simd"
Expand Down
68 changes: 68 additions & 0 deletions linalg/benches/activations_avx512_fp16.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Microbench: AVX-512_FP16 native f16 element-wise activations vs the
// f32-roundtrip versions in `act_f16.rs` (which were the AVX-512 f16 path
// before native f16 ISA was available) and the generic kernels. All variants
// run on 64-byte-aligned, 1024-element buffers — same workload as the existing
// activations_avx512_f16 bench, just adding the native-fp16 column.

use criterion::*;
use tract_data::prelude::*;
use tract_linalg::element_wise::ElementWiseKer;

const N: usize = 1024;

/// Builds the benchmark input: an `N`-element f16 tensor allocated with
/// 64-byte alignment and filled with `5 * sin(i / 10)`, so the values cover
/// both signs.
fn aligned_input() -> Tensor {
    // SAFETY: the tensor starts uninitialized, but every one of its `N`
    // elements is written below before it is returned.
    let mut tensor = unsafe { Tensor::uninitialized_aligned::<f16>(&[N], 64).unwrap() };
    // SAFETY: the tensor was allocated just above with element type f16.
    unsafe { tensor.as_slice_mut_unchecked::<f16>() }
        .iter_mut()
        .enumerate()
        .for_each(|(idx, slot)| *slot = f16::from_f32((idx as f32 / 10.0).sin() * 5.0));
    tensor
}

// Registers one criterion group called `$name` holding up to three entries
// that all process a fresh `aligned_input()` buffer in place:
//   * "generic"             — `$pred`, always benchmarked;
//   * "avx512_f32roundtrip" — `$roundtrip`, only when the host reports avx512f;
//   * "avx512fp16_native"   — `$native`, only when the host reports avx512fp16.
// The optional trailing `$param` is forwarded as the kernel parameter; when it
// is omitted the kernels receive `()`.
macro_rules! bench_triple {
    ($c:expr, $name:expr, $pred:ty, $roundtrip:ty, $native:ty $(, $param:expr)?) => {{
        let mut group = $c.benchmark_group($name);
        group.throughput(Throughput::Elements(N as u64));

        // Each variant gets its own buffer so no kernel sees another's output.
        let mut generic_buf = aligned_input();
        let generic_data = unsafe { generic_buf.as_slice_mut_unchecked::<f16>() };
        group.bench_function("generic", |bencher| {
            bencher.iter(|| <$pred>::run(generic_data, ($($param)?)))
        });

        if std::is_x86_feature_detected!("avx512f") {
            let mut roundtrip_buf = aligned_input();
            let roundtrip_data = unsafe { roundtrip_buf.as_slice_mut_unchecked::<f16>() };
            group.bench_function("avx512_f32roundtrip", |bencher| {
                bencher.iter(|| <$roundtrip>::run(roundtrip_data, ($($param)?)))
            });
        }

        if std::is_x86_feature_detected!("avx512fp16") {
            let mut native_buf = aligned_input();
            let native_data = unsafe { native_buf.as_slice_mut_unchecked::<f16>() };
            group.bench_function("avx512fp16_native", |bencher| {
                bencher.iter(|| <$native>::run(native_data, ($($param)?)))
            });
        }

        group.finish();
    }};
}

/// Registers the hardswish and leaky-relu f16 comparison groups.
///
/// The body expands `std::is_x86_feature_detected!` and names kernels from
/// `tract_linalg::x86_64_fma`, so it is only compiled for x86_64. Cargo builds
/// this bench target on every architecture, and the feature-detection macro is
/// a hard compile error elsewhere.
#[cfg(target_arch = "x86_64")]
fn benches(c: &mut Criterion) {
    bench_triple!(
        c,
        "hardswish_f16",
        tract_linalg::generic::hardswish::HHardSwish8,
        tract_linalg::x86_64_fma::act_f16::x86_64_avx512_hardswish_f16_64n,
        tract_linalg::x86_64_fma::act_f16_fp16::x86_64_avx512fp16_hardswish_f16_128n
    );
    bench_triple!(
        c,
        "leaky_relu_f16",
        tract_linalg::generic::leaky_relu::HLeakyRelu8,
        tract_linalg::x86_64_fma::act_f16::x86_64_avx512_leaky_relu_f16_64n,
        tract_linalg::x86_64_fma::act_f16_fp16::x86_64_avx512fp16_leaky_relu_f16_128n,
        f16::from_f32(0.1)
    );
}

/// Fallback for non-x86_64 targets: there is nothing to compare, so the bench
/// binary still builds and simply registers no benchmarks.
#[cfg(not(target_arch = "x86_64"))]
fn benches(_c: &mut Criterion) {}

// Criterion entry point: group `g` wraps `benches`, and `criterion_main!`
// generates the `main` that this target needs (it is declared with
// `harness = false` in Cargo.toml).
criterion_group!(g, benches);
criterion_main!(g);
22 changes: 22 additions & 0 deletions linalg/src/x86_64_fma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pub mod mmm;

pub mod act;
pub mod act_f16;
pub mod act_f16_fp16;
pub mod by_scalar;
pub mod erf;
mod intel;
Expand Down Expand Up @@ -46,6 +47,24 @@ fn plug_fma(ops: &mut Ops) {
log::info!("sigmoid_f32, tanh_f32: x86_64/fma activated");
}

/// On hosts that also support AVX-512_FP16 (Sapphire Rapids / Granite Rapids /
/// later, and recent Xeon-D / consumer parts), upgrade the f16 element-wise
/// kernels from the f32-roundtrip implementations in `act_f16.rs` to the
/// native f16 implementations in `act_f16_fp16.rs` where the native path is
/// actually faster on this uarch. We benched each op against its f32-roundtrip
/// equivalent on Sapphire Rapids and only plug in the ones that win:
///
/// hardswish_f16: 8.71 → 31.6 Gelem/s (3.62× native) — plug in
/// leaky_relu_f16: 9.44 → 5.85 Gelem/s (0.62× native — regression) — keep
/// the f32-roundtrip version from act_f16.rs. The native
/// kernel exists in act_f16_fp16.rs for future revisits but
/// is not wired here.
fn plug_avx512fp16(ops: &mut Ops) {
ops.hardswish_f16 = Box::new(|| act_f16_fp16::x86_64_avx512fp16_hardswish_f16_128n::ew());

log::info!("hardswish_f16: x86_64/avx512fp16 native activated");
}

fn plug_avx512f(ops: &mut Ops) {
ops.sigmoid_f32 = Box::new(|| avx512_sigmoid_f32::ew());
ops.tanh_f32 = Box::new(|| avx512_tanh_f32::ew());
Expand Down Expand Up @@ -88,6 +107,9 @@ pub fn plug(ops: &mut Ops) {
plug_fma(ops);
if is_x86_feature_detected!("avx512f") {
plug_avx512f(ops);
if is_x86_feature_detected!("avx512fp16") {
plug_avx512fp16(ops);
}
}
}
}
Expand Down
203 changes: 203 additions & 0 deletions linalg/src/x86_64_fma/act_f16_fp16.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
// AVX-512_FP16 native f16 element-wise activations.
//
// Sapphire Rapids (and later Intel) added the AVX-512 FP16 ISA: zmm-wide
// arithmetic on f16 directly (`vmulph`, `vfmadd*ph`, `vmaxph`, `vminph`,
// `vaddph`, `vsubph`, etc.). 32 f16 lanes per zmm — double the parallelism of
// the f32-roundtrip kernels in `act_f16.rs`, and zero conversion at the IO
// boundary.
//
// The kernels here mirror the algorithm of the f32 versions in `act.rs` and
// the f32-roundtrip f16 versions in `act_f16.rs`. Polynomials are evaluated
// directly in f16, accepting the lower mantissa precision (11 bits vs f32's
// 24) — the resulting tolerance fits inside the f16 activation tests'
// SuperApproximate band.
//
// Gated on `is_x86_feature_detected!("avx512fp16")` (the actual check happens
// in `plug` over in `x86_64_fma.rs`, right before it calls
// `plug_avx512fp16`). AVX-512 hosts without the FP16 extension (Skylake-X,
// Cascade Lake, Ice Lake server) keep using `act_f16.rs`'s f32-roundtrip
// versions.

use tract_data::internal::f16;

// Target features required by the kernels in this file. `#[target_feature]`
// only accepts a string literal, so the attributes on the kernels below spell
// the same list out again; keep them in sync with this constant.
const FP16_TARGETS: &str = "avx512f,avx512fp16,avx512bw";

// hardswish(x) = x * clamp(x + 3, 0, 6) * (1/6).
// 128 f16 per iter (4 zmm × 32 lanes), 256 bytes / iter — same memory throughput
// as the f32 kernel's 64 f32 / iter.
ew_impl_wrap!(
    f16,
    x86_64_avx512fp16_hardswish_f16_128n,
    // Presumably the block size (`nr`) and the alignment in f16 items
    // (32 × 2 bytes = 64 bytes) — confirm against `ew_impl_wrap!`.
    128,
    32,
    (),
    #[inline(never)]
    fn run(buf: &mut [f16], _: ()) {
        // Debug-only contract checks: whole `nr()`-sized blocks, aligned start.
        debug_assert!(buf.len() % Self::nr() == 0);
        debug_assert!(buf.as_ptr() as usize % Self::alignment_bytes() == 0);
        // The asm loop always runs at least one iteration, so it must never
        // be entered with an empty buffer.
        if !buf.is_empty() {
            unsafe { hardswish_f16_run(buf) }
        }
    }
);

/// In-place hardswish on `buf`, computed natively in f16:
/// `x * clamp(x + 3, 0, 6) * (1/6)`.
///
/// Each loop iteration loads four zmm registers (128 f16, 256 bytes), applies
/// add / min / max / mul / mul lane-wise and stores the result back over the
/// input.
///
/// # Safety
///
/// * The CPU must support `avx512f`, `avx512fp16` and `avx512bw`.
/// * `buf.len()` must be a non-zero multiple of 128: the loop subtracts 128
///   from the element count and only exits when it reaches exactly zero, so
///   any other length runs past the end of the slice.
/// * `buf` must start on a 64-byte boundary, because all loads and stores use
///   the aligned `vmovdqa64` form.
#[target_feature(enable = "avx512f,avx512fp16,avx512bw")]
unsafe fn hardswish_f16_run(buf: &mut [f16]) {
    let len = buf.len();
    // The asm stores through this pointer, so it must come from
    // `as_mut_ptr()`: a pointer obtained with `as_ptr()` must never be
    // written through.
    let ptr = buf.as_mut_ptr().cast::<u8>();
    // Constants are broadcast from their raw f16 bit patterns.
    let three = f16::from_f32(3.0).to_bits();
    let six = f16::from_f32(6.0).to_bits();
    let recip6 = f16::from_f32(1.0 / 6.0).to_bits();
    // SAFETY: the caller guarantees the CPU features, the length and the
    // alignment described above; `ptr` covers exactly `len` f16 elements of
    // exclusively borrowed memory, and every register the asm touches is
    // declared as an operand or a clobber.
    unsafe {
        std::arch::asm!("
            vpbroadcastw zmm0, eax // 3.0
            vpbroadcastw zmm1, ecx // 6.0
            vpbroadcastw zmm2, edx // 1/6
            vpxord zmm3, zmm3, zmm3 // 0.0
            2:
            vmovdqa64 zmm4, [{ptr}]
            vmovdqa64 zmm5, [{ptr} + 64]
            vmovdqa64 zmm6, [{ptr} + 128]
            vmovdqa64 zmm7, [{ptr} + 192]
            vaddph zmm8, zmm4, zmm0
            vaddph zmm9, zmm5, zmm0
            vaddph zmm10, zmm6, zmm0
            vaddph zmm11, zmm7, zmm0
            vminph zmm8, zmm8, zmm1
            vminph zmm9, zmm9, zmm1
            vminph zmm10, zmm10, zmm1
            vminph zmm11, zmm11, zmm1
            vmaxph zmm8, zmm8, zmm3
            vmaxph zmm9, zmm9, zmm3
            vmaxph zmm10, zmm10, zmm3
            vmaxph zmm11, zmm11, zmm3
            vmulph zmm8, zmm8, zmm4
            vmulph zmm9, zmm9, zmm5
            vmulph zmm10, zmm10, zmm6
            vmulph zmm11, zmm11, zmm7
            vmulph zmm8, zmm8, zmm2
            vmulph zmm9, zmm9, zmm2
            vmulph zmm10, zmm10, zmm2
            vmulph zmm11, zmm11, zmm2
            vmovdqa64 [{ptr}], zmm8
            vmovdqa64 [{ptr} + 64], zmm9
            vmovdqa64 [{ptr} + 128], zmm10
            vmovdqa64 [{ptr} + 192], zmm11
            add {ptr}, 256
            sub {len}, 128
            jnz 2b
            ",
            len = inout(reg) len => _,
            ptr = inout(reg) ptr => _,
            in("eax") three as u32,
            in("ecx") six as u32,
            in("edx") recip6 as u32,
            out("zmm0") _, out("zmm1") _, out("zmm2") _, out("zmm3") _,
            out("zmm4") _, out("zmm5") _, out("zmm6") _, out("zmm7") _,
            out("zmm8") _, out("zmm9") _, out("zmm10") _, out("zmm11") _,
        );
    }
}

// leaky_relu(x, alpha) = x if x >= 0 else alpha*x
// For 0 <= alpha <= 1: leaky_relu(x, alpha) = max(x, alpha*x). For the typical
// alpha values used (0.01, 0.1, 0.2) this is exact.
//
// NOTE: This native fp16 version benched ~38% SLOWER than the f32-roundtrip
// version on Sapphire Rapids (9.44 Gelem/s f32-roundtrip vs 5.85 Gelem/s
// native, n=1024, single-thread). The two compute ops per element (vmulph +
// vmaxph) appear not to saturate Sapphire Rapids' FP16 execution port the
// same way f32 mul/max saturate the FP32 ports. The kernel is correct (passes
// proptest against the f16 reference) but is NOT plugged in — see the
// `plug_avx512fp16` comment in `x86_64_fma.rs`. Kept here in case a different
// AVX-512_FP16 uarch (Granite Rapids etc.) flips the comparison.
ew_impl_wrap!(
    f16,
    x86_64_avx512fp16_leaky_relu_f16_128n,
    // Presumably the block size (`nr`) and the alignment in f16 items
    // (32 × 2 bytes = 64 bytes) — confirm against `ew_impl_wrap!`.
    128,
    32,
    f16,
    #[inline(never)]
    fn run(buf: &mut [f16], alpha: f16) {
        // Debug-only contract checks: whole `nr()`-sized blocks, aligned start.
        debug_assert!(buf.len() % Self::nr() == 0);
        debug_assert!(buf.as_ptr() as usize % Self::alignment_bytes() == 0);
        // The asm loop always runs at least one iteration, so it must never
        // be entered with an empty buffer.
        if !buf.is_empty() {
            unsafe { leaky_relu_f16_run(buf, alpha) }
        }
    }
);

/// In-place leaky-relu on `buf`, computed natively in f16 as
/// `max(alpha * x, x)`.
///
/// That formulation matches `x if x >= 0 else alpha * x` only while
/// `alpha <= 1`. For `alpha > 1` a negative `x` is larger than `alpha * x`, so
/// the kernel returns `x` unscaled. Callers must not use it with such values.
///
/// Each loop iteration loads four zmm registers (128 f16, 256 bytes), applies
/// mul / max lane-wise and stores the result back over the input.
///
/// # Safety
///
/// * The CPU must support `avx512f`, `avx512fp16` and `avx512bw`.
/// * `buf.len()` must be a non-zero multiple of 128: the loop subtracts 128
///   from the element count and only exits when it reaches exactly zero, so
///   any other length runs past the end of the slice.
/// * `buf` must start on a 64-byte boundary, because all loads and stores use
///   the aligned `vmovdqa64` form.
#[target_feature(enable = "avx512f,avx512fp16,avx512bw")]
unsafe fn leaky_relu_f16_run(buf: &mut [f16], alpha: f16) {
    let len = buf.len();
    // The asm stores through this pointer, so it must come from
    // `as_mut_ptr()`: a pointer obtained with `as_ptr()` must never be
    // written through.
    let ptr = buf.as_mut_ptr().cast::<u8>();
    // alpha is broadcast from its raw f16 bit pattern.
    let alpha_bits = alpha.to_bits();
    // SAFETY: the caller guarantees the CPU features, the length and the
    // alignment described above; `ptr` covers exactly `len` f16 elements of
    // exclusively borrowed memory, and every register the asm touches is
    // declared as an operand or a clobber.
    unsafe {
        std::arch::asm!("
            vpbroadcastw zmm0, eax // alpha
            2:
            vmovdqa64 zmm4, [{ptr}]
            vmovdqa64 zmm5, [{ptr} + 64]
            vmovdqa64 zmm6, [{ptr} + 128]
            vmovdqa64 zmm7, [{ptr} + 192]
            vmulph zmm8, zmm4, zmm0
            vmulph zmm9, zmm5, zmm0
            vmulph zmm10, zmm6, zmm0
            vmulph zmm11, zmm7, zmm0
            vmaxph zmm8, zmm8, zmm4
            vmaxph zmm9, zmm9, zmm5
            vmaxph zmm10, zmm10, zmm6
            vmaxph zmm11, zmm11, zmm7
            vmovdqa64 [{ptr}], zmm8
            vmovdqa64 [{ptr} + 64], zmm9
            vmovdqa64 [{ptr} + 128], zmm10
            vmovdqa64 [{ptr} + 192], zmm11
            add {ptr}, 256
            sub {len}, 128
            jnz 2b
            ",
            len = inout(reg) len => _,
            ptr = inout(reg) ptr => _,
            in("eax") alpha_bits as u32,
            out("zmm0") _,
            out("zmm4") _, out("zmm5") _, out("zmm6") _, out("zmm7") _,
            out("zmm8") _, out("zmm9") _, out("zmm10") _, out("zmm11") _,
        );
    }
}

// Instantiates the crate's shared hardswish test frame for the native fp16
// kernel. The first macro argument is the runtime CPU-feature condition —
// presumably used to skip the checks on hosts without avx512fp16; confirm
// against `hardswish_frame_tests!`.
#[cfg(test)]
pub mod test_x86_64_avx512fp16_hardswish {
    use super::*;
    crate::hardswish_frame_tests!(
        is_x86_feature_detected!("avx512fp16"),
        f16,
        x86_64_avx512fp16_hardswish_f16_128n
    );
}

// Instantiates the crate's shared leaky-relu test frame for the native fp16
// kernel. The first macro argument is the runtime CPU-feature condition —
// presumably used to skip the checks on hosts without avx512fp16; confirm
// against `leaky_relu_frame_tests!`.
// NOTE(review): the kernel computes max(alpha*x, x), which is only correct for
// alpha <= 1 — verify the alpha range the frame generates stays within that.
#[cfg(test)]
pub mod test_x86_64_avx512fp16_leaky_relu {
    use super::*;
    crate::leaky_relu_frame_tests!(
        is_x86_feature_detected!("avx512fp16"),
        f16,
        x86_64_avx512fp16_leaky_relu_f16_128n
    );
}

// `FP16_TARGETS` has no real consumer yet: this placeholder reference keeps
// the dead-code lint from firing on it until more kernels are added.
#[allow(dead_code)]
const _UNUSED: &str = FP16_TARGETS;
Loading