diff --git a/xls/passes/BUILD b/xls/passes/BUILD index 14d2f81b93..25ff373ada 100644 --- a/xls/passes/BUILD +++ b/xls/passes/BUILD @@ -1337,6 +1337,7 @@ cc_library( srcs = ["predicate_dominator_analysis.cc"], hdrs = ["predicate_dominator_analysis.h"], deps = [ + ":optimization_pass", ":predicate_state", "//xls/common:strong_int", "//xls/ir", @@ -3193,6 +3194,7 @@ cc_test( name = "predicate_dominator_analysis_test", srcs = ["predicate_dominator_analysis_test.cc"], deps = [ + ":optimization_pass", ":predicate_dominator_analysis", ":predicate_state", "//xls/common:xls_gunit_main", @@ -4080,6 +4082,7 @@ cc_library( hdrs = ["bit_provenance_analysis.h"], deps = [ ":dataflow_visitor", + ":optimization_pass", ":query_engine", "//xls/common/status:ret_check", "//xls/common/status:status_macros", diff --git a/xls/passes/bit_provenance_analysis.cc b/xls/passes/bit_provenance_analysis.cc index 31d1c60b2e..90f58dc664 100644 --- a/xls/passes/bit_provenance_analysis.cc +++ b/xls/passes/bit_provenance_analysis.cc @@ -34,6 +34,7 @@ #include "xls/ir/nodes.h" #include "xls/ir/type.h" #include "xls/passes/dataflow_visitor.h" +#include "xls/passes/optimization_pass.h" #include "xls/passes/query_engine.h" namespace xls { @@ -301,6 +302,13 @@ BitProvenanceAnalysis::CreatePrepopulated(FunctionBase* func) { XLS_RETURN_IF_ERROR(result.Populate(func)); return result; } +/* static */ absl::StatusOr +BitProvenanceAnalysis::CreatePrepopulated(FunctionBase* func, + OptimizationContext& context) { + BitProvenanceAnalysis result; + XLS_RETURN_IF_ERROR(result.Populate(func, context)); + return result; +} BitProvenanceAnalysis::BitProvenanceAnalysis() : visitor_{std::make_unique()} {} @@ -320,6 +328,14 @@ absl::Status BitProvenanceAnalysis::Populate(FunctionBase* func) { XLS_RETURN_IF_ERROR(func->Accept(visitor_.get())); return absl::OkStatus(); } +absl::Status BitProvenanceAnalysis::Populate(FunctionBase* func, + OptimizationContext& context) { + for (Node* node : context.TopoSort(func)) { + XLS_RETURN_IF_ERROR(node->VisitSingleNode(visitor_.get())); + visitor_->MarkVisited(node); + } + return absl::OkStatus(); +} absl::StatusOr BitProvenanceAnalysis::GetSource( const TreeBitLocation& bit) const { diff --git a/xls/passes/bit_provenance_analysis.h b/xls/passes/bit_provenance_analysis.h index 3944a27555..1768125052 100644 --- a/xls/passes/bit_provenance_analysis.h +++ b/xls/passes/bit_provenance_analysis.h @@ -28,6 +28,7 @@ #include "absl/types/span.h" #include "xls/data_structures/leaf_type_tree.h" #include "xls/ir/node.h" +#include "xls/passes/optimization_pass.h" #include "xls/passes/query_engine.h" namespace xls { @@ -157,6 +158,8 @@ class BitProvenanceAnalysis { // invalid if the function is modified. static absl::StatusOr CreatePrepopulated( FunctionBase* func); + static absl::StatusOr CreatePrepopulated( + FunctionBase* func, OptimizationContext& context); // constructors and destructors need to be declared here and implemented in // the .cc file to avoid the compiler inserting constructors and destructors @@ -170,6 +173,7 @@ class BitProvenanceAnalysis { BitProvenanceAnalysis& operator=(BitProvenanceAnalysis&& other); absl::Status Populate(FunctionBase* func); + absl::Status Populate(FunctionBase* func, OptimizationContext& context); // Get the tree-bit-location which provides the original source of the given // bit. diff --git a/xls/passes/dataflow_simplification_pass.cc b/xls/passes/dataflow_simplification_pass.cc index 24d3cb1025..48e5114ba9 100644 --- a/xls/passes/dataflow_simplification_pass.cc +++ b/xls/passes/dataflow_simplification_pass.cc @@ -86,11 +86,14 @@ absl::StatusOr DataflowSimplificationPass::RunOnFunctionBaseInternal( FunctionBase* func, const OptimizationPassOptions& options, PassResults* results, OptimizationContext& context) const { NodeSourceDataflowVisitor visitor; - XLS_RETURN_IF_ERROR(func->Accept(&visitor)); + for (Node* node : context.TopoSort(func)) { + XLS_RETURN_IF_ERROR(node->VisitSingleNode(&visitor)); + } bool changed = false; // Hashmap from the LTT of a node to the Node*. If two nodes have // the same LTT they are necessarily equivalent. absl::flat_hash_map, Node*> source_map; + source_map.reserve(func->node_count()); for (Node* node : context.TopoSort(func)) { LeafTypeTreeView source = visitor.GetValue(node); VLOG(3) << absl::StrFormat("Considering `%s`: %s", node->GetName(), diff --git a/xls/passes/narrowing_pass.cc b/xls/passes/narrowing_pass.cc index 4fbfe503ca..45db1da180 100644 --- a/xls/passes/narrowing_pass.cc +++ b/xls/passes/narrowing_pass.cc @@ -2258,7 +2258,7 @@ absl::StatusOr NarrowingPass::RunOnFunctionBaseInternal( XLS_ASSIGN_OR_RETURN(AliasingQueryEngine query_engine, GetQueryEngine(f, RealAnalysis(options), context)); - PredicateDominatorAnalysis pda = PredicateDominatorAnalysis::Run(f); + PredicateDominatorAnalysis pda = PredicateDominatorAnalysis::Run(f, context); SpecializedQueryEngines sqe(RealAnalysis(options), pda, query_engine); NarrowVisitor narrower(sqe, RealAnalysis(options), options, diff --git a/xls/passes/predicate_dominator_analysis.cc b/xls/passes/predicate_dominator_analysis.cc index 4f5bdb010d..807fadef11 100644 --- a/xls/passes/predicate_dominator_analysis.cc +++ b/xls/passes/predicate_dominator_analysis.cc @@ -18,6 +18,7 @@ #include #include #include +#include #include #include "absl/container/flat_hash_map.h" @@ -26,7 +27,7 @@ #include "xls/ir/function_base.h" #include "xls/ir/node.h" #include "xls/ir/nodes.h" -#include "xls/ir/topo_sort.h" +#include "xls/passes/optimization_pass.h" #include "xls/passes/predicate_state.h" namespace xls { @@ -97,7 +98,8 @@ class AnalysisHelper { .previous = kRootPredicateId, .distance_to_root = 0}; - explicit AnalysisHelper(FunctionBase* func) : function_(func) {} + AnalysisHelper(FunctionBase* func, OptimizationContext& context) + : function_(func), context_(context) {} absl::flat_hash_map Analyze() { CHECK(node_states_.empty()); @@ -105,7 +107,7 @@ class AnalysisHelper { node_states_.reserve(function_->node_count()); predicate_stacks_.push_back(kRootPredicateStackNode); // Run in reverse topo sort order. Handle users before the values they use. - for (Node* node : ReverseTopoSort(function_)) { + for (Node* node : context_.ReverseTopoSort(function_)) { HandleNode(node); } @@ -259,6 +261,7 @@ class AnalysisHelper { private: FunctionBase* function_; + OptimizationContext& context_; // Map of node to the predicate list head they are guarded by. absl::flat_hash_map node_states_; // Map from 'PredicateStackId' to the predicate node. @@ -266,8 +269,9 @@ class AnalysisHelper { }; } // namespace -PredicateDominatorAnalysis PredicateDominatorAnalysis::Run(FunctionBase* f) { - AnalysisHelper helper(f); +PredicateDominatorAnalysis PredicateDominatorAnalysis::Run( + FunctionBase* f, OptimizationContext& context) { + AnalysisHelper helper(f, context); return PredicateDominatorAnalysis(helper.Analyze()); } diff --git a/xls/passes/predicate_dominator_analysis.h b/xls/passes/predicate_dominator_analysis.h index 8330ae1554..ee0619cd9d 100644 --- a/xls/passes/predicate_dominator_analysis.h +++ b/xls/passes/predicate_dominator_analysis.h @@ -20,6 +20,7 @@ #include "absl/container/flat_hash_map.h" #include "xls/ir/function_base.h" #include "xls/ir/node.h" +#include "xls/passes/optimization_pass.h" #include "xls/passes/predicate_state.h" namespace xls { @@ -38,7 +39,8 @@ class PredicateDominatorAnalysis { PredicateDominatorAnalysis& operator=(PredicateDominatorAnalysis&&) = default; // Execute this analysis and return results. - static PredicateDominatorAnalysis Run(FunctionBase* f); + static PredicateDominatorAnalysis Run(FunctionBase* f, + OptimizationContext& context); // Returns a single element of the common predicate dominators which is // closest to the node (ie the last predicate which gates the use of this diff --git a/xls/passes/predicate_dominator_analysis_test.cc b/xls/passes/predicate_dominator_analysis_test.cc index 23eac74cc4..eee26e0964 100644 --- a/xls/passes/predicate_dominator_analysis_test.cc +++ b/xls/passes/predicate_dominator_analysis_test.cc @@ -21,16 +21,24 @@ #include "xls/common/status/matchers.h" #include "xls/ir/benchmark_support.h" #include "xls/ir/bits.h" +#include "xls/ir/function_base.h" #include "xls/ir/function_builder.h" #include "xls/ir/ir_test_base.h" #include "xls/ir/nodes.h" #include "xls/ir/package.h" +#include "xls/passes/optimization_pass.h" #include "xls/passes/predicate_state.h" namespace xls { namespace { -class PredicateDominatorAnalysisTest : public IrTestBase {}; +class PredicateDominatorAnalysisTest : public IrTestBase { + public: + PredicateDominatorAnalysis RunAnalysis(FunctionBase* f) { + OptimizationContext context; + return PredicateDominatorAnalysis::Run(f, context); + } +}; TEST_F(PredicateDominatorAnalysisTest, NoPredicates) { // No predicates everything goes to base predicate state. @@ -43,7 +51,7 @@ TEST_F(PredicateDominatorAnalysisTest, NoPredicates) { BValue wxyz = fb.Add(fb.Add(w, x), fb.Add(y, z)); XLS_ASSERT_OK_AND_ASSIGN(auto* f, fb.Build()); - auto analysis = PredicateDominatorAnalysis::Run(f); + auto analysis = RunAnalysis(f); EXPECT_EQ(analysis.GetSingleNearestPredicate(w.node()), PredicateState()); EXPECT_EQ(analysis.GetSingleNearestPredicate(x.node()), PredicateState()); @@ -85,7 +93,7 @@ TEST_F(PredicateDominatorAnalysisTest, Simple) { BValue nwxyz = fb.Not(wxyz); XLS_ASSERT_OK_AND_ASSIGN(auto* f, fb.Build()); - auto analysis = PredicateDominatorAnalysis::Run(f); + auto analysis = RunAnalysis(f); auto* select = wxyz.node()->As(); auto s_v2 = v2.node()->As(); auto s_v2 = v2.node()->As(); XLS_ASSERT_OK_AND_ASSIGN(auto* f, fb.Build()); - auto analysis = PredicateDominatorAnalysis::Run(f); + auto analysis = RunAnalysis(f); EXPECT_EQ(analysis.GetSingleNearestPredicate(s1.node()), PredicateState()); EXPECT_EQ(analysis.GetSingleNearestPredicate(s2.node()), PredicateState()); @@ -246,7 +254,7 @@ TEST_F(PredicateDominatorAnalysisTest, NestedSelects) { auto* s_wxyz = wxyz.node()->As(); // condition @@ -316,7 +324,7 @@ TEST_F(PredicateDominatorAnalysisTest, DisjointCovering) { auto* s_xyt = xyt.node()->As(); XLS_ASSERT_OK_AND_ASSIGN(auto* f, fb.Build()); - auto analysis = PredicateDominatorAnalysis::Run(f); + auto analysis = RunAnalysis(f); EXPECT_EQ(analysis.GetSingleNearestPredicate(s1.node()), PredicateState(s_wxty, 0)); @@ -420,7 +428,7 @@ TEST_F(PredicateDominatorAnalysisTest, NestedPartialDisjointCovering) { auto* s_xatbt = xatbt.node()->As