Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions third_party/xla/xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {

// By default, copy TF's Eigen style min_max behavior with nans.
opts.set_xla_cpu_enable_fast_min_max(true);
opts.set_xla_cpu_split_concat_dot(false);

opts.set_xla_gpu_enable_cublaslt(false);

Expand Down Expand Up @@ -1032,6 +1033,12 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
debug_options->xla_cpu_enable_concurrency_optimized_scheduler(),
"Use HLO module scheduler that is optimized for extracting concurrency "
"from an HLO module by trading off extra memory pressure."));
flag_list->push_back(tsl::Flag(
"xla_cpu_split_concat_dot",
bool_setter_for(&DebugOptions::set_xla_cpu_split_concat_dot),
debug_options->xla_cpu_split_concat_dot(),
"Split dot(concat(lhs...), rhs) into a sum of dots over slices of rhs "
"to avoid materializing concat inputs to CPU dots."));
flag_list->push_back(tsl::Flag(
"xla_cpu_prefer_vector_width",
int32_setter_for(&DebugOptions::set_xla_cpu_prefer_vector_width),
Expand Down
17 changes: 17 additions & 0 deletions third_party/xla/xla/service/cpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ cc_library(
":cpu_instruction_fusion",
":cpu_layout_assignment",
":cpu_options",
":cpu_split_concat_dot",
":dot_op_emitter",
":executable_proto_cc",
":fusion_wrapper",
Expand Down Expand Up @@ -1531,6 +1532,22 @@ xla_cc_test(
],
)

# HLO pass that splits dot(concat(lhs...), rhs) into a sum of dots over slices
# of rhs. Gated by the xla_cpu_split_concat_dot debug option.
cc_library(
    name = "cpu_split_concat_dot",
    srcs = ["cpu_split_concat_dot.cc"],
    hdrs = ["cpu_split_concat_dot.h"],
    deps = [
        "//xla:shape_util",
        # cpu_split_concat_dot.cc includes xla/status_macros.h directly.
        "//xla:status_macros",
        "//xla:xla_data_proto_cc",
        "//xla/hlo/ir:hlo",
        "//xla/hlo/pass:hlo_pass",
        "@com_google_absl//absl/container:flat_hash_set",
        "@com_google_absl//absl/log",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:string_view",
    ],
)

cc_library(
name = "cpu_instruction_fusion",
srcs = ["cpu_instruction_fusion.cc"],
Expand Down
4 changes: 4 additions & 0 deletions third_party/xla/xla/service/cpu/cpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ limitations under the License.
#include "xla/service/cpu/cpu_instruction_fusion.h"
#include "xla/service/cpu/cpu_layout_assignment.h"
#include "xla/service/cpu/cpu_options.h"
#include "xla/service/cpu/cpu_split_concat_dot.h"
#include "xla/service/cpu/dot_op_emitter.h"
#include "xla/service/cpu/executable.pb.h"
#include "xla/service/cpu/fusion_wrapper.h"
Expand Down Expand Up @@ -642,6 +643,9 @@ absl::Status CpuCompiler::RunHloPassesThroughLayoutAssn(
pipeline.AddPass<CallInliner>(/*single_call_site=*/true);
pipeline.AddPass<BatchDotSimplification>();
pipeline.AddPass<DotDecomposer>();
if (module->config().debug_options().xla_cpu_split_concat_dot()) {
pipeline.AddPass<cpu::CpuSplitConcatDot>();
}

// Rewrite to custom calls with target as oneDNN library calls.
#if defined(INTEL_MKL)
Expand Down
141 changes: 141 additions & 0 deletions third_party/xla/xla/service/cpu/cpu_split_concat_dot.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
/* Copyright 2026 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/service/cpu/cpu_split_concat_dot.h"

#include <cstdint>
#include <vector>

#include "absl/log/log.h"
#include "xla/hlo/ir/hlo_computation.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/status_macros.h"
#include "xla/xla_data.pb.h"

namespace xla {
namespace cpu {

// Visits every dot instruction in the non-fusion computations of `module`
// (restricted to `execution_threads`) and tries to split it with
// TryRewriteDot. Returns true if at least one dot was rewritten, or the first
// error reported by a rewrite attempt.
absl::StatusOr<bool> CpuSplitConcatDot::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  bool any_rewritten = false;
  for (HloComputation* comp :
       module->MakeNonfusionComputations(execution_threads)) {
    // The post-order list is a snapshot taken before any rewrite in this
    // computation, so instructions created by a rewrite are not revisited.
    const std::vector<HloInstruction*> post_order =
        comp->MakeInstructionPostOrder();
    for (HloInstruction* candidate : post_order) {
      if (candidate->opcode() == HloOpcode::kDot) {
        TF_ASSIGN_OR_RETURN(bool did_rewrite, TryRewriteDot(candidate));
        any_rewritten = any_rewritten || did_rewrite;
      }
    }
  }
  return any_rewritten;
}

// Attempts to rewrite
//   dot(concatenate(lhs_0, ..., lhs_{n-1}), rhs)
// where the concatenation runs along the LHS contracting dimension, into
//   add(add(dot(lhs_0, slice_0(rhs)), dot(lhs_1, slice_1(rhs))), ...)
// where slice_i(rhs) is the range of the RHS contracting dimension that lines
// up with lhs_i.
//
// The rewrite only applies when the dot has exactly one contracting dimension
// per side, no batch dimensions, and both the concatenate and `rhs` are
// rank 2.
//
// Returns true if `dot` was replaced, false if it was left untouched (in which
// case no instruction has been added to the computation), or the error
// reported by HloComputation::ReplaceInstruction.
//
// NOTE(review): the rewrite also fires when the concatenate has other users;
// presumably the concatenate is then still materialized for them - confirm
// whether a user_count() == 1 guard is wanted.
absl::StatusOr<bool> CpuSplitConcatDot::TryRewriteDot(HloInstruction* dot) {
  const DotDimensionNumbers& dim_numbers = dot->dot_dimension_numbers();
  if (dim_numbers.lhs_contracting_dimensions_size() != 1 ||
      dim_numbers.rhs_contracting_dimensions_size() != 1 ||
      dim_numbers.lhs_batch_dimensions_size() != 0 ||
      dim_numbers.rhs_batch_dimensions_size() != 0) {
    return false;
  }

  HloInstruction* concat = dot->mutable_operand(0);
  HloInstruction* rhs = dot->mutable_operand(1);
  if (concat->opcode() != HloOpcode::kConcatenate) {
    return false;
  }

  const int64_t lhs_contracting_dim =
      dim_numbers.lhs_contracting_dimensions(0);
  const int64_t rhs_contracting_dim =
      dim_numbers.rhs_contracting_dimensions(0);
  if (concat->concatenate_dimension() != lhs_contracting_dim ||
      concat->operand_count() < 2 ||
      concat->shape().dimensions().size() != 2 ||
      rhs->shape().dimensions().size() != 2 || !dot->shape().IsArray()) {
    return false;
  }
  const int64_t rhs_contracting_size =
      rhs->shape().dimensions(rhs_contracting_dim);
  if (concat->shape().dimensions(lhs_contracting_dim) !=
      rhs_contracting_size) {
    return false;
  }

  // Finish ALL validation before creating any instruction, so that returning
  // false never leaves orphan slices/dots behind in the computation.
  int64_t total_contracting_size = 0;
  for (const HloInstruction* concat_operand : concat->operands()) {
    const Shape& operand_shape = concat_operand->shape();
    if (operand_shape.dimensions().size() !=
            concat->shape().dimensions().size() ||
        operand_shape.element_type() != concat->shape().element_type()) {
      return false;
    }
    total_contracting_size += operand_shape.dimensions(lhs_contracting_dim);
  }
  // The slices taken below must tile the RHS contracting dimension exactly.
  if (total_contracting_size != rhs_contracting_size) {
    return false;
  }

  HloComputation* computation = dot->parent();
  const auto rhs_dims = rhs->shape().dimensions();
  std::vector<HloInstruction*> partial_dots;
  partial_dots.reserve(concat->operand_count());
  int64_t rhs_offset = 0;
  for (HloInstruction* concat_operand : concat->operands()) {
    const int64_t slice_size =
        concat_operand->shape().dimensions(lhs_contracting_dim);

    // Full extent in every dimension except the contracting one, which is
    // restricted to [rhs_offset, rhs_offset + slice_size).
    std::vector<int64_t> starts(rhs_dims.size(), 0);
    std::vector<int64_t> limits(rhs_dims.begin(), rhs_dims.end());
    std::vector<int64_t> strides(rhs_dims.size(), 1);
    std::vector<int64_t> slice_dims(rhs_dims.begin(), rhs_dims.end());
    starts[rhs_contracting_dim] = rhs_offset;
    limits[rhs_contracting_dim] = rhs_offset + slice_size;
    slice_dims[rhs_contracting_dim] = slice_size;
    rhs_offset += slice_size;

    Shape rhs_slice_shape =
        ShapeUtil::MakeShape(rhs->shape().element_type(), slice_dims);
    HloInstruction* rhs_slice = computation->AddInstruction(
        HloInstruction::CreateSlice(rhs_slice_shape, rhs, starts, limits,
                                    strides));
    rhs_slice->set_metadata(dot->metadata());

    // Each partial dot has the same output shape as the original dot; only
    // the contracted extent shrinks.
    HloInstruction* partial_dot = computation->AddInstruction(
        HloInstruction::CreateDot(dot->shape(), concat_operand, rhs_slice,
                                  dim_numbers, dot->precision_config()));
    partial_dot->set_metadata(dot->metadata());
    partial_dot->set_frontend_attributes(dot->frontend_attributes());
    partial_dots.push_back(partial_dot);
  }

  // Left-fold the partial results into a chain of adds.
  HloInstruction* replacement = partial_dots.front();
  for (auto it = partial_dots.begin() + 1; it != partial_dots.end(); ++it) {
    replacement = computation->AddInstruction(HloInstruction::CreateBinary(
        dot->shape(), HloOpcode::kAdd, replacement, *it));
    replacement->set_metadata(dot->metadata());
    replacement->set_frontend_attributes(dot->frontend_attributes());
  }

  // Log before replacing: the names are read from `dot` and `concat`, which
  // should not be touched once `dot` has been replaced.
  VLOG(2) << "CpuSplitConcatDot rewrote " << dot->name() << " by splitting "
          << concat->name() << " with " << concat->operand_count()
          << " operands";
  TF_RETURN_IF_ERROR(computation->ReplaceInstruction(dot, replacement));
  return true;
}

} // namespace cpu
} // namespace xla
44 changes: 44 additions & 0 deletions third_party/xla/xla/service/cpu/cpu_split_concat_dot.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/* Copyright 2026 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_SERVICE_CPU_CPU_SPLIT_CONCAT_DOT_H_
#define XLA_SERVICE_CPU_CPU_SPLIT_CONCAT_DOT_H_

#include "absl/container/flat_hash_set.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/pass/hlo_pass_interface.h"

namespace xla {
namespace cpu {

// HLO pass that rewrites dot(concatenate(lhs...), rhs), where the
// concatenation is along the LHS contracting dimension, into a sum of dots
// between each concatenate operand and the matching slice of rhs.
//
// Only rank-2 dots with a single contracting dimension per side and no batch
// dimensions are rewritten; every other dot is left unchanged.
//
// The CPU compiler adds this pass only when the xla_cpu_split_concat_dot debug
// option is set.
class CpuSplitConcatDot : public HloModulePass {
public:
absl::string_view name() const override { return "cpu-split-concat-dot"; }

// Applies the rewrite to every dot in the non-fusion computations of `module`
// that belong to `execution_threads`. Returns true if any dot was rewritten.
absl::StatusOr<bool> Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) override;

private:
// Rewrites `dot` if it matches the supported pattern. Returns true if `dot`
// was replaced and false if it was left as is.
absl::StatusOr<bool> TryRewriteDot(HloInstruction* dot);
};

} // namespace cpu
} // namespace xla

#endif // XLA_SERVICE_CPU_CPU_SPLIT_CONCAT_DOT_H_
5 changes: 5 additions & 0 deletions third_party/xla/xla/xla.proto
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,11 @@ message DebugOptions {
// below!
bool xla_cpu_enable_fast_min_max = 140;

// Split dot(concat(lhs...), rhs) into a sum of dots over slices of rhs.
// This can avoid materializing expensive CPU concat buffers before skinny
// matrix multiplications.
bool xla_cpu_split_concat_dot = 405;

// When xla_cpu_enable_fast_math is true then this controls whether we forbid
// to use the reciprocal of an argument instead of division. Ignored when
// xla_cpu_enable_fast_math is false.
Expand Down