Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 35 additions & 42 deletions kernels/optimized/cpu/op_add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,7 @@ Tensor& opt_add_out(
ET_SWITCH_REALB_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
CTYPE alpha_val;
ET_KERNEL_CHECK(
ctx,
torch::executor::native::utils::extract_scalar(alpha, &alpha_val),
InvalidArgument, );
ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
CTYPE_B b_val = *b.const_data_ptr<CTYPE_B>();
CTYPE b_casted = static_cast<CTYPE>(b_val);

Expand Down Expand Up @@ -81,7 +79,6 @@ Tensor& opt_add_scalar_out(
(void)ctx;

ScalarType a_type = a.scalar_type();
ScalarType b_type = utils::get_scalar_dtype(b);
ScalarType common_type =
utils::promote_type_with_scalar(a_type, b, /*half_to_float*/ false);
ScalarType out_type = out.scalar_type();
Expand All @@ -99,47 +96,43 @@ Tensor& opt_add_scalar_out(
if (a_type == common_type && a_type == out_type &&
a_type != ScalarType::Half && a_type != ScalarType::BFloat16) {
ET_SWITCH_REALB_TYPES(a_type, ctx, "add.Scalar_out", CTYPE, [&]() {
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "add.Scalar_out", CTYPE_B, [&]() {
CTYPE_B b_val;
ET_EXTRACT_SCALAR(b, b_val);
CTYPE b_casted = static_cast<CTYPE>(b_val);
CTYPE alpha_val;
ET_EXTRACT_SCALAR(alpha, alpha_val);

using Vec = at::vec::Vectorized<CTYPE>;
at::vec::map<CTYPE>(
[alpha_val, b_casted](Vec x) {
return x + Vec(alpha_val * b_casted);
},
out.mutable_data_ptr<CTYPE>(),
a.const_data_ptr<CTYPE>(),
out.numel());
});
CTYPE b_casted = utils::scalar_to<CTYPE>(b);
CTYPE alpha_val;
ET_KERNEL_CHECK(
ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );

using Vec = at::vec::Vectorized<CTYPE>;
at::vec::map<CTYPE>(
[alpha_val, b_casted](Vec x) {
return x + Vec(alpha_val * b_casted);
},
out.mutable_data_ptr<CTYPE>(),
a.const_data_ptr<CTYPE>(),
out.numel());
});
} else {
ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.Scalar_out", CTYPE_A, [&]() {
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "add.Scalar_out", CTYPE_B, [&]() {
ET_SWITCH_REALB_TYPES(
common_type, ctx, "add.Scalar_out", CTYPE_IN, [&]() {
ET_SWITCH_REALHBBF16_TYPES(
out_type, ctx, "add.Scalar_out", CTYPE_OUT, [&]() {
CTYPE_B b_val;
ET_EXTRACT_SCALAR(b, b_val);
CTYPE_IN b_casted = static_cast<CTYPE_IN>(b_val);
CTYPE_IN alpha_val;
ET_EXTRACT_SCALAR(alpha, alpha_val);

const size_t n = a.numel();
const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
for (auto i = 0; i < n; ++i) {
out_data[i] = static_cast<CTYPE_OUT>(
static_cast<CTYPE_IN>(a_data[i]) +
alpha_val * b_casted);
}
});
});
});
ET_SWITCH_REALB_TYPES(
common_type, ctx, "add.Scalar_out", CTYPE_IN, [&]() {
ET_SWITCH_REALHBBF16_TYPES(
out_type, ctx, "add.Scalar_out", CTYPE_OUT, [&]() {
CTYPE_IN b_casted = utils::scalar_to<CTYPE_IN>(b);
CTYPE_IN alpha_val;
ET_KERNEL_CHECK(
ctx,
utils::extract_scalar(alpha, &alpha_val),
InvalidArgument, );

const size_t n = a.numel();
const CTYPE_A* a_data = a.const_data_ptr<CTYPE_A>();
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
for (auto i = 0; i < n; ++i) {
out_data[i] = static_cast<CTYPE_OUT>(
static_cast<CTYPE_IN>(a_data[i]) +
alpha_val * b_casted);
}
});
});
});
}

