diff --git a/third_party/xla/xla/debug_options_flags.cc b/third_party/xla/xla/debug_options_flags.cc
index 961273151ed5e0..4d1c27d5e02b87 100644
--- a/third_party/xla/xla/debug_options_flags.cc
+++ b/third_party/xla/xla/debug_options_flags.cc
@@ -124,6 +124,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {
 
   // By default, copy TF's Eigen style min_max behavior with nans.
   opts.set_xla_cpu_enable_fast_min_max(true);
+  opts.set_xla_cpu_split_concat_dot(false);
 
   opts.set_xla_gpu_enable_cublaslt(false);
 
@@ -1032,6 +1033,12 @@ void MakeDebugOptionsFlags(std::vector* flag_list,
       debug_options->xla_cpu_enable_concurrency_optimized_scheduler(),
       "Use HLO module scheduler that is optimized for extracting concurrency "
       "from an HLO module by trading off extra memory pressure."));
+  flag_list->push_back(tsl::Flag(
+      "xla_cpu_split_concat_dot",
+      bool_setter_for(&DebugOptions::set_xla_cpu_split_concat_dot),
+      debug_options->xla_cpu_split_concat_dot(),
+      "Split dot(concat(lhs...), rhs) into a sum of dots over slices of rhs "
+      "to avoid materializing concat inputs to CPU dots."));
   flag_list->push_back(tsl::Flag(
       "xla_cpu_prefer_vector_width",
       int32_setter_for(&DebugOptions::set_xla_cpu_prefer_vector_width),
diff --git a/third_party/xla/xla/service/cpu/BUILD b/third_party/xla/xla/service/cpu/BUILD
index e5132538349bbe..7ec4da2cb976ae 100644
--- a/third_party/xla/xla/service/cpu/BUILD
+++ b/third_party/xla/xla/service/cpu/BUILD
@@ -203,6 +203,7 @@ cc_library(
         ":cpu_instruction_fusion",
         ":cpu_layout_assignment",
         ":cpu_options",
+        ":cpu_split_concat_dot",
         ":dot_op_emitter",
         ":executable_proto_cc",
         ":fusion_wrapper",
@@ -1531,6 +1532,23 @@ xla_cc_test(
     ],
 )
 
+# HLO pass that splits dot(concat(lhs...), rhs) into a sum of dots over slices
+# of rhs; it is added to the CPU pipeline only when the
+# xla_cpu_split_concat_dot debug option is set.
+cc_library(
+    name = "cpu_split_concat_dot",
+    srcs = ["cpu_split_concat_dot.cc"],
+    hdrs = ["cpu_split_concat_dot.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla:status_macros",
+        "//xla:xla_data_proto_cc",
+        "//xla/hlo/ir:hlo",
+        "//xla/hlo/pass:hlo_pass",
+        "@com_google_absl//absl/container:flat_hash_set",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+    ],
+)
+
 cc_library(
     name = "cpu_instruction_fusion",
     srcs = ["cpu_instruction_fusion.cc"],
diff --git a/third_party/xla/xla/service/cpu/cpu_compiler.cc b/third_party/xla/xla/service/cpu/cpu_compiler.cc
index 042e7e5b33a0f7..d76d8a6046585c 100644
--- a/third_party/xla/xla/service/cpu/cpu_compiler.cc
+++ b/third_party/xla/xla/service/cpu/cpu_compiler.cc
@@ -173,6 +173,7 @@ limitations under the License.
 #include "xla/service/cpu/cpu_instruction_fusion.h"
 #include "xla/service/cpu/cpu_layout_assignment.h"
 #include "xla/service/cpu/cpu_options.h"
+#include "xla/service/cpu/cpu_split_concat_dot.h"
 #include "xla/service/cpu/dot_op_emitter.h"
 #include "xla/service/cpu/executable.pb.h"
 #include "xla/service/cpu/fusion_wrapper.h"
@@ -642,6 +643,9 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn(
   pipeline.AddPass(/*single_call_site=*/true);
   pipeline.AddPass();
   pipeline.AddPass();
+  if (module->config().debug_options().xla_cpu_split_concat_dot()) {
+    pipeline.AddPass<CpuSplitConcatDot>();
+  }
 
   // Rewrite to custom calls with target as oneDNN library calls.
 #if defined(INTEL_MKL)
diff --git a/third_party/xla/xla/service/cpu/cpu_split_concat_dot.cc b/third_party/xla/xla/service/cpu/cpu_split_concat_dot.cc
new file mode 100644
index 00000000000000..63ddd0725f99e2
--- /dev/null
+++ b/third_party/xla/xla/service/cpu/cpu_split_concat_dot.cc
@@ -0,0 +1,141 @@
+/* Copyright 2026 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/cpu/cpu_split_concat_dot.h"
+
+#include <cstdint>
+#include <vector>
+
+#include "absl/log/log.h"
+#include "xla/hlo/ir/hlo_computation.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "xla/status_macros.h"
+#include "xla/xla_data.pb.h"
+
+namespace xla {
+namespace cpu {
+
+// Visits every kDot instruction in the non-fusion computations of `module`
+// selected by `execution_threads` and offers it to TryRewriteDot. Returns true
+// if at least one dot was rewritten, or the first error TryRewriteDot reports.
+absl::StatusOr<bool> CpuSplitConcatDot::Run(
+    HloModule* module,
+    const absl::flat_hash_set<absl::string_view>& execution_threads) {
+  bool changed = false;
+  for (HloComputation* computation :
+       module->MakeNonfusionComputations(execution_threads)) {
+    for (HloInstruction* instruction :
+         computation->MakeInstructionPostOrder()) {
+      if (instruction->opcode() != HloOpcode::kDot) {
+        continue;
+      }
+      TF_ASSIGN_OR_RETURN(bool rewritten, TryRewriteDot(instruction));
+      changed |= rewritten;
+    }
+  }
+  return changed;
+}
+
+// Rewrites
+//   dot(concatenate(a_0, ..., a_n), rhs)
+// where the concatenation runs along the lhs contracting dimension, into
+//   dot(a_0, slice_0(rhs)) + ... + dot(a_n, slice_n(rhs)),
+// where slice_i(rhs) is the range of `rhs` along its contracting dimension
+// that lines up with a_i.
+//
+// Only rank-2, non-batched dots with a single contracting dimension per
+// operand are handled, and only when `dot` is the sole user of the concat.
+// Returns true if `dot` was replaced. Returns false if the pattern does not
+// match; in that case no instruction has been added to the computation.
+//
+// NOTE(review): every partial dot and every add is created with dot->shape(),
+// so partial results are rounded to the output element type before they are
+// summed. Confirm this is acceptable for low-precision outputs.
+absl::StatusOr<bool> CpuSplitConcatDot::TryRewriteDot(HloInstruction* dot) {
+  const DotDimensionNumbers& dim_numbers = dot->dot_dimension_numbers();
+  if (dim_numbers.lhs_contracting_dimensions_size() != 1 ||
+      dim_numbers.rhs_contracting_dimensions_size() != 1 ||
+      dim_numbers.lhs_batch_dimensions_size() != 0 ||
+      dim_numbers.rhs_batch_dimensions_size() != 0) {
+    return false;
+  }
+
+  HloInstruction* concat = dot->mutable_operand(0);
+  HloInstruction* rhs = dot->mutable_operand(1);
+  if (concat->opcode() != HloOpcode::kConcatenate) {
+    return false;
+  }
+  // If the concat feeds anything besides the lhs of this dot it still has to
+  // be materialized, and the rewrite would only add extra dots and adds.
+  if (concat->user_count() != 1 || rhs == concat) {
+    return false;
+  }
+
+  const int64_t lhs_contracting_dim =
+      dim_numbers.lhs_contracting_dimensions(0);
+  const int64_t rhs_contracting_dim =
+      dim_numbers.rhs_contracting_dimensions(0);
+  const Shape& concat_shape = concat->shape();
+  const Shape& rhs_shape = rhs->shape();
+  if (concat->concatenate_dimension() != lhs_contracting_dim ||
+      concat->operand_count() < 2 || concat_shape.dimensions().size() != 2 ||
+      rhs_shape.dimensions().size() != 2 || !dot->shape().IsArray()) {
+    return false;
+  }
+  const int64_t contracting_size = rhs_shape.dimensions(rhs_contracting_dim);
+  if (concat_shape.dimensions(lhs_contracting_dim) != contracting_size) {
+    return false;
+  }
+
+  // Finish all validation before touching the computation, so that bailing
+  // out never leaves freshly created slices or dots behind.
+  int64_t covered_size = 0;
+  for (const HloInstruction* concat_operand : concat->operands()) {
+    const Shape& operand_shape = concat_operand->shape();
+    if (operand_shape.dimensions().size() !=
+            concat_shape.dimensions().size() ||
+        operand_shape.element_type() != concat_shape.element_type()) {
+      return false;
+    }
+    covered_size += operand_shape.dimensions(lhs_contracting_dim);
+  }
+  if (covered_size != contracting_size) {
+    return false;
+  }
+
+  HloComputation* computation = dot->parent();
+  std::vector<HloInstruction*> partial_dots;
+  partial_dots.reserve(concat->operand_count());
+  int64_t rhs_offset = 0;
+  for (HloInstruction* concat_operand : concat->operands()) {
+    const int64_t slice_size =
+        concat_operand->shape().dimensions(lhs_contracting_dim);
+    // Take the full extent of `rhs` in every dimension except the contracting
+    // one, where only [rhs_offset, rhs_offset + slice_size) is kept.
+    std::vector<int64_t> starts(rhs_shape.dimensions().size(), 0);
+    std::vector<int64_t> limits(rhs_shape.dimensions().begin(),
+                                rhs_shape.dimensions().end());
+    std::vector<int64_t> strides(rhs_shape.dimensions().size(), 1);
+    std::vector<int64_t> slice_dims(rhs_shape.dimensions().begin(),
+                                    rhs_shape.dimensions().end());
+    starts[rhs_contracting_dim] = rhs_offset;
+    limits[rhs_contracting_dim] = rhs_offset + slice_size;
+    slice_dims[rhs_contracting_dim] = slice_size;
+    rhs_offset += slice_size;
+
+    Shape rhs_slice_shape =
+        ShapeUtil::MakeShape(rhs_shape.element_type(), slice_dims);
+    HloInstruction* rhs_slice = computation->AddInstruction(
+        HloInstruction::CreateSlice(rhs_slice_shape, rhs, starts, limits,
+                                    strides));
+    rhs_slice->set_metadata(dot->metadata());
+
+    HloInstruction* partial_dot = computation->AddInstruction(
+        HloInstruction::CreateDot(dot->shape(), concat_operand, rhs_slice,
+                                  dim_numbers, dot->precision_config()));
+    partial_dot->set_metadata(dot->metadata());
+    partial_dot->set_frontend_attributes(dot->frontend_attributes());
+    partial_dots.push_back(partial_dot);
+  }
+
+  // Sum the partial products left to right.
+  HloInstruction* replacement = partial_dots.front();
+  for (auto it = partial_dots.begin() + 1; it != partial_dots.end(); ++it) {
+    replacement = computation->AddInstruction(HloInstruction::CreateBinary(
+        dot->shape(), HloOpcode::kAdd, replacement, *it));
+    replacement->set_metadata(dot->metadata());
+    replacement->set_frontend_attributes(dot->frontend_attributes());
+  }
+
+  // Log before the replacement, while `dot` is still part of the computation.
+  VLOG(2) << "CpuSplitConcatDot rewrote " << dot->name() << " by splitting "
+          << concat->name() << " with " << concat->operand_count()
+          << " operands";
+  TF_RETURN_IF_ERROR(computation->ReplaceInstruction(dot, replacement));
+  return true;
+}
+
+}  // namespace cpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/cpu/cpu_split_concat_dot.h b/third_party/xla/xla/service/cpu/cpu_split_concat_dot.h
new file mode 100644
index 00000000000000..9305fea83c87c5
--- /dev/null
+++ b/third_party/xla/xla/service/cpu/cpu_split_concat_dot.h
@@ -0,0 +1,44 @@
+/* Copyright 2026 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_CPU_CPU_SPLIT_CONCAT_DOT_H_
+#define XLA_SERVICE_CPU_CPU_SPLIT_CONCAT_DOT_H_
+
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_module.h"
+#include "xla/hlo/pass/hlo_pass_interface.h"
+
+namespace xla {
+namespace cpu {
+
+// HLO pass that splits dot(concat(lhs...), rhs) into a sum of dots, each one
+// multiplying a single concat operand by the matching slice of rhs.
+class CpuSplitConcatDot : public HloModulePass {
+ public:
+  absl::string_view name() const override { return "cpu-split-concat-dot"; }
+
+  // Keep the base-class Run overloads visible next to the override below.
+  using HloPassInterface::Run;
+  // Runs the rewrite over `module`; returns true if any dot was rewritten.
+  absl::StatusOr<bool> Run(
+      HloModule* module,
+      const absl::flat_hash_set<absl::string_view>& execution_threads) override;
+
+ private:
+  // Attempts the rewrite on a single kDot instruction. Returns true if `dot`
+  // was replaced and false if it does not match the supported pattern.
+  absl::StatusOr<bool> TryRewriteDot(HloInstruction* dot);
+};
+
+}  // namespace cpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_CPU_CPU_SPLIT_CONCAT_DOT_H_
diff --git a/third_party/xla/xla/xla.proto b/third_party/xla/xla/xla.proto
index 18abf310ac67e4..8e0cdebd8bfd97 100644
--- a/third_party/xla/xla/xla.proto
+++ b/third_party/xla/xla/xla.proto
@@ -174,6 +174,11 @@ message DebugOptions {
   // below!
   bool xla_cpu_enable_fast_min_max = 140;
 
+  // Split dot(concat(lhs...), rhs) into a sum of dots over slices of rhs.
+  // This can avoid materializing expensive CPU concat buffers before skinny
+  // matrix multiplications.
+  bool xla_cpu_split_concat_dot = 405;
+
   // When xla_cpu_enable_fast_math is true then this controls whether we forbid
   // to use the reciprocal of an argument instead of division. Ignored when
   // xla_cpu_enable_fast_math is false.