A from-scratch implementation of BitNet b1.58 (ternary {-1, 0, 1} weights) with custom Triton kernels, production-grade training stability, and full observability, designed for single-GPU pretraining on an NVIDIA L20 (48 GB).
BitNet b1.58 (Ma et al., 2024) represents a paradigm shift: every weight in the transformer's linear projections is constrained to {-1, 0, 1}, encoding each parameter in just log₂(3) ≈ 1.58 bits. This removes floating-point multiplication from the weight matmuls: each weight-activation product reduces to an addition, a subtraction, or a skip.
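As a quick check on the 1.58-bit figure: log₂(3) is the information content of one symbol that can take three equally likely values. The sketch below compares it with the 2 bits per weight spent by a fixed 2-bit code such as the packed inference format described later. It is plain arithmetic and uses no TitanBit code.

```python
import math

# Information content of one ternary symbol in {-1, 0, 1}
bits_per_ternary = math.log2(3)

# A fixed 2-bit code (as used for weight packing) spends 2 bits per weight
packed_bits = 2
efficiency = bits_per_ternary / packed_bits

print(f"log2(3) = {bits_per_ternary:.3f} bits")          # → log2(3) = 1.585 bits
print(f"2-bit packing efficiency = {efficiency:.1%}")    # → 2-bit packing efficiency = 79.2%
```

The remaining ~21% is the price of a simple, aligned 2-bit layout that unpacks with shifts and masks.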
TitanBit is not a wrapper around someone else's library. It is a ground-up implementation demonstrating:
| Component | What it demonstrates |
|---|---|
| BitLinear layer | Quantisation-Aware Training with STE, AbsMean activation quantisation |
| Triton kernels | Custom GPU kernels for ternary matmul (branch-free add/sub) |
| Weight packing | 2-bit encoding: 16× memory compression vs FP32 (8× vs BF16) for inference |
| Full transformer | RoPE, GQA, SwiGLU, RMSNorm, all with ternary projections |
| mmap data pipeline | Zero-copy NVMe reads for sustained GPU utilisation |
| Stability system | Loss spike detection, auto-rollback, LR recovery |
| MFU tracking | Real-time Model FLOPs Utilisation against the L20 BF16 peak (119.5 TFLOPS) |
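MFU is achieved model FLOPs per second divided by the hardware peak. A minimal sketch, assuming the common 6·N FLOPs-per-token approximation for one training step of a dense transformer (2N forward, 4N backward, attention FLOPs ignored). TitanBit's trainer may count FLOPs differently, the function name is hypothetical, and the throughput in the example is a made-up number, not a measured result.

```python
L20_PEAK_TFLOPS = 119.5  # peak figure quoted in the table above


def model_flops_utilisation(n_params: float, tokens_per_second: float,
                            peak_tflops: float = L20_PEAK_TFLOPS) -> float:
    """Return MFU as a fraction of peak, using the 6*N FLOPs-per-token rule."""
    achieved_flops = 6.0 * n_params * tokens_per_second  # forward (2N) + backward (4N)
    return achieved_flops / (peak_tflops * 1e12)


# Hypothetical throughput for the 1.3B model
mfu = model_flops_utilisation(n_params=1.3e9, tokens_per_second=5_000)
print(f"MFU: {mfu:.1%}")  # → MFU: 32.6%
```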
BitNet b1.58 forward pass:
Input x (BF16) → RMSNorm (SubLN) → AbsMean Quant (8-bit) → Ternary MatMul ({-1, 0, 1}, STE) → Rescale (β × γ)
Key: there is no floating-point multiplication in the matmul. The ternary constraint means W × x = ±x or 0.
- Rotary Position Embeddings (RoPE): length generalisation
- Grouped-Query Attention (GQA): reduced KV-cache memory
- SwiGLU MLP: ~5% perplexity improvement over GeLU
- RMSNorm: pre-norm architecture for stability
- BitLinear in all Q/K/V/O and MLP projections
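With GQA, several query heads share one key/value head, so the KV cache shrinks by a factor of num_heads / num_kv_heads. The sketch below shows the head mapping and the cache arithmetic for the 1.3B configuration in the size table. The 8 KV heads, the 2048-token sequence and the 2-byte (BF16) cache entries are illustrative assumptions, not values taken from TitanBit's configs, and both function names are hypothetical.

```python
def kv_head_for(query_head: int, num_heads: int, num_kv_heads: int) -> int:
    """Index of the KV head shared by a query head (contiguous groups of query heads)."""
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be divisible by num_kv_heads")
    group_size = num_heads // num_kv_heads
    return query_head // group_size


def kv_cache_bytes(num_layers: int, num_kv_heads: int, head_dim: int,
                   seq_len: int, batch: int = 1, bytes_per_elem: int = 2) -> int:
    """Bytes for the K and V caches of all layers (two tensors per layer)."""
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * batch * bytes_per_elem


# 1.3B row: hidden 2048, 24 layers, 32 heads -> head_dim = 2048 / 32 = 64
mha = kv_cache_bytes(num_layers=24, num_kv_heads=32, head_dim=64, seq_len=2048)
gqa = kv_cache_bytes(num_layers=24, num_kv_heads=8, head_dim=64, seq_len=2048)  # hypothetical 8 KV heads
print(f"MHA cache: {mha / 2**20:.0f} MiB, GQA cache: {gqa / 2**20:.0f} MiB")
# → MHA cache: 384 MiB, GQA cache: 96 MiB
```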
| Size | Hidden | Layers | Heads | Params | L20 VRAM (train) |
|---|---|---|---|---|---|
| 125M | 768 | 12 | 12 | ~125M | ~4 GB |
| 350M | 1024 | 24 | 16 | ~350M | ~8 GB |
| 700M | 1536 | 24 | 24 | ~700M | ~14 GB |
| 1.3B | 2048 | 24 | 32 | ~1.3B | ~22 GB |
| 3B | 3200 | 26 | 32 | ~3B | ~42 GB |
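The Params column is roughly in line with the usual estimate of 12·L·d² for the transformer blocks (4d² for the Q/K/V/O projections, 8d² for a SwiGLU MLP whose three matrices use an intermediate size of about 8d/3) plus V·d for the token embedding. A sketch, assuming full-width K/V projections (no GQA savings), tied input/output embeddings and a 32,000-token vocabulary; GQA and an untied output head shift the totals by tens of millions of parameters. The function name is hypothetical.

```python
def approx_params(hidden: int, layers: int, vocab: int = 32_000) -> int:
    """Rough parameter count: 12 * L * d^2 for the blocks plus the embedding."""
    attn = 4 * hidden * hidden   # Q, K, V, O projections (assumes no GQA)
    mlp = 8 * hidden * hidden    # SwiGLU: 3 matrices of hidden x (8/3) * hidden
    embed = vocab * hidden       # tied input/output embedding (assumption)
    return layers * (attn + mlp) + embed


sizes = [("125M", 768, 12), ("350M", 1024, 24), ("700M", 1536, 24),
         ("1.3B", 2048, 24), ("3B", 3200, 26)]
for name, hidden, layers in sizes:
    print(f"{name:>5}: ~{approx_params(hidden, layers) / 1e6:,.0f}M")
```

Norm weights are ignored because they add only O(L·d) parameters.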
# Core (model + training + Triton kernels)
pip install -e .
# With evaluation benchmarks
pip install -e ".[eval]"
# With FlashAttention
pip install -e ".[flash]"
# Everything
pip install -e ".[all]"
Requirements:
- Python ≥ 3.10
- PyTorch ≥ 2.2.0 (with CUDA support)
- Triton ≥ 2.2.0
- NVIDIA GPU with compute capability ≥ 7.0 (Ada Lovelace recommended)
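The version requirements above can be checked before installing with a short standard-library script. This is a hypothetical helper, not part of the titanbit CLI. It compares only the leading numeric release components, so a pre-release such as 2.2.0a0 counts as 2.2.0, and it skips the GPU check, which would need torch.cuda.get_device_capability().

```python
import re
import sys
from importlib import metadata


def parse_release(version: str) -> tuple[int, ...]:
    """'2.3.0+cu121' -> (2, 3, 0); stops at the first non-numeric suffix."""
    parts = []
    for piece in version.split("+", 1)[0].split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
        if match.end() < len(piece):  # e.g. '0a0': keep 0, ignore the pre-release tag
            break
    return tuple(parts)


def version_at_least(installed: str, required: str) -> bool:
    a, b = parse_release(installed), parse_release(required)
    width = max(len(a), len(b))
    # Pad with zeros so that '2.2' compares equal to '2.2.0'
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))


def check_requirements() -> dict[str, bool]:
    results = {"python>=3.10": sys.version_info >= (3, 10)}
    for package, minimum in [("torch", "2.2.0"), ("triton", "2.2.0")]:
        key = f"{package}>={minimum}"
        try:
            results[key] = version_at_least(metadata.version(package), minimum)
        except metadata.PackageNotFoundError:
            results[key] = False
    return results


if __name__ == "__main__":
    for requirement, ok in check_requirements().items():
        print(f"{'OK' if ok else 'MISSING':<8}{requirement}")
```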
titanbit info --model-size 1.3B
# Tokenise a text corpus into binary format
titanbit tokenize --input ./data/raw/ --output ./data/train.bin
# Or use TitanWash output directly
titanbit tokenize --input ../TitanWash/data/cleaned/ --output ./data/train.bin
# Start training with default config (1.3B on L20)
titanbit train --config configs/default.yaml
# Resume from checkpoint
titanbit train --config configs/default.yaml --resume checkpoints/bitnet-1.3B/checkpoint_step_0010000.pt
# Benchmark Triton ternary matmul vs cuBLAS
titanbit bench --m 2048 --k 2048 --n 2048
import torch
from titanbit.model import BitNetConfig, BitNetTransformer
# Create a 1.3B model
config = BitNetConfig(hidden_size=2048, num_layers=24, num_heads=32)
model = BitNetTransformer(config)
# Forward pass
ids = torch.randint(0, 32000, (1, 512))
logits, loss = model(ids, labels=ids)
print(f"Loss: {loss.item():.4f}")
Standard nn.Linear: y = x @ W^T + b (floating-point multiply-accumulate)
BitLinear:
y = rescale(quant_8bit(RMSNorm(x)) @ quant_ternary(W)^T)
- RMSNorm (SubLN): normalises the input distribution before quantisation.
- AbsMean quantisation: scales activations per token into the 8-bit range [-128, 127].
- Ternary quantisation: W_q = round(W / mean(|W|)), clipped to {-1, 0, 1}.
- STE: gradients bypass the quantisation step during backprop, as if it were the identity.
- Rescale: out × (β × γ / 127) restores the magnitude, where β = mean(|W|) is the weight scale and γ is the per-token activation scale.
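The quantise, matmul and rescale steps above can be sketched in plain Python on nested lists. This follows the BitNet b1.58 paper's published training recipe: weights are divided by β = mean(|W|), rounded and clipped to {-1, 0, 1}; each token's activations are multiplied by 127/γ, rounded and clipped to [-128, 127], with γ taken as that token's maximum absolute value, which is what the γ/127 rescale implies. TitanBit's bitlinear.py may differ in details (for example epsilon handling), all function names here are hypothetical, and RMSNorm and the STE are omitted because the STE only affects the backward pass.

```python
def quantise_weights_ternary(w, eps=1e-5):
    """Ternary quantisation: returns (W_q with entries in {-1, 0, 1}, beta)."""
    flat = [abs(v) for row in w for v in row]
    beta = max(sum(flat) / len(flat), eps)  # beta = mean(|W|)
    w_q = [[max(-1, min(1, round(v / beta))) for v in row] for row in w]
    return w_q, beta


def quantise_activations_8bit(x, eps=1e-5):
    """Per-token 8-bit quantisation: returns (x_q in [-128, 127], gamma per token)."""
    x_q, gammas = [], []
    for token in x:
        gamma = max(max(abs(v) for v in token), eps)  # per-token max |x| (paper's choice)
        x_q.append([max(-128, min(127, round(v * 127 / gamma))) for v in token])
        gammas.append(gamma)
    return x_q, gammas


def bitlinear_forward(x, w):
    """y = rescale(x_q @ W_q^T): integer accumulation, then one scale per token."""
    w_q, beta = quantise_weights_ternary(w)
    x_q, gammas = quantise_activations_8bit(x)
    out = []
    for token_q, gamma in zip(x_q, gammas):
        scale = beta * gamma / 127
        # b is -1, 0 or 1, so a kernel can do this sum with adds and subtracts only
        out.append([sum(a * b for a, b in zip(token_q, row)) * scale for row in w_q])
    return out


x = [[0.5, -1.0, 0.25, 2.0]]                           # one token, 4 features
w = [[0.3, -0.2, 0.05, 0.4], [-0.5, 0.1, 0.2, -0.1]]   # 2 outputs x 4 inputs
print(bitlinear_forward(x, w))
```

During training, the STE is commonly written in PyTorch as w + (w_quant - w).detach(), so the forward pass sees the quantised weights while gradients flow unchanged to the latent full-precision weights.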
The ternary matmul kernel replaces every multiply-accumulate (FMA) with a signed accumulation:
For each output element y[i,j]:
y[i,j] = Σ_k W[j,k] × x[i,k]
where W[j,k] ∈ {-1, 0, 1}, so:
if W = +1: accumulate += x
if W = -1: accumulate -= x
if W = 0: skip
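The accumulation rule above can be checked in plain Python. The sketch computes y = x·Wᵀ for ternary W using only additions and subtractions (indices with W = 0 are never visited) and compares it with an ordinary multiply-accumulate. It illustrates the arithmetic only; the Triton kernel's tiling, packed-weight loads and branch-free formulation are not reproduced, and the function names are hypothetical.

```python
def ternary_matmul(x, w):
    """y[i][j] = sum_k W[j][k] * x[i][k] for W in {-1, 0, 1}, with no multiplications."""
    for row in w:
        if any(v not in (-1, 0, 1) for v in row):
            raise ValueError("weights must be ternary")
    # Split each weight row into the positions that add and the positions that subtract;
    # zero weights appear in neither list and are skipped entirely.
    plus = [[k for k, v in enumerate(row) if v == 1] for row in w]
    minus = [[k for k, v in enumerate(row) if v == -1] for row in w]
    return [
        [sum(xi[k] for k in p) - sum(xi[k] for k in m) for p, m in zip(plus, minus)]
        for xi in x
    ]


def reference_matmul(x, w):
    """Ordinary multiply-accumulate, for comparison."""
    return [[sum(a * b for a, b in zip(xi, row)) for row in w] for xi in x]


x = [[3, -2, 7, 1], [0, 5, -4, 2]]   # e.g. 8-bit quantised activations
w = [[1, 0, -1, 1], [-1, -1, 0, 1]]  # ternary weights
print(ternary_matmul(x, w))          # → [[-3, 0], [6, -3]]
```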
Weight packing: 16 ternary values → 1 int32 (2 bits each), which is 16× smaller than FP32 weights (8× smaller than BF16) for inference.
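A sketch of packing 16 ternary weights into one 32-bit word at 2 bits each, plus the resulting compression ratios. The encoding (store w + 1, so -1, 0, 1 become 0b00, 0b01, 0b10, with weight i in bits 2i and 2i+1) is an assumption for illustration; TitanBit's actual bit layout may differ, and pack16/unpack16 are hypothetical names.

```python
def pack16(weights):
    """Pack 16 ternary values into an unsigned 32-bit integer, 2 bits per value."""
    if len(weights) != 16:
        raise ValueError("expected exactly 16 weights")
    word = 0
    for i, w in enumerate(weights):
        if w not in (-1, 0, 1):
            raise ValueError(f"non-ternary weight {w!r} at index {i}")
        word |= (w + 1) << (2 * i)  # -1 -> 0b00, 0 -> 0b01, 1 -> 0b10
    return word


def unpack16(word):
    """Inverse of pack16: read each 2-bit field and subtract 1."""
    return [((word >> (2 * i)) & 0b11) - 1 for i in range(16)]


weights = [1, 0, -1, 1, 0, 0, -1, -1, 1, 1, 0, -1, 0, 1, -1, 0]
word = pack16(weights)
assert unpack16(word) == weights and 0 <= word < 2**32

# Storage for a 2048 x 2048 ternary matrix
n = 2048 * 2048
fp32_bytes = n * 4
bf16_bytes = n * 2
packed_bytes = (n // 16) * 4  # one 4-byte word per 16 weights
print(fp32_bytes / packed_bytes, bf16_bytes / packed_bytes)  # → 16.0 8.0
```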
BitNet training is less stable than standard transformer training because the STE makes the optimisation landscape non-smooth. TitanBit implements a 4-layer stability system:
Layer 1: Gradient clipping (max_norm=1.0)
Layer 2: Loss spike detection (EMA-based, threshold=5×)
Layer 3: Automatic rollback to last stable checkpoint
Layer 4: Learning rate scaling (0.5×) after recovery
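Layers 2 to 4 can be sketched as a small controller. Everything here is an illustrative assumption: the class and method names are hypothetical; a step counts as a spike when its loss exceeds 5× the EMA of earlier losses (one reading of "EMA-based, threshold=5×") or is NaN/inf; the 0.99 EMA decay and 10-step warm-up are made-up values; and the rollback is a caller-supplied callback rather than TitanBit's checkpoint code. Layer 1 would normally be torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) before the optimiser step; it is left out so the sketch runs without PyTorch.

```python
import math
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class SpikeGuard:
    threshold: float = 5.0             # spike if loss > threshold * EMA
    decay: float = 0.99                # EMA smoothing factor (assumed)
    warmup_steps: int = 10             # steps of EMA history before judging (assumed)
    lr_scale_on_recovery: float = 0.5  # layer 4
    ema: Optional[float] = None
    steps_seen: int = 0
    rollbacks: int = 0

    def is_spike(self, loss: float) -> bool:
        if not math.isfinite(loss):
            return True  # NaN/inf always treated as a spike (assumption)
        return (self.ema is not None and self.steps_seen >= self.warmup_steps
                and loss > self.threshold * self.ema)

    def observe(self, loss: float, lr: float,
                restore_checkpoint: Callable[[], None]) -> float:
        """Return the learning rate to use next; roll back on a spike."""
        if self.is_spike(loss):
            restore_checkpoint()  # layer 3: back to the last stable state
            self.rollbacks += 1
            return lr * self.lr_scale_on_recovery
        # Only non-spike losses update the EMA, so one spike cannot mask the next
        self.ema = loss if self.ema is None else self.decay * self.ema + (1 - self.decay) * loss
        self.steps_seen += 1
        return lr


restored = []
guard = SpikeGuard()
lr = 3e-4
for step, loss in enumerate([2.0] * 20 + [15.0] + [2.0] * 5):
    lr = guard.observe(loss, lr, restore_checkpoint=lambda: restored.append(step))
print(guard.rollbacks, restored, lr)
```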
# Run all tests
pytest tests/ -v
# Test specific modules
pytest tests/test_bitlinear.py -v # Core quantisation
pytest tests/test_model.py -v # Full transformer
# With coverage
pytest tests/ --cov=titanbit --cov-report=term-missing
TitanBit/
├── configs/
│   └── default.yaml          # 1.3B config tuned for L20
├── src/titanbit/
│   ├── model/
│   │   ├── config.py         # Model configs (125M to 3B)
│   │   ├── bitlinear.py      # BitLinear layer + RMSNorm + STE
│   │   ├── transformer.py    # Full transformer (RoPE, GQA, SwiGLU)
│   │   └── kernels.py        # Triton ternary matmul kernel
│   ├── training/
│   │   ├── data.py           # mmap data pipeline
│   │   ├── trainer.py        # Training loop (BF16, MFU, checkpoints)
│   │   └── stability.py      # Loss spike detection & recovery
│   ├── utils/
│   │   └── metrics.py        # GPU metrics, throughput tracking
│   └── cli.py                # CLI (train, bench, tokenize, info)
├── tests/
│   ├── test_bitlinear.py     # 15+ quantisation tests
│   └── test_model.py         # 10+ transformer tests
├── pyproject.toml            # PEP 621 project config
└── README.md
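The mmap data pipeline in training/data.py can be sketched with the standard library. The sketch assumes train.bin is a flat array of little-endian uint16 token IDs (enough for a 32,000-token vocabulary); the layout actually written by titanbit tokenize is not documented here, and TokenWindows is a hypothetical name. The file is mapped read-only, and each sample copies only one window of seq_len + 1 tokens, so inputs and shifted labels come from a single read without loading the whole corpus into RAM.

```python
import mmap
import os
import random
import struct
import tempfile


class TokenWindows:
    """Random (input, label) windows from a memory-mapped uint16 token file."""

    def __init__(self, path: str, seq_len: int, seed: int = 0):
        self.seq_len = seq_len
        self._file = open(path, "rb")
        self._map = mmap.mmap(self._file.fileno(), 0, access=mmap.ACCESS_READ)
        self.num_tokens = len(self._map) // 2  # 2 bytes per uint16 token
        if self.num_tokens < seq_len + 1:
            raise ValueError("file too small for one window")
        self._rng = random.Random(seed)

    def sample(self):
        # start is chosen so that start + seq_len + 1 <= num_tokens
        start = self._rng.randrange(self.num_tokens - self.seq_len)
        raw = self._map[2 * start: 2 * (start + self.seq_len + 1)]  # only these pages are read
        tokens = list(struct.unpack(f"<{self.seq_len + 1}H", raw))
        return tokens[:-1], tokens[1:]  # inputs, next-token labels

    def close(self):
        self._map.close()
        self._file.close()


# Usage: write a tiny token file, then sample one window from it
tokens = list(range(100))
with tempfile.NamedTemporaryFile(delete=False, suffix=".bin") as f:
    f.write(struct.pack(f"<{len(tokens)}H", *tokens))
    path = f.name
ds = TokenWindows(path, seq_len=8)
x, y = ds.sample()
ds.close()
os.remove(path)
print(x, y)
```

In PyTorch code the same idea is commonly written with numpy.memmap and torch.from_numpy, which avoids the per-token unpacking done here.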
- BitNet b1.58: Ma et al. (2024). The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits. arXiv:2402.17764
- BitNet: Wang et al. (2023). BitNet: Scaling 1-bit Transformers for Large Language Models. arXiv:2310.11453
- STE: Bengio et al. (2013). Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation. arXiv:1308.3432
- RoPE: Su et al. (2021). RoFormer: Enhanced Transformer with Rotary Position Embedding. arXiv:2104.09864
- SwiGLU: Shazeer (2020). GLU Variants Improve Transformer. arXiv:2002.05202
- GQA: Ainslie et al. (2023). GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints. arXiv:2305.13245
- Chinchilla: Hoffmann et al. (2022). Training Compute-Optimal Large Language Models. arXiv:2203.15556
- FlashAttention: Dao et al. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. arXiv:2205.14135
Apache 2.0. See LICENSE for details.