transformers: route fused ScaledMaskedSoftmax through an accurate vectorized exp#2318
Open
czoli1976 wants to merge 1 commit into
Open
transformers: route fused ScaledMaskedSoftmax through an accurate vectorized exp#2318czoli1976 wants to merge 1 commit into
czoli1976 wants to merge 1 commit into
Conversation
Contributor
Author
|
taking it back, correctness failure to investigate |
eb71f47 to
457c825
Compare
ScaledMaskedSoftmax::eval hard-coded SoftmaxExp::Libc, so the fused
attention softmax always ran scalar libm expf and never reached the
linalg SIMD softmax kernels — a real perf gap. The naive fix (switch to
SoftmaxExp::FastCompact) trades correctness for speed and fails the
proptests two ways:
1. FastCompact's Schraudolph exp is ~0.5% off true softmax — outside
the suite's Approximate tolerance (f32 rtol 5e-4), 30%+ outliers.
2. On a fully-masked row (all -inf) the FastCompact kernel pads the
SIMD tail with f32::MIN and computes exp(f32::MIN - f32::MIN) ≈ 1,
so the row sums to a nonzero value and yields a finite 0 where the
scalar libc path (and the numpy reference) yield NaN (0 * 1/0).
Instead, add an accurate vectorizable exp and route the fused softmax
through it (mirrors ggml/llama.cpp, which kept an accurate vectorized
expf for softmax rather than a coarse approximation):
* linalg: `accurate_exp_f32`, a Cephes-style range-reduced exp
(Cody-Waite ln2 split + degree-6 poly + 2^n by exponent
construction). Measured max rel error ~1.9e-6 vs libc over the
softmax domain [0, -60]. exp(0)==1 and exp(-inf)==0 exactly; deep
underflow flushes to 0; NaN propagates.
* linalg: `SSoftMaxL2Accurate` / `HSoftMaxL2Accurate` map-reduce
kernels, exposed as `softmax2_accurate_{f32,f16}`. They pad the
SIMD tail with -inf (not f32::MIN), so masked/padding lanes
contribute exactly 0 and a fully-masked row sums to 0 -> NaN,
matching libc and the reference.
* core: new `SoftmaxExp::Accurate` variant + dispatch.
* nnef: `exp = "accurate"` de/serialization round-trip.
* transformers: ScaledMaskedSoftmax::eval uses SoftmaxExp::Accurate.
New linalg tests validate the accurate exp against libc (not against
itself, unlike the existing FastCompact frame test) and cover the
fully-masked degenerate row. scaled_masked_softmax + sdpa proptests
(f16/f32, raw/decluttered/optimized) pass on native and wasm32-wasip1.
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
457c825 to
9a5fa12
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What
ScaledMaskedSoftmax::evalhard-codedSoftmaxExp::Libc, so the fused attention softmax always ran scalar libmexpfand never reached the linalg SIMD softmax kernels — a real perf gap. This PR closes that gap without sacrificing accuracy, by adding an accurate vectorizableexpand routing the fused softmax through it.Why not FastCompact
Switching the fused softmax to the existing
SoftmaxExp::FastCompact(Schraudolph approximation) fails thescaled_masked_softmax/sdpaproptests two independent ways:expis ~0.5% off true softmax — outside the suite'sApproximatetolerance (f32 rtol5e-4), producing 30%+ outliers. (The existingsoftmax_l2frame test only compares FastCompact against itself, so it never caught this.)-infrow the FastCompact kernel pads the SIMD tail withf32::MINand computesexp(f32::MIN - f32::MIN) ≈ 1, so the row sums to a nonzero value and yields a finite0where the scalar libc path and the numpy reference both yieldNaN(0 * 1/0).This mirrors what ggml/llama.cpp concluded (ggml-org/llama.cpp#7154): keep an accurate vectorized
expffor softmax, reserve fast-approxexpfor error-tolerant ops.What changed
accurate_exp_f32— a Cephes-style range-reducedexp(Cody-Waiteln2split + degree-6 polynomial +2^nby exponent construction). Measured max rel error ~1.9e-6 vs libc over the softmax domain[0, -60].exp(0)==1andexp(-inf)==0exactly; deep underflow flushes to0;NaNpropagates.SSoftMaxL2Accurate/HSoftMaxL2Accuratemap-reduce kernels, exposed assoftmax2_accurate_{f32,f16}. They pad the SIMD tail with-inf(notf32::MIN), so masked/padding lanes contribute exactly0and a fully-masked row sums to0→NaN, matching libc and the reference.SoftmaxExp::Accuratevariant + dispatch (f32/f16).exp = "accurate"de/serialization round-trip.ScaledMaskedSoftmax::evalusesSoftmaxExp::Accurate.Libcremains the default everywhere;FastCompactis untouched. This adds a third, accurate-but-vectorized option and points fused attention at it.Tests
accurate_exp_f32against libc (not against itself) and cover the fully-masked degenerate row (sum == 0).scaled_masked_softmax+sdpaproptests (f16/f32 × raw/decluttered/optimized) pass on native and wasm32-wasip1 (the job that originally failed).tract-linalgsuite green;tract-core/nnef/transformersgreen;cargo fmtclean; no new clippy warnings on touched files.🤖 Generated with Claude Code