# LeakyReLU² + XSA-all + Full GPTQ + 5-gram Backoff

**val_bpb: 1.0340** (mean over 3 seeds: seed 1337→1.0342, seed 42→1.0340, seed 7→1.0338)

## Architecture

- 11 transformer layers, dim=512, 8 heads, 4 KV heads (GQA)
- LeakyReLU(0.5)² MLP with 2x expansion
- RoPE, RMSNorm, tied embeddings (vocab=1024), logit softcapping (30.0)
- U-Net skip connections with learned skip weights
- SmearGate + BigramHash embedding augmentation
- XSA (cross-sequence attention) on all 11 layers
- ~27M parameters
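The MLP activation can be sketched as follows. This is a minimal illustration assuming the straightforward reading LeakyReLU(x)² with negative slope 0.5; `leaky_relu_squared` is a hypothetical name, not the identifier used in `train_gpt.py`:

```python
def leaky_relu_squared(x: float, negative_slope: float = 0.5) -> float:
    """LeakyReLU followed by squaring: y = LeakyReLU(x)**2.

    Unlike ReLU**2, the gradient is nonzero for x < 0
    (d/dx = 2 * negative_slope**2 * x), which is the
    "better gradient flow" property noted under Key Techniques.
    """
    y = x if x >= 0 else negative_slope * x
    return y * y
```

In the real MLP this would be applied elementwise to the hidden activations between the two projection matrices of the 2x-expansion block.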

## Training

- Muon optimizer for matrix params, Adam for scalars/embeddings
- EMA (decay=0.997) + Tight SWA
- Late QAT (int6 quantization-aware training)
- Full GPTQ: Hessian-based int6 quantization with Cholesky error compensation (32 calibration batches on EMA model)
- Compression: zstd (level 22)
- Training time: ~600s wall-clock on 8xH100; 5,243 steps at ~114ms/step before the wallclock cap
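The int6 path shared by QAT and GPTQ can be illustrated with a minimal symmetric per-tensor quantizer. This is a deliberate simplification: the actual pipeline is per-layer GPTQ with Hessian-based error compensation, and the function names here are hypothetical:

```python
def quantize_int6(weights):
    """Symmetric int6 quantization: map floats to integer codes in [-31, 31].

    Returns (codes, scale); reconstruct each weight as code * scale.
    """
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 31.0 if max_abs > 0 else 1.0
    codes = [max(-31, min(31, round(w / scale))) for w in weights]
    return codes, scale

def dequantize_int6(codes, scale):
    """Inverse of quantize_int6 up to rounding error (at most scale / 2)."""
    return [c * scale for c in codes]
```

QAT exposes the model to this round-trip during late training so the weights adapt to the quantization grid before GPTQ runs on the EMA model.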

## Evaluation

- Sliding window eval at stride=64, seq_len=2048
- **5-gram multi-order backoff**: cascade 5→4→3→2-gram lookup with separate hash tables per order
- **Entropy-adaptive alpha**: alpha = 0.05 + 0.35 * sigmoid(2*(H-4.0)), where H is model entropy
- Low entropy (confident model): alpha ≈ 0.05, trust model
- High entropy (uncertain model): alpha ≈ 0.40, trust n-gram cache
- Score-first protocol: each token scored before its n-gram is added to cache
- Hash tables: 4M buckets per order, uint32 counts, min_count=2
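The entropy-adaptive mixing weight is a direct transcription of the stated formula; parameter names below are illustrative, not the actual flags in the code:

```python
import math

def adaptive_alpha(model_entropy_bits: float,
                   alpha_low: float = 0.05,
                   alpha_high: float = 0.40,
                   entropy_thresh: float = 4.0,
                   steepness: float = 2.0) -> float:
    """alpha = 0.05 + 0.35 * sigmoid(2 * (H - 4.0)).

    Low entropy (confident model)  -> alpha near 0.05: trust the model.
    High entropy (uncertain model) -> alpha near 0.40: trust the n-gram cache.
    """
    sig = 1.0 / (1.0 + math.exp(-steepness * (model_entropy_bits - entropy_thresh)))
    return alpha_low + (alpha_high - alpha_low) * sig
```

The mixed per-token probability would then be `(1 - alpha) * p_model + alpha * p_ngram`.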

## Results

| Seed | Sliding BPB | N-gram BPB |
|------|-------------|------------|
| 1337 | 1.1273 | 1.0342 |
| 42 | 1.1272 | 1.0340 |
| 7 | 1.1269 | 1.0338 |
| **Mean** | **1.1271** | **1.0340** |

Artifact size: 15,903,061 bytes (< 16,000,000)

## Reproduction

```bash
SEED=1337 GPTQ_CALIB_BATCHES=32 \
NGRAM_EVAL_ORDER=5 NGRAM_BACKOFF=1 NGRAM_ENTROPY_ADAPTIVE=1 \
NGRAM_ALPHA_LOW=0.05 NGRAM_ALPHA_HIGH=0.40 NGRAM_ENTROPY_THRESH=4.0 \
torchrun --nproc_per_node=8 train_gpt.py
```

## Key Techniques

1. **LeakyReLU(0.5)²**: Replaces ReLU² with a leaky variant (negative slope 0.5), providing nonzero gradients for negative pre-activations while maintaining the sparsity-inducing shape of squaring
2. **XSA-all**: Extended cross-sequence attention from last 4 layers to all 11
3. **Full GPTQ**: Hessian-based quantization with actorder and Cholesky error compensation, calibrated on training data within the training budget
4. **N-gram backoff**: Multi-order cascade (5→4→3→2) with separate tables per order, using entropy-adaptive mixing weights
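The multi-order cascade in technique 4 can be sketched with plain dicts standing in for the 4M-bucket hash tables; names here are hypothetical, and the real implementation hashes contexts into fixed-size uint32 count arrays:

```python
def backoff_prob(tables, context, token, max_order=5, min_count=2):
    """Cascade 5->4->3->2-gram lookup across separate per-order tables.

    Returns P(token | context) from the highest order whose context has
    been seen and whose token count meets min_count, else None (caller
    falls back to the model distribution alone).
    """
    for order in range(max_order, 1, -1):
        ctx = tuple(context[-(order - 1):])      # last (order-1) tokens
        counts = tables.get(order, {}).get(ctx)  # per-order table lookup
        if counts:
            c = counts.get(token, 0)
            if c >= min_count:
                return c / sum(counts.values())
    return None  # no order had enough evidence
```

Under the score-first protocol, this lookup happens for each token before that token's own n-grams are inserted into the tables.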
W0325 23:48:06.806000 127487562527360 torch/distributed/run.py:779]
W0325 23:48:06.806000 127487562527360 torch/distributed/run.py:779] *****************************************
W0325 23:48:06.806000 127487562527360 torch/distributed/run.py:779] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0325 23:48:06.806000 127487562527360 torch/distributed/run.py:779] *****************************************
logs/val_seed1337.txt
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:80
val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
model_params:26993756
mtp_num_heads:0 mtp_loss_weight:0.2 mtp_params:0
XSA:last_11 active_layers:[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
world_size:8 grad_accum_steps:1
sdp_backends:cudnn=False flash=True mem_efficient=False math=False
attention_mode:gqa num_heads:8 num_kv_heads:4
tie_embeddings:True embed_lr:0.035 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.025
train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000
seed:1337
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
step:0/20000 val_loss:6.9271 val_bpb:4.1026 train_time:0ms step_avg:0.01ms
step:1/20000 train_loss:6.9300 train_time:167ms step_avg:166.88ms
step:2/20000 train_loss:8.3830 train_time:276ms step_avg:138.11ms
step:3/20000 train_loss:7.5823 train_time:389ms step_avg:129.56ms
step:4/20000 train_loss:8.2829 train_time:501ms step_avg:125.33ms
step:5/20000 train_loss:8.5193 train_time:614ms step_avg:122.75ms
step:6/20000 train_loss:8.3598 train_time:726ms step_avg:121.01ms
step:7/20000 train_loss:7.8053 train_time:838ms step_avg:119.78ms
step:8/20000 train_loss:7.1608 train_time:951ms step_avg:118.87ms
step:9/20000 train_loss:6.6268 train_time:1064ms step_avg:118.17ms
step:10/20000 train_loss:6.2498 train_time:1177ms step_avg:117.66ms
step:500/20000 train_loss:2.4041 train_time:56996ms step_avg:113.99ms
step:1000/20000 train_loss:2.2636 train_time:114220ms step_avg:114.22ms
step:1500/20000 train_loss:2.2091 train_time:171392ms step_avg:114.26ms
step:2000/20000 train_loss:2.0457 train_time:228549ms step_avg:114.27ms
step:2500/20000 train_loss:2.1407 train_time:285742ms step_avg:114.30ms
step:3000/20000 train_loss:2.1218 train_time:342938ms step_avg:114.31ms
step:3500/20000 train_loss:2.1256 train_time:400124ms step_avg:114.32ms
step:4000/20000 train_loss:1.9124 train_time:457301ms step_avg:114.33ms
step:4000/20000 val_loss:2.0022 val_bpb:1.1858 train_time:457305ms step_avg:114.33ms
step:4500/20000 train_loss:2.0560 train_time:514458ms step_avg:114.32ms
swa:start step:4550
late_qat:enabled step:4721 scale:0.1498
step:5000/20000 train_loss:2.0292 train_time:572098ms step_avg:114.42ms
step:5243/20000 val_loss:1.9384 val_bpb:1.1481 train_time:600045ms step_avg:114.45ms
stopping_early: wallclock_cap train_time:600045ms step:5243/20000
peak memory allocated: 27940 MiB reserved: 29072 MiB
ema:applying EMA weights
gptq:calibrating batches=32
gptq:done layers=68 time=6930ms
DIAGNOSTIC post_ema val_loss:1.9373 val_bpb:1.1474 eval_time:2464ms
Serialized model: 106178100 bytes
Code size: 74639 bytes
Serialized model int6+zstd: 15903061 bytes
Total submission size int6+zstd: 15977700 bytes
Total submission size int8+zlib: 15977700 bytes
/workspace/parameter-golf/train_gpt.py:1431: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.
quant_state = torch.load(
final_int6_roundtrip val_loss:1.9433 val_bpb:1.1509 eval_time:52638ms
final_int6_roundtrip_exact val_loss:1.94329308 val_bpb:1.15092762
final_int6_sliding_window val_loss:1.9033 val_bpb:1.1273 stride:64 eval_time:119760ms
final_int6_sliding_window_exact val_loss:1.90334860 val_bpb:1.12727324
final_ngram val_loss:1.7462 val_bpb:1.0342 order:5 alpha:0.2 backoff:True adaptive:True time:353792ms