linalg/x86_64 + core/nn: fused AVX-512 RmsNorm kernel#2311
Open
czoli1976 wants to merge 1 commit into
Open
Conversation
This was referenced May 28, 2026
ded79fd to
855f563
Compare
855f563 to
6227823
Compare
czoli1976
added a commit
to czoli1976/tract
that referenced
this pull request
May 29, 2026
VLA SVE2 implementation of the row-wise RmsNorm primitive added by the parent stack (sonos#2311 linalg slot + core/nn fast path; sonos#2314 NEON kernel). Plugs into Ops::rms_norm_f32 in sve::plug() when FEAT_SVE2 is present on Linux aarch64, overriding the NEON 4-lane kernel with wider lanes (vl-dependent) and a predicated tail (no scalar epilogue). Structure mirrors the NEON + AVX-512 kernels: Pass 1 — sum of squares via 4 svfloat32_t accumulator chains, 4*svcntw() lanes per iteration. Tail handled by a predicated svwhilelt_b32 loop over the residue — no scalar epilogue. Pass 2 — broadcast inv_std into inv_v, fmul/st1 each 4-vec chunk; same predicated tail. Width-agnostic by construction — identical correct output at any FEAT_SVE streaming vector length (128 → 2048 bits). Wider VL = wider lanes, fewer loop iterations, real perf scaling. Validation (QEMU-only — no SVE hardware locally): - 100 cases pass at SVL=128 (4 lanes), SVL=256 (8 lanes), SVL=512 (16 lanes) via qemu-aarch64 -cpu max,sve{128,256,512}=on. Coverage: every size 1..33, hidden ∈ {768..8192} × 9 tail residues, huge rows up to 32768, all-zero pathological. Bit-equivalent vs scalar within sqrt(n)-scaled tolerance. - Local M1 macOS build clean (tract_sve cfg gated out; new code is purely additive — Linux aarch64 + FEAT_SVE2 only). Expected gain over the NEON kernel scales with SVL: - 128-bit SVE (rare Neoverse-N1): ~0× (same width as NEON) - 256-bit SVE (Graviton G3/G4): ~1.3–1.8× - 512-bit SVE (Neoverse-V2 wide): ~2.5–4× (mirroring AVX-512 vs SSE) Perf number unmeasured pending SVE hardware (AWS Graviton free tier). Same validation shape as PR sonos#2268 (correctness via QEMU + bit-equivalent vs the NEON fallback). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
kali
previously approved these changes
Jun 5, 2026
Add a linalg-side fused row-wise RmsNorm primitive (`tract_linalg::ops().rms_norm_f32`) that replaces tract-core's 4-call composition (`MeanOfSquares` + `Add` + `Rsqrt` + `Mul`) with a single two-pass kernel: sum-of-squares via 4 zmm FMA accumulators, scalar reduce + rsqrt, then multiply-back via 4 zmm broadcast-multiplies. Scalar tail handles the remainder when row_len % 64 != 0; vmovups is used throughout since per-row slices from a tensor are not guaranteed 64-byte aligned. `core::ops::nn::RmsNorm::eval` gains a fast path for F32 / F16 inputs where the normalised axis is the last (contiguous) one — it iterates row by row and dispatches to the linalg primitive. Other shapes (non-trailing axis) keep the original composition. Generic scalar fallback ships alongside the AVX-512 kernel; non-x86 and non-AVX-512 x86 keep the scalar version, which is itself ~equivalent to the composed path because both are memory-bandwidth bound. CUDA and Metal already expose a fused `rms_norm` kernel (`cuda/src/kernels/nn/rms_norm.rs`, `metal/src/kernels/nn/rms_norm.rs`); this closes the CPU side of the same gap. Measured on Cascade Lake (single-thread, kernel-level, throughput Gelem/s): - row 1024: 0.77 (composed) -> 12.4 (AVX-512) 16.2x - row 2048: 0.77 -> 13.8 17.9x - row 4096: 0.77 -> 13.8 17.9x Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
6227823 to
ed8dfb5
Compare
Collaborator
|
rebased! |
Contributor
Author
Collaborator
|
dafuk is this bird ? :) plus, i hope there is never a 0.24, i think it's time to start adulting. |
Contributor
Author
|
It's a peacock, beautiful and elegant, admittedly not considered a very smart bird |
Collaborator
|
I know, I'm a country guy. And yeah, they're pretty dumb. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.

