Commit 8bb71cf
portable: accumulate in fp32 for Half/BFloat16 in softmax, log_softmax, mean, and sum (#20090)
This PR follows up on #19117 (`op_grid_sampler_2d`).
### Motivation
softmax, log_softmax, mean, and sum all accumulate their reduction in
the input dtype. In BFloat16, a running sum of 1.0 terms stalls at 256:
from there on, each added 1.0 is rounded away and the total stops
growing. A uniform softmax over 512 elements in BFloat16 therefore gives
`~1/256` per output instead of `1/512`.
### Why FP32 accumulation is needed
BFloat16 has the same exponent width as Float32, so it has a similar
range. However, it has far fewer fraction bits, which makes its
representable spacing much coarser as values grow.
| Type | Exponent bits | Fraction bits | Practical effect |
| --- | ---: | ---: | --- |
| `BFloat16` | 8 | 7 | Similar range to `Float32`, but coarse spacing |
| `Float32` | 8 | 23 | Similar range, much finer spacing |
For BFloat16, the gap between consecutive representable values (i.e.,
the smallest step size) doubles at each power of two:
| Range | BFloat16 step size | Representable examples |
| --- | ---: | --- |
| `[128, 256)` | `1` | `128, 129, 130, ..., 255` |
| `[256, 512)` | `2` | `256, 258, 260, ..., 510` |
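The step sizes in the table can be checked directly from the bit layout. The sketch below is illustrative only and is not code from this PR: `bf16_step` is a hypothetical helper, and it assumes a BFloat16 value is the upper 16 bits of the corresponding IEEE 754 float32 encoding.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Hypothetical helper (not part of the PR): distance from a positive, finite
// BFloat16 value `x` to the next larger BFloat16 value.
// Assumption: a BFloat16 is the upper 16 bits of the float32 encoding, so a
// float that holds a BFloat16 value has its low 16 bits clear.
float bf16_step(float x) {
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  assert((bits & 0xFFFFu) == 0 && "x must be exactly representable in BFloat16");
  // Adding 1 to the 16-bit pattern moves to the next representable value.
  std::uint32_t next_bits = bits + 0x10000u;
  float next;
  std::memcpy(&next, &next_bits, sizeof(next));
  return next - x;
}
```

Under these assumptions, `bf16_step(128.0f)` and `bf16_step(255.0f)` return `1.0f`, while `bf16_step(256.0f)` and `bf16_step(510.0f)` return `2.0f`, matching the two rows above.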
As a result, once a BFloat16 running sum reaches `256`, adding `1.0` no
longer changes the value:
| Operation | Exact result | BFloat16 result | Reason |
| --- | ---: | ---: | --- |
| `256 + 1` | `257` | `256` | `257` is not representable and rounds back to `256` (according to IEEE 754; round-to-nearest-even) |
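The stall can be reproduced without any BFloat16 library by emulating the rounding in software. This is a minimal sketch, not the kernel code from this PR: `round_to_bf16`, `sum_ones_bf16`, and `sum_ones_fp32` are hypothetical names, the rounding helper handles finite inputs only, and BFloat16 values are stored in `float` variables for simplicity.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Hypothetical helper: round a finite float to the nearest BFloat16 value
// (ties to even) and return it as a float. NaN/Inf handling is omitted.
float round_to_bf16(float x) {
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  // Bias by 0x7FFF plus the lowest kept bit so that exact ties go to even.
  bits += 0x7FFFu + ((bits >> 16) & 1u);
  bits &= 0xFFFF0000u;  // drop the 16 bits BFloat16 does not store
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}

// Sum n ones, rounding the accumulator to BFloat16 after every addition.
float sum_ones_bf16(int n) {
  assert(n >= 0);
  float acc = 0.0f;
  for (int i = 0; i < n; ++i) {
    acc = round_to_bf16(acc + 1.0f);
  }
  return acc;
}

// Sum n ones in float32 and round to BFloat16 once at the end.
float sum_ones_fp32(int n) {
  assert(n >= 0);
  float acc = 0.0f;
  for (int i = 0; i < n; ++i) {
    acc += 1.0f;
  }
  return round_to_bf16(acc);
}
```

With this emulation, `round_to_bf16(257.0f)` yields `256.0f`, `sum_ones_bf16(512)` stops at `256.0f`, and `sum_ones_fp32(512)` reaches `512.0f`. The value `257` lies exactly halfway between `256` and `258`, which is why the tie-breaking rule decides the result.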
This directly affects all four ops for large inputs. For a softmax over
512 zeros, each `exp(0)` contributes `1.0`, so the denominator should be
`512`. If the BFloat16 accumulation gets stuck at `256`, the output
becomes approximately `1/256` instead of the correct `1/512`.
| Case | Expected denominator | BFloat16 accumulated denominator | Output |
| --- | ---: | ---: | ---: |
| Correct accumulation | `512` | `512` | `1/512` |
| BFloat16 accumulation | `512` | `~256` | `~1/256` |
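The two rows above can be reproduced with a small emulated softmax. The following is a sketch under stated assumptions rather than the portable kernel itself: `softmax_bf16` and its `accumulate_in_fp32` flag are hypothetical, BFloat16 values are held in `float` and rounded in software, and only the denominator accumulation differs between the two modes.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical helper: round a finite float to the nearest BFloat16 value
// (ties to even) and return it as a float. NaN/Inf handling is omitted.
float round_to_bf16(float x) {
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits += 0x7FFFu + ((bits >> 16) & 1u);
  bits &= 0xFFFF0000u;
  float out;
  std::memcpy(&out, &bits, sizeof(out));
  return out;
}

// 1-D softmax over BFloat16-valued inputs stored as float.
// accumulate_in_fp32 == false: the denominator is rounded to BFloat16 after
// every addition, mimicking accumulation in the input dtype.
// accumulate_in_fp32 == true: the denominator stays in float32.
std::vector<float> softmax_bf16(const std::vector<float>& in,
                                bool accumulate_in_fp32) {
  assert(!in.empty());
  float max_val = in[0];
  for (float v : in) {
    max_val = std::max(max_val, v);
  }
  float denom = 0.0f;
  for (float v : in) {
    denom += std::exp(v - max_val);
    if (!accumulate_in_fp32) {
      denom = round_to_bf16(denom);
    }
  }
  std::vector<float> out;
  out.reserve(in.size());
  for (float v : in) {
    // The output is rounded once, because it is stored as BFloat16.
    out.push_back(round_to_bf16(std::exp(v - max_val) / denom));
  }
  return out;
}
```

For an input of 512 zeros, calling `softmax_bf16(zeros, false)` produces `1/256` in every slot, and `softmax_bf16(zeros, true)` produces `1/512`. Both values are powers of two, so the final rounding to BFloat16 does not hide the difference.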
### Tests
```
$ cmake --build cmake-out --target portable_kernels_test -j$(nproc)
[100%] Built target portable_kernels_test
# Post-fix — new tests:
[ OK ] OpSoftmaxOutTest.BFloat16LargeDimAccumulatesInFloat
[ OK ] OpLogSoftmaxOutTest.BFloat16LargeDimAccumulatesInFloat
[ OK ] OpMeanOutTest.BFloat16LargeDimAccumulatesInFloat
[ OK ] OpSumOutTest.BFloat16LargeDimAccumulatesInFloat
# Pre-fix (reverted op files):
[ FAILED ] OpSoftmaxOutTest.BFloat16LargeDimAccumulatesInFloat
[ FAILED ] OpLogSoftmaxOutTest.BFloat16LargeDimAccumulatesInFloat
[ FAILED ] OpMeanOutTest.BFloat16LargeDimAccumulatesInFloat
[ FAILED ] OpSumOutTest.BFloat16LargeDimAccumulatesInFloat
$ lintrunner op_softmax.cpp op_log_softmax.cpp op_mean.cpp op_sum.cpp \
op_softmax_test.cpp op_log_softmax_test.cpp op_mean_test.cpp op_sum_test.cpp
ok No lint issues.
```
cc @larryliu0820 @manuelcandales1
parent e93a285 commit 8bb71cf
9 files changed
Lines changed: 194 additions & 40 deletions
File tree
- kernels
- optimized/cpu
- portable/cpu
- test