From b29a9f88f2d5f5a9bd799ecb259a5db528414218 Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Mon, 11 May 2026 17:00:35 +0100 Subject: [PATCH 1/4] Add CPU concat-dot splitting flag --- third_party/xla/xla/debug_options_flags.cc | 7 ++ .../xla/xla/service/cpu/cpu_compiler.cc | 112 ++++++++++++++++++ third_party/xla/xla/xla.proto | 5 + 3 files changed, 124 insertions(+) diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc index 961273151ed5e0..4d1c27d5e02b87 100644 --- a/third_party/xla/xla/debug_options_flags.cc +++ b/third_party/xla/xla/debug_options_flags.cc @@ -124,6 +124,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() { // By default, copy TF's Eigen style min_max behavior with nans. opts.set_xla_cpu_enable_fast_min_max(true); + opts.set_xla_cpu_split_concat_dot(false); opts.set_xla_gpu_enable_cublaslt(false); @@ -1032,6 +1033,12 @@ void MakeDebugOptionsFlags(std::vector* flag_list, debug_options->xla_cpu_enable_concurrency_optimized_scheduler(), "Use HLO module scheduler that is optimized for extracting concurrency " "from an HLO module by trading off extra memory pressure.")); + flag_list->push_back(tsl::Flag( + "xla_cpu_split_concat_dot", + bool_setter_for(&DebugOptions::set_xla_cpu_split_concat_dot), + debug_options->xla_cpu_split_concat_dot(), + "Split dot(concat(lhs...), rhs) into a sum of dots over slices of rhs " + "to avoid materializing concat inputs to CPU dots.")); flag_list->push_back(tsl::Flag( "xla_cpu_prefer_vector_width", int32_setter_for(&DebugOptions::set_xla_cpu_prefer_vector_width), diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc index 042e7e5b33a0f7..7efb9e4b4f6df8 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.cc +++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc @@ -110,6 +110,7 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/pass/hlo_pass_fix.h" +#include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" #include "xla/hlo/transforms/expanders/bitcast_dtypes_expander.h" #include "xla/hlo/transforms/expanders/cholesky_expander.h" @@ -522,6 +523,114 @@ std::unique_ptr> CreateSimplificationPipeline( return pipeline; } +class CpuSplitConcatDot : public HloModulePass { + public: + absl::string_view name() const override { return "cpu-split-concat-dot"; } + + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override { + bool changed = false; + for (HloComputation* computation : + module->MakeNonfusionComputations(execution_threads)) { + for (HloInstruction* instruction : + computation->MakeInstructionPostOrder()) { + if (instruction->opcode() != HloOpcode::kDot) { + continue; + } + TF_ASSIGN_OR_RETURN(bool rewritten, TryRewriteDot(instruction)); + changed |= rewritten; + } + } + return changed; + } + + private: + absl::StatusOr TryRewriteDot(HloInstruction* dot) { + const DotDimensionNumbers& dim_numbers = dot->dot_dimension_numbers(); + if (dim_numbers.lhs_contracting_dimensions_size() != 1 || + dim_numbers.rhs_contracting_dimensions_size() != 1 || + dim_numbers.lhs_batch_dimensions_size() != 0 || + dim_numbers.rhs_batch_dimensions_size() != 0) { + return false; + } + + HloInstruction* concat = dot->mutable_operand(0); + HloInstruction* rhs = dot->mutable_operand(1); + if (concat->opcode() != HloOpcode::kConcatenate) { + return false; + } + + const int64_t lhs_contracting_dim = + dim_numbers.lhs_contracting_dimensions(0); + const int64_t rhs_contracting_dim = + dim_numbers.rhs_contracting_dimensions(0); + if (concat->concatenate_dimension() != lhs_contracting_dim || + concat->operand_count() < 2 || concat->shape().rank() != 2 || + rhs->shape().rank() != 2 || !dot->shape().IsArray()) { + return false; + } + if (concat->shape().dimensions(lhs_contracting_dim) != + rhs->shape().dimensions(rhs_contracting_dim)) { + return false; + } + for (HloInstruction* concat_operand : concat->operands()) { + if (concat_operand->shape().rank() != concat->shape().rank() || + concat_operand->shape().element_type() != + concat->shape().element_type()) { + return false; + } + } + + HloComputation* computation = dot->parent(); + std::vector partial_dots; + int64_t rhs_offset = 0; + for (HloInstruction* concat_operand : concat->operands()) { + const int64_t slice_size = + concat_operand->shape().dimensions(lhs_contracting_dim); + std::vector starts(rhs->shape().rank(), 0); + std::vector limits(rhs->shape().dimensions().begin(), + rhs->shape().dimensions().end()); + std::vector strides(rhs->shape().rank(), 1); + std::vector slice_dims(rhs->shape().dimensions().begin(), + rhs->shape().dimensions().end()); + starts[rhs_contracting_dim] = rhs_offset; + limits[rhs_contracting_dim] = rhs_offset + slice_size; + slice_dims[rhs_contracting_dim] = slice_size; + rhs_offset += slice_size; + + Shape rhs_slice_shape = + ShapeUtil::MakeShape(rhs->shape().element_type(), slice_dims); + HloInstruction* rhs_slice = computation->AddInstruction( + HloInstruction::CreateSlice(rhs_slice_shape, rhs, starts, limits, + strides)); + rhs_slice->set_metadata(dot->metadata()); + + HloInstruction* partial_dot = computation->AddInstruction( + HloInstruction::CreateDot(dot->shape(), concat_operand, rhs_slice, + dim_numbers, dot->precision_config())); + partial_dot->set_metadata(dot->metadata()); + partial_dot->set_frontend_attributes(dot->frontend_attributes()); + partial_dots.push_back(partial_dot); + } + + if (rhs_offset != rhs->shape().dimensions(rhs_contracting_dim)) { + return false; + } + + HloInstruction* replacement = partial_dots[0]; + for (int64_t i = 1; i < partial_dots.size(); ++i) { + replacement = computation->AddInstruction(HloInstruction::CreateBinary( + dot->shape(), HloOpcode::kAdd, replacement, partial_dots[i])); + replacement->set_metadata(dot->metadata()); + replacement->set_frontend_attributes(dot->frontend_attributes()); + } + + TF_RETURN_IF_ERROR(dot->parent()->ReplaceInstruction(dot, replacement)); + return true; + } +}; + } // namespace absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn( @@ -642,6 +751,9 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass(/*single_call_site=*/true); pipeline.AddPass(); pipeline.AddPass(); + if (module->config().debug_options().xla_cpu_split_concat_dot()) { + pipeline.AddPass(); + } // Rewrite to custom calls with target as oneDNN library calls. #if defined(INTEL_MKL) diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto index 18abf310ac67e4..8e0cdebd8bfd97 100644 --- a/third_party/xla/xla/xla.proto +++ b/third_party/xla/xla/xla.proto @@ -174,6 +174,11 @@ message DebugOptions { // below! bool xla_cpu_enable_fast_min_max = 140; + // Split dot(concat(lhs...), rhs) into a sum of dots over slices of rhs. + // This can avoid materializing expensive CPU concat buffers before skinny + // matrix multiplications. + bool xla_cpu_split_concat_dot = 405; + // When xla_cpu_enable_fast_math is true then this controls whether we forbid // to use the reciprocal of an argument instead of division. Ignored when // xla_cpu_enable_fast_math is false. From 140a875dac888fa52859249038f6c8483a6ebcc0 Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Mon, 11 May 2026 17:09:00 +0100 Subject: [PATCH 2/4] Fix split concat dot Shape rank checks --- third_party/xla/xla/service/cpu/cpu_compiler.cc | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc index 7efb9e4b4f6df8..2eb9ed47155233 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.cc +++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc @@ -566,8 +566,9 @@ class CpuSplitConcatDot : public HloModulePass { const int64_t rhs_contracting_dim = dim_numbers.rhs_contracting_dimensions(0); if (concat->concatenate_dimension() != lhs_contracting_dim || - concat->operand_count() < 2 || concat->shape().rank() != 2 || - rhs->shape().rank() != 2 || !dot->shape().IsArray()) { + concat->operand_count() < 2 || + concat->shape().dimensions().size() != 2 || + rhs->shape().dimensions().size() != 2 || !dot->shape().IsArray()) { return false; } if (concat->shape().dimensions(lhs_contracting_dim) != @@ -575,7 +576,8 @@ class CpuSplitConcatDot : public HloModulePass { return false; } for (HloInstruction* concat_operand : concat->operands()) { - if (concat_operand->shape().rank() != concat->shape().rank() || + if (concat_operand->shape().dimensions().size() != + concat->shape().dimensions().size() || concat_operand->shape().element_type() != concat->shape().element_type()) { return false; @@ -588,10 +590,10 @@ class CpuSplitConcatDot : public HloModulePass { for (HloInstruction* concat_operand : concat->operands()) { const int64_t slice_size = concat_operand->shape().dimensions(lhs_contracting_dim); - std::vector starts(rhs->shape().rank(), 0); + std::vector starts(rhs->shape().dimensions().size(), 0); std::vector limits(rhs->shape().dimensions().begin(), rhs->shape().dimensions().end()); - std::vector strides(rhs->shape().rank(), 1); + std::vector strides(rhs->shape().dimensions().size(), 1); std::vector slice_dims(rhs->shape().dimensions().begin(), rhs->shape().dimensions().end()); starts[rhs_contracting_dim] = rhs_offset; From 2fe653c5bfe47e10fade1b244c761eda7fe1830a Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Mon, 11 May 2026 17:18:02 +0100 Subject: [PATCH 3/4] Log when split concat dot rewrites --- third_party/xla/xla/service/cpu/cpu_compiler.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc index 2eb9ed47155233..f2409cd2f76715 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.cc +++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc @@ -628,6 +628,9 @@ class CpuSplitConcatDot : public HloModulePass { replacement->set_frontend_attributes(dot->frontend_attributes()); } + LOG(INFO) << "CpuSplitConcatDot rewrote " << dot->name() << " by splitting " + << concat->name() << " with " << concat->operand_count() + << " operands"; TF_RETURN_IF_ERROR(dot->parent()->ReplaceInstruction(dot, replacement)); return true; } From 2cac82e012b207089168f264c53bc34f227474e2 Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Mon, 11 May 2026 18:04:49 +0100 Subject: [PATCH 4/4] Move split concat dot pass to its own file --- third_party/xla/xla/service/cpu/BUILD | 17 +++ .../xla/xla/service/cpu/cpu_compiler.cc | 117 +-------------- .../xla/service/cpu/cpu_split_concat_dot.cc | 141 ++++++++++++++++++ .../xla/service/cpu/cpu_split_concat_dot.h | 44 ++++++ 4 files changed, 204 insertions(+), 115 deletions(-) create mode 100644 third_party/xla/xla/service/cpu/cpu_split_concat_dot.cc create mode 100644 third_party/xla/xla/service/cpu/cpu_split_concat_dot.h diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD index e5132538349bbe..7ec4da2cb976ae 100644 --- a/third_party/xla/xla/service/cpu/BUILD +++ b/third_party/xla/xla/service/cpu/BUILD @@ -203,6 +203,7 @@ cc_library( ":cpu_instruction_fusion", ":cpu_layout_assignment", ":cpu_options", + ":cpu_split_concat_dot", ":dot_op_emitter", ":executable_proto_cc", ":fusion_wrapper", @@ -1531,6 +1532,22 @@ xla_cc_test( ], ) +cc_library( + name = "cpu_split_concat_dot", + srcs = ["cpu_split_concat_dot.cc"], + hdrs = ["cpu_split_concat_dot.h"], + deps = [ + "//xla:shape_util", + "//xla:xla_data_proto_cc", + "//xla/hlo/ir:hlo", + "//xla/hlo/pass:hlo_pass", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + ], +) + cc_library( name = "cpu_instruction_fusion", srcs = ["cpu_instruction_fusion.cc"], diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc index f2409cd2f76715..d76d8a6046585c 100644 --- a/third_party/xla/xla/service/cpu/cpu_compiler.cc +++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc @@ -110,7 +110,6 @@ limitations under the License. #include "xla/hlo/ir/hlo_opcode.h" #include "xla/hlo/ir/hlo_schedule.h" #include "xla/hlo/pass/hlo_pass_fix.h" -#include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/hlo/pass/hlo_pass_pipeline.h" #include "xla/hlo/transforms/expanders/bitcast_dtypes_expander.h" #include "xla/hlo/transforms/expanders/cholesky_expander.h" @@ -174,6 +173,7 @@ limitations under the License. #include "xla/service/cpu/cpu_instruction_fusion.h" #include "xla/service/cpu/cpu_layout_assignment.h" #include "xla/service/cpu/cpu_options.h" +#include "xla/service/cpu/cpu_split_concat_dot.h" #include "xla/service/cpu/dot_op_emitter.h" #include "xla/service/cpu/executable.pb.h" #include "xla/service/cpu/fusion_wrapper.h" @@ -523,119 +523,6 @@ std::unique_ptr> CreateSimplificationPipeline( return pipeline; } -class CpuSplitConcatDot : public HloModulePass { - public: - absl::string_view name() const override { return "cpu-split-concat-dot"; } - - absl::StatusOr Run( - HloModule* module, - const absl::flat_hash_set& execution_threads) override { - bool changed = false; - for (HloComputation* computation : - module->MakeNonfusionComputations(execution_threads)) { - for (HloInstruction* instruction : - computation->MakeInstructionPostOrder()) { - if (instruction->opcode() != HloOpcode::kDot) { - continue; - } - TF_ASSIGN_OR_RETURN(bool rewritten, TryRewriteDot(instruction)); - changed |= rewritten; - } - } - return changed; - } - - private: - absl::StatusOr TryRewriteDot(HloInstruction* dot) { - const DotDimensionNumbers& dim_numbers = dot->dot_dimension_numbers(); - if (dim_numbers.lhs_contracting_dimensions_size() != 1 || - dim_numbers.rhs_contracting_dimensions_size() != 1 || - dim_numbers.lhs_batch_dimensions_size() != 0 || - dim_numbers.rhs_batch_dimensions_size() != 0) { - return false; - } - - HloInstruction* concat = dot->mutable_operand(0); - HloInstruction* rhs = dot->mutable_operand(1); - if (concat->opcode() != HloOpcode::kConcatenate) { - return false; - } - - const int64_t lhs_contracting_dim = - dim_numbers.lhs_contracting_dimensions(0); - const int64_t rhs_contracting_dim = - dim_numbers.rhs_contracting_dimensions(0); - if (concat->concatenate_dimension() != lhs_contracting_dim || - concat->operand_count() < 2 || - concat->shape().dimensions().size() != 2 || - rhs->shape().dimensions().size() != 2 || !dot->shape().IsArray()) { - return false; - } - if (concat->shape().dimensions(lhs_contracting_dim) != - rhs->shape().dimensions(rhs_contracting_dim)) { - return false; - } - for (HloInstruction* concat_operand : concat->operands()) { - if (concat_operand->shape().dimensions().size() != - concat->shape().dimensions().size() || - concat_operand->shape().element_type() != - concat->shape().element_type()) { - return false; - } - } - - HloComputation* computation = dot->parent(); - std::vector partial_dots; - int64_t rhs_offset = 0; - for (HloInstruction* concat_operand : concat->operands()) { - const int64_t slice_size = - concat_operand->shape().dimensions(lhs_contracting_dim); - std::vector starts(rhs->shape().dimensions().size(), 0); - std::vector limits(rhs->shape().dimensions().begin(), - rhs->shape().dimensions().end()); - std::vector strides(rhs->shape().dimensions().size(), 1); - std::vector slice_dims(rhs->shape().dimensions().begin(), - rhs->shape().dimensions().end()); - starts[rhs_contracting_dim] = rhs_offset; - limits[rhs_contracting_dim] = rhs_offset + slice_size; - slice_dims[rhs_contracting_dim] = slice_size; - rhs_offset += slice_size; - - Shape rhs_slice_shape = - ShapeUtil::MakeShape(rhs->shape().element_type(), slice_dims); - HloInstruction* rhs_slice = computation->AddInstruction( - HloInstruction::CreateSlice(rhs_slice_shape, rhs, starts, limits, - strides)); - rhs_slice->set_metadata(dot->metadata()); - - HloInstruction* partial_dot = computation->AddInstruction( - HloInstruction::CreateDot(dot->shape(), concat_operand, rhs_slice, - dim_numbers, dot->precision_config())); - partial_dot->set_metadata(dot->metadata()); - partial_dot->set_frontend_attributes(dot->frontend_attributes()); - partial_dots.push_back(partial_dot); - } - - if (rhs_offset != rhs->shape().dimensions(rhs_contracting_dim)) { - return false; - } - - HloInstruction* replacement = partial_dots[0]; - for (int64_t i = 1; i < partial_dots.size(); ++i) { - replacement = computation->AddInstruction(HloInstruction::CreateBinary( - dot->shape(), HloOpcode::kAdd, replacement, partial_dots[i])); - replacement->set_metadata(dot->metadata()); - replacement->set_frontend_attributes(dot->frontend_attributes()); - } - - LOG(INFO) << "CpuSplitConcatDot rewrote " << dot->name() << " by splitting " - << concat->name() << " with " << concat->operand_count() - << " operands"; - TF_RETURN_IF_ERROR(dot->parent()->ReplaceInstruction(dot, replacement)); - return true; - } -}; - } // namespace absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn( @@ -757,7 +644,7 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass(); pipeline.AddPass(); if (module->config().debug_options().xla_cpu_split_concat_dot()) { - pipeline.AddPass(); + pipeline.AddPass(); } // Rewrite to custom calls with target as oneDNN library calls. diff --git a/third_party/xla/xla/service/cpu/cpu_split_concat_dot.cc b/third_party/xla/xla/service/cpu/cpu_split_concat_dot.cc new file mode 100644 index 00000000000000..63ddd0725f99e2 --- /dev/null +++ b/third_party/xla/xla/service/cpu/cpu_split_concat_dot.cc @@ -0,0 +1,141 @@ +/* Copyright 2026 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "xla/service/cpu/cpu_split_concat_dot.h" + +#include +#include + +#include "absl/log/log.h" +#include "xla/hlo/ir/hlo_computation.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_opcode.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/status_macros.h" +#include "xla/xla_data.pb.h" + +namespace xla { +namespace cpu { + +absl::StatusOr CpuSplitConcatDot::Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) { + bool changed = false; + for (HloComputation* computation : + module->MakeNonfusionComputations(execution_threads)) { + for (HloInstruction* instruction : + computation->MakeInstructionPostOrder()) { + if (instruction->opcode() != HloOpcode::kDot) { + continue; + } + TF_ASSIGN_OR_RETURN(bool rewritten, TryRewriteDot(instruction)); + changed |= rewritten; + } + } + return changed; +} + +absl::StatusOr CpuSplitConcatDot::TryRewriteDot(HloInstruction* dot) { + const DotDimensionNumbers& dim_numbers = dot->dot_dimension_numbers(); + if (dim_numbers.lhs_contracting_dimensions_size() != 1 || + dim_numbers.rhs_contracting_dimensions_size() != 1 || + dim_numbers.lhs_batch_dimensions_size() != 0 || + dim_numbers.rhs_batch_dimensions_size() != 0) { + return false; + } + + HloInstruction* concat = dot->mutable_operand(0); + HloInstruction* rhs = dot->mutable_operand(1); + if (concat->opcode() != HloOpcode::kConcatenate) { + return false; + } + + const int64_t lhs_contracting_dim = + dim_numbers.lhs_contracting_dimensions(0); + const int64_t rhs_contracting_dim = + dim_numbers.rhs_contracting_dimensions(0); + if (concat->concatenate_dimension() != lhs_contracting_dim || + concat->operand_count() < 2 || + concat->shape().dimensions().size() != 2 || + rhs->shape().dimensions().size() != 2 || !dot->shape().IsArray()) { + return false; + } + if (concat->shape().dimensions(lhs_contracting_dim) != + rhs->shape().dimensions(rhs_contracting_dim)) { + return false; + } + for (HloInstruction* concat_operand : concat->operands()) { + if (concat_operand->shape().dimensions().size() != + concat->shape().dimensions().size() || + concat_operand->shape().element_type() != + concat->shape().element_type()) { + return false; + } + } + + HloComputation* computation = dot->parent(); + std::vector partial_dots; + int64_t rhs_offset = 0; + for (HloInstruction* concat_operand : concat->operands()) { + const int64_t slice_size = + concat_operand->shape().dimensions(lhs_contracting_dim); + std::vector starts(rhs->shape().dimensions().size(), 0); + std::vector limits(rhs->shape().dimensions().begin(), + rhs->shape().dimensions().end()); + std::vector strides(rhs->shape().dimensions().size(), 1); + std::vector slice_dims(rhs->shape().dimensions().begin(), + rhs->shape().dimensions().end()); + starts[rhs_contracting_dim] = rhs_offset; + limits[rhs_contracting_dim] = rhs_offset + slice_size; + slice_dims[rhs_contracting_dim] = slice_size; + rhs_offset += slice_size; + + Shape rhs_slice_shape = + ShapeUtil::MakeShape(rhs->shape().element_type(), slice_dims); + HloInstruction* rhs_slice = computation->AddInstruction( + HloInstruction::CreateSlice(rhs_slice_shape, rhs, starts, limits, + strides)); + rhs_slice->set_metadata(dot->metadata()); + + HloInstruction* partial_dot = computation->AddInstruction( + HloInstruction::CreateDot(dot->shape(), concat_operand, rhs_slice, + dim_numbers, dot->precision_config())); + partial_dot->set_metadata(dot->metadata()); + partial_dot->set_frontend_attributes(dot->frontend_attributes()); + partial_dots.push_back(partial_dot); + } + + if (rhs_offset != rhs->shape().dimensions(rhs_contracting_dim)) { + return false; + } + + HloInstruction* replacement = partial_dots[0]; + for (int64_t i = 1; i < partial_dots.size(); ++i) { + replacement = computation->AddInstruction(HloInstruction::CreateBinary( + dot->shape(), HloOpcode::kAdd, replacement, partial_dots[i])); + replacement->set_metadata(dot->metadata()); + replacement->set_frontend_attributes(dot->frontend_attributes()); + } + + LOG(INFO) << "CpuSplitConcatDot rewrote " << dot->name() << " by splitting " + << concat->name() << " with " << concat->operand_count() + << " operands"; + TF_RETURN_IF_ERROR(dot->parent()->ReplaceInstruction(dot, replacement)); + return true; +} + +} // namespace cpu +} // namespace xla diff --git a/third_party/xla/xla/service/cpu/cpu_split_concat_dot.h b/third_party/xla/xla/service/cpu/cpu_split_concat_dot.h new file mode 100644 index 00000000000000..9305fea83c87c5 --- /dev/null +++ b/third_party/xla/xla/service/cpu/cpu_split_concat_dot.h @@ -0,0 +1,44 @@ +/* Copyright 2026 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef XLA_SERVICE_CPU_CPU_SPLIT_CONCAT_DOT_H_ +#define XLA_SERVICE_CPU_CPU_SPLIT_CONCAT_DOT_H_ + +#include "absl/container/flat_hash_set.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "xla/hlo/ir/hlo_instruction.h" +#include "xla/hlo/ir/hlo_module.h" +#include "xla/hlo/pass/hlo_pass_interface.h" + +namespace xla { +namespace cpu { + +class CpuSplitConcatDot : public HloModulePass { + public: + absl::string_view name() const override { return "cpu-split-concat-dot"; } + + absl::StatusOr Run( + HloModule* module, + const absl::flat_hash_set& execution_threads) override; + + private: + absl::StatusOr TryRewriteDot(HloInstruction* dot); +}; + +} // namespace cpu +} // namespace xla + +#endif // XLA_SERVICE_CPU_CPU_SPLIT_CONCAT_DOT_H_