Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 64 additions & 1 deletion core/src/ops/nn/rms_norm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,34 @@ impl EvalOp for RmsNorm {

fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
let input = args_1!(inputs);
let in_dt = input.datum_type();

// Fast path: F32 or F16 input where the normalised axis is the last
// (contiguous) one. Use the fused tract_linalg::rms_norm_f32 kernel
// (AVX-512 when available; scalar fallback otherwise) instead of the
// 4-call MeanOfSquares + Add + Rsqrt + Mul composition below. ~16-18x
// faster on Cascade Lake AVX-512, ~equivalent on the scalar fallback
// since the composition is also memory-bandwidth bound.
if matches!(in_dt, DatumType::F32 | DatumType::F16)
&& input.rank() > 0
&& self.axis == input.rank() - 1
{
let eps_f32: f32 = self.eps.cast_to_scalar::<f32>()?;
let mut buf = input.cast_to::<f32>()?.into_owned();
let row_len = buf.shape()[self.axis];
if row_len > 0 {
let n_rows: usize = buf.shape().iter().take(self.axis).product();
let data = unsafe { buf.as_slice_mut_unchecked::<f32>() };
let rms_norm = &tract_linalg::ops().rms_norm_f32;
for r in 0..n_rows {
let start = r * row_len;
rms_norm(&mut data[start..start + row_len], eps_f32);
}
}
return Ok(tvec![buf.cast_to_dt(in_dt)?.into_owned().into()]);
}

// Slow path: original 4-call composition (kept for non-contiguous axes).
let input_f32 = input.cast_to::<f32>()?.into_owned();
// eps inherits the input dtype from the declutter pattern (F16 when the
// surrounding LayerNorm chain is F16). The MeanOfSquares + Add + Rsqrt
Expand All @@ -41,7 +68,7 @@ impl EvalOp for RmsNorm {
let mut a2 = Add.eval(a1.into_tvalue(), eps.into_tvalue(), DatumType::F32)?;
Rsqrt {}.eval_in_place(&mut a2, None)?;
let a3 = Mul.eval(a2.into_tvalue(), input_f32.into_tvalue(), DatumType::F32)?;
Ok(tvec![a3.cast_to_dt(input.datum_type())?.into_owned().into()])
Ok(tvec![a3.cast_to_dt(in_dt)?.into_owned().into()])
}
}

Expand Down Expand Up @@ -205,4 +232,40 @@ mod tests {
assert!(diff < 0.01, "lane {i}: got {} expected {}", g.to_f32(), e);
}
}

/// Slow path: when the normalised axis is NOT the trailing one, the fast
/// path in `eval` (which dispatches to `tract_linalg::ops().rms_norm_f32`)
/// is skipped and the original 4-call `MeanOfSquares` + `Add` + `Rsqrt` +
/// `Mul` composition runs. Asserts the result is identical to a hand-
/// computed reference, so the slow path stays correct after the fast-path
/// addition.
#[test]
fn eval_with_non_trailing_axis_f32() {
// 2x3 input, axis=0 means we normalise across the 2 rows for each
// column independently:
// col 0: [1, 4] → mean_sq = (1 + 16) / 2 = 8.5 → 1/√8.5
// col 1: [2, 5] → mean_sq = (4 + 25) / 2 = 14.5 → 1/√14.5
// col 2: [3, 6] → mean_sq = (9 + 36) / 2 = 22.5 → 1/√22.5
let input = tensor2(&[[1.0_f32, 2.0, 3.0], [4.0, 5.0, 6.0]]);
let eps = tensor0(0.0_f32).into_arc_tensor();
let op = RmsNorm { axis: 0, eps };
let out = op.eval(tvec!(input.into())).expect("eval should not panic");
let out = out.into_iter().next().unwrap().into_tensor();
assert_eq!(out.datum_type(), DatumType::F32);
assert_eq!(out.shape(), &[2, 3]);
let got = unsafe { out.as_slice_unchecked::<f32>() };
let inv = |ms: f32| ms.sqrt().recip();
let expected: [f32; 6] = [
1.0 * inv(8.5),
2.0 * inv(14.5),
3.0 * inv(22.5),
4.0 * inv(8.5),
5.0 * inv(14.5),
6.0 * inv(22.5),
];
for (i, (g, e)) in got.iter().zip(expected.iter()).enumerate() {
let diff = (g - e).abs();
assert!(diff < 1e-5, "lane {i}: got {g}, want {e}, diff {diff}");
}
}
}
4 changes: 4 additions & 0 deletions linalg/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ harness = false
name = "softmax"
harness = false

[[bench]]
name = "rms_norm"
harness = false

