Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
- Package `tokenizers/bucket`
- Added `bucket` package for streaming tokenization of sentences into buckets (or batches) of discrete sizes,
to minimize padding.
- Added "Two-Bit Bucketing" strategy (`ByTwoBitBucket` and `ByTwoBitBucketBudget`).
- Package `datasets`:
- Added `datasets` package for downloading and iterating over parquet files of datasets from the HuggingFace Hub.
- Added `cmd/generate_dataset_structs` for generating Go structs for dataset records.
Expand Down
76 changes: 76 additions & 0 deletions tokenizers/bucket/bucket.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ package bucket

import (
"math"
"math/bits"
"runtime"
"sync"
"time"
Expand Down Expand Up @@ -160,6 +161,8 @@ func (b *Bucketizer) WithShapeFn(shapeFn ShapeFn) *Bucketizer {
// The minSentenceLength is the smallest sentenceLength bucket.
//
// See also ByPowerBudget, if you want to round to a fixed tokens budget.
//
// For full control on the bucketing function, see WithShapeFn.
func (b *Bucketizer) ByPower(batchSize, minSentenceLength int, base float64) *Bucketizer {
return b.WithShapeFn(func(sentenceLength int) Shape {
sentenceLength = max(sentenceLength, minSentenceLength)
Expand All @@ -178,6 +181,8 @@ func (b *Bucketizer) ByPower(batchSize, minSentenceLength int, base float64) *Bu
// and hopefully the downstream tasks run on +/- constant time.
//
// For sentence lengths > tokenBudget, it simply uses batchSize = 1.
//
// For full control on the bucketing function, see WithShapeFn.
func (b *Bucketizer) ByPowerBudget(tokensBudget, minSentenceLength int, base float64) *Bucketizer {
return b.WithShapeFn(func(sentenceLength int) Shape {
sentenceLength = max(sentenceLength, minSentenceLength)
Expand All @@ -190,6 +195,77 @@ func (b *Bucketizer) ByPowerBudget(tokensBudget, minSentenceLength int, base flo
})
}

// ByTwoBitBucket configures the bucketizer to round every sentence length up to
// the next value representable with only its two highest bits set:
// 1, 2, 3, 4, 6, 8, 12, 16, ...
//
// This is a "2-bit semi-log bucketing": consecutive sizes grow by alternating
// factors of 1.5 and 1.333 (sqrt(2) ~ 1.414 on average), while yielding numbers
// that are "friendlier" for binary addressing (memory pages, etc.).
//
// The batch size is fixed at batchSize; sentence lengths below
// minSentenceLength are bucketed as if they were minSentenceLength.
//
// For full control on the bucketing function, see WithShapeFn.
func (b *Bucketizer) ByTwoBitBucket(batchSize, minSentenceLength int) *Bucketizer {
	shapeFn := func(sentenceLength int) Shape {
		if sentenceLength < minSentenceLength {
			sentenceLength = minSentenceLength
		}
		return Shape{
			BatchSize:      batchSize,
			SentenceLength: TwoBitBucketLen(sentenceLength),
		}
	}
	return b.WithShapeFn(shapeFn)
}

// ByTwoBitBucketBudget configures the bucketizer to round every sentence length
// up to the next value representable with only its two highest bits set
// (1, 2, 3, 4, 6, 8, 12, 16, ...), and to size each batch from a token budget:
// batchSize = tokensBudget / bucketedSentenceLength, so every bucket carries
// roughly tokensBudget tokens and downstream tasks run in +/- constant time.
//
// This is a "2-bit semi-log bucketing": consecutive sizes grow by alternating
// factors of 1.5 and 1.333 (sqrt(2) ~ 1.414 on average), while yielding numbers
// that are "friendlier" for binary addressing (memory pages, etc.).
//
// For bucketed sentence lengths > tokensBudget, the batch size is simply 1.
// Sentence lengths below minSentenceLength are bucketed as if they were
// minSentenceLength.
//
// For full control on the bucketing function, see WithShapeFn.
func (b *Bucketizer) ByTwoBitBucketBudget(tokensBudget, minSentenceLength int) *Bucketizer {
	shapeFn := func(sentenceLength int) Shape {
		if sentenceLength < minSentenceLength {
			sentenceLength = minSentenceLength
		}
		paddedLen := TwoBitBucketLen(sentenceLength)
		// Fit as many sentences as the budget allows, but never fewer than one.
		batch := tokensBudget / paddedLen
		if batch < 1 {
			batch = 1
		}
		return Shape{
			BatchSize:      batch,
			SentenceLength: paddedLen,
		}
	}
	return b.WithShapeFn(shapeFn)
}

// TwoBitBucketLen rounds unpaddedLen up to the smallest value >= unpaddedLen
// whose binary representation uses only its two highest bit positions, i.e.
// either 2^n or 1.5*2^n: 1, 2, 3, 4, 6, 8, 12, 16, 24, 32, ...
//
// Values <= 2 (including non-positive lengths) are returned unchanged.
//
// It is used by Bucketizer.ByTwoBitBucket and Bucketizer.ByTwoBitBucketBudget.
func TwoBitBucketLen(unpaddedLen int) int {
	if unpaddedLen <= 2 {
		return unpaddedLen
	}

	// Largest power of two <= unpaddedLen: bits.Len(x) is the number of bits
	// needed to represent x (e.g. bits.Len(5) == 3), so the MSB sits at Len-1.
	pow := 1 << (bits.Len(uint(unpaddedLen)) - 1)

	switch {
	case unpaddedLen == pow:
		// Already an exact power of two.
		return pow
	case unpaddedLen <= pow+pow/2:
		// At or below the "two highest bits" midpoint (e.g. pow=4 -> 6).
		return pow + pow/2
	default:
		// Between 1.5*pow and 2*pow: round up to the next power of two.
		return pow << 1
	}
}

// WithMaxParallelization sets the maximum number of sentences to tokenize in parallel.
//
// Set to -1 to use runtime.NumCPU().
Expand Down
79 changes: 79 additions & 0 deletions tokenizers/bucket/bucket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,85 @@ func TestByPowerBudget(t *testing.T) {
}
}

// TestTwoBitBucket covers the "2-bit bucketing" strategy: the raw
// TwoBitBucketLen rounding function and the two Bucketizer configuration
// helpers built on top of it.
func TestTwoBitBucket(t *testing.T) {
	// Spot-check the rounding function: every result must be of the form
	// 2^n or 1.5*2^n (1, 2, 3, 4, 6, 8, 12, 16, 24, ...).
	t.Run("TwoBitBucketLen", func(t *testing.T) {
		testCases := []struct {
			length int
			want   int
		}{
			{1, 1},
			{2, 2},
			{3, 3},
			{4, 4},
			{5, 6},
			{6, 6},
			{7, 8},
			{8, 8},
			{9, 12},
			{12, 12},
			{13, 16},
			{17, 24},
			{32, 32},
			{33, 48},
		}

		for _, tc := range testCases {
			got := TwoBitBucketLen(tc.length)
			if got != tc.want {
				t.Errorf("TwoBitBucketLen(%d) = %d, want %d", tc.length, got, tc.want)
			}
		}
	})

	// Fixed batch size (32): only the sentence length is bucketed, with a
	// minimum sentence length of 4.
	t.Run("Bucketizer.ByTwoBitBucket", func(t *testing.T) {
		b := &Bucketizer{}
		b.ByTwoBitBucket(32, 4)

		testCases := []struct {
			length int
			want   Shape
		}{
			{1, Shape{BatchSize: 32, SentenceLength: 4}}, // max(1, 4) -> 4 -> 4
			{3, Shape{BatchSize: 32, SentenceLength: 4}},
			{4, Shape{BatchSize: 32, SentenceLength: 4}},
			{5, Shape{BatchSize: 32, SentenceLength: 6}},
			{7, Shape{BatchSize: 32, SentenceLength: 8}},
			{9, Shape{BatchSize: 32, SentenceLength: 12}},
		}

		for _, tc := range testCases {
			got := b.shapeFn(tc.length)
			if got != tc.want {
				t.Errorf("ByTwoBitBucket(length: %d) = %+v, want %+v", tc.length, got, tc.want)
			}
		}
	})

	// Budget-driven batch size (128 tokens): batch = budget / bucketed length,
	// clamped to at least 1 once the bucketed length exceeds the budget.
	t.Run("Bucketizer.ByTwoBitBucketBudget", func(t *testing.T) {
		b := &Bucketizer{}
		b.ByTwoBitBucketBudget(128, 4)

		testCases := []struct {
			length int
			want   Shape
		}{
			{1, Shape{BatchSize: 32, SentenceLength: 4}},    // 128 / 4 = 32
			{5, Shape{BatchSize: 21, SentenceLength: 6}},    // 128 / 6 = 21
			{7, Shape{BatchSize: 16, SentenceLength: 8}},    // 128 / 8 = 16
			{9, Shape{BatchSize: 10, SentenceLength: 12}},   // 128 / 12 = 10
			{128, Shape{BatchSize: 1, SentenceLength: 128}}, // 128 / 128 = 1
			{129, Shape{BatchSize: 1, SentenceLength: 192}}, // 128 / 192 = 0 -> max(0, 1) = 1
		}

		for _, tc := range testCases {
			got := b.shapeFn(tc.length)
			if got != tc.want {
				t.Errorf("ByTwoBitBucketBudget(length: %d) = %+v, want %+v", tc.length, got, tc.want)
			}
		}
	})
}

// mockTokenizer for testing. It returns TokenIDs [1, 2, ..., len(text)].
type mockTokenizer struct {
padID int
Expand Down
Loading