portable: accumulate in fp32 for Half/BFloat16 in softmax, log_softmax, mean, and sum#20090
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20090
Note: Links to docs will display an error until the docs builds have been completed. ❌ 3 New FailuresAs of commit cbbb3dc with merge base 71cbe9f ( NEW FAILURES - The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
3fb0012 to
d56aa5a
Compare
GregoryComer
left a comment
There was a problem hiding this comment.
Overall, the changes look good. Thanks for fixing this.
There is a failure in CI that looks legitimate. Could you take a look? This is from the unit test job.
2026-06-15T20:57:20.7535620Z [ RUN ] OpLogSoftmaxOutTest.BFloat16LargeDimAccumulatesInFloat
2026-06-15T20:57:20.7535990Z E 00:00:00.161654 executorch:op_log_softmax.cpp:152] Check failed (false):
2026-06-15T20:57:20.7536490Z /Users/ec2-user/runner/_work/executorch/executorch/pytorch/executorch/kernels/test/op_log_softmax_test.cpp:382: Failure
2026-06-15T20:57:20.7536960Z Value of: (out)
2026-06-15T20:57:20.7543010Z Expected: is close to with tol (ETensor(sizes={1, 512}, dtype=BFloat16, data={-6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25, -6.25}), 1e-05, 0.1)
2026-06-15T20:57:20.7551240Z Actual: ETensor(sizes={1, 512}, dtype=BFloat16, data={0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}) (of type executorch::runtime::etensor::Tensor)
2026-06-15T20:57:20.7553800Z /Users/ec2-user/runner/_work/executorch/executorch/pytorch/executorch/../executorch/kernels/test/TestUtil.h:108: Failure
2026-06-15T20:57:20.7554230Z Expected equality of these values:
2026-06-15T20:57:20.7554440Z context_.failure_state()
2026-06-15T20:57:20.7554620Z Which is: 4-byte object <12-00 00-00>
2026-06-15T20:57:20.7554830Z torch::executor::Error::Ok
2026-06-15T20:57:20.7555010Z Which is: 4-byte object <00-00 00-00>
2026-06-15T20:57:20.7555320Z [ FAILED ] OpLogSoftmaxOutTest.BFloat16LargeDimAccumulatesInFloat (0 ms)
…tmax Problem: Softmax and log_softmax accumulated exp(x - max) in the tensor dtype. For BFloat16, the running sum saturates around 256 — adding 1.0 stops changing the total — so a uniform softmax over N=512 elements outputs ~1/256 instead of 1/512. Changes: Accumulate the exp-sum in float for Half/BFloat16 by threading an ACC type through the map-reduce calls. Loads and stores remain in the tensor dtype. Continues the fp32-accumulation work in pytorch#19117.
Problem: The fast-path and generic reduction loops in mean.out and sum.IntList_out accumulated the running sum in the tensor dtype. For BFloat16, the sum saturates around 256, so a mean over N=512 all-ones elements gives 0.5 instead of 1.0, and summing 512 all-ones elements gives 256 instead of 512. Changes: Accumulate in float for Half/BFloat16 by promoting the loop accumulator to ACC in both the fast path and the generic path. The final result is cast back to the tensor dtype on store. Continues the fp32-accumulation work in pytorch#19117.
opt_log_softmax_out only handled Float; BFloat16 and Half fell through to ET_KERNEL_CHECK(false), leaving output unchanged. The underlying log_softmax_kernel<IN_T, OUT_T> is fully generic and the ATen vectorized functions it delegates to already support BFloat16 and Half. - Extend log_softmax_wrapper with an if constexpr branch for BFloat16/Half that calls log_softmax_kernel<T, T> - Add BFloat16 and Half dispatch cases in opt_log_softmax_out
d56aa5a to
cbbb3dc
Compare
|
Thanks for flagging this @GregoryComer. I did not encounter this test fail because my local build environment does not compile optimized kernels. Fixed by adding BFloat16 and Half support directly in opt_log_softmax_out and pushed in the latest commit. |
|
I think all three failures are infra issues:
|
|
Thanks for the contribution! |
|
heads up this PR breaks ExecuTorch's lint rule: https://hud.pytorch.org/hud/pytorch/executorch/main/1?per_page=50&name_filter=lint you may want to have another PR to solve it, or i have to revert the current PR 20090. |
|
@Gasoonjia I am looking at it and found the root cause. I'll reply shortly. |
|
I submitted a PR #20368 that fixes this issue. Please review it. |
This PR follows up on #19117 (
op_grid_sampler_2d)Motivation
softmax, log_softmax, mean, and sum all accumulate their reduction in the input dtype. For BFloat16, that sum saturates around 256. Once it gets there, adding 1.0 rounds away and the total gets stuck. A uniform softmax over 512 elements in BFloat16 gives
~1/256per output instead of1/512.Why FP32 accumulation is needed
BFloat16 has the same exponent width as Float32, so it has a similar range. However, it has far fewer fraction bits, which makes its representable spacing much coarser as values grow.
BFloat16Float32, but coarse spacingFloat32For BFloat16, the gap between consecutive representable values (i.e, the smallest step size) increases at each power-of-two range:
[128, 256)1128, 129, 130, ..., 255[256, 512)2256, 258, 260, ..., 510As a result, once a BFloat16 running sum reaches
256, adding1.0no longer changes the value:256 + 1257256257is not representable and rounds back to256(according to IEEE 754; round-to-nearest-even)This directly affects all four ops for large inputs. For a softmax over 512 zeros, each
exp(0)contributes1.0, so the denominator should be512. If the BFloat16 accumulation gets stuck at256, the output becomes approximately1/256instead of the correct1/512.5125121/512512~256~1/256Tests
cc @larryliu0820 @manuelcandales