[[bench]]
bench = false
name = "arm64simd"
Expand Down
86 changes: 86 additions & 0 deletions linalg/arm64/sve/sve_rms_norm.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// VLA SVE2 f32 fused row-wise RmsNorm. Mirrors the NEON kernel in
// arm64/arm64simd/rms_norm.rs and the AVX-512 kernel in
// x86_64_fma/rms_norm.rs.
//
// Pass 1 (sum of squares):
// 4 svfloat32_t accumulators (s0..s3), 4*svcntw() lanes per inner
// iteration. Tail handled by a predicated loop over the residue
// (svwhilelt_b32) — no scalar tail.
// Pass 2 (multiply-back):
// broadcast inv_std into inv_v, fmla/store each 4-vec chunk in place;
// same predicated tail.
//
// Width-agnostic by construction: identical correct output at any FEAT_SVE
// streaming vector length (128 → 2048 bits). Wider VL = wider lanes, fewer
// loop iterations, real perf scaling.
//
// ABI: void sve_rms_norm_f32_kernel(float *buf, int64_t n, float eps).
// Called from sve.rs::sve_rms_norm_f32 when FEAT_SVE2 is present on Linux
// aarch64. Plugs into Ops::rms_norm_f32; the core/nn::RmsNorm::eval fast
// path dispatches here automatically for trailing-axis F32 RmsNorm.

#include <arm_sve.h>
#include <math.h>
#include <stdint.h>

void sve_rms_norm_f32_kernel(float *buf, int64_t n, float eps) {
if (n <= 0) return;

const int64_t vl = (int64_t)svcntw();
const int64_t step = 4 * vl;
const svbool_t ptrue = svptrue_b32();

// --- Pass 1: sum of squares ---
svfloat32_t s0 = svdup_n_f32(0.0f);
svfloat32_t s1 = svdup_n_f32(0.0f);
svfloat32_t s2 = svdup_n_f32(0.0f);
svfloat32_t s3 = svdup_n_f32(0.0f);

int64_t i = 0;
for (; i + step <= n; i += step) {
svfloat32_t x0 = svld1_f32(ptrue, buf + i + 0 * vl);
svfloat32_t x1 = svld1_f32(ptrue, buf + i + 1 * vl);
svfloat32_t x2 = svld1_f32(ptrue, buf + i + 2 * vl);
svfloat32_t x3 = svld1_f32(ptrue, buf + i + 3 * vl);
s0 = svmla_f32_x(ptrue, s0, x0, x0);
s1 = svmla_f32_x(ptrue, s1, x1, x1);
s2 = svmla_f32_x(ptrue, s2, x2, x2);
s3 = svmla_f32_x(ptrue, s3, x3, x3);
}
// Predicated tail: handles the (n % step) remainder, possibly distributed
// across up to 4 partial vl-chunks. No scalar epilogue.
for (; i < n; i += vl) {
svbool_t pg = svwhilelt_b32((uint64_t)i, (uint64_t)n);
svfloat32_t x = svld1_f32(pg, buf + i);
s0 = svmla_f32_x(pg, s0, x, x);
}

// Reduce 4 accumulators → scalar via tree-add + horizontal reduce.
s0 = svadd_f32_x(ptrue, s0, s1);
s2 = svadd_f32_x(ptrue, s2, s3);
s0 = svadd_f32_x(ptrue, s0, s2);
float sum_sq = svaddv_f32(ptrue, s0);

float mean_sq = sum_sq / (float)n;
float inv_std = 1.0f / sqrtf(mean_sq + eps);

// --- Pass 2: multiply by inv_std ---
svfloat32_t inv_v = svdup_n_f32(inv_std);

i = 0;
for (; i + step <= n; i += step) {
svfloat32_t x0 = svld1_f32(ptrue, buf + i + 0 * vl);
svfloat32_t x1 = svld1_f32(ptrue, buf + i + 1 * vl);
svfloat32_t x2 = svld1_f32(ptrue, buf + i + 2 * vl);
svfloat32_t x3 = svld1_f32(ptrue, buf + i + 3 * vl);
svst1_f32(ptrue, buf + i + 0 * vl, svmul_f32_x(ptrue, x0, inv_v));
svst1_f32(ptrue, buf + i + 1 * vl, svmul_f32_x(ptrue, x1, inv_v));
svst1_f32(ptrue, buf + i + 2 * vl, svmul_f32_x(ptrue, x2, inv_v));
svst1_f32(ptrue, buf + i + 3 * vl, svmul_f32_x(ptrue, x3, inv_v));
}
for (; i < n; i += vl) {
svbool_t pg = svwhilelt_b32((uint64_t)i, (uint64_t)n);
svfloat32_t x = svld1_f32(pg, buf + i);
svst1_f32(pg, buf + i, svmul_f32_x(pg, x, inv_v));
}
}
62 changes: 62 additions & 0 deletions linalg/benches/rms_norm.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// Microbench: fused RmsNorm vs the 4-call composition that tract-core currently
// uses (MeanOfSquares + Add + Rsqrt + Mul). The composition is reconstructed
// inline here in the same shape as `core::ops::nn::rms_norm::RmsNorm::eval`
// drives it. Both versions run on a 64-byte-aligned f32 row.

