Skip to content

transformers: route fused ScaledMaskedSoftmax through an accurate vectorized exp#2318

Open
czoli1976 wants to merge 1 commit into
sonos:mainfrom
czoli1976:feat/arm64-neon-fp16-activations
Open

transformers: route fused ScaledMaskedSoftmax through an accurate vectorized exp#2318
czoli1976 wants to merge 1 commit into
sonos:mainfrom
czoli1976:feat/arm64-neon-fp16-activations

Conversation

@czoli1976
Copy link
Copy Markdown
Contributor

@czoli1976 czoli1976 commented May 29, 2026

What

ScaledMaskedSoftmax::eval hard-coded SoftmaxExp::Libc, so the fused attention softmax always ran scalar libm expf and never reached the linalg SIMD softmax kernels — a real perf gap. This PR closes that gap without sacrificing accuracy, by adding an accurate vectorizable exp and routing the fused softmax through it.

History: this PR originally just flipped Libc → FastCompact. CI (and a local full-suite run) showed that was wrong on two counts, so the approach was reworked — see below.

Why not FastCompact

Switching the fused softmax to the existing SoftmaxExp::FastCompact (Schraudolph approximation) fails the scaled_masked_softmax / sdpa proptests two independent ways:

  1. Precision. FastCompact's exp is ~0.5% off true softmax — outside the suite's Approximate tolerance (f32 rtol 5e-4), producing 30%+ outliers. (The existing softmax_l2 frame test only compares FastCompact against itself, so it never caught this.)
  2. Fully-masked rows → NaN mismatch. On an all--inf row the FastCompact kernel pads the SIMD tail with f32::MIN and computes exp(f32::MIN - f32::MIN) ≈ 1, so the row sums to a nonzero value and yields a finite 0 where the scalar libc path and the numpy reference both yield NaN (0 * 1/0).

This mirrors what ggml/llama.cpp concluded (ggml-org/llama.cpp#7154): keep an accurate vectorized expf for softmax, reserve fast-approx exp for error-tolerant ops.

What changed

  • linalg: accurate_exp_f32 — a Cephes-style range-reduced exp (Cody-Waite ln2 split + degree-6 polynomial + 2^n by exponent construction). Measured max rel error ~1.9e-6 vs libc over the softmax domain [0, -60]. exp(0)==1 and exp(-inf)==0 exactly; deep underflow flushes to 0; NaN propagates.
  • linalg: SSoftMaxL2Accurate / HSoftMaxL2Accurate map-reduce kernels, exposed as softmax2_accurate_{f32,f16}. They pad the SIMD tail with -inf (not f32::MIN), so masked/padding lanes contribute exactly 0 and a fully-masked row sums to 0NaN, matching libc and the reference.
  • core: new SoftmaxExp::Accurate variant + dispatch (f32/f16).
  • nnef: exp = "accurate" de/serialization round-trip.
  • transformers: ScaledMaskedSoftmax::eval uses SoftmaxExp::Accurate.

Libc remains the default everywhere; FastCompact is untouched. This adds a third, accurate-but-vectorized option and points fused attention at it.

Tests

  • New linalg tests validate accurate_exp_f32 against libc (not against itself) and cover the fully-masked degenerate row (sum == 0).
  • scaled_masked_softmax + sdpa proptests (f16/f32 × raw/decluttered/optimized) pass on native and wasm32-wasip1 (the job that originally failed).
  • Full tract-linalg suite green; tract-core / nnef / transformers green; cargo fmt clean; no new clippy warnings on touched files.

🤖 Generated with Claude Code

@czoli1976 czoli1976 marked this pull request as draft May 29, 2026 17:14
@czoli1976
Copy link
Copy Markdown
Contributor Author

taking it back, correctness failure to investigate

@czoli1976 czoli1976 force-pushed the feat/arm64-neon-fp16-activations branch from eb71f47 to 457c825 Compare May 29, 2026 21:41
@czoli1976 czoli1976 changed the title transformers: ScaledMaskedSoftmax eval uses FastCompact exp to unlock SIMD softmax transformers: route fused ScaledMaskedSoftmax through an accurate vectorized exp May 29, 2026
ScaledMaskedSoftmax::eval hard-coded SoftmaxExp::Libc, so the fused
attention softmax always ran scalar libm expf and never reached the
linalg SIMD softmax kernels — a real perf gap. The naive fix (switch to
SoftmaxExp::FastCompact) trades correctness for speed and fails the
proptests two ways:

  1. FastCompact's Schraudolph exp is ~0.5% off true softmax — outside
     the suite's Approximate tolerance (f32 rtol 5e-4), 30%+ outliers.
  2. On a fully-masked row (all -inf) the FastCompact kernel pads the
     SIMD tail with f32::MIN and computes exp(f32::MIN - f32::MIN) ≈ 1,
     so the row sums to a nonzero value and yields a finite 0 where the
     scalar libc path (and the numpy reference) yield NaN (0 * 1/0).

Instead, add an accurate vectorizable exp and route the fused softmax
through it (mirrors ggml/llama.cpp, which kept an accurate vectorized
expf for softmax rather than a coarse approximation):

  * linalg: `accurate_exp_f32`, a Cephes-style range-reduced exp
    (Cody-Waite ln2 split + degree-6 poly + 2^n by exponent
    construction). Measured max rel error ~1.9e-6 vs libc over the
    softmax domain [0, -60]. exp(0)==1 and exp(-inf)==0 exactly; deep
    underflow flushes to 0; NaN propagates.
  * linalg: `SSoftMaxL2Accurate` / `HSoftMaxL2Accurate` map-reduce
    kernels, exposed as `softmax2_accurate_{f32,f16}`. They pad the
    SIMD tail with -inf (not f32::MIN), so masked/padding lanes
    contribute exactly 0 and a fully-masked row sums to 0 -> NaN,
    matching libc and the reference.
  * core: new `SoftmaxExp::Accurate` variant + dispatch.
  * nnef: `exp = "accurate"` de/serialization round-trip.
  * transformers: ScaledMaskedSoftmax::eval uses SoftmaxExp::Accurate.

New linalg tests validate the accurate exp against libc (not against
itself, unlike the existing FastCompact frame test) and cover the
fully-masked degenerate row. scaled_masked_softmax + sdpa proptests
(f16/f32, raw/decluttered/optimized) pass on native and wasm32-wasip1.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@czoli1976 czoli1976 force-pushed the feat/arm64-neon-fp16-activations branch from 457c825 to 9a5fa12 Compare May 30, 2026 06:36
@czoli1976 czoli1976 marked this pull request as ready for review May 30, 2026 07:07
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant