Skip to content

Commit 8bb71cf

Browse files
authored
portable: accumulate in fp32 for Half/BFloat16 in softmax, log_softmax, mean, and sum (#20090)
This PR follows up on #19117 (`op_grid_sampler_2d`) ### Motivation softmax, log_softmax, mean, and sum all accumulate their reduction in the input dtype. For BFloat16, that sum saturates around 256. Once it gets there, adding 1.0 rounds away and the total gets stuck. A uniform softmax over 512 elements in BFloat16 gives `~1/256` per output instead of `1/512`. ### Why FP32 accumulation is needed BFloat16 has the same exponent width as Float32, so it has a similar range. However, it has far fewer fraction bits, which makes its representable spacing much coarser as values grow. | Type | Exponent bits | Fraction bits | Practical effect | | --- | ---: | ---: | --- | | `BFloat16` | 8 | 7 | Similar range to `Float32`, but coarse spacing | | `Float32` | 8 | 23 | Similar range, much finer spacing | For BFloat16, the gap between consecutive representable values (i.e, the smallest step size) increases at each power-of-two range: | Range | BFloat16 step size | Representable examples | | --- | ---: | --- | | `[128, 256)` | `1` | `128, 129, 130, ..., 255` | | `[256, 512)` | `2` | `256, 258, 260, ..., 510` | As a result, once a BFloat16 running sum reaches `256`, adding `1.0` no longer changes the value: | Operation | Exact result | BFloat16 result | Reason | | --- | ---: | ---: | --- | | `256 + 1` | `257` | `256` | `257` is not representable and rounds back to `256` (according to IEEE 754; round-to-nearest-even) | This directly affects all four ops for large inputs. For a softmax over 512 zeros, each `exp(0)` contributes `1.0`, so the denominator should be `512`. If the BFloat16 accumulation gets stuck at `256`, the output becomes approximately `1/256` instead of the correct `1/512`. | Case | Expected denominator | BFloat16 accumulated denominator | Output | | --- | ---: | ---: | ---: | | Correct accumulation | `512` | `512` | `1/512` | | BFloat16 accumulation | `512` | `~256` | `~1/256` | ### Tests ``` $ cmake --build cmake-out --target portable_kernels_test -j$(nproc) [100%] Built target portable_kernels_test # Post-fix — new tests: [ OK ] OpSoftmaxOutTest.BFloat16LargeDimAccumulatesInFloat [ OK ] OpLogSoftmaxOutTest.BFloat16LargeDimAccumulatesInFloat [ OK ] OpMeanOutTest.BFloat16LargeDimAccumulatesInFloat [ OK ] OpSumOutTest.BFloat16LargeDimAccumulatesInFloat # Pre-fix (reverted op files): [ FAILED ] OpSoftmaxOutTest.BFloat16LargeDimAccumulatesInFloat [ FAILED ] OpLogSoftmaxOutTest.BFloat16LargeDimAccumulatesInFloat [ FAILED ] OpMeanOutTest.BFloat16LargeDimAccumulatesInFloat [ FAILED ] OpSumOutTest.BFloat16LargeDimAccumulatesInFloat $ lintrunner op_softmax.cpp op_log_softmax.cpp op_mean.cpp op_sum.cpp \ op_softmax_test.cpp op_log_softmax_test.cpp op_mean_test.cpp op_sum_test.cpp ok No lint issues. ``` cc @larryliu0820 @manuelcandales
1 parent e93a285 commit 8bb71cf

9 files changed

Lines changed: 194 additions & 40 deletions

File tree

kernels/optimized/cpu/op_log_softmax.cpp

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -98,20 +98,27 @@ void log_softmax_kernel(const Tensor& input, int64_t dim, Tensor& out) {
9898
return;
9999
}
100100

101-
// OUT_T is the corresponding C++ type for out.scalar_type(). Only takes float
102-
// or double.
103-
template <
104-
typename OUT_T,
105-
std::enable_if_t<std::is_floating_point<OUT_T>::value, bool> = true>
101+
// OUT_T is the corresponding C++ type for out.scalar_type().
102+
template <typename OUT_T>
106103
bool log_softmax_wrapper(const Tensor& X, int64_t dim, Tensor& out) {
107-
auto input_scalar_type = X.scalar_type();
108-
switch (input_scalar_type) {
109-
// TODO: support Double as well
110-
case ScalarType::Float:
111-
log_softmax_kernel<float, OUT_T>(X, dim, out);
112-
return true;
113-
default:
114-
return false; // Unsupported input dtype
104+
if constexpr (
105+
std::is_same_v<OUT_T, executorch::aten::BFloat16> ||
106+
std::is_same_v<OUT_T, executorch::aten::Half>) {
107+
// Input dtype equals output dtype (enforced by check_log_softmax_args).
108+
// Use if constexpr to avoid instantiating cross-type combinations that
109+
// the ATen vectorized functions do not support.
110+
log_softmax_kernel<OUT_T, OUT_T>(X, dim, out);
111+
return true;
112+
} else {
113+
auto input_scalar_type = X.scalar_type();
114+
switch (input_scalar_type) {
115+
// TODO: support Double as well
116+
case ScalarType::Float:
117+
log_softmax_kernel<float, OUT_T>(X, dim, out);
118+
return true;
119+
default:
120+
return false; // Unsupported input dtype
121+
}
115122
}
116123
}
117124
} // namespace
@@ -148,6 +155,18 @@ Tensor& opt_log_softmax_out(
148155
ET_KERNEL_CHECK(context, success, InvalidArgument, out);
149156
break;
150157
}
158+
case ScalarType::BFloat16: {
159+
bool success =
160+
log_softmax_wrapper<executorch::aten::BFloat16>(self, dim, out);
161+
ET_KERNEL_CHECK(context, success, InvalidArgument, out);
162+
break;
163+
}
164+
case ScalarType::Half: {
165+
bool success =
166+
log_softmax_wrapper<executorch::aten::Half>(self, dim, out);
167+
ET_KERNEL_CHECK(context, success, InvalidArgument, out);
168+
break;
169+
}
151170
default:
152171
ET_KERNEL_CHECK(context, false, InvalidArgument, out);
153172
}

