diff --git a/trie/bintrie/binary_node.go b/trie/bintrie/binary_node.go index 3eea737e4234..e563dc99c894 100644 --- a/trie/bintrie/binary_node.go +++ b/trie/bintrie/binary_node.go @@ -25,7 +25,7 @@ import ( ) type ( - NodeFlushFn func([]byte, BinaryNode) + NodeFlushFn func(*BitArray, BinaryNode) NodeResolverFn func([]byte, common.Hash) ([]byte, error) ) @@ -54,7 +54,7 @@ type BinaryNode interface { Hash() common.Hash GetValuesAtStem([]byte, NodeResolverFn) ([][]byte, error) InsertValuesAtStem([]byte, [][]byte, NodeResolverFn, int) (BinaryNode, error) - CollectNodes([]byte, NodeFlushFn) error + CollectNodes(*BitArray, NodeFlushFn) error toDot(parent, path string) string GetHeight() int diff --git a/trie/bintrie/binary_node_test.go b/trie/bintrie/binary_node_test.go index 242743ba53bb..d1ca91743dbf 100644 --- a/trie/bintrie/binary_node_test.go +++ b/trie/bintrie/binary_node_test.go @@ -199,28 +199,28 @@ func TestKeyToPath(t *testing.T) { name: "depth 0", depth: 0, key: []byte{0x80}, // 10000000 in binary - expected: []byte{1}, + expected: []byte{1}, // first 1 bit = 1 wantErr: false, }, { name: "depth 7", depth: 7, key: []byte{0xFF}, // 11111111 in binary - expected: []byte{1, 1, 1, 1, 1, 1, 1, 1}, + expected: []byte{0xFF}, // first 8 bits = 0xFF wantErr: false, }, { name: "depth crossing byte boundary", depth: 10, key: []byte{0xFF, 0x00}, // 11111111 00000000 in binary - expected: []byte{1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0}, + expected: []byte{0x07, 0xF8}, // first 11 bits = 11111111000 = 0x7F8 wantErr: false, }, { name: "max valid depth", depth: StemSize * 8, key: make([]byte, HashSize), - expected: make([]byte, StemSize*8+1), + expected: make([]byte, StemSize), // 248 bits of zeros (capped to key length) wantErr: false, }, { diff --git a/trie/bintrie/bitarray.go b/trie/bintrie/bitarray.go new file mode 100644 index 000000000000..d5b9183166ba --- /dev/null +++ b/trie/bintrie/bitarray.go @@ -0,0 +1,566 @@ +// Copyright 2026 The go-ethereum Authors +// This file is part of go-ethereum. +// +// go-ethereum is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// go-ethereum is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with go-ethereum. If not, see . +package bintrie + +import ( + "encoding/binary" + "encoding/hex" + "fmt" + "math" +) + +const ( + maxUint64 = uint64(math.MaxUint64) // 0xFFFFFFFFFFFFFFFF + maxUint8 = uint8(math.MaxUint8) +) + +var emptyBitArray = new(BitArray) + +// BitArray represents a bit array with length representing the number of used bits. +// It uses a little endian representation to do bitwise operations of the words efficiently. +// For example, if len is 10, it means that the 2^9, 2^8, ..., 2^0 bits are used. +// The max length is 255 bits (uint8), because our use case only need up to 248 bits for a given trie key. +// Although words can be used to represent 256 bits, we don't want to add an additional byte for the length. +type BitArray struct { + len uint8 // number of used bits + words [4]uint64 // little endian (i.e. words[0] is the least significant) +} + +// NewBitArray creates a new bit array with the given length and value. +func NewBitArray(length uint8, val uint64) BitArray { + var b BitArray + b.SetUint64(length, val) + return b +} + +func (b *BitArray) Len() uint8 { + return b.len +} + +// Bytes returns the bytes representation of the bit array in big endian format +func (b *BitArray) Bytes() [32]byte { + var res [32]byte + + binary.BigEndian.PutUint64(res[0:8], b.words[3]) + binary.BigEndian.PutUint64(res[8:16], b.words[2]) + binary.BigEndian.PutUint64(res[16:24], b.words[1]) + binary.BigEndian.PutUint64(res[24:32], b.words[0]) + + return res +} + +// Append sets the bit array to the concatenation of x and y and returns the bit array. +// For example: +// +// x = 000 (len=3) +// y = 111 (len=3) +// Append(x,y) = 000111 (len=6) +func (b *BitArray) Append(x, y *BitArray) *BitArray { + if x.len == 0 { + return b.Set(y) + } + if y.len == 0 { + return b.Set(x) + } + if x.len > maxUint8-y.len { + panic("error on bitarray append: result would exceed maximum length of 255 bits") + } + + // Shift left by y's length and OR with y + return b.lsh(x, y.len).or(b, y) +} + +// AppendBit sets the bit array to the concatenation of x and a single bit. +func (b *BitArray) AppendBit(x *BitArray, bit uint8) *BitArray { + return b.Append(x, new(BitArray).SetBit(bit)) +} + +// MSBs sets the bit array to the most significant 'n' bits of x, that is position 0 to n (exclusive). +// If n >= x.len, the bit array is an exact copy of x. +// Think of this method as array[0:n] +// For example: +// +// x = 11001011 (len=8) +// MSBs(x, 4) = 1100 (len=4) +// MSBs(x, 10) = 11001011 (len=8, original x) +// MSBs(x, 0) = 0 (len=0) +func (b *BitArray) MSBs(x *BitArray, n uint8) *BitArray { + if n >= x.len { + return b.Set(x) + } + + return b.rsh(x, x.len-n) +} + +// Equal checks if two bit arrays are equal +func (b *BitArray) Equal(x *BitArray) bool { + if b == nil || x == nil { + panic("bit array is nil") + } + + return b.len == x.len && b.words == x.words +} + +// SetBytes interprets the data as the big-endian bytes, sets the bit array to that value and returns it. +// If the data is larger than 32 bytes, only the first 32 bytes are used. +func (b *BitArray) SetBytes(length uint8, data []byte) *BitArray { + switch l := len(data); l { + case 0: + b.clear() + case 1: + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, uint64(data[0]) + case 2: + _ = data[1] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, uint64(binary.BigEndian.Uint16(data[0:2])) + case 3: + _ = data[2] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, uint64(binary.BigEndian.Uint16(data[1:3]))|uint64(data[0])<<16 + case 4: + _ = data[3] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, uint64(binary.BigEndian.Uint32(data[0:4])) + case 5: + _ = data[4] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, bigEndianUint40(data[0:5]) + case 6: + _ = data[5] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, bigEndianUint48(data[0:6]) + case 7: + _ = data[6] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, bigEndianUint56(data[0:7]) + case 8: + _ = data[7] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, binary.BigEndian.Uint64(data[0:8]) + case 9: + _ = data[8] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, uint64(data[0]), binary.BigEndian.Uint64(data[1:9]) + case 10: + _ = data[9] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, uint64(binary.BigEndian.Uint16(data[0:2])), binary.BigEndian.Uint64(data[2:10]) + case 11: + _ = data[10] + b.words[3], b.words[2] = 0, 0 + b.words[1], b.words[0] = uint64(binary.BigEndian.Uint16(data[1:3]))|uint64(data[0])<<16, binary.BigEndian.Uint64(data[3:11]) + case 12: + _ = data[11] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, uint64(binary.BigEndian.Uint32(data[0:4])), binary.BigEndian.Uint64(data[4:12]) + case 13: + _ = data[12] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, bigEndianUint40(data[0:5]), binary.BigEndian.Uint64(data[5:13]) + case 14: + _ = data[13] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, bigEndianUint48(data[0:6]), binary.BigEndian.Uint64(data[6:14]) + case 15: + _ = data[14] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, bigEndianUint56(data[0:7]), binary.BigEndian.Uint64(data[7:15]) + case 16: + _ = data[15] + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, binary.BigEndian.Uint64(data[0:8]), binary.BigEndian.Uint64(data[8:16]) + case 17: + _ = data[16] + b.words[3], b.words[2] = 0, uint64(data[0]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[1:9]), binary.BigEndian.Uint64(data[9:17]) + case 18: + _ = data[17] + b.words[3], b.words[2] = 0, uint64(binary.BigEndian.Uint16(data[0:2])) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[2:10]), binary.BigEndian.Uint64(data[10:18]) + case 19: + _ = data[18] + b.words[3], b.words[2] = 0, uint64(binary.BigEndian.Uint16(data[1:3]))|uint64(data[0])<<16 + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[3:11]), binary.BigEndian.Uint64(data[11:19]) + case 20: + _ = data[19] + b.words[3], b.words[2] = 0, uint64(binary.BigEndian.Uint32(data[0:4])) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[4:12]), binary.BigEndian.Uint64(data[12:20]) + case 21: + _ = data[20] + b.words[3], b.words[2] = 0, bigEndianUint40(data[0:5]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[5:13]), binary.BigEndian.Uint64(data[13:21]) + case 22: + _ = data[21] + b.words[3], b.words[2] = 0, bigEndianUint48(data[0:6]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[6:14]), binary.BigEndian.Uint64(data[14:22]) + case 23: + _ = data[22] + b.words[3], b.words[2] = 0, bigEndianUint56(data[0:7]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[7:15]), binary.BigEndian.Uint64(data[15:23]) + case 24: + _ = data[23] + b.words[3], b.words[2] = 0, binary.BigEndian.Uint64(data[0:8]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[8:16]), binary.BigEndian.Uint64(data[16:24]) + case 25: + _ = data[24] + b.words[3], b.words[2] = uint64(data[0]), binary.BigEndian.Uint64(data[1:9]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[9:17]), binary.BigEndian.Uint64(data[17:25]) + case 26: + _ = data[25] + b.words[3], b.words[2] = uint64(binary.BigEndian.Uint16(data[0:2])), binary.BigEndian.Uint64(data[2:10]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[10:18]), binary.BigEndian.Uint64(data[18:26]) + case 27: + _ = data[26] + b.words[3] = uint64(binary.BigEndian.Uint16(data[1:3])) | uint64(data[0])<<16 + b.words[2] = binary.BigEndian.Uint64(data[3:11]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[11:19]), binary.BigEndian.Uint64(data[19:27]) + case 28: + _ = data[27] + b.words[3], b.words[2] = uint64(binary.BigEndian.Uint32(data[0:4])), binary.BigEndian.Uint64(data[4:12]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[12:20]), binary.BigEndian.Uint64(data[20:28]) + case 29: + _ = data[28] + b.words[3], b.words[2] = bigEndianUint40(data[0:5]), binary.BigEndian.Uint64(data[5:13]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[13:21]), binary.BigEndian.Uint64(data[21:29]) + case 30: + _ = data[29] + b.words[3], b.words[2] = bigEndianUint48(data[0:6]), binary.BigEndian.Uint64(data[6:14]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[14:22]), binary.BigEndian.Uint64(data[22:30]) + case 31: + _ = data[30] + b.words[3], b.words[2] = bigEndianUint56(data[0:7]), binary.BigEndian.Uint64(data[7:15]) + b.words[1], b.words[0] = binary.BigEndian.Uint64(data[15:23]), binary.BigEndian.Uint64(data[23:31]) + default: + b.setBytes32(data) + } + b.len = length + b.truncateToLength() + return b +} + +// SetUint64 sets the bit array to the uint64 representation of a bit array. +func (b *BitArray) SetUint64(length uint8, data uint64) *BitArray { + b.words[0] = data + b.len = length + b.truncateToLength() + return b +} + +// SetBit sets the bit array to a single bit. +func (b *BitArray) SetBit(bit uint8) *BitArray { + b.len = 1 + b.words[0] = uint64(bit & 1) + b.words[1], b.words[2], b.words[3] = 0, 0, 0 + return b +} + +// Copy returns a deep copy of the bit array. +func (b *BitArray) Copy() BitArray { + var res BitArray + res.Set(b) + return res +} + +// String returns a string representation of the bit array. +// This is typically used for logging or debugging. +func (b *BitArray) String() string { + bt := b.Bytes() + return fmt.Sprintf("(%d) %s", b.len, hex.EncodeToString(bt[:])) +} + +// Bit returns the bit value at position n, where n = 0 is MSB. +// If n is out of bounds, returns 0. +func (b *BitArray) Bit(n uint8) uint8 { + if n >= b.Len() { + return 0 + } + + return b.bitFromLSB(b.Len() - n - 1) +} + +// Set sets the bit array to the same value as x. +func (b *BitArray) Set(x *BitArray) *BitArray { + b.len = x.len + b.words[0] = x.words[0] + b.words[1] = x.words[1] + b.words[2] = x.words[2] + b.words[3] = x.words[3] + return b +} + +// ActiveBytes returns a slice containing only the bytes that are actually used by the bit array, +// as specified by the length. The returned slice is in big-endian order. +// +// Example: +// +// len = 10, words = [0x3FF, 0, 0, 0] -> [0x03, 0xFF] +func (b *BitArray) ActiveBytes() []byte { + wordsBytes := b.Bytes() + return wordsBytes[32-b.byteCount():] +} + +// bitFromLSB returns the bit value at position n, where n = 0 is LSB. +// If n is out of bounds, returns 0. +func (b *BitArray) bitFromLSB(n uint8) uint8 { + if n >= b.len { + return 0 + } + + if (b.words[n/64] & (1 << (n % 64))) != 0 { + return 1 + } + + return 0 +} + +// copyLsb sets the bit array to the least significant 'n' bits of x. +// n is counted from the least significant bit, starting at 0. +// If length >= x.len, the bit array is an exact copy of x. +// For example: +// +// x = 11001011 (len=8) +// copyLsb(x, 4) = 1011 (len=4) +// copyLsb(x, 10) = 11001011 (len=8, original x) +// copyLsb(x, 0) = 0 (len=0) +func (b *BitArray) copyLsb(x *BitArray, n uint8) *BitArray { + if n >= x.len { + return b.Set(x) + } + + b.len = n + + switch { + case n == 0: + b.words = [4]uint64{0, 0, 0, 0} + case n <= 64: + b.words[0] = x.words[0] & (maxUint64 >> (64 - n)) + b.words[1], b.words[2], b.words[3] = 0, 0, 0 + case n <= 128: + b.words[0] = x.words[0] + b.words[1] = x.words[1] & (maxUint64 >> (128 - n)) + b.words[2], b.words[3] = 0, 0 + case n <= 192: + b.words[0] = x.words[0] + b.words[1] = x.words[1] + b.words[2] = x.words[2] & (maxUint64 >> (192 - n)) + b.words[3] = 0 + default: + b.words[0] = x.words[0] + b.words[1] = x.words[1] + b.words[2] = x.words[2] + b.words[3] = x.words[3] & (maxUint64 >> (256 - uint16(n))) + } + + return b +} + +// lsb returns the least significant bits of `x` with `n` counted from the most significant bit, starting at 0. +// Think of this method as array[n:] +// For example: +// +// x = 11001011 (len=8) +// lsb(x, 1) = 1001011 (len=7) +// lsb(x, 10) = 0 (len=0) +// lsb(x, 0) = 11001011 (len=8, original x) +func (b *BitArray) lsb(x *BitArray, n uint8) *BitArray { + if n == 0 { + return b.Set(x) + } + + if n > x.Len() { + return b.clear() + } + + return b.copyLsb(x, x.Len()-n) +} + +// or sets the bit array to x | y and returns the bit array. +func (b *BitArray) or(x, y *BitArray) *BitArray { + b.words[0] = x.words[0] | y.words[0] + b.words[1] = x.words[1] | y.words[1] + b.words[2] = x.words[2] | y.words[2] + b.words[3] = x.words[3] | y.words[3] + b.len = x.len + return b +} + +// rsh sets the bit array to x >> n and returns the bit array. +func (b *BitArray) rsh(x *BitArray, n uint8) *BitArray { + if x.len == 0 { + return b.Set(x) + } + + if n >= x.len { + return b.clear() + } + + switch { + case n == 0: + return b.Set(x) + case n >= 192: + b.rsh192(x) + b.len = x.len - n + n -= 192 + b.words[0] >>= n + case n >= 128: + b.rsh128(x) + b.len = x.len - n + n -= 128 + b.words[0] = (b.words[0] >> n) | (b.words[1] << (64 - n)) + b.words[1] >>= n + case n >= 64: + b.rsh64(x) + b.len = x.len - n + n -= 64 + b.words[0] = (b.words[0] >> n) | (b.words[1] << (64 - n)) + b.words[1] = (b.words[1] >> n) | (b.words[2] << (64 - n)) + b.words[2] >>= n + default: + b.Set(x) + b.len -= n + b.words[0] = (b.words[0] >> n) | (b.words[1] << (64 - n)) + b.words[1] = (b.words[1] >> n) | (b.words[2] << (64 - n)) + b.words[2] = (b.words[2] >> n) | (b.words[3] << (64 - n)) + b.words[3] >>= n + } + + b.truncateToLength() + return b +} + +// lsh sets the bit array to x << n and returns the bit array. +func (b *BitArray) lsh(x *BitArray, n uint8) *BitArray { + if x.len == 0 || n == 0 { + return b.Set(x) + } + + // If the result will overflow, we set the length to the max length + // but we still shift `n` bits + if n > maxUint8-x.len { + b.len = maxUint8 + } else { + b.len = x.len + n + } + + switch { + case n >= 192: + b.lsh192(x) + n -= 192 + b.words[3] <<= n + case n >= 128: + b.lsh128(x) + n -= 128 + b.words[3] = (b.words[3] << n) | (b.words[2] >> (64 - n)) + b.words[2] <<= n + case n >= 64: + b.lsh64(x) + n -= 64 + b.words[3] = (b.words[3] << n) | (b.words[2] >> (64 - n)) + b.words[2] = (b.words[2] << n) | (b.words[1] >> (64 - n)) + b.words[1] <<= n + default: + b.words[3], b.words[2], b.words[1], b.words[0] = x.words[3], x.words[2], x.words[1], x.words[0] + b.words[3] = (b.words[3] << n) | (b.words[2] >> (64 - n)) + b.words[2] = (b.words[2] << n) | (b.words[1] >> (64 - n)) + b.words[1] = (b.words[1] << n) | (b.words[0] >> (64 - n)) + b.words[0] <<= n + } + + b.truncateToLength() + return b +} + +func (b *BitArray) setBytes32(data []byte) { + _ = data[31] // bound check hint, see https://golang.org/issue/14808 + b.words[3] = binary.BigEndian.Uint64(data[0:8]) + b.words[2] = binary.BigEndian.Uint64(data[8:16]) + b.words[1] = binary.BigEndian.Uint64(data[16:24]) + b.words[0] = binary.BigEndian.Uint64(data[24:32]) +} + +// byteCount returns the minimum number of bytes needed to represent the bit array. +// It rounds up to the nearest byte. +func (b *BitArray) byteCount() uint { + const bits8 = 8 + return (uint(b.len) + (bits8 - 1)) / uint(bits8) +} + +func (b *BitArray) rsh64(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = 0, x.words[3], x.words[2], x.words[1] +} + +func (b *BitArray) rsh128(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, x.words[3], x.words[2] +} + +func (b *BitArray) rsh192(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = 0, 0, 0, x.words[3] +} + +func (b *BitArray) lsh64(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = x.words[2], x.words[1], x.words[0], 0 +} + +func (b *BitArray) lsh128(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = x.words[1], x.words[0], 0, 0 +} + +func (b *BitArray) lsh192(x *BitArray) { + b.words[3], b.words[2], b.words[1], b.words[0] = x.words[0], 0, 0, 0 +} + +func (b *BitArray) clear() *BitArray { + b.len = 0 + b.words[0], b.words[1], b.words[2], b.words[3] = 0, 0, 0, 0 + return b +} + +// truncateToLength truncates the bit array to the specified length, ensuring that any unused bits are all zeros. +// +// Example: +// +// b := &BitArray{ +// len: 5, +// words: [4]uint64{ +// 0xFFFFFFFFFFFFFFFF, // Before: all bits are 1 +// 0x0, 0x0, 0x0, +// }, +// } +// b.truncateToLength() +// // After: only first 5 bits remain +// // words[0] = 0x000000000000001F +// // words[1..3] = 0x0 +func (b *BitArray) truncateToLength() { + switch { + case b.len == 0: + b.words = [4]uint64{0, 0, 0, 0} + case b.len <= 64: + b.words[0] &= maxUint64 >> (64 - b.len) + b.words[1], b.words[2], b.words[3] = 0, 0, 0 + case b.len <= 128: + b.words[1] &= maxUint64 >> (128 - b.len) + b.words[2], b.words[3] = 0, 0 + case b.len <= 192: + b.words[2] &= maxUint64 >> (192 - b.len) + b.words[3] = 0 + default: + b.words[3] &= maxUint64 >> (256 - uint16(b.len)) + } +} + +func bigEndianUint40(b []byte) uint64 { + _ = b[4] // bounds check hint to compiler; see golang.org/issue/14808 + return uint64(b[4]) | uint64(b[3])<<8 | uint64(b[2])<<16 | uint64(b[1])<<24 | + uint64(b[0])<<32 +} + +func bigEndianUint48(b []byte) uint64 { + _ = b[5] // bounds check hint to compiler; see golang.org/issue/14808 + return uint64(b[5]) | uint64(b[4])<<8 | uint64(b[3])<<16 | uint64(b[2])<<24 | + uint64(b[1])<<32 | uint64(b[0])<<40 +} + +func bigEndianUint56(b []byte) uint64 { + _ = b[6] // bounds check hint to compiler; see golang.org/issue/14808 + return uint64(b[6]) | uint64(b[5])<<8 | uint64(b[4])<<16 | uint64(b[3])<<24 | + uint64(b[2])<<32 | uint64(b[1])<<40 | uint64(b[0])<<48 +} diff --git a/trie/bintrie/bitarray_test.go b/trie/bintrie/bitarray_test.go new file mode 100644 index 000000000000..3ea01f14e5ec --- /dev/null +++ b/trie/bintrie/bitarray_test.go @@ -0,0 +1,1078 @@ +package bintrie + +import ( + "bytes" + "encoding/binary" + "math/bits" + "testing" +) + +const ( + ones63 = 0x7FFFFFFFFFFFFFFF // 63 bits of 1 +) + +func TestBytes(t *testing.T) { + tests := []struct { + name string + ba BitArray + want [32]byte + }{ + { + name: "length == 0", + ba: BitArray{len: 0, words: [4]uint64{0, 0, 0, 0}}, + want: [32]byte{}, + }, + { + name: "length < 64", + ba: BitArray{len: 38, words: [4]uint64{0x3FFFFFFFFF, 0, 0, 0}}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[24:32], 0x3FFFFFFFFF) + return b + }(), + }, + { + name: "64 <= length < 128", + ba: BitArray{len: 100, words: [4]uint64{maxUint64, 0xFFFFFFFFF, 0, 0}}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[16:24], 0xFFFFFFFFF) + binary.BigEndian.PutUint64(b[24:32], maxUint64) + return b + }(), + }, + { + name: "128 <= length < 192", + ba: BitArray{len: 130, words: [4]uint64{maxUint64, maxUint64, 0x3, 0}}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[8:16], 0x3) + binary.BigEndian.PutUint64(b[16:24], maxUint64) + binary.BigEndian.PutUint64(b[24:32], maxUint64) + return b + }(), + }, + { + name: "192 <= length < 255", + ba: BitArray{len: 201, words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x1FF}}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[0:8], 0x1FF) + binary.BigEndian.PutUint64(b[8:16], maxUint64) + binary.BigEndian.PutUint64(b[16:24], maxUint64) + binary.BigEndian.PutUint64(b[24:32], maxUint64) + return b + }(), + }, + { + name: "length == 254", + ba: BitArray{len: 254, words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x3FFFFFFFFFFFFFFF}}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[0:8], 0x3FFFFFFFFFFFFFFF) + binary.BigEndian.PutUint64(b[8:16], maxUint64) + binary.BigEndian.PutUint64(b[16:24], maxUint64) + binary.BigEndian.PutUint64(b[24:32], maxUint64) + return b + }(), + }, + { + name: "length == 255", + ba: BitArray{len: 255, words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63}}, + want: func() [32]byte { + var b [32]byte + binary.BigEndian.PutUint64(b[0:8], ones63) + binary.BigEndian.PutUint64(b[8:16], maxUint64) + binary.BigEndian.PutUint64(b[16:24], maxUint64) + binary.BigEndian.PutUint64(b[24:32], maxUint64) + return b + }(), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.ba.Bytes() + if !bytes.Equal(got[:], tt.want[:]) { + t.Errorf("BitArray.Bytes() = %v, want %v", got, tt.want) + } + + // check if the received bytes has the same bit count as the BitArray.len + count := 0 + for _, b := range got { + count += bits.OnesCount8(b) + } + if count != int(tt.ba.len) { + t.Errorf("BitArray.Bytes() bit count = %v, want %v", count, tt.ba.len) + } + }) + } +} + +func TestRsh(t *testing.T) { + tests := []struct { + name string + initial *BitArray + shiftBy uint8 + expected *BitArray + }{ + { + name: "zero length array", + initial: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + shiftBy: 5, + expected: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "shift by 0", + initial: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + shiftBy: 0, + expected: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "shift by more than length", + initial: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + shiftBy: 65, + expected: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "shift by less than 64", + initial: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + shiftBy: 32, + expected: &BitArray{ + len: 96, + words: [4]uint64{maxUint64, 0x00000000FFFFFFFF, 0, 0}, + }, + }, + { + name: "shift by exactly 64", + initial: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + shiftBy: 64, + expected: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "shift by 127", + initial: &BitArray{ + len: 255, + words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63}, + }, + shiftBy: 127, + expected: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + }, + { + name: "shift by 128", + initial: &BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, maxUint64}, + }, + shiftBy: 128, + expected: &BitArray{ + len: 123, + words: [4]uint64{maxUint64, 0x7FFFFFFFFFFFFFF, 0, 0}, + }, + }, + { + name: "shift by 192", + initial: &BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, maxUint64}, + }, + shiftBy: 192, + expected: &BitArray{ + len: 59, + words: [4]uint64{0x7FFFFFFFFFFFFFF, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := new(BitArray).rsh(tt.initial, tt.shiftBy) + if !result.Equal(tt.expected) { + t.Errorf("rsh() got = %+v, want %+v", result, tt.expected) + } + }) + } +} + +func TestLsh(t *testing.T) { + tests := []struct { + name string + x *BitArray + n uint8 + want *BitArray + }{ + { + name: "empty array", + x: emptyBitArray, + n: 5, + want: emptyBitArray, + }, + { + name: "shift by 0", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + n: 0, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "shift within first word", + x: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + n: 4, + want: &BitArray{ + len: 8, + words: [4]uint64{0xF0, 0, 0, 0}, // 11110000 + }, + }, + { + name: "shift across word boundary", + x: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + n: 62, + want: &BitArray{ + len: 66, + words: [4]uint64{0xC000000000000000, 0x3, 0, 0}, + }, + }, + { + name: "shift by 64 (full word)", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + n: 64, + want: &BitArray{ + len: 72, + words: [4]uint64{0, 0xFF, 0, 0}, + }, + }, + { + name: "shift by 128", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + n: 128, + want: &BitArray{ + len: 136, + words: [4]uint64{0, 0, 0xFF, 0}, + }, + }, + { + name: "shift by 192", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + n: 192, + want: &BitArray{ + len: 200, + words: [4]uint64{0, 0, 0, 0xFF}, + }, + }, + { + name: "shift causing length overflow", + x: &BitArray{ + len: 200, + words: [4]uint64{0xFF, 0, 0, 0}, + }, + n: 60, + want: &BitArray{ + len: 255, // capped at maxUint8 + words: [4]uint64{ + 0xF000000000000000, + 0xF, + 0, + 0, + }, + }, + }, + { + name: "shift sparse bits", + x: &BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 + }, + n: 4, + want: &BitArray{ + len: 12, + words: [4]uint64{0xAA0, 0, 0, 0}, // 101010100000 + }, + }, + { + name: "shift partial word across boundary", + x: &BitArray{ + len: 100, + words: [4]uint64{0xFF, 0xFF, 0, 0}, + }, + n: 60, + want: &BitArray{ + len: 160, + words: [4]uint64{ + 0xF000000000000000, + 0xF00000000000000F, + 0xF, + 0, + }, + }, + }, + { + name: "near maximum length shift", + x: &BitArray{ + len: 251, + words: [4]uint64{0xFF, 0, 0, 0}, + }, + n: 4, + want: &BitArray{ + len: 255, // capped at maxUint8 + words: [4]uint64{0xFF0, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).lsh(tt.x, tt.n) + if !got.Equal(tt.want) { + t.Errorf("Lsh() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestAppend(t *testing.T) { + tests := []struct { + name string + x *BitArray + y *BitArray + want *BitArray + }{ + { + name: "both empty arrays", + x: emptyBitArray, + y: emptyBitArray, + want: emptyBitArray, + }, + { + name: "first array empty", + x: emptyBitArray, + y: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + want: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + }, + { + name: "second array empty", + x: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + y: emptyBitArray, + want: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + }, + { + name: "within first word", + x: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + y: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + want: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + }, + { + name: "different lengths within word", + x: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + y: &BitArray{ + len: 2, + words: [4]uint64{0x3, 0, 0, 0}, // 11 + }, + want: &BitArray{ + len: 6, + words: [4]uint64{0x3F, 0, 0, 0}, // 111111 + }, + }, + { + name: "across word boundary", + x: &BitArray{ + len: 62, + words: [4]uint64{0x3FFFFFFFFFFFFFFF, 0, 0, 0}, + }, + y: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, // 1111 + }, + want: &BitArray{ + len: 66, + words: [4]uint64{maxUint64, 0x3, 0, 0}, + }, + }, + { + name: "across multiple words", + x: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + y: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + want: &BitArray{ + len: 192, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, + }, + }, + { + name: "sparse bits", + x: &BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 + }, + y: &BitArray{ + len: 8, + words: [4]uint64{0x55, 0, 0, 0}, // 01010101 + }, + want: &BitArray{ + len: 16, + words: [4]uint64{0xAA55, 0, 0, 0}, // 1010101001010101 + }, + }, + { + name: "result exactly at length limit", + x: &BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFFF}, + }, + y: &BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, + }, + want: &BitArray{ + len: 255, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFFF}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).Append(tt.x, tt.y) + if !got.Equal(tt.want) { + t.Errorf("Append() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestLSBs(t *testing.T) { + tests := []struct { + name string + x *BitArray + pos uint8 + want *BitArray + }{ + { + name: "zero position", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + pos: 0, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "position beyond length", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + pos: 65, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "get last 4 bits", + x: &BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, // 11111111 + }, + pos: 4, + want: &BitArray{ + len: 4, + words: [4]uint64{0x0F, 0, 0, 0}, // 1111 + }, + }, + { + name: "get bits across word boundary", + x: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + pos: 64, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "get bits from max length array", + x: &BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, + }, + pos: 200, + want: &BitArray{ + len: 51, + words: [4]uint64{0x7FFFFFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "empty array", + x: emptyBitArray, + pos: 1, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "sparse bits", + x: &BitArray{ + len: 16, + words: [4]uint64{0xAAAA, 0, 0, 0}, // 1010101010101010 + }, + pos: 8, + want: &BitArray{ + len: 8, + words: [4]uint64{0xAA, 0, 0, 0}, // 10101010 + }, + }, + { + name: "position equals length", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + pos: 64, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).lsb(tt.x, tt.pos) + if !got.Equal(tt.want) { + t.Errorf("LSBs() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestLSBsFromLSB(t *testing.T) { + tests := []struct { + name string + initial BitArray + length uint8 + expected BitArray + }{ + { + name: "zero", + initial: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + length: 0, + expected: BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "get 32 LSBs", + initial: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + length: 32, + expected: BitArray{ + len: 32, + words: [4]uint64{0x00000000FFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "get 1 LSB", + initial: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + length: 1, + expected: BitArray{ + len: 1, + words: [4]uint64{0x1, 0, 0, 0}, + }, + }, + { + name: "get 100 LSBs across words", + initial: BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + length: 100, + expected: BitArray{ + len: 100, + words: [4]uint64{maxUint64, 0x0000000FFFFFFFFF, 0, 0}, + }, + }, + { + name: "get 64 LSBs at word boundary", + initial: BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + length: 64, + expected: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "get 128 LSBs at word boundary", + initial: BitArray{ + len: 192, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, + }, + length: 128, + expected: BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + }, + { + name: "get 150 LSBs in third word", + initial: BitArray{ + len: 192, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0}, + }, + length: 150, + expected: BitArray{ + len: 150, + words: [4]uint64{maxUint64, maxUint64, 0x3FFFFF, 0}, + }, + }, + { + name: "get 220 LSBs in fourth word", + initial: BitArray{ + len: 255, + words: [4]uint64{maxUint64, maxUint64, maxUint64, maxUint64}, + }, + length: 220, + expected: BitArray{ + len: 220, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0xFFFFFFF}, + }, + }, + { + name: "get 251 LSBs", + initial: BitArray{ + len: 255, + words: [4]uint64{maxUint64, maxUint64, maxUint64, maxUint64}, + }, + length: 251, + expected: BitArray{ + len: 251, + words: [4]uint64{maxUint64, maxUint64, maxUint64, 0x7FFFFFFFFFFFFFF}, + }, + }, + { + name: "get 100 LSBs from sparse bits", + initial: BitArray{ + len: 128, + words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x5555555555555555, 0, 0}, + }, + length: 100, + expected: BitArray{ + len: 100, + words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x0000000555555555, 0, 0}, + }, + }, + { + name: "no change when new length equals current length", + initial: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + length: 64, + expected: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "no change when new length greater than current length", + initial: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + length: 128, + expected: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := new(BitArray).copyLsb(&tt.initial, tt.length) + if !result.Equal(&tt.expected) { + t.Errorf("Truncate() got = %+v, want %+v", result, tt.expected) + } + }) + } +} + +func TestMSBs(t *testing.T) { + tests := []struct { + name string + x *BitArray + n uint8 + want *BitArray + }{ + { + name: "empty array", + x: emptyBitArray, + n: 0, + want: emptyBitArray, + }, + { + name: "get all bits", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + n: 64, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "get more bits than available", + x: &BitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, + }, + n: 64, + want: &BitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF, 0, 0, 0}, + }, + }, + { + name: "get half of available bits", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + n: 32, + want: &BitArray{ + len: 32, + words: [4]uint64{0xFFFFFFFF00000000 >> 32, 0, 0, 0}, + }, + }, + { + name: "get MSBs across word boundary", + x: &BitArray{ + len: 128, + words: [4]uint64{maxUint64, maxUint64, 0, 0}, + }, + n: 100, + want: &BitArray{ + len: 100, + words: [4]uint64{maxUint64, maxUint64 >> 28, 0, 0}, + }, + }, + { + name: "get MSBs from max length array", + x: &BitArray{ + len: 255, + words: [4]uint64{maxUint64, maxUint64, maxUint64, ones63}, + }, + n: 64, + want: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "get zero bits", + x: &BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + n: 0, + want: &BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "sparse bits", + x: &BitArray{ + len: 128, + words: [4]uint64{0xAAAAAAAAAAAAAAAA, 0x5555555555555555, 0, 0}, + }, + n: 64, + want: &BitArray{ + len: 64, + words: [4]uint64{0x5555555555555555, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).MSBs(tt.x, tt.n) + if !got.Equal(tt.want) { + t.Errorf("MSBs() = %v, want %v", got, tt.want) + } + + if got.len != tt.want.len { + t.Errorf("MSBs() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestSetBit(t *testing.T) { + tests := []struct { + name string + bit uint8 + want BitArray + }{ + { + name: "set bit 0", + bit: 0, + want: BitArray{ + len: 1, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "set bit 1", + bit: 1, + want: BitArray{ + len: 1, + words: [4]uint64{1, 0, 0, 0}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).SetBit(tt.bit) + if !got.Equal(&tt.want) { + t.Errorf("SetBit(%v) = %v, want %v", tt.bit, got, tt.want) + } + }) + } +} + +func TestSetBytes(t *testing.T) { + tests := []struct { + name string + length uint8 + data []byte + want BitArray + }{ + { + name: "empty data", + length: 0, + data: []byte{}, + want: BitArray{ + len: 0, + words: [4]uint64{0, 0, 0, 0}, + }, + }, + { + name: "single byte", + length: 8, + data: []byte{0xFF}, + want: BitArray{ + len: 8, + words: [4]uint64{0xFF, 0, 0, 0}, + }, + }, + { + name: "two bytes", + length: 16, + data: []byte{0xAA, 0xFF}, + want: BitArray{ + len: 16, + words: [4]uint64{0xAAFF, 0, 0, 0}, + }, + }, + { + name: "three bytes", + length: 24, + data: []byte{0xAA, 0xBB, 0xCC}, + want: BitArray{ + len: 24, + words: [4]uint64{0xAABBCC, 0, 0, 0}, + }, + }, + { + name: "four bytes", + length: 32, + data: []byte{0xAA, 0xBB, 0xCC, 0xDD}, + want: BitArray{ + len: 32, + words: [4]uint64{0xAABBCCDD, 0, 0, 0}, + }, + }, + { + name: "eight bytes (full word)", + length: 64, + data: []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, + want: BitArray{ + len: 64, + words: [4]uint64{maxUint64, 0, 0, 0}, + }, + }, + { + name: "sixteen bytes (two words)", + length: 128, + data: []byte{ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, 0xAA, + }, + want: BitArray{ + len: 128, + words: [4]uint64{ + 0xAAAAAAAAAAAAAAAA, + 0xFFFFFFFFFFFFFFFF, + 0, 0, + }, + }, + }, + { + name: "thirty-two bytes (full array)", + length: 251, + data: []byte{ + 0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + }, + want: BitArray{ + len: 251, + words: [4]uint64{ + maxUint64, + maxUint64, + maxUint64, + 0x7FFFFFFFFFFFFFF, + }, + }, + }, + { + name: "truncate to length", + length: 4, + data: []byte{0xFF}, + want: BitArray{ + len: 4, + words: [4]uint64{0xF, 0, 0, 0}, + }, + }, + { + name: "data larger than 32 bytes", + length: 251, + data: []byte{ + 0x7F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, // extra bytes should be ignored + }, + want: BitArray{ + len: 251, + words: [4]uint64{ + maxUint64, + maxUint64, + maxUint64, + 0x7FFFFFFFFFFFFFF, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(BitArray).SetBytes(tt.length, tt.data) + if !got.Equal(&tt.want) { + t.Errorf("SetBytes(%d, %v) = %v, want %v", tt.length, tt.data, got, tt.want) + } + }) + } +} diff --git a/trie/bintrie/empty.go b/trie/bintrie/empty.go index 7cfe373b35b0..5ba94f8c1517 100644 --- a/trie/bintrie/empty.go +++ b/trie/bintrie/empty.go @@ -59,7 +59,7 @@ func (e Empty) InsertValuesAtStem(key []byte, values [][]byte, _ NodeResolverFn, }, nil } -func (e Empty) CollectNodes(_ []byte, _ NodeFlushFn) error { +func (e Empty) CollectNodes(*BitArray, NodeFlushFn) error { return nil } diff --git a/trie/bintrie/empty_test.go b/trie/bintrie/empty_test.go index 574ae1830bed..134a59b6c14f 100644 --- a/trie/bintrie/empty_test.go +++ b/trie/bintrie/empty_test.go @@ -182,11 +182,12 @@ func TestEmptyCollectNodes(t *testing.T) { node := Empty{} var collected []BinaryNode - flushFn := func(path []byte, n BinaryNode) { + flushFn := func(path *BitArray, n BinaryNode) { collected = append(collected, n) } - err := node.CollectNodes([]byte{0, 1, 0}, flushFn) + path := NewBitArray(3, 0b010) + err := node.CollectNodes(&path, flushFn) if err != nil { t.Fatalf("Unexpected error: %v", err) } diff --git a/trie/bintrie/expired_node.go b/trie/bintrie/expired_node.go index d3b90ee9eae7..1569c0fef801 100644 --- a/trie/bintrie/expired_node.go +++ b/trie/bintrie/expired_node.go @@ -148,7 +148,7 @@ func (n *expiredNode) InsertValuesAtStem(stem []byte, values [][]byte, resolver return resolved.InsertValuesAtStem(stem, values, resolver, depth) } -func (n *expiredNode) CollectNodes(path []byte, flushfn NodeFlushFn) error { +func (n *expiredNode) CollectNodes(path *BitArray, flushfn NodeFlushFn) error { return nil } diff --git a/trie/bintrie/expired_node_test.go b/trie/bintrie/expired_node_test.go index ca9a7548cb2d..55d07b80d1b4 100644 --- a/trie/bintrie/expired_node_test.go +++ b/trie/bintrie/expired_node_test.go @@ -128,7 +128,7 @@ func TestExpiredNodeGetHeight(t *testing.T) { func TestExpiredNodeCollectNodes(t *testing.T) { node := &expiredNode{Offset: 100, depth: 5} called := false - err := node.CollectNodes(nil, func(path []byte, n BinaryNode) { + err := node.CollectNodes(new(BitArray), func(path *BitArray, n BinaryNode) { called = true }) diff --git a/trie/bintrie/hashed_node.go b/trie/bintrie/hashed_node.go index e4d8c2e7ac7d..042677ca55b5 100644 --- a/trie/bintrie/hashed_node.go +++ b/trie/bintrie/hashed_node.go @@ -80,7 +80,7 @@ func (h HashedNode) toDot(parent string, path string) string { return ret } -func (h HashedNode) CollectNodes([]byte, NodeFlushFn) error { +func (h HashedNode) CollectNodes(*BitArray, NodeFlushFn) error { // HashedNodes are already persisted in the database and don't need to be collected. return nil } diff --git a/trie/bintrie/internal_node.go b/trie/bintrie/internal_node.go index 0a7bece521fd..89eca83e6a93 100644 --- a/trie/bintrie/internal_node.go +++ b/trie/bintrie/internal_node.go @@ -24,16 +24,20 @@ import ( "github.com/ethereum/go-ethereum/common" ) +// keyToPath converts a key (stem) and depth into a compact path representation. +// It extracts the first depth+1 bits from the key and returns them as a +// packed byte slice. For example, depth=10 with key starting 0xFF 0x00 +// returns [0x07, 0xF8] (the first 11 bits: 11111111000). +// TODO(weiihann): double check on the correctness, depth=0 should point to root node in which the path should be nil? func keyToPath(depth int, key []byte) ([]byte, error) { if depth > 31*8 { return nil, errors.New("node too deep") } - path := make([]byte, 0, depth+1) - for i := range depth + 1 { - bit := key[i/8] >> (7 - (i % 8)) & 1 - path = append(path, bit) - } - return path, nil + // Cap key length to 31 bytes (248 bits) to avoid uint8 overflow + keyLen := min(len(key), 31) + ba := new(BitArray).SetBytes(uint8(keyLen*8), key[:keyLen]) + path := new(BitArray).MSBs(ba, uint8(depth+1)) + return path.ActiveBytes(), nil } // InternalNode is a binary trie internal node. @@ -186,21 +190,15 @@ func (bt *InternalNode) InsertValuesAtStem(stem []byte, values [][]byte, resolve // CollectNodes collects all child nodes at a given path, and flushes it // into the provided node collector. -func (bt *InternalNode) CollectNodes(path []byte, flushfn NodeFlushFn) error { +func (bt *InternalNode) CollectNodes(path *BitArray, flushfn NodeFlushFn) error { if bt.left != nil { - var p [256]byte - copy(p[:], path) - childpath := p[:len(path)] - childpath = append(childpath, 0) + childpath := new(BitArray).AppendBit(path, 0) if err := bt.left.CollectNodes(childpath, flushfn); err != nil { return err } } if bt.right != nil { - var p [256]byte - copy(p[:], path) - childpath := p[:len(path)] - childpath = append(childpath, 1) + childpath := new(BitArray).AppendBit(path, 1) if err := bt.right.CollectNodes(childpath, flushfn); err != nil { return err } diff --git a/trie/bintrie/internal_node_test.go b/trie/bintrie/internal_node_test.go index 158d8b7147d5..a1ab794c1c23 100644 --- a/trie/bintrie/internal_node_test.go +++ b/trie/bintrie/internal_node_test.go @@ -369,17 +369,17 @@ func TestInternalNodeCollectNodes(t *testing.T) { right: rightStem, } - var collectedPaths [][]byte + var collectedPaths []BitArray var collectedNodes []BinaryNode - flushFn := func(path []byte, n BinaryNode) { - pathCopy := make([]byte, len(path)) - copy(pathCopy, path) - collectedPaths = append(collectedPaths, pathCopy) + flushFn := func(path *BitArray, n BinaryNode) { + collectedPaths = append(collectedPaths, path.Copy()) collectedNodes = append(collectedNodes, n) } - err := node.CollectNodes([]byte{1}, flushFn) + // Initial path: 1 (single bit) + initialPath := NewBitArray(1, 1) + err := node.CollectNodes(&initialPath, flushFn) if err != nil { t.Fatalf("Failed to collect nodes: %v", err) } @@ -389,15 +389,15 @@ func TestInternalNodeCollectNodes(t *testing.T) { t.Errorf("Expected 3 collected nodes, got %d", len(collectedNodes)) } - // Check paths - expectedPaths := [][]byte{ - {1, 0}, // left child - {1, 1}, // right child - {1}, // internal node itself + // Check paths (binary: 10 = left child after 1, 11 = right child after 1, 1 = internal node) + expectedPaths := []BitArray{ + NewBitArray(2, 0b10), // left child (1 followed by 0) + NewBitArray(2, 0b11), // right child (1 followed by 1) + NewBitArray(1, 0b1), // internal node itself } for i, expectedPath := range expectedPaths { - if !bytes.Equal(collectedPaths[i], expectedPath) { + if !collectedPaths[i].Equal(&expectedPath) { t.Errorf("Path %d mismatch: expected %v, got %v", i, expectedPath, collectedPaths[i]) } } diff --git a/trie/bintrie/stem_node.go b/trie/bintrie/stem_node.go index 60856b42ce60..e422eb8b8d0e 100644 --- a/trie/bintrie/stem_node.go +++ b/trie/bintrie/stem_node.go @@ -135,7 +135,7 @@ func (bt *StemNode) Hash() common.Hash { // CollectNodes collects all child nodes at a given path, and flushes it // into the provided node collector. -func (bt *StemNode) CollectNodes(path []byte, flush NodeFlushFn) error { +func (bt *StemNode) CollectNodes(path *BitArray, flush NodeFlushFn) error { flush(path, bt) return nil } diff --git a/trie/bintrie/stem_node_test.go b/trie/bintrie/stem_node_test.go index d8d6844427de..b4754c274c5b 100644 --- a/trie/bintrie/stem_node_test.go +++ b/trie/bintrie/stem_node_test.go @@ -336,18 +336,17 @@ func TestStemNodeCollectNodes(t *testing.T) { depth: 0, } - var collectedPaths [][]byte + var collectedPaths []BitArray var collectedNodes []BinaryNode - flushFn := func(path []byte, n BinaryNode) { - // Make a copy of the path - pathCopy := make([]byte, len(path)) - copy(pathCopy, path) - collectedPaths = append(collectedPaths, pathCopy) + flushFn := func(path *BitArray, n BinaryNode) { + collectedPaths = append(collectedPaths, path.Copy()) collectedNodes = append(collectedNodes, n) } - err := node.CollectNodes([]byte{0, 1, 0}, flushFn) + // Path 010 in binary (3 bits) + initialPath := NewBitArray(3, 0b010) + err := node.CollectNodes(&initialPath, flushFn) if err != nil { t.Fatalf("Failed to collect nodes: %v", err) } @@ -363,7 +362,8 @@ func TestStemNodeCollectNodes(t *testing.T) { } // Check the path - if !bytes.Equal(collectedPaths[0], []byte{0, 1, 0}) { - t.Errorf("Path mismatch: expected [0, 1, 0], got %v", collectedPaths[0]) + expectedPath := NewBitArray(3, 0b010) + if !collectedPaths[0].Equal(&expectedPath) { + t.Errorf("Path mismatch: expected %v, got %v", expectedPath, collectedPaths[0]) } } diff --git a/trie/bintrie/trie.go b/trie/bintrie/trie.go index c082d57bdfb2..695ab157cd1c 100644 --- a/trie/bintrie/trie.go +++ b/trie/bintrie/trie.go @@ -324,9 +324,10 @@ func (t *BinaryTrie) Commit(_ bool) (common.Hash, *trienode.NodeSet) { nodeset := trienode.NewNodeSet(common.Hash{}) // The root can be any type of BinaryNode (InternalNode, StemNode, etc.) - err := t.root.CollectNodes(nil, func(path []byte, node BinaryNode) { + err := t.root.CollectNodes(new(BitArray), func(path *BitArray, node BinaryNode) { + pathBytes := path.ActiveBytes() serialized := SerializeNode(node) - nodeset.AddNode(path, trienode.NewNodeWithPrev(node.Hash(), serialized, t.tracer.Get(path))) + nodeset.AddNode(pathBytes, trienode.NewNodeWithPrev(node.Hash(), serialized, t.tracer.Get(pathBytes))) }) if err != nil { panic(fmt.Errorf("CollectNodes failed: %v", err))