diff --git a/benchmarks/bench_scaling.py b/benchmarks/bench_scaling.py
new file mode 100644
index 0000000..5b395a8
--- /dev/null
+++ b/benchmarks/bench_scaling.py
@@ -0,0 +1,60 @@
+"""Benchmark how training scales with text size and merge count."""
+
+import time
+
+from complex_tokenization.fast_bpe_trainer import FastBPETrainer
+from complex_tokenization.graphs.settings import GraphSettings
+from complex_tokenization.graphs.units import utf8_clusters
+from complex_tokenization.graphs.words import words
+from complex_tokenization.trainer import Trainer
+
+
+def train_graph_bpe(texts, num_merges):
+    # Restrict merges to size-2 pairs so graph BPE output is directly comparable to FastBPETrainer.
+    GraphSettings.ONLY_MINIMAL_MERGES = True
+    GraphSettings.MAX_MERGE_SIZE = 2
+    GraphSettings.USE_SINGLETONS = False
+    graphs = tuple(words(t, connected=False, units=utf8_clusters) for t in texts)
+    trainer = Trainer(graphs=graphs)
+    trainer.train(num_merges=num_merges)
+    return trainer.get_merges()
+
+
+def train_fast_bpe(texts, num_merges):
+    fast = FastBPETrainer(texts)
+    fast.train(num_merges=num_merges)
+    return fast.get_merges()
+
+
+BASE_TEXT = "the teacher teaches the thick thing about the theorem "
+
+
+def run():
+    print(f"\n{'='*80}")
+    print("Scaling Benchmark: Graph BPE vs Fast BPE")
+    print(f"{'='*80}")
+    print(f"{'Config':30s} {'Graph BPE':>10s} {'Fast BPE':>10s} {'Speedup':>8s} {'Match':>8s}")
+    print("-" * 80)
+
+    for num_texts in [10, 50, 100]:
+        for repeat in [10, 50]:
+            for num_merges in [50, 100, 200]:
+                texts = [BASE_TEXT * repeat] * num_texts
+                total_chars = sum(len(t) for t in texts)
+
+                start = time.perf_counter()
+                graph_merges = train_graph_bpe(texts, num_merges)
+                graph_time = time.perf_counter() - start
+
+                start = time.perf_counter()
+                fast_merges = train_fast_bpe(texts, num_merges)
+                fast_time = time.perf_counter() - start
+
+                speedup = graph_time / fast_time if fast_time > 0 else float('inf')
+                match = "ok" if graph_merges == fast_merges else "MISMATCH"
+                label = f"{num_texts}x{repeat}rep m={num_merges} ({total_chars:,}ch)"
+                print(f"{label:30s} {graph_time:>9.3f}s {fast_time:>9.3f}s {speedup:>7.1f}x {match:>8s}")
+
+
+if __name__ == "__main__":
+    run()
diff --git a/tests/test_scaling.py b/tests/test_scaling.py
new file mode 100644
index 0000000..f594f81
--- /dev/null
+++ b/tests/test_scaling.py
@@ -0,0 +1,28 @@
+"""Test that FastBPE scales well with larger inputs."""
+
+import time
+
+from complex_tokenization.fast_bpe_trainer import FastBPETrainer
+
+BASE = "the teacher teaches the thick thing about the theorem "
+
+
+class TestScaling:
+    def test_270k_chars_under_5s(self):
+        texts = [BASE * 50] * 100  # ~270k chars
+        start = time.perf_counter()
+        fast = FastBPETrainer(texts)
+        fast.train(num_merges=25)
+        elapsed = time.perf_counter() - start
+        assert elapsed < 5, f"FastBPE on 270k chars took {elapsed:.1f}s (limit: 5s)"
+        assert len(fast.merges) == 25
+
+    def test_merges_scale_with_data(self):
+        # Replicating the same base text scales all pair counts proportionally, so both corpora should learn identical merges.
+        small = [BASE * 10] * 10
+        large = [BASE * 50] * 50
+        f_small = FastBPETrainer(small)
+        f_small.train(num_merges=50)
+        f_large = FastBPETrainer(large)
+        f_large.train(num_merges=50)
+        assert f_small.get_merges() == f_large.get_merges()
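
A minimal usage sketch of the FastBPETrainer API exercised above, assuming complex_tokenization is importable (e.g. the repository root is on PYTHONPATH); the toy corpus is made up for illustration:

    from complex_tokenization.fast_bpe_trainer import FastBPETrainer

    # Hypothetical toy corpus; any list of strings works as input.
    texts = ["the theorem about the thing "] * 20
    trainer = FastBPETrainer(texts)
    trainer.train(num_merges=10)
    print(trainer.get_merges())  # the learned merge sequence

The benchmark itself runs as python benchmarks/bench_scaling.py, and the TestScaling class is discovered by pytest via pytest tests/test_scaling.py.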