
portable: accumulate in fp32 for Half/BFloat16 in softmax, log_softmax, mean, and sum#20090

Merged
GregoryComer merged 3 commits into
pytorch:mainfrom
vacu9708:fp32-accumulation-bfloat16
Jun 17, 2026

@vacu9708

@vacu9708 vacu9708 commented Jun 8, 2026

Contributor

This PR follows up on #19117 (op_grid_sampler_2d)

Motivation

softmax, log_softmax, mean, and sum all accumulate their reductions in the input dtype. For BFloat16, a running sum of 1.0s stalls at 256: from that point on, adding 1.0 is rounded away and the total stops growing. As a result, a uniform softmax over 512 elements in BFloat16 returns ~1/256 per output instead of 1/512.

Why FP32 accumulation is needed

BFloat16 has the same exponent width as Float32, so it has a similar range. However, it has far fewer fraction bits, which makes its representable spacing much coarser as values grow.

| Type | Exponent bits | Fraction bits | Practical effect |
| --- | --- | --- | --- |
| BFloat16 | 8 | 7 | Similar range to Float32, but coarse spacing |
| Float32 | 8 | 23 | Similar range, much finer spacing |

For BFloat16, the gap between consecutive representable values (i.e., the step size) doubles at each power of two:

| Range | BFloat16 step size | Representable examples |
| --- | --- | --- |
| [128, 256) | 1 | 128, 129, 130, ..., 255 |
| [256, 512) | 2 | 256, 258, 260, ..., 510 |
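The step sizes in the table follow from the 7 fraction bits: for a value in [2^e, 2^(e+1)), the gap is 2^(e - 7). A minimal sketch of that formula, where `bf16_step` is a hypothetical helper name used only for illustration and not part of ExecuTorch:

```cpp
#include <cassert>
#include <cmath>

// Hypothetical helper: gap between consecutive BFloat16 values around a
// positive, normal x. With 7 fraction bits the gap is 2^(e - 7), where
// 2^e <= x < 2^(e + 1).
inline float bf16_step(float x) {
  assert(x > 0.0f && std::isfinite(x));
  const int e = std::ilogb(x);     // floor(log2(x)) for normal values
  return std::ldexp(1.0f, e - 7);  // 2^(e - 7)
}
```

For example, `bf16_step(200.0f)` falls in the [128, 256) row and `bf16_step(300.0f)` in the [256, 512) row.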

As a result, once a BFloat16 running sum reaches 256, adding 1.0 no longer changes the value:

| Operation | Exact result | BFloat16 result | Reason |
| --- | --- | --- | --- |
| 256 + 1 | 257 | 256 | 257 is not representable and rounds back to 256 (IEEE 754 round-to-nearest-even) |
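The stall can be reproduced without any BFloat16 library by emulating the rounding on float32 bit patterns. This is only a sketch: `round_to_bf16` and `count_up_in_bf16` are illustrative names, and the emulation assumes BFloat16 is the top 16 bits of a float32 rounded to nearest, ties to even.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>

// Emulates storing a float as BFloat16: keep the top 16 bits of the float32
// pattern, rounding to nearest with ties to even. NaN is not handled.
inline float round_to_bf16(float x) {
  assert(!std::isnan(x));
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  const std::uint32_t lsb = (bits >> 16) & 1u;  // lowest bit that survives
  bits += 0x7FFFu + lsb;                        // round half to even
  bits &= 0xFFFF0000u;                          // drop the 16 discarded bits
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}

// Adds 1.0 n times, rounding the running sum to BFloat16 after every add.
inline float count_up_in_bf16(int n) {
  float acc = 0.0f;
  for (int i = 0; i < n; ++i) {
    acc = round_to_bf16(acc + 1.0f);
  }
  return acc;
}
```

Here `round_to_bf16(257.0f)` is the 256 + 1 case from the table, and `count_up_in_bf16(512)` shows the running sum never leaving 256.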

This directly affects all four ops for large inputs. For a softmax over 512 zeros, each exp(0) contributes 1.0, so the denominator should be 512. If the BFloat16 accumulation gets stuck at 256, the output becomes approximately 1/256 instead of the correct 1/512.

| Case | Expected denominator | BFloat16 accumulated denominator | Output |
| --- | --- | --- | --- |
| Correct accumulation | 512 | 512 | 1/512 |
| BFloat16 accumulation | 512 | ~256 | ~1/256 |
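The two rows of the table can be sketched as a single softmax routine with a switch for the accumulator precision. This is an illustration under assumptions, not the kernel code: `softmax_row` and `round_to_bf16` are hypothetical names, and BFloat16 storage is emulated by rounding floats.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <vector>

// Illustrative helper (not an ExecuTorch API): round a float to the nearest
// BFloat16 value, ties to even. NaN is not handled.
inline float round_to_bf16(float x) {
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits += 0x7FFFu + ((bits >> 16) & 1u);
  bits &= 0xFFFF0000u;
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}

// Softmax over one row whose values are BFloat16 (held here in floats).
inline std::vector<float> softmax_row(const std::vector<float>& x,
                                      bool accumulate_in_fp32) {
  assert(!x.empty());
  const float max_val = *std::max_element(x.begin(), x.end());
  float denom = 0.0f;
  for (float v : x) {
    denom += std::exp(v - max_val);
    if (!accumulate_in_fp32) {
      denom = round_to_bf16(denom);  // running sum kept in BFloat16
    }
  }
  std::vector<float> out;
  out.reserve(x.size());
  for (float v : x) {
    // The store is always in the tensor dtype, so round the result.
    out.push_back(round_to_bf16(std::exp(v - max_val) / denom));
  }
  return out;
}
```

Calling `softmax_row(std::vector<float>(512, 0.0f), true)` corresponds to the "Correct accumulation" row, and passing `false` corresponds to the "BFloat16 accumulation" row.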

Tests

$ cmake --build cmake-out --target portable_kernels_test -j$(nproc)
[100%] Built target portable_kernels_test

# Post-fix — new tests:
[ OK ] OpSoftmaxOutTest.BFloat16LargeDimAccumulatesInFloat
[ OK ] OpLogSoftmaxOutTest.BFloat16LargeDimAccumulatesInFloat
[ OK ] OpMeanOutTest.BFloat16LargeDimAccumulatesInFloat
[ OK ] OpSumOutTest.BFloat16LargeDimAccumulatesInFloat

# Pre-fix (reverted op files):
[ FAILED ] OpSoftmaxOutTest.BFloat16LargeDimAccumulatesInFloat
[ FAILED ] OpLogSoftmaxOutTest.BFloat16LargeDimAccumulatesInFloat
[ FAILED ] OpMeanOutTest.BFloat16LargeDimAccumulatesInFloat
[ FAILED ] OpSumOutTest.BFloat16LargeDimAccumulatesInFloat

$ lintrunner op_softmax.cpp op_log_softmax.cpp op_mean.cpp op_sum.cpp \
             op_softmax_test.cpp op_log_softmax_test.cpp op_mean_test.cpp op_sum_test.cpp
ok  No lint issues.

cc @larryliu0820 @manuelcandales

@vacu9708 vacu9708 requested a review from manuelcandales as a code owner June 8, 2026 04:09
@pytorch-bot

pytorch-bot Bot commented Jun 8, 2026


🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20090

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures

As of commit cbbb3dc with merge base 71cbe9f:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 8, 2026
@github-actions

github-actions Bot commented Jun 8, 2026


This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@vacu9708 vacu9708 force-pushed the fp32-accumulation-bfloat16 branch 2 times, most recently from 3fb0012 to d56aa5a Compare June 8, 2026 06:02
@nil-is-all nil-is-all added the module: kernels Issues related to kernel libraries and utilities, and code under kernels/ label Jun 15, 2026

@GregoryComer GregoryComer left a comment

Member


Overall, the changes look good. Thanks for fixing this.

There is a failure in CI that looks legitimate. Could you take a look? This is from the unit test job.

2026-06-15T20:57:20.7535620Z [ RUN      ] OpLogSoftmaxOutTest.BFloat16LargeDimAccumulatesInFloat
2026-06-15T20:57:20.7535990Z E 00:00:00.161654 executorch:op_log_softmax.cpp:152] Check failed (false): 
2026-06-15T20:57:20.7536490Z /Users/ec2-user/runner/_work/executorch/executorch/pytorch/executorch/kernels/test/op_log_softmax_test.cpp:382: Failure
2026-06-15T20:57:20.7536960Z Value of: (out)
2026-06-15T20:57:20.7543010Z Expected: is close to with tol (ETensor(sizes={1, 512}, dtype=BFloat16, data={-6.25, -6.25, -6.25, ..., -6.25}), 1e-05, 0.1)
2026-06-15T20:57:20.7551240Z   Actual: ETensor(sizes={1, 512}, dtype=BFloat16, data={0, 0, 0, ..., 0}) (of type executorch::runtime::etensor::Tensor)
2026-06-15T20:57:20.7553800Z /Users/ec2-user/runner/_work/executorch/executorch/pytorch/executorch/../executorch/kernels/test/TestUtil.h:108: Failure
2026-06-15T20:57:20.7554230Z Expected equality of these values:
2026-06-15T20:57:20.7554440Z   context_.failure_state()
2026-06-15T20:57:20.7554620Z     Which is: 4-byte object <12-00 00-00>
2026-06-15T20:57:20.7554830Z   torch::executor::Error::Ok
2026-06-15T20:57:20.7555010Z     Which is: 4-byte object <00-00 00-00>
2026-06-15T20:57:20.7555320Z [  FAILED  ] OpLogSoftmaxOutTest.BFloat16LargeDimAccumulatesInFloat (0 ms)

vacu9708 added 3 commits June 16, 2026 12:27
…tmax

Problem:
Softmax and log_softmax accumulated exp(x - max) in the tensor dtype.
For BFloat16, the running sum saturates around 256 — adding 1.0 stops
changing the total — so a uniform softmax over N=512 elements outputs
~1/256 instead of 1/512.

Changes:
Accumulate the exp-sum in float for Half/BFloat16 by threading an ACC
type through the map-reduce calls. Loads and stores remain in the tensor
dtype.

Continues the fp32-accumulation work in pytorch#19117.
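One way to picture "threading an ACC type" is a trait that maps the tensor dtype to its accumulator type, passed as a template parameter to the reduction. The sketch below is an assumption about the general shape, not the code in this PR: `BF16`, `AccType`, and this `log_softmax` are hypothetical stand-ins written for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical BFloat16 stand-in: the top 16 bits of a float32,
// rounded to nearest with ties to even. NaN is not handled.
struct BF16 {
  std::uint16_t bits = 0;
  BF16() = default;
  explicit BF16(float x) {
    std::uint32_t b;
    std::memcpy(&b, &x, sizeof(b));
    b += 0x7FFFu + ((b >> 16) & 1u);
    bits = static_cast<std::uint16_t>(b >> 16);
  }
  explicit operator float() const {
    const std::uint32_t b = static_cast<std::uint32_t>(bits) << 16;
    float out;
    std::memcpy(&out, &b, sizeof(out));
    return out;
  }
};

// Hypothetical trait: accumulate in float for BF16, in T otherwise.
template <typename T> struct AccType { using type = T; };
template <> struct AccType<BF16> { using type = float; };

// log_softmax over one row. Loads and stores use T; the exp-sum uses ACC.
template <typename T, typename ACC = typename AccType<T>::type>
std::vector<T> log_softmax(const std::vector<T>& x) {
  assert(!x.empty());
  float max_val = static_cast<float>(x[0]);
  for (const T& v : x) max_val = std::max(max_val, static_cast<float>(v));
  ACC sum = static_cast<ACC>(0.0f);
  for (const T& v : x) {
    const float e = std::exp(static_cast<float>(v) - max_val);
    sum = static_cast<ACC>(static_cast<float>(sum) + e);  // rounds to ACC
  }
  const float log_sum = std::log(static_cast<float>(sum));
  std::vector<T> out;
  out.reserve(x.size());
  for (const T& v : x) {
    out.push_back(static_cast<T>(static_cast<float>(v) - max_val - log_sum));
  }
  return out;
}
```

With the default accumulator, a row of 512 zeros yields -6.25 per element, which matches the expected tensor in the CI log above. Forcing `log_softmax<BF16, BF16>` emulates the old behaviour, where the sum stalls at 256.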
Problem:
The fast-path and generic reduction loops in mean.out and sum.IntList_out
accumulated the running sum in the tensor dtype. For BFloat16, the sum
saturates around 256, so a mean over N=512 all-ones elements gives 0.5
instead of 1.0, and summing 512 all-ones elements gives 256 instead of
512.

Changes:
Accumulate in float for Half/BFloat16 by promoting the loop accumulator
to ACC in both the fast path and the generic path. The final result is
cast back to the tensor dtype on store.

Continues the fp32-accumulation work in pytorch#19117.
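The mean and sum figures quoted in this commit message can be checked with a small emulation. As before, this is a sketch under assumptions rather than the kernel code: `round_to_bf16`, `accumulate`, `sum_row`, and `mean_row` are illustrative names, and BFloat16 is emulated by rounding floats.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// Illustrative helper: round a float to the nearest BFloat16 value,
// ties to even. NaN is not handled.
inline float round_to_bf16(float x) {
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits += 0x7FFFu + ((bits >> 16) & 1u);
  bits &= 0xFFFF0000u;
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}

// Running sum; optionally rounded to BFloat16 after every add.
inline float accumulate(const std::vector<float>& x, bool accumulate_in_fp32) {
  float acc = 0.0f;
  for (float v : x) {
    acc += v;
    if (!accumulate_in_fp32) {
      acc = round_to_bf16(acc);
    }
  }
  return acc;
}

// The final result is cast back to the tensor dtype on store.
inline float sum_row(const std::vector<float>& x, bool accumulate_in_fp32) {
  return round_to_bf16(accumulate(x, accumulate_in_fp32));
}

inline float mean_row(const std::vector<float>& x, bool accumulate_in_fp32) {
  assert(!x.empty());
  const float n = static_cast<float>(x.size());
  return round_to_bf16(accumulate(x, accumulate_in_fp32) / n);
}
```

For a vector of 512 ones, passing `false` reproduces the 256 and 0.5 results described above, while `true` gives 512 and 1.0.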
opt_log_softmax_out only handled Float; BFloat16 and Half fell through to
ET_KERNEL_CHECK(false), leaving output unchanged. The underlying
log_softmax_kernel<IN_T, OUT_T> is fully generic and the ATen vectorized
functions it delegates to already support BFloat16 and Half.

- Extend log_softmax_wrapper with an if constexpr branch for BFloat16/Half
  that calls log_softmax_kernel<T, T>
- Add BFloat16 and Half dispatch cases in opt_log_softmax_out
@vacu9708 vacu9708 force-pushed the fp32-accumulation-bfloat16 branch from d56aa5a to cbbb3dc Compare June 16, 2026 03:44
@vacu9708

vacu9708 commented Jun 16, 2026

Contributor Author

Thanks for flagging this @GregoryComer.
The failure was in optimized_kernels_test, not portable_kernels_test. opt_log_softmax_out only handled Float; BFloat16 hit ET_KERNEL_CHECK(context, false, InvalidArgument, out), leaving the output as zeros.

I did not hit this test failure locally because my build environment does not compile the optimized kernels.

Fixed by adding BFloat16 and Half support directly in opt_log_softmax_out and pushed in the latest commit.

@vacu9708

vacu9708 commented Jun 17, 2026

Contributor Author

I think all three failures are infra issues:

  • Cadence hifi4 / vision (Input required and not supplied: aws-region):
    • fails in the AWS-credentials step before any build.
  • test-models-linux-basic (mv3, portable):
    • died during pip install . setup with No matching distribution found for scikit-learn==1.7.1 (from versions: none)

@vacu9708 vacu9708 requested a review from GregoryComer June 17, 2026 08:54
@GregoryComer GregoryComer merged commit 8bb71cf into pytorch:main Jun 17, 2026
178 of 181 checks passed
@GregoryComer

Member

Thanks for the contribution!

@Gasoonjia

Gasoonjia commented Jun 17, 2026

Contributor

Heads up: this PR breaks ExecuTorch's lint checks: https://hud.pytorch.org/hud/pytorch/executorch/main/1?per_page=50&name_filter=lint

You may want to open a follow-up PR to fix it; otherwise I will have to revert PR 20090.

@vacu9708

@vacu9708

Contributor Author

@Gasoonjia I am looking at it and found the root cause. I'll reply shortly.

@vacu9708

vacu9708 commented Jun 18, 2026

Contributor Author

The root cause is commit e93a285ebd "Extend CPPCHECK scope to portable kernels" (2026-06-04).
My work branch was cut before this new lint coverage landed.

I submitted PR #20368 to fix this issue. Please review it.


Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. module: kernels Issues related to kernel libraries and utilities, and code under kernels/


4 participants