A 1.7B parameter masked diffusion language model combining Microsoft's BitNet a4.8 quantization (ternary weights + hybrid 4-bit/8-bit activations) with an MDLM-style absorbing-state diffusion objective, quantized KV cache, and latent scratchpad thinking tokens.
License: Model weights are released under the BigCode OpenRAIL-M license. Source code is Apache 2.0. See LICENSE for details.
| Property | BitDiffusion a4.8 | Standard LLM |
|---|---|---|
| Generation | Bidirectional masked diffusion | Left-to-right autoregressive |
| Weights | Ternary {−1, 0, +1} (BitNet b1.58) | float16 / bfloat16 |
| Activations | INT4 inputs + TopK(55%) + INT8 intermediates | float16 |
| KV Cache | 3-bit quantized (TurboQuant-style rotation) | float16 |
| Thinking | 64-token latent scratchpad (adaptive) | Chain-of-thought in prompt |
| Inference weights | Packed 2-bit ternary + Triton/CPU INT4×INT8 kernel | Float16 GEMM |
| Context | 4,096 tokens | Varies |
| Training stages | Two-stage A8 → A4 activation schedule | Single stage |
Prerequisites: Python 3.10+, CUDA 12.1+, ~40 GB VRAM for training (single A100 40GB with gradient checkpointing).
git clone https://github.com/Fury7425/bitDiffusion-a4.8
cd bitDiffusion-a4.8
pip install -e .
Dependencies (requirements.txt):
torch>=2.2
transformers>=4.40
safetensors>=0.4.3
wandb
datasets>=2.19
python sample.py \
--checkpoint checkpoints/step_57500.pt \
--prompt "The theory of relativity states that" \
--length 200 \
--steps 20

python sample.py \
--checkpoint checkpoints/step_57500.pt \
--thinking \
--adaptive_think \
--prompt "Explain how neural networks learn" \
--length 300 \
--verbose
After loading any checkpoint, call pack_for_inference() once to swap the
float ternary simulation for the real low-bit kernel. Packing shrinks the
ternary weight tensors roughly 16× relative to their fp32 latents and enables
a real GPU compute speedup (see Low-bit packed inference).
import torch
from bitdiffusion import BitDiffusionTransformer, ModelConfig
ckpt = torch.load("checkpoints/step_57500.pt", weights_only=True)
cfg = ModelConfig(**ckpt["model_config"])
model = BitDiffusionTransformer(cfg)
model.load_state_dict(ckpt["model_state_dict"])
model.eval().pack_for_inference()  # one-time, before sampling

- Weights: Ternary {−1, 0, +1} via absmean quantization with straight-through estimator (STE). Full-precision latent weights are maintained during training and quantized on every forward pass.
- Activations (BitNet a4.8 hybrid scheme):
  - Q, K, V, FFN gate/up projections: absmax INT4 per-token
  - Attention output, FFN down projections: TopK(55%) sparsification + absmax INT8
  - Two-stage training schedule transitions from INT8 → hybrid INT4+TopK at 90% of steps
- KV Cache: 3-bit quantized K/V tensors via random rotation + scalar quantization (TurboQuant). BOS token stored at 4-bit for outlier precision. Cache resets between denoising steps; ephemeral mode supports block diffusion.
- Thinking tokens: 64 latent scratchpad positions prepended to every sequence. At inference, the thinking phase runs adaptively, stopping when the token change rate drops below 2% for 3 consecutive steps (max 128 steps).
- Diffusion objective: Masked absorbing-state diffusion (MDLM-style). Tokens are corrupted to a [MASK] absorbing state according to a cosine noise schedule. The model is trained to denoise all masked positions simultaneously.
- Positional encoding: Rotary Position Embeddings (RoPE) with auto-extending cache. Supports rope_offset for correct positions in block diffusion.
- FFN: SwiGLU with hidden dimension 8,192.
- Normalization: RMSNorm pre-norm at each layer.
- Noise conditioning: Sinusoidal + learned projection embeds the per-sample noise level t ∈ [0,1] and injects it as an additive bias after the first RMSNorm in every block.
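The weight and activation quantizers in the list above can be illustrated in plain Python. This is a minimal sketch under stated assumptions, not the code in quantization.py: the function names are hypothetical, tensors are flat lists of floats, and rounding uses Python's built-in round-half-to-even.

```python
def absmean_ternary(weights, eps=1e-8):
    """Quantize floats to {-1, 0, +1} with one shared absmean scale."""
    scale = sum(abs(w) for w in weights) / len(weights) + eps
    # Divide by the mean magnitude, round, then clamp to the ternary range.
    ternary = [max(-1, min(1, round(w / scale))) for w in weights]
    return ternary, scale


def absmax_quantize(row, bits):
    """Symmetric absmax quantization of one token's activations to `bits` bits."""
    qmax = 2 ** (bits - 1) - 1                      # 7 for INT4, 127 for INT8
    scale = max(abs(x) for x in row) / qmax or 1.0  # all-zero row: fall back to 1.0
    q = [max(-qmax - 1, min(qmax, round(x / scale))) for x in row]
    return q, scale


def topk_sparsify(row, keep_fraction=0.55):
    """Keep the largest-magnitude `keep_fraction` of entries, zero the rest."""
    k = max(1, round(len(row) * keep_fraction))
    by_magnitude = sorted(range(len(row)), key=lambda i: abs(row[i]), reverse=True)
    keep = set(by_magnitude[:k])
    return [x if i in keep else 0.0 for i, x in enumerate(row)]
```

In this sketch the Q/K/V path would call absmax_quantize(row, 4), while the attention-output and FFN-down path would call topk_sparsify followed by absmax_quantize(row, 8). During training the straight-through estimator passes gradients through the rounding step unchanged, which plain Python cannot express.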
| Hyperparameter | Value |
|---|---|
| Parameters (total) | 1.705B |
| Parameters (ternary) | 1.074B |
| Parameters (full precision) | 0.631B |
| Hidden dimension | 2,048 |
| Layers | 16 |
| Attention heads | 16 |
| Head dimension | 128 |
| FFN dimension | 8,192 |
| Vocabulary size | 152,064 (Qwen tokenizer) |
| Context window | 4,096 tokens |
| Thinking tokens | 64 |
| KV cache bits | 3 (BOS: 4) |
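The parameter split in the table can be sanity-checked from the listed dimensions. The sketch below assumes four attention projections (Q, K, V, output), three SwiGLU projections (gate, up, down), and separate (untied) input embedding and output head matrices; the untied assumption is not stated in this README.

```python
hidden, layers, ffn, vocab, heads = 2048, 16, 8192, 152_064, 16

attn_per_layer = 4 * hidden * hidden   # Q, K, V, output projections
ffn_per_layer = 3 * hidden * ffn       # SwiGLU gate, up, down projections
ternary_params = layers * (attn_per_layer + ffn_per_layer)

# Assumption: embedding and unembedding head are two separate matrices.
embedding_params = 2 * vocab * hidden

head_dim = hidden // heads

print(f"ternary    ~ {ternary_params / 1e9:.3f}B")    # 1.074B
print(f"embeddings ~ {embedding_params / 1e9:.3f}B")  # 0.623B
print(f"head dim   = {head_dim}")                     # 128
```

Under these assumptions the ternary count matches the table exactly, and the two embedding matrices account for about 0.623B of the 0.631B full-precision parameters. The remaining ~8M would then sit in small tensors such as norms and conditioning projections.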
bitdiffusion/rdt.py is the default model in this repo. Built on the
OpenMythos architecture, it replaces the stacked-layers design with a
Prelude → RecurrentBlock → Coda structure: shared weights are applied for
multiple loop iterations, giving the model depth-adaptivity without extra parameters.
Key adaptations for diffusion:
- Bidirectional attention throughout (no causal mask)
- Diffusion timestep t_emb re-injected at every recurrence iteration
- Soft ACT weighting (no hard per-token halting) for uniform refinement
- LTI A matrix: 0.99 * tanh(A_raw) guarantees spectral radius < 1
- Loop dropout during training so every loop prefix is independently useful
- Inference-time depth extrapolation via --n_loops in sample.py
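The stability bound on the LTI A matrix can be illustrated with a small sketch. It assumes A is diagonal, so the elementwise bound is also a bound on the spectral radius; this README does not say whether rdt.py uses a diagonal or a full matrix, and the function names are hypothetical.

```python
import math


def stable_diag(a_raw):
    """Map unconstrained values into [-0.99, 0.99] elementwise."""
    return [0.99 * math.tanh(v) for v in a_raw]


def recur(a_diag, h0, n_loops):
    """Apply h <- a * h elementwise for n_loops iterations (input term omitted)."""
    h = list(h0)
    for _ in range(n_loops):
        h = [a * x for a, x in zip(a_diag, h)]
    return h


a = stable_diag([-50.0, 0.0, 3.0, 50.0])   # extreme raw values still stay bounded
h = recur(a, [1.0, 1.0, 1.0, 1.0], n_loops=200)
```

Because every diagonal entry has magnitude at most 0.99, the state contracts by at least that factor per loop. This is why running more loops at inference than were used in training cannot blow up the recurrence in this simplified setting.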
train.py uses model_type="rdt" by default. Pass --model_type standard to
opt out and train the flat BitDiffusionTransformer instead. Both checkpoint
formats are auto-detected by sample.py and export.py.
Download and preprocess the ~40B token training mix:
export HF_TOKEN=hf_your_token_here
python prepare_hf_jsonl.py
Produces data/train/hf_mix_train.jsonl and data/val/hf_mix_val.jsonl.
Progress is checkpointed to data/hf_shards/progress.json — safe to interrupt and resume.
Dataset mix (~40B tokens):
| Dataset | Source | Tokens |
|---|---|---|
| FineWeb-Edu | HuggingFaceFW/fineweb-edu (sample-100BT) | 15B |
| DCLM | HuggingFaceFW/dclm_100BT | 8B |
| OpenWebMath | open-web-math/open-web-math | 7B |
| Cosmopedia | HuggingFaceTB/cosmopedia | 4B |
| Wikipedia (EN) | wikimedia/wikipedia 20231101.en | 2B |
| FinePDFs | HuggingFaceFW/finepdfs_100BT | 2B |
| MathCode-Pile | MathGenie/MathCode-Pile | 2B |
| StarCoder Python | bigcode/starcoderdata (python) | 2B |
| StarCoder JS | bigcode/starcoderdata (javascript) | 1B |
Chunks are sampled from a weighted sequence-length distribution
{128: 5%, 256: 8%, 512: 10%, 1024: 15%, 2048: 20%, 4096: 42%} so the model
learns to handle the full range of context lengths.
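A weighted draw from this length distribution could look like the following sketch. The function name is hypothetical and the actual chunking logic in data.py may differ.

```python
import random

# Sequence length -> sampling probability, as listed above.
LENGTH_WEIGHTS = {128: 0.05, 256: 0.08, 512: 0.10, 1024: 0.15, 2048: 0.20, 4096: 0.42}


def sample_chunk_length(rng):
    """Draw one chunk length according to LENGTH_WEIGHTS."""
    lengths = list(LENGTH_WEIGHTS)
    weights = list(LENGTH_WEIGHTS.values())
    return rng.choices(lengths, weights=weights, k=1)[0]


# Mean chunk length implied by the distribution (about 2,362 tokens).
expected_length = sum(length * w for length, w in LENGTH_WEIGHTS.items())

rng = random.Random(0)
draws = [sample_chunk_length(rng) for _ in range(5)]
```

The weights sum to 1, and the implied mean chunk length is roughly 2,362 tokens, a little over half of the 4,096-token context.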
All 1B defaults are baked in:
wandb login # optional
python train.py
Runs 57,500 steps × (8 batch × 16 grad accum × 4,096 seq) = 30.1B tokens.
Training stays on the float-sim path. Never call pack_for_inference() during training: packed BitLinears are not differentiable. Packing is an inference-only, one-way operation.
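The token budget quoted above follows directly from the default settings. The short check below treats every sequence as a full 4,096 tokens and takes the data mix as exactly 40B tokens, both simplifying assumptions.

```python
steps, batch_size, grad_accum, seq_len = 57_500, 8, 16, 4_096

tokens_per_step = batch_size * grad_accum * seq_len   # 524,288
total_tokens = steps * tokens_per_step                # 30,146,560,000

# Assumption: the prepared mix holds exactly 40B tokens ("~40B" in this README).
fraction_of_mix = total_tokens / 40e9

print(f"{tokens_per_step:,} tokens per optimizer step")
print(f"{total_tokens / 1e9:.1f}B tokens total")
print(f"{fraction_of_mix:.2f} of the data mix")
```

On these numbers a full run consumes about three quarters of the prepared mix, so the default schedule stays under one epoch.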
Resume after preemption:
python train.py --resume_from checkpoints/step_XXXXX.pt
Custom config:
python train.py \
--max_steps 57500 \
--batch_size 8 \
--max_seq_len 4096 \
--lr 2e-4 \
--warmup_steps 4000 \
--grad_accum_steps 16 \
--a4_warmup_fraction 0.10 \
--gradient_checkpointing \
--wandb_project bitdiffusion-a48

| Parameter | Value | Notes |
|---|---|---|
| Steps | 57,500 | 30.1B total tokens |
| Batch size | 8 | Per-device |
| Gradient accumulation | 16 | Effective batch: 524,288 tok/step |
| Sequence length | 4,096 | |
| Peak LR (AdamW) | 2e-4 | Embeddings, norms, biases, unembedding head |
| Peak LR (Muon) | 0.02 | 2D weight matrices in the transformer body |
| LR schedule | Cosine + linear warmup | Min LR ratio: 0.1 |
| Warmup steps | 4,000 | |
| Weight decay | 0.05 | AdamW |
| Optimizer | Muon + AdamW hybrid | DeepSeek V4 style; toggle via use_muon=False |
| Gradient clip | 1.0 | |
| Mixed precision | bf16 | |
| Gradient checkpointing | Yes | ~29.5 GB on A100 40GB |
| A4 warmup fraction | 0.10 | Last 10% of steps in A4 mode |
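The learning-rate row (cosine decay with linear warmup and a minimum ratio of 0.1) can be written as a small function. This is a sketch using the AdamW defaults from the table; the exact warmup formula in train.py is an assumption, and lr_at is a hypothetical name.

```python
import math


def lr_at(step, peak_lr=2e-4, warmup_steps=4_000, max_steps=57_500, min_ratio=0.1):
    """Linear warmup to peak_lr, then cosine decay to min_ratio * peak_lr."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = min(1.0, (step - warmup_steps) / max(1, max_steps - warmup_steps))
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))   # 1 -> 0 over the decay
    return peak_lr * (min_ratio + (1.0 - min_ratio) * cosine)


for step in (0, 2_000, 4_000, 30_750, 57_500):
    print(step, f"{lr_at(step):.2e}")
```

With these defaults the rate peaks at 2e-4 at step 4,000 and ends at 2e-5 at step 57,500. The same shape with peak_lr=0.02 would describe the Muon parameter group if it shares the schedule, which the table does not state.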
Steps 0 → 51,750 (90%) W1.58A8: all activations INT8
Steps 51,750 → 57,500 (10%) W1.58A4: hybrid INT4 + TopK(55%) + INT8
Stage 1 lets ternary weights converge under a less aggressive quantization regime.
Stage 2 fine-tunes under the exact target inference quantization.
Adjust with --a4_warmup_fraction.
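The stage boundary follows from --a4_warmup_fraction. The sketch below shows the arithmetic; the helper name is hypothetical, and treating the boundary step itself as the first A4 step is an assumption about ActivationSchedule.

```python
def activation_mode(step, max_steps=57_500, a4_warmup_fraction=0.10):
    """Return "A8" for stage 1 and "A4" for stage 2 of the activation schedule."""
    # round() avoids an off-by-one from floating-point error in the product.
    a4_start = round(max_steps * (1.0 - a4_warmup_fraction))
    return "A4" if step >= a4_start else "A8"


print(activation_mode(0))        # A8
print(activation_mode(51_749))   # A8
print(activation_mode(51_750))   # A4
```

Raising the fraction moves the switch earlier: with --a4_warmup_fraction 0.20 the boundary would fall at step 46,000.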
Basic generation:
python sample.py \
--checkpoint checkpoints/step_57500.pt \
--prompt "The theory of relativity states that" \
--length 200 \
--steps 20
Adaptive thinking (scratchpad runs until token change rate < 2% for 3 steps, max 128):
python sample.py \
--checkpoint checkpoints/step_57500.pt \
--thinking --adaptive_think \
--prompt "Explain how neural networks learn" \
--length 300 --answer_steps 20 --verbose
Auto-length (recommended), stops at EOS:
python sample.py \
--checkpoint checkpoints/step_57500.pt \
--block --auto_length \
--prompt "What is the mitochondria?" \
--max_length 2048
Block diffusion, for outputs longer than the training context:
python sample.py \
--checkpoint checkpoints/step_57500.pt \
--block --block_size 256 --steps 20 \
--prompt "Write a detailed explanation of" \
--length 2048

| Flag | Default | Description |
|---|---|---|
| --steps | 20 | Denoising steps (more = better quality, slower) |
| --temperature | 0.9 | Higher = more creative |
| --top_p | 0.95 | Nucleus sampling cutoff |
| --num_samples | 1 | Generate N independent samples |
| --thinking | False | Enable thinking phase |
| --adaptive_think | False | Stop thinking when tokens converge |
| --max_think_steps | 128 | Hard cap on thinking steps |
| --think_change_threshold | 0.02 | Convergence threshold (2%) |
| --think_patience | 3 | Consecutive below-threshold steps to stop |
| --auto_length | False | Stop at EOS automatically |
| --max_length | 2048 | Hard cap for auto-length mode |
| --block | False | Use block diffusion for long generation |
| --block_size | 256 | Tokens per block |
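The adaptive-thinking stop rule controlled by --think_change_threshold, --think_patience, and --max_think_steps can be sketched as follows. The function names and the step_fn callback are hypothetical stand-ins for one denoising step over the scratchpad; the sampler in sample.py may track convergence differently.

```python
def token_change_rate(prev_tokens, new_tokens):
    """Fraction of positions whose token id changed in one step."""
    changed = sum(p != n for p, n in zip(prev_tokens, new_tokens))
    return changed / max(1, len(new_tokens))


def should_stop(change_rates, threshold=0.02, patience=3):
    """True once the last `patience` change rates are all below `threshold`."""
    if len(change_rates) < patience:
        return False
    return all(rate < threshold for rate in change_rates[-patience:])


def run_thinking(step_fn, tokens, max_steps=128, threshold=0.02, patience=3):
    """Call step_fn until the tokens converge or max_steps is reached."""
    rates = []
    for step in range(1, max_steps + 1):
        new_tokens = step_fn(tokens)
        rates.append(token_change_rate(tokens, new_tokens))
        tokens = new_tokens
        if should_stop(rates, threshold, patience):
            return tokens, step
    return tokens, max_steps
```

With a 64-token scratchpad, a 2% threshold means any single changed token (1/64, about 1.6%) still counts as converged, while two changed tokens (about 3.1%) reset the patience window.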
Resume from a pretrained checkpoint with a lower learning rate:
python train.py \
--resume_from checkpoints/step_57500.pt \
--train_data "data/finetune/train/*.jsonl" \
--val_data "data/finetune/val/*.jsonl" \
--lr 2e-5 \
--max_steps 5000 \
--warmup_steps 200
Data should follow the same {"text": "..."} JSONL format. For instruction tuning,
concatenate the turn into a single string:
{"text": "User: What is the mitochondria?\nAssistant: The mitochondria is the powerhouse of the cell."}
Knowledge distillation (recommended): Use a teacher model (e.g. Claude Haiku, GPT-4o-mini) to generate completions for a large prompt set, then fine-tune on those completions. Generating ~100K examples costs roughly $20–50 in API fees and yields significant quality improvement.
Export to a portable safetensors package:
python export.py \
--checkpoint checkpoints/step_57500.pt \
--output_dir exports/bitdiffusion-1b \
--format safetensors \
--tokenizer Qwen/Qwen-tokenizer
Produces:
- model.safetensors: model weights
- model_config.json: serialized ModelConfig
- export_metadata.json: checkpoint and export metadata
- tokenizer files
Standard GGUF runtimes (llama.cpp, etc.) cannot run BitDiffusion because it is a bidirectional diffusion model, not an autoregressive decoder. Use safetensors and build a custom runtime if needed.
Training keeps full-precision latent weights and quantizes them on every forward pass via straight-through estimator — the model on disk is a regular float checkpoint. Inference, by default, simulates the same quantization in float and gets no speedup.
The packed-inference path replaces this simulation with a real INT4 × 2-bit ternary compute kernel:
| | Training | Default inference (float-sim) | Packed inference |
|---|---|---|---|
| Weight dtype on disk | fp32 latent | fp32 latent | uint8 (2 bits/param) |
| Activation compute | float | float (rounded) | INT8 dot-product |
| Weight bytes | 4×params | 4×params | params/4 (16× smaller than fp32, 8× smaller than fp16) |
| Speedup vs fp16 | n/a | none | hardware-dependent (Triton kernel) |
| Trainable | yes | yes | no |
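The 2 bits/param figure in the table comes from storing four ternary values per byte. The sketch below shows one possible encoding, mapping {-1, 0, +1} to the codes {0, 1, 2} with the lowest bits first. The encoding and the function names are assumptions; the bit layout used by kernels.py may differ.

```python
def pack_ternary(values):
    """Pack ternary values {-1, 0, +1} into bytes, four 2-bit codes per byte."""
    codes = [v + 1 for v in values]            # -1, 0, +1 -> 0, 1, 2
    codes += [1] * (-len(codes) % 4)           # pad to a multiple of 4 with "0"
    out = bytearray()
    for i in range(0, len(codes), 4):
        a, b, c, d = codes[i:i + 4]
        out.append(a | (b << 2) | (c << 4) | (d << 6))
    return bytes(out)


def unpack_ternary(packed, n):
    """Inverse of pack_ternary; returns the first n ternary values."""
    values = []
    for byte in packed:
        for shift in (0, 2, 4, 6):
            values.append(((byte >> shift) & 0b11) - 1)
    return values[:n]


weights = [-1, 0, 1, 1, 0, -1]
packed = pack_ternary(weights)                 # 2 bytes for 6 weights
restored = unpack_ternary(packed, len(weights))
```

Four codes per byte gives params/4 bytes, which is 16× below the 4 bytes per parameter of an fp32 latent and 8× below the 2 bytes per parameter of fp16.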
python export.py \
--checkpoint checkpoints/step_57500.pt \
--output_dir exports/packed \
--format safetensors \
--tokenizer Qwen/Qwen-tokenizer \
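--pack
--pack runs pack_for_inference() before serializing, drops every
latent_weight tensor, and emits w_packed + scale_w per BitLinear. Each
packed weight tensor is roughly 16× smaller than its fp32 latent (8× smaller
than fp16). The full-precision parameters are not packed, so the whole file
shrinks by less than that. The metadata file gains "packed": true.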
If your checkpoint isn't pre-packed, do it once after load_state_dict:
model = BitDiffusionTransformer(cfg)
model.load_state_dict(ckpt["model_state_dict"])
model.eval().pack_for_inference()
# every BitLinear in attention + FFN is now packed; MoE FFNs stack
# their per-expert weights into a single grouped-matmul tensor.
BitLinear._load_from_state_dict auto-detects packed exports: if the
state dict has w_packed (and no latent_weight), the layer flips into
packed mode automatically:
from safetensors.torch import load_file
from bitdiffusion.model import BitMoEFFN  # assumed import path, per the file layout in this README

sd = load_file("exports/packed/model.safetensors")  # or torch.load(...)
model = BitDiffusionTransformer(cfg)
model.load_state_dict(sd)  # works without code changes
# For MoE models, re-stack the expert weights:
if any(isinstance(m, BitMoEFFN) for m in model.modules()):
    model.pack_for_inference()  # idempotent for already-packed BitLinears

bitdiffusion.kernels.packed_ternary_linear picks a backend per call:
| Device | Backend |
|---|---|
| CUDA / ROCm with triton installed | Autotuned Triton kernel (INT8 tl.dot → INT32 accumulator) |
| Intel XPU with triton (via intel-extension-for-pytorch) | Same Triton kernel |
| CPU | torch._int_mm if available, otherwise int32 torch.mm (correctness, not throughput) |
| Anywhere triton import fails | Silent fallback to the CPU path |
The MoE path uses a fused grouped kernel: tokens are permuted by their
assigned expert, the per-expert packed weights are stacked into a
(n_experts, out, in_padded//4) tensor, and one kernel handles the whole
ragged batch instead of n_experts × top_k_experts separate launches.
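The permutation step in the grouped MoE path can be sketched in plain Python: sort token indices by assigned expert, then record where each expert's contiguous group starts and ends. The function name is hypothetical, and the real kernel does this on the GPU with stacked packed weights.

```python
def group_by_expert(expert_ids, n_experts):
    """Return (order, offsets): tokens sorted by expert, plus group boundaries.

    Tokens for expert e sit at order[offsets[e]:offsets[e + 1]].
    """
    # sorted() is stable, so tokens keep their original order within an expert.
    order = sorted(range(len(expert_ids)), key=lambda i: expert_ids[i])
    counts = [0] * n_experts
    for e in expert_ids:
        counts[e] += 1
    offsets = [0]
    for c in counts:
        offsets.append(offsets[-1] + c)
    return order, offsets


order, offsets = group_by_expert([2, 0, 1, 0, 2], n_experts=3)
# order   == [1, 3, 2, 0, 4]
# offsets == [0, 2, 3, 5]
```

With the tokens laid out this way, a single launch can walk the offsets and apply each expert's weights to its own slice, which is the ragged batch the paragraph above refers to.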
- Training must stay on the float-sim path. Never call pack_for_inference() during training: packed BitLinears are not differentiable and latent_weight is deleted to free memory.
- MoE bit-equivalence requires no token drops. The grouped path uses vectorized capacity dropping, while the unpacked Python loop uses first-come-first-served per (top_k_slot, expert). Set expert_capacity_factor high enough that no drops occur if you need exact bit-equivalence with the training-time forward.
- Numerical drift. topk_int8 activation quantization is sensitive to tiny FP noise from int_mm-vs-float matmul ordering, so end-to-end outputs can drift ~1% relative even though every individual BitLinear is bit-perfect against the float-sim path.
scripts/bench_packed_linear.py compares an FP16 reference, the float-sim
path, and the real packed path across a sweep of shapes.
It is CUDA-only and exits cleanly on CPU machines.
python scripts/bench_packed_linear.py --shapes 768,1024,2048,4096
python scripts/bench_packed_linear.py --batch 1 --seq 1024
Numbers depend heavily on the GPU SKU; this repo does not ship pre-measured throughput tables. Run the script on your hardware to validate.
bitdiffusion/
├── model.py # BitLinear, BitAttention, BitFFN, BitMoEFFN, BitDiffusionTransformer
├── rdt.py # BitRDTTransformer — Recurrent-Depth Transformer variant
├── quantization.py # HybridQuantizer, KVCache, TurboQuant rotation, absmax/TopK
├── kernels.py # 2-bit pack/unpack, INT4×ternary Triton kernel + CPU fallback,
│ # grouped MoE-expert kernel, AOT compile probe
├── diffusion.py # CosineSchedule, MaskDiffusionLoss, masking utilities
├── data.py # StreamingJsonlDataset, variable-length chunking, DataLoader
├── train.py # Training loop, TrainConfig, ActivationSchedule, main()
├── sample.py # ThinkingDiffusionSampler, BlockDiffusionSampler, auto-length
├── export.py # Checkpoint export to safetensors / PyTorch (with --pack)
└── utils.py # BitStats, checkpoint save/load, logging, WandB wrapper
scripts/
└── bench_packed_linear.py # GPU benchmark: fp16 vs float-sim vs real packed
prepare_hf_jsonl.py # 40B token data pipeline (HuggingFace streaming)
train.py # CLI entry point for bitdiffusion.train
sample.py # CLI entry point for bitdiffusion.sample
export.py # CLI entry point for bitdiffusion.export
This repo trains a 1.7B model on a single A100 40GB for ~$200. To scale:
| Target | Change |
|---|---|
| Longer context (8K) | --max_seq_len 8192 --batch_size 4 |
| Longer context (32K+) | Multi-GPU cluster, sparse attention |
| Larger model (3B) | --hidden_dim 2560 --n_layers 32 |
| Larger model (7B) | --hidden_dim 4096 --n_layers 32 |
| Multi-GPU | torchrun --nproc_per_node=N train.py (DDP ready) |
| MoE variant | --use_moe --n_experts 8 --top_k_experts 2 |
Flash Attention (F.scaled_dot_product_attention) scales memory linearly with
sequence length — compute, not VRAM, is the bottleneck at long context.
- BitNet b1.58: Ma et al. (Microsoft Research, 2024). The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits. arXiv:2402.17764
- BitNet a4.8: Wang et al. (Microsoft Research, 2024). BitNet a4.8: 4-bit Activations for 1-bit LLMs. arXiv:2411.04965
- MDLM: Sahoo et al. (2024). Simple and Effective Masked Diffusion Language Models. arXiv:2406.07524
- SEDD: Lou et al. (2024). Discrete Diffusion Modeling by Estimating the Ratios of the Data Distribution. arXiv:2310.16834
- TurboQuant: Zandieh, Daliri, Hadian, Mirrokni (Google Research / Google DeepMind, 2025). TurboQuant: Online Vector Quantization with Near-optimal Distortion Rate. ICLR 2026. arXiv:2504.19874. 3-bit KV cache quantization via random rotation (PolarQuant) + 1-bit Johnson-Lindenstrauss residual. Implemented in quantization.py KVCache.
- Flash Attention 2: Dao (2023). FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning. arXiv:2307.08691
- RoPE: Su et al. (2021). RoFormer: Enhanced Transformer with Rotary Position Embedding. arXiv:2104.09864
- SwiGLU: Shazeer (2020). GLU Variants Improve Transformer. arXiv:2002.05202
- RMSNorm: Zhang & Sennrich (2019). Root Mean Square Layer Normalization. arXiv:1910.07467
- Chinchilla: Hoffmann et al. (DeepMind, 2022). Training Compute-Optimal Large Language Models. arXiv:2203.15556
- FineWeb-Edu: Penedo et al. (HuggingFace, 2024). The FineWeb Datasets: Decanting the Web for the Finest Text Data at Scale. arXiv:2406.17557
- StarCoder: Li et al. (BigCode, 2023). StarCoder: may the source be with you! arXiv:2305.06161
- PLAID: Gulrajani & Hashimoto (2024). Likelihood-Based Diffusion Language Models. arXiv:2305.18619
- Mercury: Inception Labs (2025). Commercial masked diffusion LM demonstrating production viability of diffusion-based text generation at scale.
- OpenMythos: Gomez (2025). Recurrent-Depth Transformer. Basis for the BitRDTTransformer variant in rdt.py. github.com/kyegomez/OpenMythos
- Model weights: BigCode OpenRAIL-M v1.0 (use restrictions apply — see LICENSE)
- Source code: Apache 2.0
- Training data: Mixed licenses — see individual dataset cards. StarCoderData (BigCode OpenRAIL-M) is the most restrictive source in the mix, which is why model weights carry the OpenRAIL-M terms.