linalg/arm64: NEON rms_norm_f32 kernel (stacked on #2311) #2314

Open
czoli1976 wants to merge 2 commits into sonos:main from czoli1976:feat/arm64-neon-rms-norm

Conversation

@czoli1976
Contributor

Stacked on #2311 — needs the Ops::rms_norm_f32 slot + RmsNorm::eval fast path that PR adds. Single commit on top; review only the top commit. Rebases trivially to a standalone PR if #2311 merges first.

What

NEON (aarch64, 128-bit, 4-lane) implementation of tract_linalg::ops().rms_norm_f32, mirroring the AVX-512 kernel from the parent PR. 16 f32 lanes per inner iteration (4 v-registers × 4 lanes):

  • Pass 1 — sum of squares via 4 parallel fmla chains (v0..v3) → 3-way fadd tree → vaddvq_f32 horizontal reduce.
  • Pass 2 — broadcast inv_std into v0, fmul/st1 each 4-v-register chunk in place.
  • Scalar tail for (len % 16 != 0).
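
The two passes above can be sketched in portable Rust. This is only an illustration of the structure (four accumulator "registers" of four lanes, a fold, a horizontal reduce, a scalar tail), not the PR's inline-asm kernel. The function name `rms_norm_f32_sketch` and its `(row, eps)` signature are assumptions made for the example, and the summation order may differ from the real kernel.

```rust
/// Hypothetical stand-in for the kernel: normalises `row` in place.
/// `eps` is added to the mean of squares before the reciprocal square root.
pub fn rms_norm_f32_sketch(row: &mut [f32], eps: f32) {
    if row.is_empty() {
        return;
    }
    const LANES: usize = 4; // one 128-bit NEON register holds 4 f32
    const CHUNK: usize = 4 * LANES; // 4 registers per inner iteration

    // Pass 1: four independent accumulator "registers" (v0..v3 in the PR).
    let mut acc = [[0.0f32; LANES]; 4];
    let main_len = row.len() - row.len() % CHUNK;
    for chunk in row[..main_len].chunks_exact(CHUNK) {
        for (r, reg) in acc.iter_mut().enumerate() {
            for (l, lane) in reg.iter_mut().enumerate() {
                let x = chunk[r * LANES + l];
                *lane = x.mul_add(x, *lane); // fused multiply-add: lane += x * x
            }
        }
    }

    // Three vector adds fold the four accumulators into one register.
    let mut folded = [0.0f32; LANES];
    for l in 0..LANES {
        folded[l] = (acc[0][l] + acc[1][l]) + (acc[2][l] + acc[3][l]);
    }
    // Horizontal reduce across lanes (the role vaddvq_f32 plays).
    let mut sum_sq = (folded[0] + folded[1]) + (folded[2] + folded[3]);

    // Scalar tail for the len % 16 leftover elements.
    for &x in &row[main_len..] {
        sum_sq += x * x;
    }

    let inv_std = 1.0 / (sum_sq / row.len() as f32 + eps).sqrt();

    // Pass 2: multiply every element by the broadcast inv_std, in place.
    for x in row.iter_mut() {
        *x *= inv_std;
    }
}
```

Keeping four separate accumulators means consecutive fused multiply-adds do not depend on each other, which is the usual reason for unrolling a reduction this way.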

Plugs into arm64::plug(). NEON is mandatory on aarch64 so no runtime feature detection — #[target_feature(enable = "neon")] is just for the inline-asm + intrinsic context. Generic scalar fallback in tract_linalg::generic covers non-aarch64.

Validation (M1, real hardware)

| Surface | Result |
| --- | --- |
| 4 new unit tests (trivial / 1024+7 tail / n=8 all-tail / empty) | 4/4 pass |
| Full cargo test -p tract-linalg --lib | 3718/0 pass, 0 regressions |
| cargo fmt --check, cargo clippy | clean (no new warnings; the 6 pre-existing tract-linalg warnings on main are unchanged) |

Synthetic stress test (local, not in this commit)

Local-only stress test file covering every size 1..32, hidden ∈ {768..8192} × 5 tail residues, pathological distributions (all-zero / all-equal / mixed-sign large-magnitude / subnormal-mostly), 8 epsilons (0 → 100), 500 random sizes in [1, 8192], and huge rows up to 32768. 10/10 pass vs scalar reference. Happy to include the file in this PR if reviewers want it as test/.

E2E correctness — 64×RmsNorm@4096 ONNX chain (this PR vs main)

max_abs_diff   = 0.000e+00
mean_abs_diff  = 0.000e+00
bit-exact equal? True

Bit-exact across 64 chained layers. No drift, no FMA-reorder loss.
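
For readers who want to reproduce this kind of comparison, here is a minimal sketch of how the three figures can be computed from two output buffers. The `DiffReport` type and `compare_outputs` function are hypothetical names for this example and are not part of tract or of the script used for the numbers above.

```rust
/// Summary of how two f32 buffers differ (hypothetical helper type).
#[derive(Debug, PartialEq)]
pub struct DiffReport {
    pub max_abs_diff: f32,
    pub mean_abs_diff: f32,
    pub bit_exact: bool,
}

/// Compares two equally sized outputs element by element.
pub fn compare_outputs(a: &[f32], b: &[f32]) -> DiffReport {
    assert_eq!(a.len(), b.len(), "outputs must have the same length");
    let mut max_abs = 0.0f32;
    let mut sum_abs = 0.0f64; // accumulate in f64 to limit rounding in the mean
    let mut bit_exact = true;
    for (&x, &y) in a.iter().zip(b) {
        let d = (x - y).abs();
        max_abs = max_abs.max(d);
        sum_abs += d as f64;
        // Compare raw bit patterns: stricter than a zero absolute difference.
        bit_exact &= x.to_bits() == y.to_bits();
    }
    let mean = if a.is_empty() { 0.0 } else { (sum_abs / a.len() as f64) as f32 };
    DiffReport { max_abs_diff: max_abs, mean_abs_diff: mean, bit_exact }
}
```

Comparing bit patterns is stricter than checking for a zero difference: `0.0` and `-0.0` have an absolute difference of zero but are not bit-identical.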

Performance (M1 P-core)

Kernel-level microbench (linalg/benches/rms_norm.rs, 30 samples × 3s)

| Row N | composed (inline 2-loop, autovec) | NEON | gap |
| --- | --- | --- | --- |
| 1024 | 1.04 µs / 0.98 Gelem/s | 0.140 µs / 7.31 Gelem/s | 7.43× |
| 2048 | 2.08 µs / 0.98 Gelem/s | 0.279 µs / 7.34 Gelem/s | 7.47× |
| 4096 | 4.17 µs / 0.98 Gelem/s | 0.549 µs / 7.45 Gelem/s | 7.59× |
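
As a sanity check on how the columns relate, the throughput and gap figures can be recomputed from the row length and the per-row time. The helper names below are made up for this sketch. Because the inputs are the rounded table entries, the results agree with the reported values only to within rounding.

```rust
/// Throughput in Gelem/s for `elems` elements processed in `micros` microseconds.
pub fn gelem_per_s(elems: usize, micros: f64) -> f64 {
    elems as f64 / (micros * 1e-6) / 1e9
}

/// Speed ratio between the composed path and the NEON kernel.
pub fn gap(composed_us: f64, neon_us: f64) -> f64 {
    composed_us / neon_us
}
```

For the 1024-element row, `gelem_per_s(1024, 0.140)` is about 7.31 and `gap(1.04, 0.140)` is about 7.43, in line with the first table row.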

E2E on a 64×RmsNorm@4096 ONNX (real tract pipeline, bench mode)

| Form | main (composed 4-call) | this PR | speedup |
| --- | --- | --- | --- |
| Opset-23 native RMSNormalization × 64 | 0.299 ms/i | 0.097 ms/i | 3.08× |
| Opset-17 composed pattern × 64 (declutter hoists) | 0.265 ms/i | 0.077 ms/i | 3.44× |

E2E gap is smaller than kernel-level because per-call dispatch overhead is ~constant; with the fast NEON kernel it becomes a larger fraction. For a real LLM where RmsNorm is ~5% of inference time, this is ~1–4% E2E.
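
The whole-model estimate follows the usual Amdahl's-law arithmetic: if a fraction f of the runtime is sped up by a factor s, the runtime shrinks by f · (1 − 1/s). A small sketch, taking the ~5% share from the paragraph above as an assumption and using a hypothetical helper name:

```rust
/// Fraction of total runtime saved when a `fraction` of it is sped up by `speedup`.
pub fn e2e_time_saved(fraction: f64, speedup: f64) -> f64 {
    fraction * (1.0 - 1.0 / speedup)
}
```

With `fraction = 0.05`, the measured E2E speedup of 3.08× gives a saving of roughly 3.4% and the kernel-level 7.59× gives roughly 4.3%. The saving can never exceed the 5% share itself, and it shrinks proportionally for models where RmsNorm takes a smaller share.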

Caveat / context

The fast path only fires on graphs where RmsNorm is hoisted as a single op — i.e., native opset-23 RMSNormalization, SimplifiedLayerNormalization contrib (#2288), or the "clean" composed pattern that detect_rms_norm matches. Many HuggingFace optimum-onnxruntime LLM exports use a Cast → Pow → ReduceMean → Add → Sqrt → Div → Mul → Cast → Mul variant that the current declutter doesn't recognize, so on those models this kernel doesn't fire and there's no gain. Worth a follow-up to broaden the declutter pattern; not in scope here.

Risk

  • NEON is always present on aarch64 — no gating required, no host detection issue.
  • Pure addition: no existing kernel modified, no Ops field semantics changed.
  • Non-AVX-512 x86 / non-aarch64 keeps the generic scalar fallback from #2311 (unchanged).

🤖 Generated with Claude Code

czoli1976 and others added 2 commits May 29, 2026 09:13
Add a linalg-side fused row-wise RmsNorm primitive
(`tract_linalg::ops().rms_norm_f32`) that replaces tract-core's 4-call
composition (`MeanOfSquares` + `Add` + `Rsqrt` + `Mul`) with a single
two-pass kernel: sum-of-squares via 4 zmm FMA accumulators, scalar reduce
+ rsqrt, then multiply-back via 4 zmm broadcast-multiplies. Scalar tail
handles the remainder when row_len % 64 != 0; vmovups is used throughout
since per-row slices from a tensor are not guaranteed 64-byte aligned.

`core::ops::nn::RmsNorm::eval` gains a fast path for F32 / F16 inputs
where the normalised axis is the last (contiguous) one — it iterates row
by row and dispatches to the linalg primitive. Other shapes (non-trailing
axis) keep the original composition. Generic scalar fallback ships
alongside the AVX-512 kernel; non-x86 and non-AVX-512 x86 keep the scalar
version, which is itself ~equivalent to the composed path because both
are memory-bandwidth bound.

CUDA and Metal already expose a fused `rms_norm` kernel
(`cuda/src/kernels/nn/rms_norm.rs`, `metal/src/kernels/nn/rms_norm.rs`);
this closes the CPU side of the same gap.

Measured on Cascade Lake (single-thread, kernel-level, throughput Gelem/s):
  - row 1024:  0.77 (composed) -> 12.4 (AVX-512)   16.2x
  - row 2048:  0.77            -> 13.8             17.9x
  - row 4096:  0.77            -> 13.8             17.9x

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds an aarch64 NEON implementation of `tract_linalg::ops().rms_norm_f32`,
mirroring the AVX-512 kernel from the parent RmsNorm PR. 16 f32 lanes per
inner loop iteration (4 v-registers of 4 lanes each):

  Pass 1 — sum of squares via 4 fmla chains (v0..v3), 3-way fadd reduce,
           then horizontal reduce to scalar via vaddvq_f32.
  Pass 2 — broadcast inv_std into v0, multiply each 4-v-register chunk
           in place.
  Scalar tail handles (len % 16 != 0).

Plugs into `Ops::rms_norm_f32` in `arm64::plug()`. The core-side fast path
in `core::ops::nn::RmsNorm::eval` (added by the parent PR) is already
arch-neutral and picks this up automatically — every model with a trailing-
axis F32/F16 RmsNorm now hits this kernel on Apple Silicon / Cortex-A /
Neoverse instead of the generic 4-call composition.

Tests use the same scalar-reference pattern as the AVX-512 kernel:
trivial, prop-style sin/cos input at n=16, n=1024+7 (exercising the
scalar tail), and a sub-chunk n=8 (all-tail) case. NEON is mandatory on
aarch64 so no runtime feature detection is needed; the kernel is gated by
`#[target_feature(enable = "neon")]` only for the inline-asm + intrinsic
context.

Cross-compile check: `cargo check --target aarch64-unknown-linux-gnu -p
tract-linalg` clean on the modified files. The x86_64 bench output is
unchanged (the kernel module is `#[cfg(target_arch = "aarch64")]`-only via
the `arm64` parent), and the rms_norm bench gains a "neon" column when
built for aarch64.

Dependencies: needs the parent RmsNorm PR (which adds the `Ops::rms_norm_f32`
slot and the `core::ops::nn::RmsNorm::eval` dispatcher). If the parent
lands first this rebases trivially.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
czoli1976 force-pushed the feat/arm64-neon-rms-norm branch from 459d38b to 94ffd7b May 29, 2026 08:13
czoli1976 added a commit to czoli1976/tract that referenced this pull request May 29, 2026
VLA SVE2 implementation of the row-wise RmsNorm primitive added by the
parent stack (sonos#2311 linalg slot + core/nn fast path; sonos#2314 NEON kernel).
Plugs into Ops::rms_norm_f32 in sve::plug() when FEAT_SVE2 is present on
Linux aarch64, overriding the NEON 4-lane kernel with wider lanes
(vl-dependent) and a predicated tail (no scalar epilogue).

Structure mirrors the NEON + AVX-512 kernels:

  Pass 1 — sum of squares via 4 svfloat32_t accumulator chains, 4*svcntw()
           lanes per iteration. Tail handled by a predicated svwhilelt_b32
           loop over the residue — no scalar epilogue.
  Pass 2 — broadcast inv_std into inv_v, fmul/st1 each 4-vec chunk;
           same predicated tail.

Width-agnostic by construction — identical correct output at any FEAT_SVE
streaming vector length (128 → 2048 bits). Wider VL = wider lanes, fewer
loop iterations, real perf scaling.

Validation (QEMU-only — no SVE hardware locally):
- 100 cases pass at SVL=128 (4 lanes), SVL=256 (8 lanes), SVL=512 (16
  lanes) via qemu-aarch64 -cpu max,sve{128,256,512}=on. Coverage: every
  size 1..33, hidden ∈ {768..8192} × 9 tail residues, huge rows up to
  32768, all-zero pathological. Matches the scalar reference within a
  sqrt(n)-scaled tolerance.
- Local M1 macOS build clean (tract_sve cfg gated out; new code is purely
  additive — Linux aarch64 + FEAT_SVE2 only).

Expected gain over the NEON kernel scales with SVL:
- 128-bit SVE (rare Neoverse-N1):  ~1× (no gain; same width as NEON)
- 256-bit SVE (Graviton G3/G4):    ~1.3–1.8×
- 512-bit SVE (Neoverse-V2 wide):  ~2.5–4× (mirroring AVX-512 vs SSE)

Perf numbers are unmeasured pending access to SVE hardware (AWS Graviton free tier).
Same validation shape as PR sonos#2268 (correctness via QEMU + bit-equivalent
vs the NEON fallback).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
czoli1976 marked this pull request as ready for review June 1, 2026 10:50