Expand Down
8 changes: 6 additions & 2 deletions kernels/portable/cpu/op_add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ Tensor& add_out(
static constexpr const char op_name[] = "add.out";

ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
CTYPE_COMPUTE val_alpha;
ET_KERNEL_CHECK(
ctx, utils::extract_scalar(alpha, &val_alpha), InvalidArgument, );
utils::apply_bitensor_elementwise_fn<
CTYPE_COMPUTE,
op_name,
Expand Down Expand Up @@ -103,7 +105,9 @@ Tensor& add_scalar_out(

ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
CTYPE_COMPUTE val_alpha;
ET_KERNEL_CHECK(
ctx, utils::extract_scalar(alpha, &val_alpha), InvalidArgument, );
auto val_alpha_times_b = val_alpha * val_b;
utils::apply_unitensor_elementwise_fn<
CTYPE_COMPUTE,
Expand Down
112 changes: 110 additions & 2 deletions kernels/test/op_add_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@

#include <gtest/gtest.h>

#include <iostream>

using namespace ::testing;
using executorch::aten::Scalar;
using executorch::aten::ScalarType;
Expand Down Expand Up @@ -231,6 +229,16 @@ class OpAddOutKernelTest : public OperatorTest {
EXPECT_TENSOR_CLOSE(op_add_out(a, b, 1.0, out), expected);
EXPECT_TENSOR_CLOSE(op_add_out(b, a, 1.0, out), expected);
}

// Test helper: calls op_add_out on 2x2 tensors of dtype DTYPE, passing
// `bad_value` as the third argument (the alpha slot, per this helper's name),
// and expects the kernel to report a failure.
template <ScalarType DTYPE>
void expect_bad_alpha_value_dies(Scalar bad_value) {
TensorFactory<DTYPE> tf;
Tensor a = tf.ones({2, 2});
Tensor b = tf.ones({2, 2});
Tensor out = tf.zeros({2, 2});

// The operand values are valid; only `bad_value` is expected to be rejected.
ET_EXPECT_KERNEL_FAILURE(context_, op_add_out(a, b, bad_value, out));
}
};

class OpAddScalarOutKernelTest : public OperatorTest {
Expand All @@ -242,6 +250,16 @@ class OpAddScalarOutKernelTest : public OperatorTest {
Tensor& out) {
return torch::executor::aten::add_outf(context_, self, other, alpha, out);
}

// Test helper: calls op_add_scalar_out on a 2x2 tensor of dtype DTYPE with the
// scalar operand 1, passing `bad_value` as alpha, and expects the kernel to
// report a failure.
template <ScalarType DTYPE>
void expect_bad_alpha_value_dies(Scalar bad_value) {
TensorFactory<DTYPE> tf;
Tensor a = tf.ones({2, 2});
Scalar b = 1;
Tensor out = tf.zeros({2, 2});

// The operands are valid; only `bad_value` is expected to be rejected.
ET_EXPECT_KERNEL_FAILURE(context_, op_add_scalar_out(a, b, bad_value, out));
}
};

/**
Expand Down Expand Up @@ -794,3 +812,93 @@ TEST_F(OpAddScalarOutKernelTest, DtypeTest_float16_bool_int_float16) {
op_add_scalar_out(self, other, alpha, out);
EXPECT_TENSOR_CLOSE(out, out_expected);
}

TEST_F(OpAddOutKernelTest, ByteTensorTooLargeAlphaDies) {
// 256 is one past the uint8_t maximum (255), so it cannot be represented by a
// uint8_t.
expect_bad_alpha_value_dies<ScalarType::Byte>(256);
}

TEST_F(OpAddOutKernelTest, ByteTensorFloatingPointAlphaDies) {
// 2.2 is not an integer, so it cannot be represented by a uint8_t.
expect_bad_alpha_value_dies<ScalarType::Byte>(2.2);
}

#ifndef USE_ATEN_LIB
TEST_F(OpAddOutKernelTest, IntTensorTooSmallAlphaDies) {
// -2147483649 is one below the int32_t minimum (-2147483648), so it cannot be
// represented by an int32_t.
expect_bad_alpha_value_dies<ScalarType::Int>(-2147483649);
}

TEST_F(OpAddOutKernelTest, IntTensorTooLargeAlphaDies) {
// 2147483648 is one past the int32_t maximum (2147483647), so it cannot be
// represented by an int32_t.
expect_bad_alpha_value_dies<ScalarType::Int>(2147483648);
}
#endif

TEST_F(OpAddOutKernelTest, IntTensorFloatingPointAlphaDies) {
// 2.2 is not an integer, so it cannot be represented by an int32_t.
expect_bad_alpha_value_dies<ScalarType::Int>(2.2);
}

TEST_F(OpAddOutKernelTest, FloatTensorTooSmallAlphaDies) {
// -3.41e+38 is below the lowest finite float (about -3.4028e+38), so it cannot
// be represented by a float.
expect_bad_alpha_value_dies<ScalarType::Float>(-3.41e+38);
}

TEST_F(OpAddOutKernelTest, FloatTensorTooLargeAlphaDies) {
// 3.41e+38 is above the largest finite float (about 3.4028e+38), so it cannot
// be represented by a float.
expect_bad_alpha_value_dies<ScalarType::Float>(3.41e+38);
}

TEST_F(OpAddOutKernelTest, HalfTensorTooLargeAlphaDies) {
// Only meaningful in ATen mode: per the skip message below, the portable
// kernel computes in float, where this alpha value is representable.
if (!torch::executor::testing::SupportedFeatures::get()->is_aten) {
GTEST_SKIP() << "Portable kernel does the computation in float";
}
// 65505.0 is above the largest finite half-precision value (65504), so it
// cannot be represented by a Half.
expect_bad_alpha_value_dies<ScalarType::Half>(65505.0);
}

TEST_F(OpAddScalarOutKernelTest, ByteTensorTooLargeAlphaDies) {
// 256 is one past the uint8_t maximum (255), so it cannot be represented by a
// uint8_t.
expect_bad_alpha_value_dies<ScalarType::Byte>(256);
}

TEST_F(OpAddScalarOutKernelTest, ByteTensorFloatingPointAlphaDies) {
// 2.2 is not an integer, so it cannot be represented by a uint8_t.
expect_bad_alpha_value_dies<ScalarType::Byte>(2.2);
}

#ifndef USE_ATEN_LIB
TEST_F(OpAddScalarOutKernelTest, IntTensorTooSmallAlphaDies) {
// -2147483649 is one below the int32_t minimum (-2147483648), so it cannot be
// represented by an int32_t.
expect_bad_alpha_value_dies<ScalarType::Int>(-2147483649);
}

TEST_F(OpAddScalarOutKernelTest, IntTensorTooLargeAlphaDies) {
// 2147483648 is one past the int32_t maximum (2147483647), so it cannot be
// represented by an int32_t.
expect_bad_alpha_value_dies<ScalarType::Int>(2147483648);
}
#endif

TEST_F(OpAddScalarOutKernelTest, IntTensorFloatingPointAlphaDies) {
// 2.2 is not an integer, so it cannot be represented by an int32_t.
expect_bad_alpha_value_dies<ScalarType::Int>(2.2);
}

TEST_F(OpAddScalarOutKernelTest, FloatTensorTooSmallAlphaDies) {
// -3.41e+38 is below the lowest finite float (about -3.4028e+38), so it cannot
// be represented by a float.
expect_bad_alpha_value_dies<ScalarType::Float>(-3.41e+38);
}

TEST_F(OpAddScalarOutKernelTest, FloatTensorTooLargeAlphaDies) {
// 3.41e+38 is above the largest finite float (about 3.4028e+38), so it cannot
// be represented by a float.
expect_bad_alpha_value_dies<ScalarType::Float>(3.41e+38);
}

TEST_F(OpAddScalarOutKernelTest, HalfTensorTooLargeAlphaDies) {
// Only meaningful in ATen mode: per the skip message below, the portable
// kernel computes in float, where this alpha value is representable.
if (!torch::executor::testing::SupportedFeatures::get()->is_aten) {
GTEST_SKIP() << "Portable kernel does the computation in float";
}
// 65505.0 is above the largest finite half-precision value (65504), so it
// cannot be represented by a Half.
expect_bad_alpha_value_dies<ScalarType::Half>(65505.0);
}
Loading