Kevin-Li-2025/bitnet-1p58b-experiments


⚡ TitanBit

1.58-bit LLM Pretraining Engine with Custom Triton Kernels

Python 3.10+ License: Apache 2.0 PyTorch Triton

A from-scratch implementation of BitNet b1.58 (ternary {-1, 0, 1} weights) with custom Triton kernels, production-grade training stability, and full observability, designed for single-GPU pretraining on an NVIDIA L20 (48 GB).


🎯 Why TitanBit?

BitNet b1.58 (Ma et al., 2024) constrains every weight in the transformer's linear projections to {-1, 0, 1}, so each parameter carries just log₂(3) ≈ 1.58 bits. This removes floating-point multiplication from the weight matmuls in the forward pass: each product reduces to an addition, a subtraction, or a skip.
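
The 1.58-bit figure and its effect on weight storage can be checked with a few lines of arithmetic. This is a minimal sketch: `ideal_weight_bytes` is a hypothetical helper, and it gives an information-theoretic lower bound, not the size of any file this project writes.

```python
import math

# Information content of one ternary weight: log2 of the number of states.
bits_per_weight = math.log2(3)


def ideal_weight_bytes(num_params: int, bits: float) -> float:
    """Lower bound on storage for num_params weights at `bits` bits each."""
    return num_params * bits / 8


params = 1_300_000_000  # the 1.3B configuration
print(f"{bits_per_weight:.3f} bits per ternary weight")
print(f"BF16:    {ideal_weight_bytes(params, 16) / 1e9:.2f} GB")
print(f"ternary: {ideal_weight_bytes(params, bits_per_weight) / 1e9:.2f} GB")
```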

TitanBit is not a wrapper around someone else's library. It is a ground-up implementation demonstrating:

| Component | What it demonstrates |
|---|---|
| BitLinear layer | Quantisation-Aware Training with STE, AbsMean activation quantisation |
| Triton kernels | Custom GPU kernels for ternary matmul (branch-free add/sub) |
| Weight packing | 2-bit encoding, 16× smaller than FP32 weights (8× smaller than BF16) for inference |
| Full transformer | RoPE, GQA, SwiGLU, RMSNorm, all with ternary projections |
| mmap data pipeline | Zero-copy NVMe reads for sustained GPU utilisation |
| Stability system | Loss spike detection, auto-rollback, LR recovery |
| MFU tracking | Real-time Model FLOPs Utilisation against L20 peak (119.5 TFLOPS) |
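
As a rough illustration of the MFU tracking entry, the sketch below divides achieved training FLOP/s by the L20 peak quoted above. It assumes the common estimate of about 6 FLOPs per parameter per token (forward plus backward, attention FLOPs ignored). The function name and the 5,000 tokens/s throughput are hypothetical, not measurements from this project.

```python
def model_flops_utilisation(num_params: float, tokens_per_second: float,
                            peak_flops: float) -> float:
    """MFU = achieved training FLOP/s divided by hardware peak FLOP/s.

    Uses the ~6 * N FLOPs-per-token approximation (forward + backward).
    """
    achieved = 6 * num_params * tokens_per_second
    return achieved / peak_flops


L20_PEAK_FLOPS = 119.5e12  # peak quoted in the table above

# Hypothetical throughput for the 1.3B model; substitute a measured value.
mfu = model_flops_utilisation(1.3e9, tokens_per_second=5_000,
                              peak_flops=L20_PEAK_FLOPS)
print(f"MFU: {mfu:.1%}")
```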

🏗️ Architecture

┌────────────────────────────────────────────────────────────────┐
│                   BitNet b1.58 Forward Pass                    │
│                                                                │
│  Input x (BF16)                                                │
│      │                                                         │
│      ▼                                                         │
│  ┌────────┐   ┌────────────┐   ┌──────────┐   ┌──────────┐     │
│  │RMSNorm │──▶│ AbsMean    │──▶│ Ternary  │──▶│ Rescale  │     │
│  │(SubLN) │   │ Quant (8b) │   │ MatMul   │   │ (β × γ)  │     │
│  └────────┘   └────────────┘   │ {-1,0,1} │   └──────────┘     │
│                                │ ← STE    │                    │
│                                └──────────┘                    │
│                                                                │
│  Key: No floating-point multiplication in the matmul.          │
│  The ternary constraint means W×x = ±x or 0.                   │
└────────────────────────────────────────────────────────────────┘

Model Architecture

  • Rotary Position Embeddings (RoPE): length generalisation
  • Group Query Attention (GQA): reduced KV-cache memory
  • SwiGLU MLP: ~5% perplexity improvement over GELU
  • RMSNorm: pre-norm architecture for stability
  • BitLinear in all Q/K/V/O and MLP projections
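
To make the GQA memory saving concrete, here is a small sketch of the KV-cache arithmetic and of how query heads map onto shared K/V heads. The layer count, head count and head dimension follow the 1.3B row of the size table. `num_kv_heads=8`, the 4096-token sequence and the 2-byte cache entries are assumptions for illustration, not values from this project's configs.

```python
def kv_cache_bytes(layers: int, kv_heads: int, head_dim: int,
                   seq_len: int, bytes_per_value: int = 2) -> int:
    """Bytes needed to cache keys and values for one sequence (2 = K and V)."""
    return 2 * layers * kv_heads * head_dim * seq_len * bytes_per_value


def kv_head_for(query_head: int, num_heads: int, num_kv_heads: int) -> int:
    """Index of the shared K/V head used by a given query head."""
    group_size = num_heads // num_kv_heads
    return query_head // group_size


# 1.3B shape: 24 layers, 32 heads, hidden 2048 -> head_dim 64.
mha = kv_cache_bytes(layers=24, kv_heads=32, head_dim=64, seq_len=4096)
gqa = kv_cache_bytes(layers=24, kv_heads=8, head_dim=64, seq_len=4096)
print(f"MHA cache: {mha / 2**20:.0f} MiB, GQA cache: {gqa / 2**20:.0f} MiB")

# The first eight query heads share K/V heads in groups of four.
mapping = [kv_head_for(h, num_heads=32, num_kv_heads=8) for h in range(8)]
print(mapping)
```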

Pre-defined Sizes

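| Size | Hidden | Layers | Heads | Params | L20 VRAM (train) |
|---|---|---|---|---|---|
| 125M | 768 | 12 | 12 | ~125M | ~4 GB |
| 350M | 1024 | 24 | 16 | ~350M | ~8 GB |
| 700M | 1536 | 24 | 24 | ~700M | ~14 GB |
| 1.3B | 2048 | 24 | 32 | ~1.3B | ~22 GB |
| 3B | 3200 | 26 | 32 | ~3B | ~42 GB |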
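
As a rough cross-check of the Params column, the sketch below applies the usual back-of-the-envelope estimate for a decoder-only transformer: about 12·L·h² weights in the blocks plus vocab·h for the embedding. The 32,000-token vocabulary is borrowed from the Python API example and may not match the real configs. GQA and the exact SwiGLU width also shift the totals, so expect agreement only to within roughly 10-15%.

```python
def estimate_params(hidden: int, layers: int, vocab: int = 32_000) -> int:
    """Approximate parameter count: 12 * L * h^2 (blocks) + vocab * h (embedding)."""
    return 12 * layers * hidden * hidden + vocab * hidden


# (hidden, layers) pairs copied from the size table.
SIZES = {
    "125M": (768, 12),
    "350M": (1024, 24),
    "700M": (1536, 24),
    "1.3B": (2048, 24),
    "3B": (3200, 26),
}

for name, (hidden, layers) in SIZES.items():
    print(f"{name:>5}: ~{estimate_params(hidden, layers) / 1e6:,.0f}M parameters")
```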

📦 Installation

# Core (model + training + Triton kernels)
pip install -e .

# With evaluation benchmarks
pip install -e ".[eval]"

# With FlashAttention
pip install -e ".[flash]"

# Everything
pip install -e ".[all]"

Requirements:

  • Python ≥ 3.10
  • PyTorch ≥ 2.2.0 (with CUDA support)
  • Triton ≥ 2.2.0
  • NVIDIA GPU with compute capability ≥ 7.0 (Ada Lovelace recommended)

🚀 Quick Start

1. Show model info

titanbit info --model-size 1.3B

2. Prepare training data

# Tokenise a text corpus into binary format
titanbit tokenize --input ./data/raw/ --output ./data/train.bin

# Or use TitanWash output directly
titanbit tokenize --input ../TitanWash/data/cleaned/ --output ./data/train.bin
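
The layout of `train.bin` is not documented here. Purely as an illustration of how a memory-mapped token file can be read without loading it whole, the sketch below assumes a flat array of little-endian uint16 token ids. `write_tokens` and `read_window` are hypothetical helpers, not part of the titanbit CLI or API.

```python
import mmap
import os
import struct
import tempfile

TOKEN_BYTES = 2  # assumption: each token id is a little-endian uint16


def write_tokens(path: str, token_ids: list) -> None:
    """Write token ids back to back as uint16 values."""
    with open(path, "wb") as f:
        f.write(struct.pack(f"<{len(token_ids)}H", *token_ids))


def read_window(path: str, start: int, length: int) -> list:
    """Read `length` token ids starting at token index `start` via mmap.

    Only the pages backing the requested window are touched by the OS.
    """
    with open(path, "rb") as f, \
            mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
        return list(struct.unpack_from(f"<{length}H", mm, start * TOKEN_BYTES))


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "train.bin")
    write_tokens(path, list(range(1000)))
    window = read_window(path, start=100, length=8)

print(window)
```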

3. Train

# Start training with default config (1.3B on L20)
titanbit train --config configs/default.yaml

# Resume from checkpoint
titanbit train --config configs/default.yaml --resume checkpoints/bitnet-1.3B/checkpoint_step_0010000.pt

4. Benchmark kernels

# Benchmark Triton ternary matmul vs cuBLAS
titanbit bench --m 2048 --k 2048 --n 2048

Python API

import torch

from titanbit.model import BitNetConfig, BitNetTransformer

# Create a 1.3B model (hidden size, depth and head count from the size table)
config = BitNetConfig(hidden_size=2048, num_layers=24, num_heads=32)
model = BitNetTransformer(config)

# Forward pass on random token ids; passing labels also returns the loss
ids = torch.randint(0, 32000, (1, 512))
logits, loss = model(ids, labels=ids)
print(f"Loss: {loss.item():.4f}")

🔬 Technical Deep-Dives

BitLinear: The Core Innovation

Standard nn.Linear: y = x @ W^T + b (FP16 multiply-accumulate)

BitLinear:

y = rescale(quant_8bit(RMSNorm(x)) @ quant_ternary(W)^T)

  1. RMSNorm (SubLN): normalises the input distribution before quantisation
  2. AbsMean Quantisation: scales activations to [-128, 127] per token
  3. Ternary Quantisation: W_q = round(W / mean(|W|)), clipped to {-1, 0, 1}
  4. STE: the straight-through estimator treats quantisation as the identity in the backward pass, so gradients reach the full-precision latent weights
  5. Rescale: out × (β × γ / 127) restores the magnitude, where β is the weight scale mean(|W|) and γ is the per-token activation scale
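
The forward path of these steps can be sketched in plain Python on small lists. This is not the project's `BitLinear` code. It assumes γ is the per-token maximum absolute activation (the absmax scheme of the BitNet b1.58 paper, which matches the /127 rescale), and it omits the learnable RMSNorm gain and the STE backward pass. All function names are hypothetical.

```python
import math


def rms_norm(x, eps=1e-6):
    """Divide a token vector by its root-mean-square (no learnable gain)."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]


def quantise_weights(W):
    """W_q = round(W / mean|W|) clipped to {-1, 0, 1}; returns (W_q, beta)."""
    magnitudes = [abs(w) for row in W for w in row]
    beta = sum(magnitudes) / len(magnitudes) + 1e-8
    W_q = [[max(-1, min(1, round(w / beta))) for w in row] for row in W]
    return W_q, beta


def quantise_activations(x):
    """Scale one token into the 8-bit range; returns (x_q, gamma)."""
    gamma = max(abs(v) for v in x) + 1e-8  # assumption: absmax scale
    x_q = [max(-128, min(127, round(v * 127 / gamma))) for v in x]
    return x_q, gamma


def bitlinear_forward(x, W):
    """y = (W_q @ x_q) * beta * gamma / 127 for a single token x."""
    x_q, gamma = quantise_activations(rms_norm(x))
    W_q, beta = quantise_weights(W)
    # Integer accumulation: every W_q entry is -1, 0 or +1.
    acc = [sum(w * v for w, v in zip(row, x_q)) for row in W_q]
    return [a * beta * gamma / 127 for a in acc]


W = [[0.9, -0.8, 0.05], [0.0, 1.2, -1.1]]
x = [1.0, -2.0, 0.5]
print(quantise_weights(W)[0])
print(bitlinear_forward(x, W))
```

The output stays close to `beta * W_q @ RMSNorm(x)`; the remaining gap is the 8-bit rounding error of the activations.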

Custom Triton Kernel

The ternary matmul kernel eliminates all FMA operations:

For each output element y[i,j]:
    y[i,j] = Σ_k  W[j,k] × x[i,k]

    where W[j,k] ∈ {-1, 0, 1}, so:
        if W = +1:  accumulate += x
        if W = -1:  accumulate -= x
        if W =  0:  skip
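
The pseudocode above translates directly into a plain-Python reference, which is handy for checking a GPU kernel against. This is a readability sketch with explicit branches, not the Triton kernel itself. `ternary_matmul` and `dense_matmul` are hypothetical names.

```python
def ternary_matmul(x, W):
    """y[i][j] = sum_k W[j][k] * x[i][k] for W in {-1, 0, 1}, add/sub only."""
    y = []
    for x_row in x:
        out = []
        for w_row in W:
            acc = 0
            for w, v in zip(w_row, x_row):
                if w == 1:
                    acc += v
                elif w == -1:
                    acc -= v
                # w == 0 contributes nothing, so it is skipped
            out.append(acc)
        y.append(out)
    return y


def dense_matmul(x, W):
    """Ordinary multiply-accumulate version, used only as a cross-check."""
    return [[sum(w * v for w, v in zip(w_row, x_row)) for w_row in W]
            for x_row in x]


x = [[1, 2, 3], [4, 5, 6]]
W = [[1, 0, -1], [-1, 1, 1]]
print(ternary_matmul(x, W))
```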

Weight packing: 16 ternary values → 1 int32 (2 bits each), a 16× reduction relative to FP32 weights (8× relative to BF16) for inference.
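
A minimal sketch of the packing idea follows. The code assignment (-1 → 0, 0 → 1, +1 → 2) and the low-bits-first ordering are assumptions; the kernel's actual bit layout may differ.

```python
def pack_ternary(values):
    """Pack 16 ternary values into one unsigned 32-bit word, 2 bits each."""
    if len(values) != 16:
        raise ValueError("expected exactly 16 values")
    word = 0
    for i, v in enumerate(values):
        code = v + 1                 # map {-1, 0, 1} to {0, 1, 2}
        word |= code << (2 * i)      # value i occupies bits 2i and 2i+1
    return word


def unpack_ternary(word):
    """Inverse of pack_ternary: recover the 16 ternary values."""
    return [((word >> (2 * i)) & 0b11) - 1 for i in range(16)]


vals = [-1, 0, 1, 1, 0, -1, 0, 0, 1, 1, -1, -1, 0, 1, 0, -1]
word = pack_ternary(vals)
print(hex(word), unpack_ternary(word) == vals)
```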

Stability System

BitNet training is less stable than standard transformer training because the STE produces a non-smooth loss landscape. TitanBit implements a four-layer stability system:

Layer 1: Gradient clipping (max_norm=1.0)
Layer 2: Loss spike detection (EMA-based, threshold=5×)
Layer 3: Automatic rollback to last stable checkpoint
Layer 4: Learning rate scaling (0.5×) after recovery
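
Layers 2 to 4 can be sketched as a small loop around an exponential moving average of the loss. `SpikeGuard` is a hypothetical helper, not the class in `stability.py`. It assumes a spike means a loss above 5× the EMA (or a non-finite loss) and an EMA decay of 0.99, and it records the rollback target in place of real checkpoint I/O. Gradient clipping (Layer 1) is left out because it needs a live model.

```python
import math


class SpikeGuard:
    """EMA-based loss spike detector (illustrative only)."""

    def __init__(self, threshold=5.0, decay=0.99, lr_scale=0.5):
        self.threshold = threshold
        self.decay = decay
        self.lr_scale = lr_scale
        self.ema = None

    def is_spike(self, loss):
        """Layer 2: flag non-finite losses or losses far above the EMA."""
        if math.isnan(loss) or math.isinf(loss):
            return True
        return self.ema is not None and loss > self.threshold * self.ema

    def observe(self, loss):
        """Fold a healthy loss into the running average."""
        if self.ema is None:
            self.ema = loss
        else:
            self.ema = self.decay * self.ema + (1 - self.decay) * loss


guard = SpikeGuard()
lr = 3e-4
stable_step = 0  # stands in for the last stable checkpoint
rollbacks = []

for step, loss in enumerate([2.0, 1.9, 1.8, 12.0, 1.7]):
    if guard.is_spike(loss):
        rollbacks.append((step, stable_step))  # Layer 3: restore checkpoint here
        lr *= guard.lr_scale                   # Layer 4: resume at a lower LR
        continue
    guard.observe(loss)
    stable_step = step

print(rollbacks, lr)
```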

🧪 Testing

# Run all tests
pytest tests/ -v

# Test specific modules
pytest tests/test_bitlinear.py -v      # Core quantisation
pytest tests/test_model.py -v          # Full transformer

# With coverage
pytest tests/ --cov=titanbit --cov-report=term-missing

📁 Project Structure

TitanBit/
├── configs/
│   └── default.yaml              # 1.3B config tuned for L20
├── src/titanbit/
│   ├── model/
│   │   ├── config.py             # Model configs (125M → 3B)
│   │   ├── bitlinear.py          # BitLinear layer + RMSNorm + STE
│   │   ├── transformer.py        # Full transformer (RoPE, GQA, SwiGLU)
│   │   └── kernels.py            # Triton ternary matmul kernel
│   ├── training/
│   │   ├── data.py               # mmap data pipeline
│   │   ├── trainer.py            # Training loop (BF16, MFU, checkpoints)
│   │   └── stability.py          # Loss spike detection & recovery
│   ├── utils/
│   │   └── metrics.py            # GPU metrics, throughput tracking
│   └── cli.py                    # CLI (train, bench, tokenize, info)
├── tests/
│   ├── test_bitlinear.py         # 15+ quantisation tests
│   └── test_model.py             # 10+ transformer tests
├── pyproject.toml                # PEP 621 project config
└── README.md

🔬 References

  • BitNet b1.58: Ma et al. (2024). The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits. arXiv:2402.17764
  • BitNet: Wang et al. (2023). BitNet: Scaling 1-bit Transformers for Large Language Models. arXiv:2310.11453
  • STE: Bengio et al. (2013). Estimating or Propagating Gradients Through Stochastic Neurons.
  • RoPE: Su et al. (2021). RoFormer: Enhanced Transformer with Rotary Position Embedding.
  • SwiGLU: Shazeer (2020). GLU Variants Improve Transformer.
  • GQA: Ainslie et al. (2023). GQA: Training Generalized Multi-Query Transformer Models.
  • Chinchilla: Hoffmann et al. (2022). Training Compute-Optimal Large Language Models.
  • FlashAttention: Dao et al. (2022). FlashAttention: Fast and Memory-Efficient Exact Attention.

📝 License

Apache 2.0. See LICENSE for details.
