linalg/arm64/sve: VLA SVE2 rms_norm_f32 kernel (stacked on #2314)#2315
Open
czoli1976 wants to merge 3 commits into
Open
linalg/arm64/sve: VLA SVE2 rms_norm_f32 kernel (stacked on #2314)#2315czoli1976 wants to merge 3 commits into
czoli1976 wants to merge 3 commits into
Conversation
Contributor
Author
|
@kali this should complete the RMS_NORM Perf work AFAIK, so we got it for AVX/NEON/SME/SVE_SVL |
Add a linalg-side fused row-wise RmsNorm primitive (`tract_linalg::ops().rms_norm_f32`) that replaces tract-core's 4-call composition (`MeanOfSquares` + `Add` + `Rsqrt` + `Mul`) with a single two-pass kernel: sum-of-squares via 4 zmm FMA accumulators, scalar reduce + rsqrt, then multiply-back via 4 zmm broadcast-multiplies. Scalar tail handles the remainder when row_len % 64 != 0; vmovups is used throughout since per-row slices from a tensor are not guaranteed 64-byte aligned. `core::ops::nn::RmsNorm::eval` gains a fast path for F32 / F16 inputs where the normalised axis is the last (contiguous) one — it iterates row by row and dispatches to the linalg primitive. Other shapes (non-trailing axis) keep the original composition. Generic scalar fallback ships alongside the AVX-512 kernel; non-x86 and non-AVX-512 x86 keep the scalar version, which is itself ~equivalent to the composed path because both are memory-bandwidth bound. CUDA and Metal already expose a fused `rms_norm` kernel (`cuda/src/kernels/nn/rms_norm.rs`, `metal/src/kernels/nn/rms_norm.rs`); this closes the CPU side of the same gap. Measured on Cascade Lake (single-thread, kernel-level, throughput Gelem/s): - row 1024: 0.77 (composed) -> 12.4 (AVX-512) 16.2x - row 2048: 0.77 -> 13.8 17.9x - row 4096: 0.77 -> 13.8 17.9x Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds an aarch64 NEON implementation of `tract_linalg::ops().rms_norm_f32`,
mirroring the AVX-512 kernel from the parent RmsNorm PR. 16 f32 lanes per
inner loop iteration (4 v-registers of 4 lanes each):
Pass 1 — sum of squares via 4 fmla chains (v0..v3), 3-way fadd reduce,
then horizontal reduce to scalar via vaddvq_f32.
Pass 2 — broadcast inv_std into v0, multiply each 4-v-register chunk
in place.
Scalar tail handles (len % 16 != 0).
Plugs into `Ops::rms_norm_f32` in `arm64::plug()`. The core-side fast path
in `core::ops::nn::RmsNorm::eval` (added by the parent PR) is already
arch-neutral and picks this up automatically — every model with a trailing-
axis F32/F16 RmsNorm now hits this kernel on Apple Silicon / Cortex-A /
Neoverse instead of the generic 4-call composition.
Tests use the same scalar-reference pattern as the AVX-512 kernel:
trivial, prop-style sin/cos input at n=16, n=1024+7 (exercising the
scalar tail), and a sub-chunk n=8 (all-tail) case. NEON is mandatory on
aarch64 so no runtime feature detection is needed; the kernel is gated by
`#[target_feature(enable = "neon")]` only for the inline-asm + intrinsic
context.
Cross-compile check: `cargo check --target aarch64-unknown-linux-gnu -p
tract-linalg` clean on the modified files. The x86_64 bench output is
unchanged (the kernel module is `#[cfg(target_arch = "aarch64")]`-only via
the `arm64` parent), and the rms_norm bench gains a "neon" column when
built for aarch64.
Dependencies: needs the parent RmsNorm PR (which adds the `Ops::rms_norm_f32`
slot and the `core::ops::nn::RmsNorm::eval` dispatcher). If the parent
lands first this rebases trivially.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
VLA SVE2 implementation of the row-wise RmsNorm primitive added by the parent stack (sonos#2311 linalg slot + core/nn fast path; sonos#2314 NEON kernel). Plugs into Ops::rms_norm_f32 in sve::plug() when FEAT_SVE2 is present on Linux aarch64, overriding the NEON 4-lane kernel with wider lanes (vl-dependent) and a predicated tail (no scalar epilogue). Structure mirrors the NEON + AVX-512 kernels: Pass 1 — sum of squares via 4 svfloat32_t accumulator chains, 4*svcntw() lanes per iteration. Tail handled by a predicated svwhilelt_b32 loop over the residue — no scalar epilogue. Pass 2 — broadcast inv_std into inv_v, fmul/st1 each 4-vec chunk; same predicated tail. Width-agnostic by construction — identical correct output at any FEAT_SVE streaming vector length (128 → 2048 bits). Wider VL = wider lanes, fewer loop iterations, real perf scaling. Validation (QEMU-only — no SVE hardware locally): - 100 cases pass at SVL=128 (4 lanes), SVL=256 (8 lanes), SVL=512 (16 lanes) via qemu-aarch64 -cpu max,sve{128,256,512}=on. Coverage: every size 1..33, hidden ∈ {768..8192} × 9 tail residues, huge rows up to 32768, all-zero pathological. Bit-equivalent vs scalar within sqrt(n)-scaled tolerance. - Local M1 macOS build clean (tract_sve cfg gated out; new code is purely additive — Linux aarch64 + FEAT_SVE2 only). Expected gain over the NEON kernel scales with SVL: - 128-bit SVE (rare Neoverse-N1): ~0× (same width as NEON) - 256-bit SVE (Graviton G3/G4): ~1.3–1.8× - 512-bit SVE (Neoverse-V2 wide): ~2.5–4× (mirroring AVX-512 vs SSE) Perf number unmeasured pending SVE hardware (AWS Graviton free tier). Same validation shape as PR sonos#2268 (correctness via QEMU + bit-equivalent vs the NEON fallback). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
a4885fb to
16dbc6a
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Stacked on #2314 (which is stacked on #2311). Adds the SVE2 sibling of the new NEON
rms_norm_f32kernel. Single commit on top; review only the top commit. Trivially rebases to standalone once parents merge.What
VLA SVE2 f32 fused row-wise RmsNorm in
linalg/arm64/sve/sve_rms_norm.c. Plugs intoOps::rms_norm_f32insve::plug()whenhas_sve2()is true, overriding the NEON 4-lane kernel from #2314 with wider, vector-length-agnostic lanes:svfloat32_taccumulator chains;4 * svcntw()lanes per inner iteration; predicatedsvwhilelt_b32loop over the residue — no scalar tail.inv_std,fmul/st1each 4-vec chunk; same predicated tail.Width-agnostic by construction: same correct output at any FEAT_SVE streaming vector length (128 → 2048 bits). Wider VL = wider lanes, fewer iterations, real perf scaling.
Validation (QEMU — no SVE hardware locally)
Built a standalone C harness that links
sve_rms_norm.cdirectly + a scalar reference, ran underqemu-aarch64 -cpu max,sve{128,256,512}=on:Test coverage: every size 1..33 (boundary residues), hidden ∈ {768, 1024, 2048, 3072, 4096, 5120, 8192} × 9 tail residues each, huge rows {8192, 16384, 32768}, all-zero pathological. Bit-equivalent vs scalar reference within
sqrt(n)-scaled tolerance.Local macOS M1 build clean (
tract_svecfg gated off; new code is purely additive — only fires on Linux aarch64 + FEAT_SVE2).Expected gain (perf number unmeasured pending SVE hardware)
Scales with the host's streaming vector length:
Same validation shape as the original SVE backend PR #2268: correctness via QEMU + bit-equivalent vs the NEON fallback, perf gain to be confirmed when a Graviton (or comparable) is available.
Risk
has_sve2()ANDtract_svecfg (which requires Linux + aarch64 + a SVE-capable C compiler). Non-Linux / non-aarch64 / no-SVE hosts hit zero new code; the NEON kernel from linalg/arm64: NEON rms_norm_f32 kernel (stacked on #2311) #2314 remains the active path.cargo fmt --check,cargo clippyclean.Follow-ups
Same pattern would close the rest of the gap on SVE: softmax, max, reduce, elementwise activations are still NEON-by-default. Out of scope here.
🤖 Generated with Claude Code