use criterion::*;
use tract_data::prelude::*;

fn aligned_row(n: usize) -> Tensor {
let mut t = unsafe { Tensor::uninitialized_aligned::<f32>(&[n], 64).unwrap() };
let s = unsafe { t.as_slice_mut_unchecked::<f32>() };
for (i, x) in s.iter_mut().enumerate() {
*x = (i as f32 / 10.0).sin() * 5.0;
}
t
}

#[inline(never)]
fn composed_rms_norm(buf: &mut [f32], eps: f32) {
// Same shape as tract-core's RmsNorm::eval: separate passes for sum-of-squares,
// mean, +eps, rsqrt, multiply — each writing/reading the row once.
let mut sum_sq = 0.0_f32;
for &x in buf.iter() {
sum_sq += x * x;
}
let mean_sq = sum_sq / buf.len() as f32;
let added = mean_sq + eps;
let inv_std = added.sqrt().recip();
for x in buf.iter_mut() {
*x *= inv_std;
}
}

fn rms_norm(c: &mut Criterion) {
for &n in &[1024usize, 2048, 4096] {
let id = format!("{n}");
let mut g = c.benchmark_group(format!("rms_norm_f32/{id}"));
g.throughput(Throughput::Elements(n as u64));
let mut t = aligned_row(n);
let s = unsafe { t.as_slice_mut_unchecked::<f32>() };
g.bench_function("composed", |b| b.iter(|| composed_rms_norm(s, 1e-5)));
g.bench_function("generic", |b| {
b.iter(|| tract_linalg::generic::rms_norm::rms_norm_f32(s, 1e-5))
});
#[cfg(target_arch = "x86_64")]
if std::is_x86_feature_detected!("avx512f") {
g.bench_function("avx512", |b| {
b.iter(|| tract_linalg::x86_64_fma::rms_norm::rms_norm_f32(s, 1e-5))
});
}
#[cfg(target_arch = "aarch64")]
{
g.bench_function("neon", |b| {
b.iter(|| tract_linalg::arm64::arm64simd_rms_norm_f32(s, 1e-5))
});
}
g.finish();
}
}

criterion_group!(g, rms_norm);
criterion_main!(g);
1 change: 1 addition & 0 deletions linalg/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ fn main() {
.file("arm64/sve/sve_mmv_f32_64x1.c")
.file("arm64/sve/sve_mmm_i32.c")
.file("arm64/sve/sve_mmm_i32_64x1.c")
.file("arm64/sve/sve_rms_norm.c")
.flag("-march=armv8.2-a+sve")
.compile("tract_sve_kernels");
// f16 kernels need native FP16 arithmetic (+fp16); compiled
Expand Down
1 change: 1 addition & 0 deletions linalg/src/arm64.rs
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,7 @@ pub fn plug(ops: &mut Ops) {
ops.sum_f32 = Box::new(|| arm64simd_sum_f32_16n::red());
ops.mul_by_scalar_f32 = Box::new(|| arm64simd_mul_by_scalar_f32_16n::ew());
ops.softmax2_fastcompact_f32 = Box::new(|| arm64simd_softmax2_fastcompact_f32_16n::red());
ops.rms_norm_f32 = Box::new(arm64simd_rms_norm_f32);
#[cfg(not(feature = "no_fp16"))]
if has_fp16() {
log::info!("ARMv8.2 tanh_f16 and sigmoid_f16 activated");
Expand Down
2 changes: 2 additions & 0 deletions linalg/src/arm64/arm64simd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ mod hardswish;
mod leaky_relu;
mod max;
mod panel_extract;
mod rms_norm;
mod silu;
mod silu_fused;
mod softmax;
Expand All @@ -17,6 +18,7 @@ pub use gelu_fused::arm64simd_gelu_f32_4n_fused;
pub use hardswish::arm64simd_hardswish_f32_8n;
pub use leaky_relu::arm64simd_leaky_relu_f32_8n;
pub use max::arm64simd_max_f32_16n;
pub use rms_norm::rms_norm_f32 as arm64simd_rms_norm_f32;
pub use silu::arm64simd_silu_f32_4n;
pub use silu_fused::arm64simd_silu_f32_4n_fused;
pub use softmax::arm64simd_softmax2_fastcompact_f32_16n;
Expand Down
Loading