Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions tensorflow/compiler/jit/compilability_check_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,14 @@ bool RecursiveCompilabilityChecker::IsCompilableNode(
return false;
}

if (!op_filter_.allow_softmax_op && node.type_string() == "Softmax") {
absl::string_view uncompilable_reason = "Softmax op";
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
encapsulating_function, uncompilable_nodes);
LogNotCompilable(node, uncompilable_reason);
return false;
}

if (!op_filter_.allow_unique_op && node.type_string() == "Unique") {
absl::string_view uncompilable_reason = "Unique op";
MaybeMarkUncompilableNode(uncompilable_reason, *stack_trace,
Expand Down
3 changes: 3 additions & 0 deletions tensorflow/compiler/jit/compilability_check_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,9 @@ class RecursiveCompilabilityChecker {
// Whether to allow the compilation of WhereOp.
bool allow_where_op = true;

// Whether to allow the compilation of SoftmaxOp.
bool allow_softmax_op = true;

// Whether to allow the compilation of UniqueOp. Compilation of the UniqueOp
// generates output with bounded dynamic shape that may cause failures with
// auto clustering.
Expand Down
28 changes: 28 additions & 0 deletions tensorflow/compiler/jit/compilability_check_util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,34 @@ TEST_F(CompilabilityCheckUtilTest, CheckOutsideCompiledNode) {
ASSERT_EQ(0, uncompilable_nodes2.size());
}

TEST_F(CompilabilityCheckUtilTest, CheckSoftmaxDisallowedByFilter) {
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
auto opts = builder.opts();
Node* input = ops::SourceOp("InputFloatOp", opts);
Node* softmax = ops::UnaryOp("Softmax", input, opts);
GraphDef graph_def;
TF_EXPECT_OK(builder.ToGraphDef(&graph_def));

auto* flib_runtime = GetFunctionLibraryRuntime();
EXPECT_TRUE(checker_->IsCompilableNode(*softmax, flib_runtime));

op_filter_.allow_softmax_op = false;
checker_ = CreateCompilabilityChecker();
EXPECT_FALSE(checker_->IsCompilableNode(*softmax, flib_runtime));

const auto uncompilable_nodes =
checker_->FindUncompilableNodes(*softmax, flib_runtime);
ASSERT_EQ(1, uncompilable_nodes.size());
auto node_info_it =
uncompilable_nodes.find(NameAttrList().ShortDebugString());
ASSERT_NE(uncompilable_nodes.end(), node_info_it);
const auto& uncompilable_nodes_inside_function = node_info_it->second.second;
ASSERT_EQ(1, uncompilable_nodes_inside_function.size());
EXPECT_TRUE(absl::StrContains(
uncompilable_nodes_inside_function.at(0).uncompilable_reason,
"Softmax op"));
}

TEST_F(CompilabilityCheckUtilTest, CheckSimpleFunctionNode) {
FunctionDefLibrary flib;
*flib.add_function() = FunctionDefHelper::Define(
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/compiler/jit/flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ void AppendMarkForCompilationPassFlagsInternal(std::vector<Flag>* flag_list) {
"(experimental) "
"Exclude the operations from auto-clustering. "
"If multiple, separate them with commas."
" Where, Some_other_ops"),
" Where, Softmax, Some_other_ops"),
Flag("tf_xla_clustering_debug",
&mark_for_compilation_flags->tf_xla_clustering_debug,
"Dump graphs during XLA compilation."),
Expand Down
4 changes: 4 additions & 0 deletions tensorflow/compiler/jit/mark_for_compilation_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1648,9 +1648,12 @@ absl::Status MarkForCompilationPassImpl::FindCompilationCandidates() {

auto cluster_exclude_op_list = CreateClusterExcludeList();
bool allow_where_op = true;
bool allow_softmax_op = true;
for (const auto& s : cluster_exclude_op_list) {
if (s == "Where") {
allow_where_op = false;
} else if (s == "Softmax") {
allow_softmax_op = false;
} else {
return errors::InvalidArgument(
"The operation '", s,
Expand Down Expand Up @@ -1703,6 +1706,7 @@ absl::Status MarkForCompilationPassImpl::FindCompilationCandidates() {
filter.allow_collective_reduce_v2 = false;
filter.allow_unique_op = false;
filter.allow_where_op = allow_where_op;
filter.allow_softmax_op = allow_softmax_op;

RecursiveCompilabilityChecker checker(
filter, DeviceType{registration->compilation_device_name});
Expand Down
28 changes: 28 additions & 0 deletions tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ limitations under the License.
#include "tensorflow/cc/ops/sendrecv_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
#include "tensorflow/compiler/jit/node_matchers.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
Expand All @@ -50,6 +51,7 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"
Expand Down Expand Up @@ -220,6 +222,32 @@ TEST(XlaCompilationTest, WhereUnsupported) {
EXPECT_TRUE(!clusters.empty());
}

TEST(XlaCompilationTest, SoftmaxCanBeExcludedFromAutoclustering) {
MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags();
string old_cluster_exclude_ops = flags->tf_xla_cluster_exclude_ops;
flags->tf_xla_cluster_exclude_ops = "Softmax";
auto cleanup = gtl::MakeCleanup([&] {
flags->tf_xla_cluster_exclude_ops = old_cluster_exclude_ops;
});

std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
{
GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
Node* a = ops::SourceOp("Const", builder.opts()
.WithName("A")
.WithAttr("dtype", DT_FLOAT)
.WithAttr("value", Tensor()));
Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
ops::UnaryOp("Softmax", b, builder.opts().WithName("C"));
TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get()));
}

TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
auto clusters = GetClusters(*graph);
EXPECT_TRUE(clusters.find("B") != clusters.cend());
EXPECT_TRUE(clusters.find("C") == clusters.cend());
}

TEST(XlaCompilationTest, HalfSupported) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
{
Expand Down