From 1f881c4f00bd58a392a57c615d9304f41cae1014 Mon Sep 17 00:00:00 2001 From: Chris Leary Date: Sat, 24 Jan 2026 13:29:11 -0800 Subject: [PATCH] [opt] Simplify carry extraction pattern (prepending zero bit, adding, slicing above the MSb) --- xls/ir/node_util.h | 90 +++++++++++ xls/ir/node_util_test.cc | 34 ++++ xls/passes/bit_slice_simplification_pass.cc | 101 ++++++++++++ .../bit_slice_simplification_pass_test.cc | 147 ++++++++++++++++++ 4 files changed, 372 insertions(+) diff --git a/xls/ir/node_util.h b/xls/ir/node_util.h index af34473ff6..8b2b0c1ce9 100644 --- a/xls/ir/node_util.h +++ b/xls/ir/node_util.h @@ -53,6 +53,73 @@ inline bool IsLiteralZero(Node* node) { node->As()->value().bits().IsZero(); } +// A uniform view of a bits-typed node that represents a zero-extension of a +// narrower bits-typed `base` value. +// +// This matches either: +// - `zero_ext(base, new_bit_count=W)` where base has width N < W +// - `concat(0..., base)` where all leading operands are literal zeros and base +// is the final operand. +struct ZeroExtendedBitsView { + // The original value being extended. + Node* base; + // Bit width of `base`. + int64_t base_width; + // Bit width of the zero-extended result (i.e. the node being matched). + int64_t result_width; + // Number of leading zero bits added (result_width - base_width). + int64_t leading_zero_width; +}; + +// Returns a view if `node` is a zero extension of a narrower bits value, as +// defined by `ZeroExtendedBitsView`. +inline std::optional MatchZeroExtendedBits(Node* node) { + if (node == nullptr || !node->GetType()->IsBits()) { + return std::nullopt; + } + const int64_t result_width = node->BitCountOrDie(); + + if (node->op() == Op::kZeroExt) { + ExtendOp* ext = node->As(); + Node* base = ext->operand(0); + const int64_t base_width = base->BitCountOrDie(); + if (base_width < result_width) { + return ZeroExtendedBitsView{ + .base = base, + .base_width = base_width, + .result_width = result_width, + .leading_zero_width = result_width - base_width}; + } + return std::nullopt; + } + + if (node->op() == Op::kConcat) { + Concat* concat = node->As(); + if (concat->operand_count() < 2) { + return std::nullopt; + } + int64_t prefix_width = 0; + for (int64_t i = 0; i < concat->operand_count() - 1; ++i) { + Node* prefix = concat->operand(i); + if (!IsLiteralZero(prefix)) { + return std::nullopt; + } + prefix_width += prefix->BitCountOrDie(); + } + Node* base = concat->operand(concat->operand_count() - 1); + const int64_t base_width = base->BitCountOrDie(); + if (prefix_width <= 0) { + return std::nullopt; + } + return ZeroExtendedBitsView{.base = base, + .base_width = base_width, + .result_width = result_width, + .leading_zero_width = prefix_width}; + } + + return std::nullopt; +} + // Returns true if the given node is a literal with the value one when // interpreted as an unsigned number inline bool IsLiteralUnsignedOne(Node* node) { @@ -123,6 +190,29 @@ inline bool AnyTwoOperandsWhere(Node* node, return false; } +// Returns true if `pred_a` and `pred_b` match `a` and `b` in either order. +// +// If a match is found, `on_match(matched_a, matched_b)` is invoked with +// `matched_a` being the node that satisfied `pred_a` and `matched_b` being the +// node that satisfied `pred_b`. +// +// This is useful for matching commutative patterns while populating additional +// context via captures in `on_match`. +inline bool MatchNodesInAnyOrder( + Node* a, Node* b, const std::function& pred_a, + const std::function& pred_b, + const std::function& on_match) { + if (pred_a(a) && pred_b(b)) { + on_match(a, b); + return true; + } + if (pred_a(b) && pred_b(a)) { + on_match(b, a); + return true; + } + return false; +} + inline bool HasSingleUse(Node* node) { if (node->function_base()->HasImplicitUse(node)) { return node->users().empty(); diff --git a/xls/ir/node_util_test.cc b/xls/ir/node_util_test.cc index d1733f0a6a..06dd357adf 100644 --- a/xls/ir/node_util_test.cc +++ b/xls/ir/node_util_test.cc @@ -244,6 +244,40 @@ TEST_F(NodeUtilTest, MatchBinarySelectLikeNonMatch) { EXPECT_FALSE(arms.has_value()); } +TEST_F(NodeUtilTest, MatchNodesInAnyOrder) { + auto p = CreatePackage(); + FunctionBuilder fb(TestName(), p.get()); + BValue x = fb.Param("x", p->GetBitsType(8)); + BValue y = fb.Param("y", p->GetBitsType(8)); + BValue sum = fb.Add(x, y); + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(sum)); + + const auto pred_x = [](Node* n) { return n->GetName() == "x"; }; + const auto pred_y = [](Node* n) { return n->GetName() == "y"; }; + + // Already in order. + Node* matched_a = nullptr; + Node* matched_b = nullptr; + EXPECT_TRUE(MatchNodesInAnyOrder(f->param(0), f->param(1), pred_x, pred_y, + [&](Node* a, Node* b) { + matched_a = a; + matched_b = b; + })); + EXPECT_EQ(matched_a, f->param(0)); + EXPECT_EQ(matched_b, f->param(1)); + + // Swapped order. + matched_a = nullptr; + matched_b = nullptr; + EXPECT_TRUE(MatchNodesInAnyOrder(f->param(0), f->param(1), pred_y, pred_x, + [&](Node* a, Node* b) { + matched_a = a; + matched_b = b; + })); + EXPECT_EQ(matched_a, f->param(1)); + EXPECT_EQ(matched_b, f->param(0)); +} + TEST_F(NodeUtilTest, GatherTreeBits) { auto p = CreatePackage(); FunctionBuilder fb(TestName(), p.get()); diff --git a/xls/passes/bit_slice_simplification_pass.cc b/xls/passes/bit_slice_simplification_pass.cc index 3bf7238b76..6bb071c59c 100644 --- a/xls/passes/bit_slice_simplification_pass.cc +++ b/xls/passes/bit_slice_simplification_pass.cc @@ -151,6 +151,101 @@ absl::StatusOr> GetUnscaledIndex( return unscaled_index; } +// Simplifies bit-slices that extract the carry bit from an addition of a +// zero-extended value and a literal: +// +// x: bits[N] +// zext: bits[W] = zero_ext(x) or concat(0..., x) (W > N) +// sum: bits[W] = add(zext, K) (K is a literal) +// ret: bits[1] = bit_slice(sum, start=N, width=1) +// +// This bit slice extracts the carry-out bit of x + K (in N bits). Rewrite it +// into a simpler comparison against a literal. +static absl::StatusOr SimplifyCarryExtraction(BitSlice* bit_slice) { + if (bit_slice->width() != 1) { + return false; + } + Node* add = bit_slice->operand(0); + if (add->op() != Op::kAdd) { + return false; + } + const int64_t add_width = add->BitCountOrDie(); + Node* add_lhs = add->operand(0); + Node* add_rhs = add->operand(1); + + // Match (zero_ext(v), literal) in either operand order. + std::optional maybe_v; + Literal* literal = nullptr; + if (!MatchNodesInAnyOrder( + add_lhs, add_rhs, + [](Node* n) { return MatchZeroExtendedBits(n).has_value(); }, + [](Node* n) { return n->Is(); }, + [&](Node* zeroext_node, Node* literal_node) { + maybe_v = MatchZeroExtendedBits(zeroext_node); + literal = literal_node->As(); + })) { + return false; + } + DCHECK(maybe_v.has_value()); + DCHECK(literal != nullptr); + DCHECK_EQ(maybe_v->result_width, add_width); + + Node* v = maybe_v->base; // bits[N] + const int64_t n_width = maybe_v->base_width; + DCHECK_LT(n_width, add_width); + if (n_width <= 0) { + return false; + } + if (bit_slice->start() != n_width) { + return false; + } + const Bits k = literal->value().bits(); + DCHECK_EQ(k.bit_count(), add_width); + + // Let `A = zero_ext(v)` (so `A[N] = 0`) and `B = k` (so `B[N] = b_n`). + // Then `sum[N] = b_n XOR carry_in`, where `carry_in` is the carry-out from + // adding the low N bits: `v + k_low`. + const bool b_n = k.Get(n_width); + Bits k_low_wide = k.Slice(0, n_width); + if (k_low_wide.IsZero()) { + // No carry-in is possible; sum[N] == bN. + XLS_RETURN_IF_ERROR( + bit_slice->ReplaceUsesWithNew(Value(UBits(b_n ? 1 : 0, 1))) + .status()); + return true; + } + + // `carry_in(v + k_low)` <=> `v >= 2^N - k_low` + Bits k_low_ext = bits_ops::ZeroExtend(k_low_wide, n_width + 1); + Bits two_pow_n = Bits::PowerOfTwo(n_width, /*bit_count=*/n_width + 1); + Bits threshold_ext = bits_ops::Sub(two_pow_n, k_low_ext); + Bits threshold = threshold_ext.Slice(0, n_width); + XLS_ASSIGN_OR_RETURN(Node * threshold_literal, + bit_slice->function_base()->MakeNode( + bit_slice->loc(), Value(threshold))); + + if (!b_n) { + XLS_ASSIGN_OR_RETURN(Node * cmp, + bit_slice->function_base()->MakeNode( + bit_slice->loc(), v, threshold_literal, Op::kUGe)); + VLOG(3) << absl::StreamFormat( + "Replacing bitslice(add(zext(x), k), start=N) => uge(x, T): %s", + bit_slice->GetName()); + XLS_RETURN_IF_ERROR(bit_slice->ReplaceUsesWith(cmp)); + return true; + } + + // `bN==1`: `sum[N] == !carry_in == v < threshold` + XLS_ASSIGN_OR_RETURN(Node * cmp, + bit_slice->function_base()->MakeNode( + bit_slice->loc(), v, threshold_literal, Op::kULt)); + VLOG(3) << absl::StreamFormat( + "Replacing bitslice(add(zext(x), k), start=N) => ult(x, T): %s", + bit_slice->GetName()); + XLS_RETURN_IF_ERROR(bit_slice->ReplaceUsesWith(cmp)); + return true; +} + // Attempts to replace the given bit slice with a simpler or more canonical // form. Returns true if the bit slice was replaced. Any newly created // bit-slices are added to the worklist. @@ -159,6 +254,12 @@ absl::StatusOr SimplifyBitSlice(BitSlice* bit_slice, int64_t opt_level, Node* operand = bit_slice->operand(0); BitsType* operand_type = operand->GetType()->AsBitsOrDie(); + XLS_ASSIGN_OR_RETURN(bool carry_rewritten, + SimplifyCarryExtraction(bit_slice)); + if (carry_rewritten) { + return true; + } + // Creates a new bit slice and adds it to the worklist. auto make_bit_slice = [&](const SourceInfo& loc, Node* operand, int64_t start, int64_t width) -> absl::StatusOr { diff --git a/xls/passes/bit_slice_simplification_pass_test.cc b/xls/passes/bit_slice_simplification_pass_test.cc index d8e43f0c42..3b484a43ce 100644 --- a/xls/passes/bit_slice_simplification_pass_test.cc +++ b/xls/passes/bit_slice_simplification_pass_test.cc @@ -215,6 +215,153 @@ TEST_F(BitSliceSimplificationPassTest, EXPECT_THAT(Run(f), IsOkAndHolds(false)); } +// Parameterized by the *result width* (W) of the zero-extension, so we cover +// multiple `zero_ext`/`concat(0..., x)` widths and ensure the rewrite does not +// depend on how many leading zeros are added (so long as W > N). +class CarryOutOfAddWithZeroExtendedOperandTest + : public BitSliceSimplificationPassTest, + public ::testing::WithParamInterface {}; + +TEST_P(CarryOutOfAddWithZeroExtendedOperandTest, + CarryOutOfAddWithZeroExtendedOperand) { + const int64_t extended_width = GetParam(); + ASSERT_GT(extended_width, 8); + + auto p = CreatePackage(); + Type* u8 = p->GetBitsType(8); + Type* u1 = p->GetBitsType(1); + + FunctionBuilder fb("f", p.get()); + + BValue x = fb.Param("x", u8); + BValue not_x = fb.Not(x); + BValue zext_not_x = + fb.Concat({fb.Literal(UBits(0, extended_width - 8)), not_x}); + BValue sum = fb.Add(zext_not_x, fb.Literal(UBits(127, extended_width))); + BValue carry = fb.BitSlice(sum, /*start=*/8, /*width=*/1); + + // Include extra boolean context to mirror the motivating pattern. + BValue eq_x_ff = fb.Eq(x, fb.Literal(UBits(255, 8))); + BValue leaf_237 = fb.Param("leaf_237", u1); + BValue leaf_466 = fb.Param("leaf_466", u1); + BValue not_leaf_237 = fb.Not(leaf_237); + BValue out = fb.And({carry, eq_x_ff, leaf_466, not_leaf_237}); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(out)); + + ScopedVerifyEquivalence eq(f, kProverTimeout); + ASSERT_THAT(Run(f), IsOkAndHolds(true)); + + // The carry bit is rewritten to a compare against a literal. + EXPECT_THAT( + f->return_value(), + AllOf(m::And(m::UGe(m::Not(m::Param("x")), + AllOf(m::Literal(129), m::Type("bits[8]"))), + m::Eq(m::Param("x"), AllOf(m::Literal(255), m::Type(u8))), + m::Param("leaf_466"), m::Not(m::Param("leaf_237"))), + m::Type(u1))); +} + +INSTANTIATE_TEST_SUITE_P(VariousExtensionWidths, + CarryOutOfAddWithZeroExtendedOperandTest, + ::testing::Values(int64_t{9}, int64_t{12}, + int64_t{16})); + +TEST_F(BitSliceSimplificationPassTest, + CarryOutOfAddWithZeroExtendedOperand_MatchesSwappedAddOperands) { + auto p = CreatePackage(); + Type* u8 = p->GetBitsType(8); + Type* u1 = p->GetBitsType(1); + + FunctionBuilder fb("f", p.get()); + BValue x = fb.Param("x", u8); + BValue not_x = fb.Not(x); + BValue zext_not_x = fb.Concat({fb.Literal(UBits(0, 4)), not_x}); + + // Put the literal on the LHS so we exercise the "either operand order" + // matching logic. + BValue sum = fb.Add(fb.Literal(UBits(127, 12)), zext_not_x); + BValue carry = fb.BitSlice(sum, /*start=*/8, /*width=*/1); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(carry)); + + ScopedVerifyEquivalence eq(f, kProverTimeout); + ASSERT_THAT(Run(f), IsOkAndHolds(true)); + EXPECT_THAT(f->return_value(), + AllOf(m::UGe(m::Not(m::Param("x")), + AllOf(m::Literal(129), m::Type("bits[8]"))), + m::Type(u1))); +} + +TEST_F(BitSliceSimplificationPassTest, + CarryOutOfAddWithZeroExtendedOperand_MatchesZeroExtNode) { + auto p = CreatePackage(); + Type* u8 = p->GetBitsType(8); + Type* u1 = p->GetBitsType(1); + + FunctionBuilder fb("f", p.get()); + BValue x = fb.Param("x", u8); + BValue not_x = fb.Not(x); + BValue zext_not_x = fb.ZeroExtend(not_x, /*new_bit_count=*/12); + BValue sum = fb.Add(zext_not_x, fb.Literal(UBits(127, 12))); + BValue carry = fb.BitSlice(sum, /*start=*/8, /*width=*/1); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(carry)); + + ScopedVerifyEquivalence eq(f, kProverTimeout); + ASSERT_THAT(Run(f), IsOkAndHolds(true)); + EXPECT_THAT(f->return_value(), + AllOf(m::UGe(m::Not(m::Param("x")), + AllOf(m::Literal(129), m::Type("bits[8]"))), + m::Type(u1))); +} + +TEST_F(BitSliceSimplificationPassTest, + CarryOutOfAddWithZeroExtendedOperand_MsbOneRewritesToULt) { + auto p = CreatePackage(); + Type* u8 = p->GetBitsType(8); + Type* u1 = p->GetBitsType(1); + + FunctionBuilder fb("f", p.get()); + BValue x = fb.Param("x", u8); + BValue not_x = fb.Not(x); + BValue zext_not_x = fb.ZeroExtend(not_x, /*new_bit_count=*/9); + + // Here we choose a `k` with `k[8] == 1` and non-zero low 8 bits, which makes + // the carry bit rewrite to a `ult` comparison. + BValue sum = fb.Add(zext_not_x, fb.Literal(UBits(257, 9))); + BValue carry = fb.BitSlice(sum, /*start=*/8, /*width=*/1); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(carry)); + + ScopedVerifyEquivalence eq(f, kProverTimeout); + ASSERT_THAT(Run(f), IsOkAndHolds(true)); + EXPECT_THAT(f->return_value(), + AllOf(m::ULt(m::Not(m::Param("x")), + AllOf(m::Literal(255), m::Type("bits[8]"))), + m::Type(u1))); +} + +TEST_F(BitSliceSimplificationPassTest, + CarryOutOfAddWithZeroExtendedOperand_KLowZeroRewritesToLiteral) { + auto p = CreatePackage(); + Type* u8 = p->GetBitsType(8); + Type* u1 = p->GetBitsType(1); + + FunctionBuilder fb("f", p.get()); + BValue x = fb.Param("x", u8); + BValue not_x = fb.Not(x); + BValue zext_not_x = fb.ZeroExtend(not_x, /*new_bit_count=*/12); + BValue sum = fb.Add(zext_not_x, fb.Literal(UBits(0, 12))); + BValue carry = fb.BitSlice(sum, /*start=*/8, /*width=*/1); + + XLS_ASSERT_OK_AND_ASSIGN(Function * f, fb.BuildWithReturnValue(carry)); + + ScopedVerifyEquivalence eq(f, kProverTimeout); + ASSERT_THAT(Run(f), IsOkAndHolds(true)); + EXPECT_THAT(f->return_value(), AllOf(m::Literal(0), m::Type(u1))); +} + TEST_F(BitSliceSimplificationPassTest, SliceOfSignExtCaseOne) { auto p = CreatePackage(); XLS_ASSERT_OK_AND_ASSIGN(Function * f, ParseFunction(R"(