Skip to content

Commit ba12352

Browse files
Add complex type support to div operator (#17414)
Summary: As titled Differential Revision: D93086411
1 parent 7343cb0 commit ba12352

3 files changed

Lines changed: 144 additions & 21 deletions

File tree

kernels/optimized/cpu/op_div.cpp

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,11 @@ namespace native {
2121
namespace {
2222

2323
ScalarType get_common_type(ScalarType a_type, ScalarType b_type) {
24-
ET_CHECK(
25-
!isComplexType(a_type) && !isQIntType(a_type) && !isBitsType(a_type));
26-
ET_CHECK(
27-
!isComplexType(b_type) && !isQIntType(b_type) && !isBitsType(b_type));
24+
if (isComplexType(a_type) || isComplexType(b_type)) {
25+
return promoteTypes(a_type, b_type);
26+
}
27+
ET_CHECK(!isQIntType(a_type) && !isBitsType(a_type));
28+
ET_CHECK(!isQIntType(b_type) && !isBitsType(b_type));
2829

2930
if (isFloatingType(a_type) && isFloatingType(b_type)) {
3031
return promoteTypes(a_type, b_type);
@@ -61,6 +62,20 @@ Tensor& opt_div_out(
6162
ScalarType b_type = b.scalar_type();
6263
ScalarType out_type = out.scalar_type();
6364

65+
// Handle complex types
66+
if (isComplexType(a_type) || isComplexType(b_type)) {
67+
ScalarType common_type = get_common_type(a_type, b_type);
68+
ET_SWITCH_COMPLEX_TYPES(common_type, ctx, op_name, CTYPE, [&]() {
69+
const CTYPE* a_data = a.const_data_ptr<CTYPE>();
70+
const CTYPE* b_data = b.const_data_ptr<CTYPE>();
71+
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
72+
for (size_t i = 0; i < out.numel(); ++i) {
73+
out_data[i] = a_data[i] / b_data[i];
74+
}
75+
});
76+
return out;
77+
}
78+
6479
if (a.numel() == 1 || b.numel() == 1) {
6580
if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half &&
6681
a_type != ScalarType::BFloat16) {

kernels/portable/cpu/op_div.cpp

Lines changed: 30 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@ namespace native {
2020
namespace {
2121

2222
ScalarType get_common_type(ScalarType a_type, ScalarType b_type) {
23-
if (isFloatingType(a_type) && isFloatingType(b_type)) {
23+
if (executorch::runtime::isComplexType(a_type) ||
24+
executorch::runtime::isComplexType(b_type)) {
25+
return promoteTypes(a_type, b_type);
26+
} else if (isFloatingType(a_type) && isFloatingType(b_type)) {
2427
return promoteTypes(a_type, b_type);
2528
} else if (isFloatingType(a_type)) {
2629
return a_type;
@@ -51,25 +54,35 @@ Tensor& div_out(
5154
InvalidArgument,
5255
out);
5356

54-
// Compute Dtype
55-
ScalarType compute_type = utils::get_compute_type(common_type);
56-
5757
// @lint-ignore CLANGTIDY facebook-hte-CArray
5858
static constexpr const char op_name[] = "div.out";
5959

60-
ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
61-
utils::apply_bitensor_elementwise_fn<
62-
CTYPE_COMPUTE,
63-
op_name,
64-
utils::SupportedTensorDtypes::FLOATHBF16>(
65-
[](const auto& val_a, const auto& val_b) { return val_a / val_b; },
66-
ctx,
67-
a,
68-
utils::SupportedTensorDtypes::REALHBBF16,
69-
b,
70-
utils::SupportedTensorDtypes::REALHBBF16,
71-
out);
72-
});
60+
if (executorch::runtime::isComplexType(common_type)) {
61+
ET_SWITCH_COMPLEX_TYPES(common_type, ctx, op_name, CTYPE, [&]() {
62+
const CTYPE* a_data = a.const_data_ptr<CTYPE>();
63+
const CTYPE* b_data = b.const_data_ptr<CTYPE>();
64+
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
65+
for (ssize_t i = 0; i < out.numel(); ++i) {
66+
out_data[i] = a_data[i] / b_data[i];
67+
}
68+
});
69+
} else {
70+
// Compute Dtype for real types
71+
ScalarType compute_type = utils::get_compute_type(common_type);
72+
ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
73+
utils::apply_bitensor_elementwise_fn<
74+
CTYPE_COMPUTE,
75+
op_name,
76+
utils::SupportedTensorDtypes::FLOATHBF16>(
77+
[](const auto& val_a, const auto& val_b) { return val_a / val_b; },
78+
ctx,
79+
a,
80+
utils::SupportedTensorDtypes::REALHBBF16,
81+
b,
82+
utils::SupportedTensorDtypes::REALHBBF16,
83+
out);
84+
});
85+
}
7386

7487
return out;
7588
}

kernels/test/op_div_test.cpp

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -601,3 +601,98 @@ TEST_F(OpDivScalarOutTest, OptimizedSanityCheck) {
601601
// Check that it matches the expected output.
602602
EXPECT_TENSOR_CLOSE(out, tf.make(sizes, {0.65, 1.05, 2.3, 4.1}));
603603
}
604+
605+
//
606+
// Complex Type Tests
607+
//
608+
609+
TEST_F(OpDivOutTest, ComplexFloatBasic) {
610+
TensorFactory<ScalarType::ComplexFloat> tf;
611+
612+
const std::vector<int32_t> sizes = {2, 2};
613+
614+
// (1+2i) / (1+0i) = (1+2i)
615+
// (4+4i) / (2+0i) = (2+2i)
616+
// (3+4i) / (1-1i) = (3+4i)(1+1i) / 2 = (-1+7i) / 2 = (-0.5+3.5i)
617+
// (8+0i) / (2+2i) = (8)(2-2i) / 8 = (2-2i)
618+
Tensor a = tf.make(
619+
sizes,
620+
{executorch::aten::complex<float>(1.0f, 2.0f),
621+
executorch::aten::complex<float>(4.0f, 4.0f),
622+
executorch::aten::complex<float>(3.0f, 4.0f),
623+
executorch::aten::complex<float>(8.0f, 0.0f)});
624+
625+
Tensor b = tf.make(
626+
sizes,
627+
{executorch::aten::complex<float>(1.0f, 0.0f),
628+
executorch::aten::complex<float>(2.0f, 0.0f),
629+
executorch::aten::complex<float>(1.0f, -1.0f),
630+
executorch::aten::complex<float>(2.0f, 2.0f)});
631+
632+
Tensor out = tf.zeros(sizes);
633+
634+
op_div_out(a, b, out);
635+
636+
Tensor expected = tf.make(
637+
sizes,
638+
{executorch::aten::complex<float>(1.0f, 2.0f),
639+
executorch::aten::complex<float>(2.0f, 2.0f),
640+
executorch::aten::complex<float>(-0.5f, 3.5f),
641+
executorch::aten::complex<float>(2.0f, -2.0f)});
642+
643+
EXPECT_TENSOR_CLOSE(out, expected);
644+
}
645+
646+
TEST_F(OpDivOutTest, ComplexDoubleBasic) {
647+
TensorFactory<ScalarType::ComplexDouble> tf;
648+
649+
const std::vector<int32_t> sizes = {2};
650+
651+
Tensor a = tf.make(
652+
sizes,
653+
{executorch::aten::complex<double>(6.0, 8.0),
654+
executorch::aten::complex<double>(4.0, 0.0)});
655+
656+
Tensor b = tf.make(
657+
sizes,
658+
{executorch::aten::complex<double>(2.0, 0.0),
659+
executorch::aten::complex<double>(0.0, 2.0)});
660+
661+
Tensor out = tf.zeros(sizes);
662+
663+
op_div_out(a, b, out);
664+
665+
// (6+8i) / 2 = (3+4i)
666+
// 4 / 2i = 4 * (-i) / 2 = -2i = (0-2i)
667+
Tensor expected = tf.make(
668+
sizes,
669+
{executorch::aten::complex<double>(3.0, 4.0),
670+
executorch::aten::complex<double>(0.0, -2.0)});
671+
672+
EXPECT_TENSOR_CLOSE(out, expected);
673+
}
674+
675+
TEST_F(OpDivOutTest, ComplexFloatIdentity) {
676+
TensorFactory<ScalarType::ComplexFloat> tf;
677+
678+
const std::vector<int32_t> sizes = {3};
679+
680+
// Dividing by 1 should return the same value
681+
Tensor a = tf.make(
682+
sizes,
683+
{executorch::aten::complex<float>(1.0f, 2.0f),
684+
executorch::aten::complex<float>(3.0f, 4.0f),
685+
executorch::aten::complex<float>(-5.0f, 6.0f)});
686+
687+
Tensor one = tf.make(
688+
sizes,
689+
{executorch::aten::complex<float>(1.0f, 0.0f),
690+
executorch::aten::complex<float>(1.0f, 0.0f),
691+
executorch::aten::complex<float>(1.0f, 0.0f)});
692+
693+
Tensor out = tf.zeros(sizes);
694+
695+
op_div_out(a, one, out);
696+
697+
EXPECT_TENSOR_CLOSE(out, a);
698+
}

0 commit comments

Comments
 (0)