diff --git a/kernels/test/op_le_test.cpp b/kernels/test/op_le_test.cpp index d96d87be596..9bc4e0a4cfa 100644 --- a/kernels/test/op_le_test.cpp +++ b/kernels/test/op_le_test.cpp @@ -991,6 +991,41 @@ TEST_F(OpLeTensorOutTest, Broadcast1DTo2DShapeTest) { EXPECT_TENSOR_EQ(out, tf_bool.make({1, 6}, expected_data)); } +TEST_F(OpLeTensorOutTest, Broadcast2DBy1DShapeTest) { + TensorFactory tf; + TensorFactory tf_bool; + + // Test case: (1, 10) and (6,) -> (6, 10) + Tensor a = tf.make({1, 10}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + Tensor b = tf.make({6}, {2, 4, 6, 8, 10, 12}); + + Tensor out = tf_bool.zeros({6, 10}); + + op_le_tensor_out(a, b, out); + + // Expected: a[0,j] <= b[i] for all i,j + // Each row i should be [a[0,0]<=b[i], a[0,1]<=b[i], ..., a[0,9]<=b[i]] + using ctype = + executorch::runtime::testing::internal::ScalarTypeToCppTypeWrapper< + ScalarType::Bool>::ctype; + std::vector expected_data = { + // Row 0 (b=2): [1,2,3,4,5,6,7,8,9,10] <= 2 + true, true, false, false, false, false, false, false, false, false, + // Row 1 (b=4): [1,2,3,4,5,6,7,8,9,10] <= 4 + true, true, true, true, false, false, false, false, false, false, + // Row 2 (b=6): [1,2,3,4,5,6,7,8,9,10] <= 6 + true, true, true, true, true, true, false, false, false, false, + // Row 3 (b=8): [1,2,3,4,5,6,7,8,9,10] <= 8 + true, true, true, true, true, true, true, true, false, false, + // Row 4 (b=10): [1,2,3,4,5,6,7,8,9,10] <= 10 + true, true, true, true, true, true, true, true, true, true, + // Row 5 (b=12): [1,2,3,4,5,6,7,8,9,10] <= 12 + true, true, true, true, true, true, true, true, true, true + }; + + EXPECT_TENSOR_EQ(out, tf_bool.make({6, 10}, expected_data)); +} + TEST_F(OpLeTensorOutTest, Broadcast2dBy1dReverseTest) { TensorFactory tf; TensorFactory tf_bool;