From 0dd85d9be621e579c8e9a6ae63a0bba279668b5b Mon Sep 17 00:00:00 2001 From: Manuel Candales Date: Thu, 26 Jun 2025 14:29:16 -0700 Subject: [PATCH] [ET][Portable] Check scalar overflow: op_scatter Differential Revision: [D77401093](https://our.internmc.facebook.com/intern/diff/D77401093/) [ghstack-poisoned] --- kernels/portable/cpu/op_scatter.cpp | 4 +++- kernels/test/op_scatter_test.cpp | 16 ++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/kernels/portable/cpu/op_scatter.cpp b/kernels/portable/cpu/op_scatter.cpp index 7de0ec4d5f9..965afbb4b66 100644 --- a/kernels/portable/cpu/op_scatter.cpp +++ b/kernels/portable/cpu/op_scatter.cpp @@ -154,7 +154,9 @@ Tensor& scatter_value_out( constexpr auto name = "scatter.value_out"; ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() { - const CTYPE val = utils::scalar_to(value); + auto opt_val = utils::internal::check_overflow_scalar_cast(value); + ET_KERNEL_CHECK(ctx, opt_val.has_value(), InvalidArgument, ); + auto val = opt_val.value(); scatter_value_helper(in, dim, index, val, out); }); diff --git a/kernels/test/op_scatter_test.cpp b/kernels/test/op_scatter_test.cpp index 0e55aadaeda..dac9017d188 100644 --- a/kernels/test/op_scatter_test.cpp +++ b/kernels/test/op_scatter_test.cpp @@ -7,6 +7,7 @@ */ #include // Declares the operator +#include #include #include #include @@ -364,6 +365,19 @@ class OpScatterValueOutTest : public OperatorTest { op_scatter_value_out(input, 2, index, value, out); EXPECT_TENSOR_EQ(out, expected); } + + template + void expect_bad_scalar_value_dies(const Scalar& bad_value) { + TensorFactory tf; + TensorFactory tf_index; + + Tensor self = tf.ones({2, 2}); + Tensor index = tf_index.zeros({2, 2}); + Tensor out = tf.zeros({2, 2}); + + ET_EXPECT_KERNEL_FAILURE( + context_, op_scatter_value_out(self, 0, index, bad_value, out)); + } }; TEST_F(OpScatterSrcOutTest, AllValidInputOutputSupport) { @@ -652,3 +666,5 @@ TEST_F(OpScatterSrcOutTest, InvalidOneDimInputAndZeroDimIndex) { ET_EXPECT_KERNEL_FAILURE( context_, op_scatter_src_out(self, 0, index, src, out)); } + +GENERATE_SCALAR_OVERFLOW_TESTS(OpScatterValueOutTest)