Summary
Adds a fused AVX-512 RmsNorm kernel + scalar fallback at the linalg level, and routes
core::ops::nn::RmsNorm::eval's common case (contiguous trailing axis, F32 or F16) through it instead of the existing 4-call composition.Why:
RmsNorm::evalcurrently runsReducer::MeanOfSquares+Add+Rsqrt+Mulas 4 separate ops, each writing/reading the full input through L1/L2. CUDA and Metal already expose a fusedrms_normkernel (cuda/src/kernels/nn/rms_norm.rs,metal/src/kernels/nn/rms_norm.rs); this closes the CPU side of that gap.What:
Ops::rms_norm_f32: Box<dyn Fn(&mut [f32], f32) + Send + Sync>— a simple closure slot since the op operates on a whole row (doesn't fit the per-elementElementWisepattern).linalg/src/generic/rms_norm.rs— scalar default.linalg/src/x86_64_fma/rms_norm.rs— AVX-512 kernel. Two passes: (1) sum-of-squares via 4 zmm FMA accumulators + zmm→scalar reduce + rsqrt, (2) multiply-back via 4 zmm broadcast-multiplies. Scalar tail handlesrow_len % 64 ≠ 0. Usesvmovups(per-row slices aren't guaranteed 64-byte aligned).plug_avx512foverrides the closure on AVX-512 hosts.core::ops::nn::RmsNorm::evaladds a fast path: whendtype ∈ {F32, F16}andself.axis == input.rank() - 1, iterates over outer dims and dispatches to the linalg primitive. Other axes keep the original composition.Bench (single-thread, Cascade Lake, throughput Gelem/s):
The composed and generic-scalar numbers are nearly identical because both are memory-bandwidth bound on a 4-pass / 2-pass workload of the same total work; the AVX-512 win comes from doing both passes in 1/4 the loop iterations (zmm width) and from avoiding the inter-op allocation + dispatch overhead in the composed version.
Test plan
cargo test --release -p tract-linalg --lib rms_norm— 8 tests (4 generic, 4 AVX-512 against the scalar reference, including a length-with-tail case).cargo test --release -p tract-core --lib rms_norm—eval_with_f16_eps_and_f16_input(now exercises the new fast path: rank-1 F16 input + F16 eps) plus neweval_with_non_trailing_axis_f32(rank-2, axis=0) which asserts the slow-path 4-call composition still matches a hand-computed reference within 1e-5.cargo test --release -p tract-linalg— 2665 passed, 0 failed.cargo test --release -p tract-core --lib— 246 passed, 0 failed.cargo bench --bench rms_norm— numbers above.plug_avx512fgating).aarch64-unknown-linux-gnu,wasm32-unknown-unknown): the AVX-512 asm inlinalg/src/x86_64_fma/rms_norm.rsis walled off by#[cfg(target_arch = "x86_64")]on thex86_64_fmamodule inlinalg/src/lib.rs.Validation environment
x86_64 KVM guest, Ubuntu 24.04.4 LTS (kernel 6.18.5), rustc 1.94.1.
is_x86_feature_detected!): f, vnni, dq, bw, vl, cd.RAYON_NUM_THREADS=1,taskset -c 0, Criterion warm-up 5 s / measure 15 s, sample size 100.Co-Authored-By: Claude Opus 4.7 (1M context) noreply@anthropic.com
Generated by Claude Code