From 3a907ccbb8091c5c10c523039bce09a1ffa236ad Mon Sep 17 00:00:00 2001 From: Jan Pfeifer Date: Tue, 7 Apr 2026 11:03:03 +0200 Subject: [PATCH] Added "two-bits bucketing" strategy. --- docs/CHANGELOG.md | 1 + tokenizers/bucket/bucket.go | 76 ++++++++++++++++++++++++++++++ tokenizers/bucket/bucket_test.go | 79 ++++++++++++++++++++++++++++++++ 3 files changed, 156 insertions(+) diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 0a9dbbb..e137f12 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -16,6 +16,7 @@ - Package `tokenizers/bucket` - Added `bucket` package for streaming tokenization of sentences into buckets (or batches) of discrete sizes, to minimize padding. + - Added "Two-Bits Bucketing" strategy. - Package `datasets`: - Added `datasets` package for downloading and iterating over parquet files of datasets from the HuggingFace Hub. - Added `cmd/generate_dataset_structs` for generating Go structs for dataset records. diff --git a/tokenizers/bucket/bucket.go b/tokenizers/bucket/bucket.go index 99fcd94..a0c0786 100644 --- a/tokenizers/bucket/bucket.go +++ b/tokenizers/bucket/bucket.go @@ -29,6 +29,7 @@ package bucket import ( "math" + "math/bits" "runtime" "sync" "time" @@ -160,6 +161,8 @@ func (b *Bucketizer) WithShapeFn(shapeFn ShapeFn) *Bucketizer { // The minSentenceLength is the smallest sentenceLength bucket. // // See also ByPowerBudget, if you want to round to a fixed tokens budget. +// +// For full control on the bucketing function, see WithShapeFn. func (b *Bucketizer) ByPower(batchSize, minSentenceLength int, base float64) *Bucketizer { return b.WithShapeFn(func(sentenceLength int) Shape { sentenceLength = max(sentenceLength, minSentenceLength) @@ -178,6 +181,8 @@ func (b *Bucketizer) ByPower(batchSize, minSentenceLength int, base float64) *Bu // and hopefully the downstream tasks run on +/- constant time. // // For sentence lengths > tokenBudget, it simply uses batchSize = 1. +// +// For full control on the bucketing function, see WithShapeFn. func (b *Bucketizer) ByPowerBudget(tokensBudget, minSentenceLength int, base float64) *Bucketizer { return b.WithShapeFn(func(sentenceLength int) Shape { sentenceLength = max(sentenceLength, minSentenceLength) @@ -190,6 +195,77 @@ func (b *Bucketizer) ByPowerBudget(tokensBudget, minSentenceLength int, base flo }) } +// ByTwoBitBucket configures the bucketizer to use buckets of sentence-length sized to the +// next value that can be represented with 2 bits. So: 1, 2, 3, 4, 6, 8, 12, 16, ... +// +// This is a "2-bit semi-log bucketing", and each size is separated from the other by +// a factor of 1.5 or 1.333 alternatingly, on average, by a factor of 1.414 (sqrt(2)), +// but results in numbers that are "friendlier" for binary addressing (and memory pages, etc.). +// +// For full control on the bucketing function, see WithShapeFn. +func (b *Bucketizer) ByTwoBitBucket(batchSize, minSentenceLength int) *Bucketizer { + return b.WithShapeFn(func(sentenceLength int) Shape { + sentenceLength = max(sentenceLength, minSentenceLength) + bucketSentenceLen := TwoBitBucketLen(sentenceLength) + return Shape{ + BatchSize: batchSize, + SentenceLength: bucketSentenceLen, + } + }) +} + +// ByTwoBitBucketBudget configures the bucketizer to use buckets of sentence-length sized to the +// next value that can be represented with 2 bits. So: 1, 2, 3, 4, 6, 8, 12, 16, ... +// +// This is a "2-bit semi-log bucketing", and each size is separated from the other by +// a factor of 1.5 or 1.333 alternatingly, on average, by a factor of 1.414 (sqrt(2)), +// but results in numbers that are "friendlier" for binary addressing (and memory pages, etc.). +// +// For full control on the bucketing function, see WithShapeFn. +func (b *Bucketizer) ByTwoBitBucketBudget(tokensBudget, minSentenceLength int) *Bucketizer { + return b.WithShapeFn(func(sentenceLength int) Shape { + sentenceLength = max(sentenceLength, minSentenceLength) + bucketSentenceLen := TwoBitBucketLen(sentenceLength) + batchSize := max(tokensBudget/bucketSentenceLen, 1) + return Shape{ + BatchSize: batchSize, + SentenceLength: bucketSentenceLen, + } + }) +} + +// TwoBitBucketLen returns the smallest size >= unpaddedLen that uses only +// the two highest bits (either 2^n or 1.5 * 2^n). +// +// It is used by Bucketizer.ByTwoBitBucket and Bucketizer.ByTwoBitBucketBudget. +func TwoBitBucketLen(unpaddedLen int) int { + if unpaddedLen <= 2 { + return unpaddedLen + } + + // Find the position of the most significant bit (MSB). + // bits.Len returns the number of bits required to represent the uint. + // For 5 (101), Len is 3. + msbPos := bits.Len(uint(unpaddedLen)) - 1 + msbValue := 1 << msbPos + + // Case 1: Exact power of 2 + if unpaddedLen == msbValue { + return msbValue + } + + // Case 2: Check the "1.5" threshold (the two highest bits) + // Example: If msbValue is 4 (100), threshold is 6 (110) + threshold := msbValue | (msbValue >> 1) + + if unpaddedLen <= threshold { + return threshold + } + + // Case 3: Above the 1.5 threshold, jump to the next power of 2 + return msbValue << 1 +} + // WithMaxParallelization sets the maximum number of sentences to tokenize in parallel. // // Set to -1 to use runtime.NumCPU(). diff --git a/tokenizers/bucket/bucket_test.go b/tokenizers/bucket/bucket_test.go index d42eacd..618779b 100644 --- a/tokenizers/bucket/bucket_test.go +++ b/tokenizers/bucket/bucket_test.go @@ -102,6 +102,85 @@ func TestByPowerBudget(t *testing.T) { } } +func TestTwoBitBucket(t *testing.T) { + t.Run("TwoBitBucketLen", func(t *testing.T) { + testCases := []struct { + length int + want int + }{ + {1, 1}, + {2, 2}, + {3, 3}, + {4, 4}, + {5, 6}, + {6, 6}, + {7, 8}, + {8, 8}, + {9, 12}, + {12, 12}, + {13, 16}, + {17, 24}, + {32, 32}, + {33, 48}, + } + + for _, tc := range testCases { + got := TwoBitBucketLen(tc.length) + if got != tc.want { + t.Errorf("TwoBitBucketLen(%d) = %d, want %d", tc.length, got, tc.want) + } + } + }) + + t.Run("Bucketizer.ByTwoBitBucket", func(t *testing.T) { + b := &Bucketizer{} + b.ByTwoBitBucket(32, 4) + + testCases := []struct { + length int + want Shape + }{ + {1, Shape{BatchSize: 32, SentenceLength: 4}}, // max(1, 4) -> 4 -> 4 + {3, Shape{BatchSize: 32, SentenceLength: 4}}, + {4, Shape{BatchSize: 32, SentenceLength: 4}}, + {5, Shape{BatchSize: 32, SentenceLength: 6}}, + {7, Shape{BatchSize: 32, SentenceLength: 8}}, + {9, Shape{BatchSize: 32, SentenceLength: 12}}, + } + + for _, tc := range testCases { + got := b.shapeFn(tc.length) + if got != tc.want { + t.Errorf("ByTwoBitBucket(length: %d) = %+v, want %+v", tc.length, got, tc.want) + } + } + }) + + t.Run("Bucketizer.ByTwoBitBucketBudget", func(t *testing.T) { + b := &Bucketizer{} + b.ByTwoBitBucketBudget(128, 4) + + testCases := []struct { + length int + want Shape + }{ + {1, Shape{BatchSize: 32, SentenceLength: 4}}, // 128 / 4 = 32 + {5, Shape{BatchSize: 21, SentenceLength: 6}}, // 128 / 6 = 21 + {7, Shape{BatchSize: 16, SentenceLength: 8}}, // 128 / 8 = 16 + {9, Shape{BatchSize: 10, SentenceLength: 12}}, // 128 / 12 = 10 + {128, Shape{BatchSize: 1, SentenceLength: 128}}, // 128 / 128 = 1 + {129, Shape{BatchSize: 1, SentenceLength: 192}}, // 128 / 192 = 0 -> max(0, 1) = 1 + } + + for _, tc := range testCases { + got := b.shapeFn(tc.length) + if got != tc.want { + t.Errorf("ByTwoBitBucketBudget(length: %d) = %+v, want %+v", tc.length, got, tc.want) + } + } + }) +} + // mockTokenizer for testing. It returns TokenIDs [1, 2, ..., len(text)]. type mockTokenizer struct { padID int