linalg/x86_64 + core/nn: fused AVX-512 RmsNorm kernel #2311

Open
czoli1976 wants to merge 1 commit into sonos:main from czoli1976:feat/avx512-rms-norm

@czoli1976
Contributor

Summary

Adds a fused AVX-512 RmsNorm kernel + scalar fallback at the linalg level, and routes core::ops::nn::RmsNorm::eval's common case (contiguous trailing axis, F32 or F16) through it instead of the existing 4-call composition.

Why: RmsNorm::eval currently runs Reducer::MeanOfSquares + Add + Rsqrt + Mul as 4 separate ops, each writing/reading the full input through L1/L2. CUDA and Metal already expose a fused rms_norm kernel (cuda/src/kernels/nn/rms_norm.rs, metal/src/kernels/nn/rms_norm.rs); this closes the CPU side of that gap.
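For reference, the four-op composition can be written as one per-row formula. This is a sketch inferred from the op names above, assuming a row $x$ of length $n$ and that the `Add` step adds the scalar epsilon $\varepsilon$ to the mean of squares:

$$
y_i = x_i \cdot \frac{1}{\sqrt{\frac{1}{n}\sum_{j=1}^{n} x_j^2 + \varepsilon}}, \qquad i = 1, \dots, n
$$

A fused kernel needs only two sweeps over the row: one to accumulate $\sum_j x_j^2$, one to scale every $x_i$ by the resulting reciprocal square root.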

What:

  • New Ops::rms_norm_f32: Box<dyn Fn(&mut [f32], f32) + Send + Sync> — a simple closure slot since the op operates on a whole row (doesn't fit the per-element ElementWise pattern).
  • linalg/src/generic/rms_norm.rs — scalar default.
  • linalg/src/x86_64_fma/rms_norm.rs — AVX-512 kernel. Two passes: (1) sum-of-squares via 4 zmm FMA accumulators + zmm→scalar reduce + rsqrt, (2) multiply-back via 4 zmm broadcast-multiplies. Scalar tail handles row_len % 64 ≠ 0. Uses vmovups (per-row slices aren't guaranteed 64-byte aligned).
  • plug_avx512f overrides the closure on AVX-512 hosts.
  • core::ops::nn::RmsNorm::eval adds a fast path: when dtype ∈ {F32, F16} and self.axis == input.rank() - 1, iterates over outer dims and dispatches to the linalg primitive. Other axes keep the original composition.
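The two-pass structure described in the list above can be sketched in portable Rust. This is an illustration only, not the PR's code: the names `rms_norm_scalar`, `rms_norm_chunked`, and `rms_norm_rows` are hypothetical, plain arrays stand in for zmm registers (16 `f32` lanes each), and the compiler rather than hand-written assembly decides what gets vectorized. The only signature taken from the PR is the `(&mut [f32], f32)` shape of the slot.

```rust
/// Scalar reference: x[i] <- x[i] / sqrt(mean(x^2) + eps), in place.
fn rms_norm_scalar(row: &mut [f32], eps: f32) {
    if row.is_empty() {
        return;
    }
    let sum_sq: f32 = row.iter().map(|x| x * x).sum();
    let inv_std = 1.0 / (sum_sq / row.len() as f32 + eps).sqrt();
    for x in row.iter_mut() {
        *x *= inv_std;
    }
}

const LANES: usize = 16; // f32 lanes in one 512-bit register
const ACCS: usize = 4; // independent accumulators
const CHUNK: usize = ACCS * LANES; // 64 elements per main-loop iteration

/// Same result as the scalar reference (up to rounding), but pass 1 keeps
/// four independent 16-lane accumulators and handles `len % 64` in a tail.
fn rms_norm_chunked(row: &mut [f32], eps: f32) {
    if row.is_empty() {
        return;
    }
    // Pass 1: sum of squares, 64 elements per iteration.
    let mut acc = [[0.0f32; LANES]; ACCS];
    let mut chunks = row.chunks_exact(CHUNK);
    for chunk in &mut chunks {
        for (a, lanes) in acc.iter_mut().zip(chunk.chunks_exact(LANES)) {
            for (s, x) in a.iter_mut().zip(lanes) {
                *s += x * x; // a hand-written kernel would use an FMA here
            }
        }
    }
    // Scalar tail for the remaining `len % 64` elements.
    let tail_sq: f32 = chunks.remainder().iter().map(|x| x * x).sum();
    // Horizontal reduce of all accumulator lanes, then the reciprocal sqrt.
    let sum_sq: f32 = acc.iter().flatten().sum::<f32>() + tail_sq;
    let inv_std = 1.0 / (sum_sq / row.len() as f32 + eps).sqrt();
    // Pass 2: multiply every element by the broadcast scale.
    for x in row.iter_mut() {
        *x *= inv_std;
    }
}

/// Fast-path shape: a contiguous trailing axis means the buffer is a
/// sequence of `row_len`-sized rows, each normalized independently.
fn rms_norm_rows(data: &mut [f32], row_len: usize, eps: f32, kernel: fn(&mut [f32], f32)) {
    if row_len == 0 {
        return;
    }
    for row in data.chunks_exact_mut(row_len) {
        kernel(row, eps);
    }
}
```

Because the chunked version sums in a different order than the scalar one, the two agree only up to floating-point rounding, which is why a comparison against the reference needs a tolerance rather than exact equality.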

Bench (single-thread, Cascade Lake, throughput Gelem/s):

| row | composed | generic scalar | AVX-512 | AVX-512 vs composed |
|---|---|---|---|---|
| 1024 | 0.77 | 0.75 | 12.4 | 16.2× |
| 2048 | 0.77 | 0.75 | 13.8 | 17.9× |
| 4096 | 0.77 | 0.75 | 13.8 | 17.9× |

The composed and generic-scalar numbers are nearly identical because both are memory-bandwidth bound on a 4-pass / 2-pass workload of the same total work; the AVX-512 win comes from doing both passes in 1/4 the loop iterations (zmm width) and from avoiding the inter-op allocation + dispatch overhead in the composed version.

Test plan

  • cargo test --release -p tract-linalg --lib rms_norm — 8 tests (4 generic, 4 AVX-512 against the scalar reference, including a length-with-tail case).
  • cargo test --release -p tract-core --lib rms_norm: eval_with_f16_eps_and_f16_input (now exercises the new fast path: rank-1 F16 input + F16 eps) plus new eval_with_non_trailing_axis_f32 (rank-2, axis=0), which asserts the slow-path 4-call composition still matches a hand-computed reference within 1e-5.
  • cargo test --release -p tract-linalg — 2665 passed, 0 failed.
  • cargo test --release -p tract-core --lib — 246 passed, 0 failed.
  • cargo bench --bench rms_norm — numbers above.
  • Non-AVX512 x86 hosts unchanged (scalar fallback exercised via plug_avx512f gating).
  • Non-trailing-axis case unchanged (slow path = original 4-call composition; defended by the new test above).
  • Cross-arch builds clean (aarch64-unknown-linux-gnu, wasm32-unknown-unknown): the AVX-512 asm in linalg/src/x86_64_fma/rms_norm.rs is walled off by #[cfg(target_arch = "x86_64")] on the x86_64_fma module in linalg/src/lib.rs.
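The gating described in the last few items (compile-time `cfg` on the x86_64 module, runtime feature detection before overriding the slot) follows a common Rust pattern. The sketch below is a minimal stand-in, not tract's code: the `Ops` struct, `build_ops`, `plug`, and `rms_norm_wide` are hypothetical, and the "wide" kernel simply calls the generic one so the example runs on any host. Only the closure type and the use of `is_x86_feature_detected!` come from the PR text.

```rust
/// Closure type matching the slot described in the PR.
type RmsNormF32 = Box<dyn Fn(&mut [f32], f32) + Send + Sync>;

/// Generic scalar kernel, used on every target by default.
fn rms_norm_generic(row: &mut [f32], eps: f32) {
    if row.is_empty() {
        return;
    }
    let sum_sq: f32 = row.iter().map(|x| x * x).sum();
    let inv_std = 1.0 / (sum_sq / row.len() as f32 + eps).sqrt();
    for x in row.iter_mut() {
        *x *= inv_std;
    }
}

/// Hypothetical registry standing in for the linalg-level `Ops` table.
struct Ops {
    rms_norm_f32: RmsNormF32,
}

/// The whole module disappears on non-x86_64 targets, so nothing
/// x86-specific is ever compiled for aarch64 or wasm32.
#[cfg(target_arch = "x86_64")]
mod x86_64_fma {
    use super::Ops;

    /// Placeholder for a real AVX-512 kernel (assumption: same signature).
    fn rms_norm_wide(row: &mut [f32], eps: f32) {
        super::rms_norm_generic(row, eps)
    }

    /// Override the slot only when the running CPU reports AVX-512F.
    pub fn plug(ops: &mut Ops) {
        if is_x86_feature_detected!("avx512f") {
            ops.rms_norm_f32 = Box::new(rms_norm_wide);
        }
    }
}

/// Start from the generic kernel, then let arch-specific code override it.
fn build_ops() -> Ops {
    let mut ops = Ops { rms_norm_f32: Box::new(rms_norm_generic) };
    #[cfg(target_arch = "x86_64")]
    x86_64_fma::plug(&mut ops);
    ops
}
```

Callers invoke the slot as `(ops.rms_norm_f32)(&mut row, eps)` and never need to know which implementation was installed.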

Validation environment

x86_64 KVM guest, Ubuntu 24.04.4 LTS (kernel 6.18.5), rustc 1.94.1.

  • CPU: Intel Xeon @ 2.80 GHz, family 6 / model 85 / stepping 7 (Cascade Lake-SP); 4 vCPU, 1 thread/core, 1 socket.
  • AVX-512 features (all confirmed by is_x86_feature_detected!): f, vnni, dq, bw, vl, cd.
  • Cache: L1d 32 KiB/core, L2 1 MiB/core, L3 33 MiB shared. 15 GiB RAM.
  • Bench discipline: RAYON_NUM_THREADS=1, taskset -c 0, Criterion warm-up 5 s / measure 15 s, sample size 100.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>


Generated by Claude Code

@czoli1976 czoli1976 force-pushed the feat/avx512-rms-norm branch from ded79fd to 855f563 Compare May 28, 2026 20:13
@czoli1976 czoli1976 force-pushed the feat/avx512-rms-norm branch from 855f563 to 6227823 Compare May 29, 2026 08:13
czoli1976 added a commit to czoli1976/tract that referenced this pull request May 29, 2026
VLA SVE2 implementation of the row-wise RmsNorm primitive added by the
parent stack (sonos#2311 linalg slot + core/nn fast path; sonos#2314 NEON kernel).
Plugs into Ops::rms_norm_f32 in sve::plug() when FEAT_SVE2 is present on
Linux aarch64, overriding the NEON 4-lane kernel with wider lanes
(vl-dependent) and a predicated tail (no scalar epilogue).

Structure mirrors the NEON + AVX-512 kernels:

  Pass 1 — sum of squares via 4 svfloat32_t accumulator chains, 4*svcntw()
           lanes per iteration. Tail handled by a predicated svwhilelt_b32
           loop over the residue — no scalar epilogue.
  Pass 2 — broadcast inv_std into inv_v, fmul/st1 each 4-vec chunk;
           same predicated tail.

Width-agnostic by construction — identical correct output at any FEAT_SVE
streaming vector length (128 → 2048 bits). Wider VL = wider lanes, fewer
loop iterations, real perf scaling.

Validation (QEMU-only — no SVE hardware locally):
- 100 cases pass at SVL=128 (4 lanes), SVL=256 (8 lanes), SVL=512 (16
  lanes) via qemu-aarch64 -cpu max,sve{128,256,512}=on. Coverage: every
  size 1..33, hidden ∈ {768..8192} × 9 tail residues, huge rows up to
  32768, all-zero pathological. Bit-equivalent vs scalar within
  sqrt(n)-scaled tolerance.
- Local M1 macOS build clean (tract_sve cfg gated out; new code is purely
  additive — Linux aarch64 + FEAT_SVE2 only).

Expected gain over the NEON kernel scales with SVL:
- 128-bit SVE (rare Neoverse-N1):  ~1× (no gain; same width as NEON)
- 256-bit SVE (Graviton G3/G4):    ~1.3–1.8×
- 512-bit SVE (Neoverse-V2 wide):  ~2.5–4× (mirroring AVX-512 vs SSE)

Perf number unmeasured pending SVE hardware (AWS Graviton free tier).
Same validation shape as PR sonos#2268 (correctness via QEMU + bit-equivalent
vs the NEON fallback).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
kali previously approved these changes Jun 5, 2026
Add a linalg-side fused row-wise RmsNorm primitive
(`tract_linalg::ops().rms_norm_f32`) that replaces tract-core's 4-call
composition (`MeanOfSquares` + `Add` + `Rsqrt` + `Mul`) with a single
two-pass kernel: sum-of-squares via 4 zmm FMA accumulators, scalar reduce
+ rsqrt, then multiply-back via 4 zmm broadcast-multiplies. Scalar tail
handles the remainder when row_len % 64 != 0; vmovups is used throughout
since per-row slices from a tensor are not guaranteed 64-byte aligned.

`core::ops::nn::RmsNorm::eval` gains a fast path for F32 / F16 inputs
where the normalised axis is the last (contiguous) one — it iterates row
by row and dispatches to the linalg primitive. Other shapes (non-trailing
axis) keep the original composition. Generic scalar fallback ships
alongside the AVX-512 kernel; non-x86 and non-AVX-512 x86 keep the scalar
version, which is itself ~equivalent to the composed path because both
are memory-bandwidth bound.

CUDA and Metal already expose a fused `rms_norm` kernel
(`cuda/src/kernels/nn/rms_norm.rs`, `metal/src/kernels/nn/rms_norm.rs`);
this closes the CPU side of the same gap.

Measured on Cascade Lake (single-thread, kernel-level, throughput Gelem/s):
  - row 1024:  0.77 (composed) -> 12.4 (AVX-512)   16.2x
  - row 2048:  0.77            -> 13.8             17.9x
  - row 4096:  0.77            -> 13.8             17.9x

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@kali kali force-pushed the feat/avx512-rms-norm branch from 6227823 to ed8dfb5 Compare June 5, 2026 11:58
@kali
Collaborator

kali commented Jun 5, 2026

rebased!

@czoli1976
Contributor Author

(image attachment)

@kali
Collaborator

kali commented Jun 5, 2026

dafuk is this bird ? :) plus, i hope there is never a 0.24, i think it's time to start adulting.

@czoli1976
Contributor Author

czoli1976 commented Jun 5, 2026

It's a peacock, beautiful and elegant, admittedly not considered a very smart bird

@kali
Collaborator

kali commented Jun 5, 2026

I know, I'm a country guy. And yeah, they're pretty dumb.
