diff --git a/third_party/xla/xla/service/shape_inference.cc b/third_party/xla/xla/service/shape_inference.cc index 2a006352e566a3..b23b00d29a7784 100644 --- a/third_party/xla/xla/service/shape_inference.cc +++ b/third_party/xla/xla/service/shape_inference.cc @@ -1320,8 +1320,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(const Shape& lhs, output_dimensions_expressions[i] = lhs.dimensions(i) == 1 ? rhs.expressions(i) : lhs.expressions(i); - } else if (lhs.dimensions(i) == rhs.dimensions(i)) { // && - // *lhs.expressions(i) == *rhs.expressions(i)) { + } else if (lhs.dimensions(i) == rhs.dimensions(i)) { // LHS | RHS | Result // X | X | X // X | <=X | <=X @@ -1529,6 +1528,7 @@ ShapeInference::InferElementwiseBinaryOpShape( for (int64_t i = 0; i < rhs.dimensions_size(); ++i) { if (rhs.is_dynamic_dimension(i)) { result.set_dynamic_dimension(i, true); + result.set_expression(i, rhs.expressions(i)); } } diff --git a/third_party/xla/xla/service/shape_inference_test.cc b/third_party/xla/xla/service/shape_inference_test.cc index 87841d476712d7..e57022767794a3 100644 --- a/third_party/xla/xla/service/shape_inference_test.cc +++ b/third_party/xla/xla/service/shape_inference_test.cc @@ -2822,6 +2822,18 @@ TEST_F(ShapeInferenceTest, BinOpBroadcastBadDimension) { HasSubstr("dimensions order is wrong")); } +TEST_F(ShapeInferenceTest, BinOpPreservesBroadcastedExpressionSameRank) { + Shape lhs = ShapeUtil::MakeShape(F32, {1}, {true}); + lhs.set_expression(0, DExpr::Const(1)); + Shape rhs = ShapeUtil::MakeShape(F32, {1}, {true}); + rhs.set_expression(0, DExpr::Var(1)); + + const absl::StatusOr inferred_shape = + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, lhs, rhs, {0}); + ASSERT_IS_OK(inferred_shape.status()); + EXPECT_EQ(inferred_shape->expressions(0), DExpr::Var(1)); +} + // Tests for the while instruction with proper shapes. TEST_F(ShapeInferenceTest, WhileWithCorrectShapes) { const Shape result_shape = ShapeUtil::MakeTupleShape({s32_, vector_32_});