kernels/portable/cpu/op_log_softmax.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
*/
88

99
#include <cmath>
10+
#include <type_traits>
1011

1112
#include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
1213
#include <executorch/kernels/portable/cpu/util/functional_util.h>
@@ -42,8 +43,16 @@ Tensor& log_softmax_out(
4243
// Adjust for negative dim
4344
dim = dim < 0 ? dim + nonzero_dim(in) : dim;
4445

46+
// For half-precision inputs, the exp-sum is accumulated in float to avoid
47+
// saturation (BFloat16 saturates near 256, Half near 2048). Matches ATen's
48+
// acc_type behavior. See also op_grid_sampler_2d.cpp.
4549
ET_SWITCH_FLOATHBF16_TYPES(
4650
in.scalar_type(), ctx, "_log_softmax.out", CTYPE, [&]() {
51+
using ACC = std::conditional_t<
52+
std::is_same_v<CTYPE, executorch::aten::Half> ||
53+
std::is_same_v<CTYPE, executorch::aten::BFloat16>,
54+
float,
55+
CTYPE>;
4756
const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
4857
CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
4958

@@ -61,11 +70,12 @@ Tensor& log_softmax_out(
6170
size,
6271
stride);
6372

64-
CTYPE temp_sum = apply_unary_map_reduce_fn<CTYPE, CTYPE>(
73+
ACC temp_sum = apply_unary_map_reduce_fn<CTYPE, ACC>(
6574
[max_in](const CTYPE val_in) {
66-
return std::exp(val_in - max_in);
75+
return std::exp(
76+
static_cast<ACC>(val_in) - static_cast<ACC>(max_in));
6777
},
68-
[](const CTYPE mapped_in, CTYPE val_accum) {
78+
[](const ACC mapped_in, ACC val_accum) {
6979
return val_accum + mapped_in;
7080
},
7181
in_data + base,
@@ -75,7 +85,9 @@ Tensor& log_softmax_out(
7585

7686
apply_unary_map_fn(
7787
[max_in, temp_sum](const CTYPE val_in) {
78-
return val_in - max_in - temp_sum;
88+
return static_cast<CTYPE>(
89+
static_cast<ACC>(val_in) - static_cast<ACC>(max_in) -
90+
temp_sum);
7991
},
8092
in_data + base,
8193
out_data + base,

kernels/portable/cpu/op_mean.cpp

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
*/
88
#include <c10/util/irange.h>
99

10+
#include <type_traits>
11+
1012
#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
1113
#include <executorch/kernels/portable/cpu/util/reduce_util.h>
1214
#include <executorch/runtime/kernel/kernel_includes.h>
@@ -58,17 +60,24 @@ Tensor& mean_dim_out(
5860

5961
// @lint-ignore CLANGTIDY facebook-hte-CArray
6062
static constexpr const char op_name[] = "mean.out";
63+
// For half-precision inputs, accumulate in float to avoid saturation.
64+
// Matches ATen's acc_type behavior.
6165
ET_SWITCH_FLOATHBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE, [&] {
66+
using ACC = std::conditional_t<
67+
std::is_same_v<CTYPE, executorch::aten::Half> ||
68+
std::is_same_v<CTYPE, executorch::aten::BFloat16>,
69+
float,
70+
CTYPE>;
6271
const CTYPE* in_data = in.const_data_ptr<CTYPE>();
6372
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
64-
const CTYPE denom = static_cast<CTYPE>(reduce_size);
73+
const ACC denom = static_cast<ACC>(reduce_size);
6574
for (int64_t i = 0; i < outer_size; i++) {
6675
const CTYPE* row = in_data + i * reduce_size;
67-
CTYPE acc = 0;
76+
ACC acc = 0;
6877
for (int64_t j = 0; j < reduce_size; j++) {
6978
acc += row[j];
7079
}
71-
out_data[i] = acc / denom;
80+
out_data[i] = static_cast<CTYPE>(acc / denom);
7281
}
7382
});
7483
return out;
@@ -83,19 +92,25 @@ Tensor& mean_dim_out(
8392
static constexpr const char op_name[] = "mean.out";
8493
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE_IN, [&] {
8594
ET_SWITCH_FLOATHBF16_TYPES(out.scalar_type(), ctx, op_name, CTYPE_OUT, [&] {
95+
using ACC = std::conditional_t<
96+
std::is_same_v<CTYPE_OUT, executorch::aten::Half> ||
97+
std::is_same_v<CTYPE_OUT, executorch::aten::BFloat16>,
98+
float,
99+
CTYPE_OUT>;
86100
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
87101
const size_t num = get_reduced_dim_product(in, dim_list);
88102
const bool success = parallel_for_each_reduce_over_dim_list_output_index(
89103
in, dim_list, out, [&](const auto begin, const auto end) {
90104
for (const auto out_ix : c10::irange(begin, end)) {
91-
CTYPE_OUT sum = 0;
105+
ACC sum = 0;
92106
if (plan.has_value()) {
93-
sum = plan->execute<CTYPE_IN, CTYPE_OUT>(
94-
[](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
95-
[](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
107+
sum = plan->execute<CTYPE_IN, ACC>(
108+
[](CTYPE_IN v) { return static_cast<ACC>(v); },
109+
[](ACC outv, ACC acc) { return acc + outv; },
96110
out_ix);
97111
}
98-
out_data[out_ix] = sum / static_cast<float>(num);
112+
out_data[out_ix] =
113+
static_cast<CTYPE_OUT>(sum / static_cast<float>(num));
99114
}
100115
});
101116
ET_KERNEL_CHECK_MSG(ctx, success, Internal, , "parallel_for failed");

kernels/portable/cpu/op_softmax.cpp

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
*/
88

99
#include <cmath>
10+
#include <type_traits>
1011

1112
#include <executorch/kernels/portable/cpu/util/activation_ops_util.h>
1213
#include <executorch/kernels/portable/cpu/util/functional_util.h>
@@ -42,8 +43,16 @@ Tensor& softmax_out(
4243
// Adjust for negative dim
4344
dim = dim < 0 ? dim + nonzero_dim(in) : dim;
4445

46+
// For half-precision inputs, the exp-sum is accumulated in float to avoid
47+
// saturation (BFloat16 saturates near 256, Half near 2048). Matches ATen's
48+
// acc_type behavior. See also op_grid_sampler_2d.cpp.
4549
ET_SWITCH_FLOATHBF16_TYPES(
4650
in.scalar_type(), ctx, "_softmax.out", CTYPE, [&]() {
51+
using ACC = std::conditional_t<
52+
std::is_same_v<CTYPE, executorch::aten::Half> ||
53+
std::is_same_v<CTYPE, executorch::aten::BFloat16>,
54+
float,
55+
CTYPE>;
4756
const CTYPE* const in_data = in.const_data_ptr<CTYPE>();
4857
CTYPE* const out_data = out.mutable_data_ptr<CTYPE>();
4958

@@ -61,11 +70,12 @@ Tensor& softmax_out(
6170
size,
6271
stride);
6372

64-
const CTYPE temp_sum = apply_unary_map_reduce_fn<CTYPE, CTYPE>(
73+
const ACC temp_sum = apply_unary_map_reduce_fn<CTYPE, ACC>(
6574
[max_in](const CTYPE val_in) {
66-
return std::exp(val_in - max_in);
75+
return std::exp(
76+
static_cast<ACC>(val_in) - static_cast<ACC>(max_in));
6777
},
68-
[](const CTYPE mapped_in, CTYPE val_accum) {
78+
[](const ACC mapped_in, ACC val_accum) {
6979
return val_accum + mapped_in;
7080
},
7181
in_data + base,
@@ -74,7 +84,11 @@ Tensor& softmax_out(
7484

7585
apply_unary_map_fn(
7686
[max_in, temp_sum](const CTYPE val_in) {
77-
return std::exp(val_in - max_in) / temp_sum;
87+
return static_cast<CTYPE>(
88+
std::exp(
89+
static_cast<ACC>(val_in) -
90+
static_cast<ACC>(max_in)) /
91+
temp_sum);
7892
},
7993
in_data + base,
8094
out_data + base,

kernels/portable/cpu/op_sum.cpp

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
*/
88
#include <c10/util/irange.h>
99

10+
#include <type_traits>
11+
1012
#include <executorch/kernels/portable/cpu/util/reduce_util.h>
1113
#include <executorch/runtime/kernel/kernel_includes.h>
1214
#include <executorch/runtime/platform/assert.h>
@@ -60,16 +62,23 @@ Tensor& sum_dim_out(
6062

6163
// @lint-ignore CLANGTIDY facebook-hte-CArray
6264
static constexpr const char op_name[] = "sum.IntList_out";
65+
// For half-precision inputs, accumulate in float to avoid saturation.
66+
// Matches ATen's acc_type behavior. See also op_grid_sampler_2d.cpp.
6367
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE, [&] {
68+
using ACC = std::conditional_t<
69+
std::is_same_v<CTYPE, executorch::aten::Half> ||
70+
std::is_same_v<CTYPE, executorch::aten::BFloat16>,
71+
float,
72+
CTYPE>;
6473
const CTYPE* in_data = in.const_data_ptr<CTYPE>();
6574
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
6675
for (int64_t i = 0; i < outer_size; i++) {
6776
const CTYPE* row = in_data + i * reduce_size;
68-
CTYPE acc = 0;
77+
ACC acc = 0;
6978
for (int64_t j = 0; j < reduce_size; j++) {
7079
acc += row[j];
7180
}
72-
out_data[i] = acc;
81+
out_data[i] = static_cast<CTYPE>(acc);
7382
}
7483
});
7584
return out;
@@ -108,23 +117,24 @@ Tensor& sum_dim_out(
108117
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE_IN, [&] {
109118
ET_SWITCH_REALHBBF16_TYPES(
110119
out.scalar_type(), ctx, op_name, CTYPE_OUT, [&] {
120+
using ACC = std::conditional_t<
121+
std::is_same_v<CTYPE_OUT, executorch::aten::Half> ||
122+
std::is_same_v<CTYPE_OUT, executorch::aten::BFloat16>,
123+
float,
124+
CTYPE_OUT>;
111125
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
112126
const bool success =
113127
parallel_for_each_reduce_over_dim_list_output_index(
114128
in, dim_list, out, [&](const auto begin, const auto end) {
115129
for (const auto out_ix : c10::irange(begin, end)) {
116-
CTYPE_OUT sum = 0;
130+
ACC sum = 0;
117131
if (plan.has_value()) {
118-
sum = plan->execute<CTYPE_IN, CTYPE_OUT>(
119-
[](CTYPE_IN v) {
120-
return static_cast<CTYPE_OUT>(v);
121-
},
122-
[](CTYPE_OUT outv, CTYPE_OUT acc) {
123-
return acc + outv;
124-
},
132+
sum = plan->execute<CTYPE_IN, ACC>(
133+
[](CTYPE_IN v) { return static_cast<ACC>(v); },
134+
[](ACC outv, ACC acc) { return acc + outv; },
125135
out_ix);
126136
}
127-
out_data[out_ix] = sum;
137+
out_data[out_ix] = static_cast<CTYPE_OUT>(sum);
128138
}
129139
});
130140
ET_KERNEL_CHECK_MSG(

kernels/test/op_log_softmax_test.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,19 @@ TEST_F(OpLogSoftmaxOutTest, SimpleGeneratedCase) {
369369
EXPECT_TENSOR_CLOSE(out, expected_result);
370370
}
371371

372+
TEST_F(OpLogSoftmaxOutTest, BFloat16LargeDimAccumulatesInFloat) {
373+
TensorFactory<ScalarType::BFloat16> tf;
374+
// N=512: without fp32 accumulation, the exp-sum saturates at BFloat16's
375+
// precision limit (~256), so the output is ~-log(256) instead of -log(512).
376+
// atol=1e-1 can catch pre-fix error: |log(512) - log(256)| = log(2)
377+
constexpr int N = 512;
378+
Tensor x = tf.zeros({1, N});
379+
Tensor out = tf.zeros({1, N});
380+
op_log_softmax_out(x, /*dim=*/1, /*half_to_float=*/false, out);
381+
Tensor expected = tf.full({1, N}, -std::log(static_cast<float>(N)));
382+
EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, /*rtol=*/1e-5, /*atol=*/1e-1);
383+
}
384+
372385
TEST_F(OpLogSoftmaxOutTest, DynamicShapeUpperBoundSameAsExpected) {
373386
TensorFactory<ScalarType::Float> tf;
374387

kernels/test/op_mean_test.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,35 @@ void OpMeanOutTest::
263263
test_mean_dim_out_bool<ScalarType::Double>();
264264
}
265265

266+
TEST_F(OpMeanOutTest, BFloat16GenericPathAccumulatesInFloat) {
267+
TensorFactory<ScalarType::BFloat16> tf;
268+
// Reducing dim=0 of {512, 1} is not the last dim, so the generic path is
269+
// taken. Without fp32 accumulation the sum saturates at ~256, giving
270+
// 256/512 = 0.5 instead of 1.0.
271+
constexpr int N = 512;
272+
Tensor x = tf.ones({N, 1});
273+
Tensor out = tf.zeros({1});
274+
int64_t dim = 0;
275+
op_mean_out(
276+
x, ArrayRef<int64_t>{&dim, 1}, /*keepdim=*/false, /*dtype=*/{}, out);
277+
Tensor expected = tf.full({1}, 1.0f);
278+
EXPECT_TENSOR_CLOSE(out, expected);
279+
}
280+
281+
TEST_F(OpMeanOutTest, BFloat16LargeDimAccumulatesInFloat) {
282+
TensorFactory<ScalarType::BFloat16> tf;
283+
// N=512, all-ones input: without fp32 accumulation the sum saturates at
284+
// ~256 in BFloat16, giving 256/512 = 0.5 instead of 1.0.
285+
constexpr int N = 512;
286+
Tensor x = tf.ones({1, N});
287+
Tensor out = tf.zeros({1});
288+
int64_t dim = 1;
289+
op_mean_out(
290+
x, ArrayRef<int64_t>{&dim, 1}, /*keepdim=*/false, /*dtype=*/{}, out);
291+
Tensor expected = tf.full({1}, 1.0f);
292+
EXPECT_TENSOR_CLOSE(out, expected);
293+
}
294+
266295
TEST_F(OpMeanOutTest, InvalidDimensionListDies) {
267296
ET_SKIP_IF(
268297
torch::executor::testing::SupportedFeatures::get()->is_aten,

kernels/test/op_softmax_test.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,19 @@ TEST_F(OpSoftmaxOutTest, SimpleGeneratedCase) {
251251
EXPECT_TENSOR_CLOSE(out, expected_result);
252252
}
253253

254+
TEST_F(OpSoftmaxOutTest, BFloat16LargeDimAccumulatesInFloat) {
255+
TensorFactory<ScalarType::BFloat16> tf;
256+
// N=512: without fp32 accumulation the exp-sum saturates at BFloat16's
257+
// precision limit (~256), so the output is ~1/256 instead of 1/512.
258+
// 1e-3 is tight enough to catch pre-fix error: |1/256 - 1/512| ≈ 0.00195
259+
constexpr int N = 512;
260+
Tensor x = tf.zeros({1, N});
261+
Tensor out = tf.zeros({1, N});
262+
op_softmax_out(x, /*dim=*/1, /*half_to_float=*/false, out);
263+
Tensor expected = tf.full({1, N}, 1.0f / N);
264+
EXPECT_TENSOR_CLOSE_WITH_TOL(out, expected, /*rtol=*/1e-5, /*atol=*/1e-3);
265+
}
266+
254267
TEST_F(OpSoftmaxOutTest, DynamicShapeUpperBoundSameAsExpected) {
255268
TensorFactory<ScalarType::Float> tf;
256269

0 commit comments

Comments
 (0)