Skip to content

Commit 16ba018

Browse files
committed
portable: accumulate in fp32 for Half/BFloat16 in mean and sum
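Problem: The fast-path and generic reduction loops in mean.out and sum.IntList_out accumulated the running sum in the tensor dtype. For BFloat16, the running sum stalls at 256 (BFloat16 carries only 8 significant bits, so 256 + 1 rounds back to 256), so a mean over N=512 all-ones elements gives 0.5 instead of 1.0, and summing 512 all-ones elements gives 256 instead of 512. Changes: Accumulate in float for Half/BFloat16 by promoting the loop accumulator to an ACC type in both the fast path and the generic path. ACC is float when the element type is Half or BFloat16 and the element type itself otherwise; it is selected from the tensor dtype in the fast path and from the output dtype in the generic path. The final result is cast back to the tensor dtype on store. Continues the fp32-accumulation work in #19117.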
1 parent 9a39008 commit 16ba018

4 files changed

Lines changed: 102 additions & 19 deletions

File tree

kernels/portable/cpu/op_mean.cpp

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
*/
88
#include <c10/util/irange.h>
99

10+
#include <type_traits>
11+
1012
#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
1113
#include <executorch/kernels/portable/cpu/util/reduce_util.h>
1214
#include <executorch/runtime/kernel/kernel_includes.h>
@@ -58,17 +60,24 @@ Tensor& mean_dim_out(
5860

5961
// @lint-ignore CLANGTIDY facebook-hte-CArray
6062
static constexpr const char op_name[] = "mean.out";
63+
// For half-precision inputs, accumulate in float to avoid saturation.
64+
// Matches ATen's acc_type behavior.
6165
ET_SWITCH_FLOATHBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE, [&] {
66+
using ACC = std::conditional_t<
67+
std::is_same_v<CTYPE, executorch::aten::Half> ||
68+
std::is_same_v<CTYPE, executorch::aten::BFloat16>,
69+
float,
70+
CTYPE>;
6271
const CTYPE* in_data = in.const_data_ptr<CTYPE>();
6372
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
64-
const CTYPE denom = static_cast<CTYPE>(reduce_size);
73+
const ACC denom = static_cast<ACC>(reduce_size);
6574
for (int64_t i = 0; i < outer_size; i++) {
6675
const CTYPE* row = in_data + i * reduce_size;
67-
CTYPE acc = 0;
76+
ACC acc = 0;
6877
for (int64_t j = 0; j < reduce_size; j++) {
6978
acc += row[j];
7079
}
71-
out_data[i] = acc / denom;
80+
out_data[i] = static_cast<CTYPE>(acc / denom);
7281
}
7382
});
7483
return out;
@@ -83,19 +92,25 @@ Tensor& mean_dim_out(
8392
static constexpr const char op_name[] = "mean.out";
8493
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE_IN, [&] {
8594
ET_SWITCH_FLOATHBF16_TYPES(out.scalar_type(), ctx, op_name, CTYPE_OUT, [&] {
95+
using ACC = std::conditional_t<
96+
std::is_same_v<CTYPE_OUT, executorch::aten::Half> ||
97+
std::is_same_v<CTYPE_OUT, executorch::aten::BFloat16>,
98+
float,
99+
CTYPE_OUT>;
86100
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
87101
const size_t num = get_reduced_dim_product(in, dim_list);
88102
const bool success = parallel_for_each_reduce_over_dim_list_output_index(
89103
in, dim_list, out, [&](const auto begin, const auto end) {
90104
for (const auto out_ix : c10::irange(begin, end)) {
91-
CTYPE_OUT sum = 0;
105+
ACC sum = 0;
92106
if (plan.has_value()) {
93-
sum = plan->execute<CTYPE_IN, CTYPE_OUT>(
94-
[](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
95-
[](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
107+
sum = plan->execute<CTYPE_IN, ACC>(
108+
[](CTYPE_IN v) { return static_cast<ACC>(v); },
109+
[](ACC outv, ACC acc) { return acc + outv; },
96110
out_ix);
97111
}
98-
out_data[out_ix] = sum / static_cast<float>(num);
112+
out_data[out_ix] =
113+
static_cast<CTYPE_OUT>(sum / static_cast<float>(num));
99114
}
100115
});
101116
ET_KERNEL_CHECK_MSG(ctx, success, Internal, , "parallel_for failed");

kernels/portable/cpu/op_sum.cpp

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
*/
88
#include <c10/util/irange.h>
99

10+
#include <type_traits>
11+
1012
#include <executorch/kernels/portable/cpu/util/reduce_util.h>
1113
#include <executorch/runtime/kernel/kernel_includes.h>
1214
#include <executorch/runtime/platform/assert.h>
@@ -60,16 +62,23 @@ Tensor& sum_dim_out(
6062

6163
// @lint-ignore CLANGTIDY facebook-hte-CArray
6264
static constexpr const char op_name[] = "sum.IntList_out";
65+
// For half-precision inputs, accumulate in float to avoid saturation.
66+
// Matches ATen's acc_type behavior. See also op_grid_sampler_2d.cpp.
6367
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE, [&] {
68+
using ACC = std::conditional_t<
69+
std::is_same_v<CTYPE, executorch::aten::Half> ||
70+
std::is_same_v<CTYPE, executorch::aten::BFloat16>,
71+
float,
72+
CTYPE>;
6473
const CTYPE* in_data = in.const_data_ptr<CTYPE>();
6574
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
6675
for (int64_t i = 0; i < outer_size; i++) {
6776
const CTYPE* row = in_data + i * reduce_size;
68-
CTYPE acc = 0;
77+
ACC acc = 0;
6978
for (int64_t j = 0; j < reduce_size; j++) {
7079
acc += row[j];
7180
}
72-
out_data[i] = acc;
81+
out_data[i] = static_cast<CTYPE>(acc);
7382
}
7483
});
7584
return out;
@@ -108,23 +117,24 @@ Tensor& sum_dim_out(
108117
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE_IN, [&] {
109118
ET_SWITCH_REALHBBF16_TYPES(
110119
out.scalar_type(), ctx, op_name, CTYPE_OUT, [&] {
120+
using ACC = std::conditional_t<
121+
std::is_same_v<CTYPE_OUT, executorch::aten::Half> ||
122+
std::is_same_v<CTYPE_OUT, executorch::aten::BFloat16>,
123+
float,
124+
CTYPE_OUT>;
111125
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
112126
const bool success =
113127
parallel_for_each_reduce_over_dim_list_output_index(
114128
in, dim_list, out, [&](const auto begin, const auto end) {
115129
for (const auto out_ix : c10::irange(begin, end)) {
116-
CTYPE_OUT sum = 0;
130+
ACC sum = 0;
117131
if (plan.has_value()) {
118-
sum = plan->execute<CTYPE_IN, CTYPE_OUT>(
119-
[](CTYPE_IN v) {
120-
return static_cast<CTYPE_OUT>(v);
121-
},
122-
[](CTYPE_OUT outv, CTYPE_OUT acc) {
123-
return acc + outv;
124-
},
132+
sum = plan->execute<CTYPE_IN, ACC>(
133+
[](CTYPE_IN v) { return static_cast<ACC>(v); },
134+
[](ACC outv, ACC acc) { return acc + outv; },
125135
out_ix);
126136
}
127-
out_data[out_ix] = sum;
137+
out_data[out_ix] = static_cast<CTYPE_OUT>(sum);
128138
}
129139
});
130140
ET_KERNEL_CHECK_MSG(

kernels/test/op_mean_test.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,35 @@ void OpMeanOutTest::
263263
test_mean_dim_out_bool<ScalarType::Double>();
264264
}
265265

266+
TEST_F(OpMeanOutTest, BFloat16GenericPathAccumulatesInFloat) {
267+
TensorFactory<ScalarType::BFloat16> tf;
268+
// Reducing dim=0 of {512, 1} is not the last dim, so the generic path is
269+
// taken. Without fp32 accumulation the sum saturates at ~256, giving
270+
// 256/512 = 0.5 instead of 1.0.
271+
constexpr int N = 512;
272+
Tensor x = tf.ones({N, 1});
273+
Tensor out = tf.zeros({1});
274+
int64_t dim = 0;
275+
op_mean_out(
276+
x, ArrayRef<int64_t>{&dim, 1}, /*keepdim=*/false, /*dtype=*/{}, out);
277+
Tensor expected = tf.full({1}, 1.0f);
278+
EXPECT_TENSOR_CLOSE(out, expected);
279+
}
280+
281+
TEST_F(OpMeanOutTest, BFloat16LargeDimAccumulatesInFloat) {
282+
TensorFactory<ScalarType::BFloat16> tf;
283+
// N=512, all-ones input: without fp32 accumulation the sum saturates at
284+
// ~256 in BFloat16, giving 256/512 = 0.5 instead of 1.0.
285+
constexpr int N = 512;
286+
Tensor x = tf.ones({1, N});
287+
Tensor out = tf.zeros({1});
288+
int64_t dim = 1;
289+
op_mean_out(
290+
x, ArrayRef<int64_t>{&dim, 1}, /*keepdim=*/false, /*dtype=*/{}, out);
291+
Tensor expected = tf.full({1}, 1.0f);
292+
EXPECT_TENSOR_CLOSE(out, expected);
293+
}
294+
266295
TEST_F(OpMeanOutTest, InvalidDimensionListDies) {
267296
ET_SKIP_IF(
268297
torch::executor::testing::SupportedFeatures::get()->is_aten,

kernels/test/op_sum_test.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,35 @@ class OpSumOutTest : public OperatorTest {
307307
}
308308
};
309309

310+
TEST_F(OpSumOutTest, BFloat16GenericPathAccumulatesInFloat) {
311+
TensorFactory<ScalarType::BFloat16> tf;
312+
// Reducing dim=0 of {512, 1} is not the last dim, so the generic path is
313+
// taken. Without fp32 accumulation the sum saturates at ~256 instead of
314+
// 512. 512 = 2^9 is exactly representable in BFloat16.
315+
constexpr int N = 512;
316+
Tensor x = tf.ones({N, 1});
317+
Tensor out = tf.zeros({1});
318+
int64_t dim = 0;
319+
op_sum_intlist_out(
320+
x, ArrayRef<int64_t>{&dim, 1}, /*keepdim=*/false, /*dtype=*/{}, out);
321+
Tensor expected = tf.full({1}, static_cast<float>(N));
322+
EXPECT_TENSOR_CLOSE(out, expected);
323+
}
324+
325+
TEST_F(OpSumOutTest, BFloat16LargeDimAccumulatesInFloat) {
326+
TensorFactory<ScalarType::BFloat16> tf;
327+
// N=512, all-ones input: without fp32 accumulation the sum saturates at
328+
// ~256 in BFloat16 instead of 512.
329+
constexpr int N = 512;
330+
Tensor x = tf.ones({1, N});
331+
Tensor out = tf.zeros({1});
332+
int64_t dim = 1;
333+
op_sum_intlist_out(
334+
x, ArrayRef<int64_t>{&dim, 1}, /*keepdim=*/false, /*dtype=*/{}, out);
335+
Tensor expected = tf.full({1}, static_cast<float>(N));
336+
EXPECT_TENSOR_CLOSE(out, expected);
337+
}
338+
310339
TEST_F(OpSumOutTest, InvalidDimensionListDies) {
311340
ET_SKIP_IF(
312341
torch::executor::testing::SupportedFeatures::get()->is_aten,

0 commit comments

Comments
 (0)