From 0955514dbf718da00ddf33f36a87086e3174ad5d Mon Sep 17 00:00:00 2001 From: Steven Varoumas Date: Tue, 5 May 2026 16:38:24 +0100 Subject: [PATCH] Allow excluding Softmax from XLA clustering --- .../compiler/jit/compilability_check_util.cc | 8 ++++++ .../compiler/jit/compilability_check_util.h | 3 ++ .../jit/compilability_check_util_test.cc | 28 +++++++++++++++++++ tensorflow/compiler/jit/flags.cc | 2 +- .../compiler/jit/mark_for_compilation_pass.cc | 4 +++ .../jit/mark_for_compilation_pass_test.cc | 28 +++++++++++++++++++ 6 files changed, 72 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc index 50b26371698877..69e0929046ae32 100644 --- a/tensorflow/compiler/jit/compilability_check_util.cc +++ b/tensorflow/compiler/jit/compilability_check_util.cc @@ -503,6 +503,14 @@ bool RecursiveCompilabilityChecker::IsCompilableNode( return false; } + if (!op_filter_.allow_softmax_op && node.type_string() == "Softmax") { + absl::string_view uncompilable_reason = "Softmax op"; + MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace, + encapsulating_function, uncompilable_nodes); + LogNotCompilable(node, uncompilable_reason); + return false; + } + if (!op_filter_.allow_unique_op && node.type_string() == "Unique") { absl::string_view uncompilable_reason = "Unique op"; MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace, diff --git a/tensorflow/compiler/jit/compilability_check_util.h b/tensorflow/compiler/jit/compilability_check_util.h index 0d86c22de11a22..2392a5383aca53 100644 --- a/tensorflow/compiler/jit/compilability_check_util.h +++ b/tensorflow/compiler/jit/compilability_check_util.h @@ -139,6 +139,9 @@ class RecursiveCompilabilityChecker { // Whether to allow the compilation of WhereOp. bool allow_where_op = true; + // Whether to allow the compilation of SoftmaxOp. + bool allow_softmax_op = true; + // Whether to allow the compilation of UniqueOp. 
Compilation of the UniqueOp // generates output with bounded dynamic shape that may cause failures with // auto clustering. diff --git a/tensorflow/compiler/jit/compilability_check_util_test.cc b/tensorflow/compiler/jit/compilability_check_util_test.cc index ea24176bb04a4a..df182e80a79c62 100644 --- a/tensorflow/compiler/jit/compilability_check_util_test.cc +++ b/tensorflow/compiler/jit/compilability_check_util_test.cc @@ -178,6 +178,34 @@ TEST_F(CompilabilityCheckUtilTest, CheckOutsideCompiledNode) { ASSERT_EQ(0, uncompilable_nodes2.size()); } +TEST_F(CompilabilityCheckUtilTest, CheckSoftmaxDisallowedByFilter) { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + auto opts = builder.opts(); + Node* input = ops::SourceOp("InputFloatOp", opts); + Node* softmax = ops::UnaryOp("Softmax", input, opts); + GraphDef graph_def; + TF_EXPECT_OK(builder.ToGraphDef(&graph_def)); + + auto* flib_runtime = GetFunctionLibraryRuntime(); + EXPECT_TRUE(checker_->IsCompilableNode(*softmax, flib_runtime)); + + op_filter_.allow_softmax_op = false; + checker_ = CreateCompilabilityChecker(); + EXPECT_FALSE(checker_->IsCompilableNode(*softmax, flib_runtime)); + + const auto uncompilable_nodes = + checker_->FindUncompilableNodes(*softmax, flib_runtime); + ASSERT_EQ(1, uncompilable_nodes.size()); + auto node_info_it = + uncompilable_nodes.find(NameAttrList().ShortDebugString()); + ASSERT_NE(uncompilable_nodes.end(), node_info_it); + const auto& uncompilable_nodes_inside_function = node_info_it->second.second; + ASSERT_EQ(1, uncompilable_nodes_inside_function.size()); + EXPECT_TRUE(absl::StrContains( + uncompilable_nodes_inside_function.at(0).uncompilable_reason, + "Softmax op")); +} + TEST_F(CompilabilityCheckUtilTest, CheckSimpleFunctionNode) { FunctionDefLibrary flib; *flib.add_function() = FunctionDefHelper::Define( diff --git a/tensorflow/compiler/jit/flags.cc b/tensorflow/compiler/jit/flags.cc index 5a6f741a01e972..0c46da76b2bccb 100644 --- a/tensorflow/compiler/jit/flags.cc 
+++ b/tensorflow/compiler/jit/flags.cc @@ -139,7 +139,7 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) { "(experimental) " "Exclude the operations from auto-clustering. " "If multiple, separate them with commas." - " Where, Some_other_ops"), + " Where, Softmax, Some_other_ops"), Flag("tf_xla_clustering_debug", &mark_for_compilation_flags->tf_xla_clustering_debug, "Dump graphs during XLA compilation."), diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 566bab23a11867..c462a2a597a7b8 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -1648,9 +1648,12 @@ absl::Status MarkForCompilationPassImpl::FindCompilationCandidates() { auto cluster_exclude_op_list = CreateClusterExcludeList(); bool allow_where_op = true; + bool allow_softmax_op = true; for (const auto& s : cluster_exclude_op_list) { if (s == "Where") { allow_where_op = false; + } else if (s == "Softmax") { + allow_softmax_op = false; } else { return errors::InvalidArgument( "The operation '", s, @@ -1703,6 +1706,7 @@ absl::Status MarkForCompilationPassImpl::FindCompilationCandidates() { filter.allow_collective_reduce_v2 = false; filter.allow_unique_op = false; filter.allow_where_op = allow_where_op; + filter.allow_softmax_op = allow_softmax_op; RecursiveCompilabilityChecker checker( filter, DeviceType{registration->compilation_device_name}); diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 1a120791206369..498a00ba489f77 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -39,6 +39,7 @@ limitations under the License. 
#include "tensorflow/cc/ops/sendrecv_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/compiler/jit/defs.h" +#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h" #include "tensorflow/compiler/jit/node_matchers.h" #include "tensorflow/compiler/jit/xla_cluster_util.h" @@ -50,6 +51,7 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/test.h" @@ -220,6 +222,32 @@ TEST(XlaCompilationTest, WhereUnsupported) { EXPECT_TRUE(!clusters.empty()); } +TEST(XlaCompilationTest, SoftmaxCanBeExcludedFromAutoclustering) { + MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); + string old_cluster_exclude_ops = flags->tf_xla_cluster_exclude_ops; + flags->tf_xla_cluster_exclude_ops = "Softmax"; + auto cleanup = gtl::MakeCleanup([&] { + flags->tf_xla_cluster_exclude_ops = old_cluster_exclude_ops; + }); + + std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp("Const", builder.opts() + .WithName("A") + .WithAttr("dtype", DT_FLOAT) + .WithAttr("value", Tensor())); + Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); + ops::UnaryOp("Softmax", b, builder.opts().WithName("C")); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph)); + auto clusters = GetClusters(*graph); + EXPECT_TRUE(clusters.find("B") != clusters.cend()); + EXPECT_TRUE(clusters.find("C") == clusters.cend()); +} + TEST(XlaCompilationTest, HalfSupported) { std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global())); {