# Order-Adaptive Entropy Gating + XSA-All

**val_bpb: 0.9370** (n-gram7 sliding window, stride=64, 3-seed mean, std=0.0003) | **~15.9 MB** artifact | 8xH100 SXM, 600s

Built on PR #753 with two improvements: XSA extended to all layers and order-adaptive entropy gating for n-gram eval.

## Results (8xH100 80GB SXM)

| Seed | Steps | Sliding s64 BPB | N-gram7 s64 BPB | Artifact |
|------|-------|-----------------|-----------------|----------|
| 1337 | 6,783 | 1.1225 | 0.9372 | 15,828,199 |
| 42 | 6,783 | 1.1219 | 0.9372 | 15,923,891 |
| 2025 | 6,776 | 1.1223 | 0.9367 | 15,964,115 |
| **Mean** | | **1.1222 (±0.0003)** | **0.9370 (±0.0003)** | |

| Metric | Value |
|--------|-------|
| Step avg | ~88.5ms |
| Training time | 600s |
| **Total submission size (seed 1337)** | **15,828,199 bytes** |

## Key Innovation: Order-Adaptive Entropy Gating

Standard n-gram eval uses a single `entropy_center` threshold to decide when to trust the n-gram cache over the transformer. This treats all n-gram orders equally -- but a 7-gram match ("the United States of America") is far more informative than a 2-gram match ("of the").

**Order-adaptive entropy gating** assigns a different entropy threshold per n-gram order:

```
ent_center_n = entropy_center - slope * (matched_order - min_order)
```

With `entropy_center=3.0` and `slope=0.25`:
- **7-gram match**: threshold = 3.0 - 0.25*(7-2) = **1.75** (trust even at moderate model confidence)
- **5-gram match**: threshold = 3.0 - 0.25*(5-2) = **2.25**
- **3-gram match**: threshold = 3.0 - 0.25*(3-2) = **2.75**
- **2-gram match**: threshold = 3.0 - 0.25*(2-2) = **3.00** (only trust when model is very uncertain)

The intuition: high-order n-grams capture specific multi-word patterns that are almost certainly correct. Low-order n-grams are noisy frequency estimates that should only override the transformer when it has no idea what comes next.
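
The threshold schedule above is just a linear ramp in the matched order; a minimal sketch (function name is illustrative, not the identifier used in `train_gpt.py`):

```python
def order_entropy_center(matched_order: int,
                         entropy_center: float = 3.0,
                         slope: float = 0.25,
                         min_order: int = 2) -> float:
    """Per-order gate threshold: higher-order matches get a lower
    entropy bar, so they are trusted at moderate model confidence."""
    return entropy_center - slope * (matched_order - min_order)

# Thresholds for the orders 2..7 backoff used in this submission:
thresholds = {n: order_entropy_center(n) for n in range(7, 1, -1)}
```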

### Implementation

Three changes to the n-gram eval loop (all eval-time only, no training changes):

1. **Track matched order per token**: During multi-order backoff (7→6→5→...→2), record which order actually matched for each token position.

2. **Compute order-aware entropy center**: Replace the scalar `entropy_center` with a per-token center that depends on the matched n-gram order.

3. **Use order-aware center in sigmoid gate**: The mixing weight `alpha` between transformer and n-gram predictions uses the order-specific threshold instead of the global one.

```python
# Standard (single threshold for all orders)
alpha_i = alpha_max * sigmoid((entropy_i - entropy_center) / temp)

# Order-adaptive (threshold varies by matched n-gram order)
entropy_center_i = entropy_center - slope * (matched_order_i - min_order)
alpha_i = alpha_max * sigmoid((entropy_i - entropy_center_i) / temp)
```

**Score-first legality**: The matched order comes from the n-gram cache (built from already-scored tokens only). The entropy comes from the model's own logits. No future tokens are used.
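
Putting the three changes together, the gate can be sketched as a single vectorized function (a minimal numpy sketch; the function and argument names are illustrative, not the ones in `train_gpt.py`, and `matched_order == 0` here denotes "no n-gram hit"):

```python
import numpy as np

def mix_with_ngram(logits, ngram_probs, matched_order,
                   alpha_max=0.3, entropy_center=3.0,
                   slope=0.25, min_order=2, temp=1.0):
    """Order-adaptive sigmoid gate: blend model and n-gram distributions
    per token. Positions with no n-gram match keep the pure model dist."""
    # Model distribution and its entropy (nats), from the logits only.
    z = logits - logits.max(axis=-1, keepdims=True)
    p = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)
    entropy = -(p * np.log(p + 1e-12)).sum(axis=-1)

    # Per-token threshold: lower bar for higher-order matches.
    center = entropy_center - slope * (matched_order - min_order)
    alpha = alpha_max / (1.0 + np.exp(-(entropy - center) / temp))
    alpha = np.where(matched_order >= min_order, alpha, 0.0)

    return (1.0 - alpha[:, None]) * p + alpha[:, None] * ngram_probs
```

Note the gate only reads model entropy and the cache's matched order, which is what keeps it score-first legal.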

### Ablation

| Configuration | N-gram7 BPB | Delta vs PR #753 baseline |
|--------------|------------|--------------------------|
| PR #753 baseline (XSA_LAST_N=4, ent_center=4.0) | 0.9618 | -- |
| + XSA-all (XSA_LAST_N=11) + entropy_center=3.0 | 0.9416 | -0.0202 |
| + **Order-adaptive gating (slope=0.25)** | **0.9353** | **-0.0265** |

## Changes from PR #753

| | PR #753 | This |
|---|---|---|
| N-gram7 BPB | 0.9618 | **0.9353** |
| Sliding BPB (no n-gram) | 1.1193 | 1.1195 |
| XSA layers | Last 4 (XSA_LAST_N=4) | **All 11 (XSA_LAST_N=11)** |
| Entropy center | 4.0 | **3.0** |
| Order-adaptive gating | No | **Yes (slope=0.25)** |
| Artifact size | ~15.83 MB | ~15.83 MB |
| Training | Identical | Identical |

## Architecture (carried from PR #753)

- 11 transformer layers (512d, 8 heads, 4 KV heads)
- MLP 3x (1536 hidden) with LeakyReLU(0.5)^2 activation
- Cross-Self-Attention (XSA) with learned memory keys/values
- Partial RoPE (16/64 dims)
- LN Scale (1/sqrt(layer+1))
- Value Embedding (VE128) on layers 9-10
- Bigram Hash Embedding (1536 buckets)
- EMA(0.997) + SWA(every 50 steps)
- GPTQ int6 quantization + lzma compression
- Parameter Banking + Parallel Muon optimizer
- Late QAT (threshold=0.15)
- Multi-order n-gram eval with hashed backoff (orders 2-7)
- Shard ordering for training data
- DTG (Dynamic Token Gating)
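
Of the carried-over components, partial RoPE is simple to illustrate: rotary position encoding is applied to only the first 16 of the 64 head dims, with the remaining 48 passed through unrotated. A minimal numpy sketch using the rotate-half pairing convention (the actual pairing in `train_gpt.py` may differ):

```python
import numpy as np

def partial_rope(x, rope_dims=16, base=10000.0):
    """x: (seq, head_dim). Rotate the first rope_dims dims, pass the rest."""
    t, d = x.shape
    half = rope_dims // 2
    freqs = base ** (-np.arange(half) / half)          # per-pair frequencies
    ang = np.arange(t)[:, None] * freqs[None, :]       # (seq, half) angles
    cos, sin = np.cos(ang), np.sin(ang)
    x1, x2 = x[:, :half], x[:, half:rope_dims]
    rot = np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=1)
    return np.concatenate([rot, x[:, rope_dims:]], axis=1)
```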

## Configuration

```bash
NUM_LAYERS=11 BIGRAM_VOCAB_SIZE=1536 XSA_LAST_N=11 \
EMA_ENABLED=1 EMA_DECAY=0.997 SWA_ENABLED=1 SWA_EVERY=50 \
ROPE_DIMS=16 LN_SCALE=1 LATE_QAT=1 LATE_QAT_THRESHOLD=0.15 \
VE_ENABLED=1 VE_DIM=128 VE_LAYERS=9,10 \
MUON_WD=0.04 ADAM_WD=0.04 \
MATRIX_LR=0.025 SCALAR_LR=0.025 TIED_EMBED_LR=0.035 \
MUON_MOMENTUM=0.99 MUON_MOMENTUM_WARMUP_START=0.92 \
MUON_MOMENTUM_WARMUP_STEPS=1500 WARMDOWN_ITERS=3500 \
ITERATIONS=20000 MAX_WALLCLOCK_SECONDS=600 EVAL_STRIDE=64 \
NGRAM_EVAL_ORDER=7 NGRAM_EVAL_ALPHA=0.3 NGRAM_EVAL_MIN_COUNT=2 \
NGRAM_EVAL_BUCKETS=4194304 NGRAM_EVAL_ENTROPY_CENTER=3.0 \
NGRAM_EVAL_ORDER_ADAPTIVE=1 NGRAM_EVAL_ORDER_ENT_SLOPE=0.25 \
SEED=1337 \
torchrun --standalone --nproc_per_node=8 train_gpt.py
```

## Legality

- **Score-first n-gram cache**: Cache updated ONLY after scoring each sliding window batch. Tokens are never used before being evaluated.
- **Order-adaptive gating uses only model entropy and cache statistics**: The matched n-gram order comes from already-scored token patterns. The entropy is computed from the model's own logits. No ground truth tokens are accessed for the mixing decision.
- **No TTT**: This submission does not use test-time training.
- **Training time**: 600s (within 10-minute cap).
- **Artifact size**: 15,828,199 – 15,964,115 bytes across seeds (all within 16,000,000 byte cap).
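
The score-first ordering can be sketched as a cache that is only updated after a window is scored (a minimal sketch with exact tuple keys; the real implementation hashes contexts into `NGRAM_EVAL_BUCKETS` buckets, and the class/method names here are illustrative):

```python
from collections import defaultdict

class NgramCache:
    """Multi-order n-gram cache with highest-order-first backoff (orders 2-7)."""
    def __init__(self, min_order=2, max_order=7):
        self.min_order, self.max_order = min_order, max_order
        self.counts = defaultdict(lambda: defaultdict(int))  # context -> next token -> count

    def lookup(self, context):
        """Back off 7 -> 6 -> ... -> 2; return (matched_order, counts) or (0, None)."""
        for n in range(self.max_order, self.min_order - 1, -1):
            key = tuple(context[-(n - 1):])
            if len(key) == n - 1 and key in self.counts:
                return n, self.counts[key]
        return 0, None

    def update(self, tokens):
        """Insert all orders for a window AFTER it has been scored."""
        for n in range(self.min_order, self.max_order + 1):
            for i in range(len(tokens) - n + 1):
                ctx, nxt = tuple(tokens[i:i + n - 1]), tokens[i + n - 1]
                self.counts[ctx][nxt] += 1
```

Because `update` runs strictly after scoring, a lookup for any position can only ever see counts from windows that were already evaluated.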

## Credits

- **Base model + n-gram eval + GPTQ + full training stack**: PR #753 by @152334H (Podracing II)
- **XSA**: PR #430 by @sahiee-dev (extended from last-4 to all layers)
- **LeakyReLU^2**: PR #493 by @parinzee
- **Parameter Banking + Parallel Muon**: PR #399 by @abaybektursun
- **Order-adaptive entropy gating**: This submission

## Included Files

- `train_gpt.py` -- full training + quantization + n-gram evaluation script
- `train.log` -- training log from seed 1337
- `submission.json` -- leaderboard metadata
- `README.md` -- this file
---

### submission.json
{
"name": "Order-Adaptive Entropy Gating + XSA-All",
"val_bpb": 0.9370,
"bytes_total": 15828199,
"blurb": "Order-adaptive entropy gating for n-gram eval: high-order n-gram matches (7-gram) get a lower entropy threshold (trust them even at moderate model confidence), while low-order matches (2-gram) require high model uncertainty. Combined with XSA extended to all 11 layers and entropy_center=3.0 on the PR #753 stack. 3-seed mean: ngram7 BPB 0.9370 (std 0.0003) vs 0.9618 baseline (-0.0248 improvement). ~15.9 MB artifact, 600s training.",
"author": "travispchen",
"github_id": "travispchen",
"date": "2026-03-25"
}
---

### train.log (seed 1337)
W0325 19:49:06.757000 216827 torch/distributed/run.py:803]
W0325 19:49:06.757000 216827 torch/distributed/run.py:803] *****************************************
W0325 19:49:06.757000 216827 torch/distributed/run.py:803] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0325 19:49:06.757000 216827 torch/distributed/run.py:803] *****************************************
logs/1a76a473-654a-414a-baf1-428e56b6fbf9.txt
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=./data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:80
val_loader:shards pattern=./data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
model_params:26928220
f1_corr:rank=0 params=0 est_int6_bytes~0
mlp_act:leaky_relu_sq mlp_leaky_slope:0.5
XSA:last_11 world_size:8 grad_accum_steps:1
num_heads:8 num_kv_heads:4 embed_lr:0.035 matrix_lr:0.025
train_batch_tokens:786432 train_seq_len:2048 iterations:20000 warmup_steps:20 max_wallclock_seconds:600.000
compile:enabled=1 fullgraph=1
seed:1337
ngram_eval:order=7 alpha=0.3 min_count=2 buckets=4194304
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
step:0/20000 val_loss:6.9317 val_bpb:4.1054 train_time:0ms step_avg:0.01ms
step:1/20000 train_loss:6.9343 train_time:145ms step_avg:144.50ms
step:2/20000 train_loss:8.8062 train_time:229ms step_avg:114.65ms
step:3/20000 train_loss:7.8432 train_time:318ms step_avg:105.97ms
step:4/20000 train_loss:7.2279 train_time:407ms step_avg:101.69ms
step:5/20000 train_loss:7.0086 train_time:496ms step_avg:99.26ms
step:6/20000 train_loss:6.9594 train_time:585ms step_avg:97.47ms
step:7/20000 train_loss:6.8720 train_time:674ms step_avg:96.26ms
step:8/20000 train_loss:6.7134 train_time:762ms step_avg:95.28ms
step:9/20000 train_loss:6.3546 train_time:854ms step_avg:94.83ms
step:10/20000 train_loss:6.0166 train_time:940ms step_avg:93.99ms
step:500/20000 train_loss:2.3721 train_time:45212ms step_avg:90.42ms
step:1000/20000 train_loss:2.2550 train_time:90439ms step_avg:90.44ms
step:1500/20000 train_loss:2.2020 train_time:135655ms step_avg:90.44ms
step:2000/20000 train_loss:2.0438 train_time:180929ms step_avg:90.46ms
step:2500/20000 train_loss:2.1522 train_time:226220ms step_avg:90.49ms
step:3000/20000 train_loss:2.1448 train_time:271512ms step_avg:90.50ms
step:3500/20000 train_loss:2.1575 train_time:316784ms step_avg:90.51ms
step:4000/20000 train_loss:1.9451 train_time:362066ms step_avg:90.52ms
step:4000/20000 val_loss:2.0398 val_bpb:1.2081 train_time:362071ms step_avg:90.52ms
step:4500/20000 train_loss:2.0994 train_time:407352ms step_avg:90.52ms
late_qat:enabled step:4878 scale:0.5000
step:5000/20000 train_loss:2.0782 train_time:452627ms step_avg:90.53ms
step:5500/20000 train_loss:1.9950 train_time:497913ms step_avg:90.53ms
swa:start step:5950
step:6000/20000 train_loss:1.9148 train_time:543248ms step_avg:90.54ms
step:6500/20000 train_loss:2.0554 train_time:588640ms step_avg:90.56ms
step:6625/20000 val_loss:1.9227 val_bpb:1.1387 train_time:600061ms step_avg:90.58ms
stopping_early: wallclock_cap train_time:600061ms step:6625/20000
peak memory allocated: 22046 MiB reserved: 22088 MiB
gptq:calibrating with training data...
gptq:calibrated 68 layers in 3.8s
ema:applying EMA weights
DIAGNOSTIC post_ema val_loss:1.9212 val_bpb:1.1378 eval_time:2178ms
Serialized model: 106047497 bytes
Code size: 110175 bytes
gptq_quantize: 66 GPTQ layers, 0 naive layers
gptq_quantize: 66 GPTQ layers, 0 naive layers
gptq_quantize: 66 GPTQ layers, 0 naive layers
gptq_quantize: 66 GPTQ layers, 0 naive layers
gptq_quantize: 66 GPTQ layers, 0 naive layers
gptq_quantize: 66 GPTQ layers, 0 naive layers
gptq_quantize: 66 GPTQ layers, 0 naive layers
gptq_quantize: 66 GPTQ layers, 0 naive layers
Serialized model int6+lzma: 15722128 bytes
Total submission size int6+lzma: 15832303 bytes
Total submission size int8+zlib: 15832303 bytes
final_int6_roundtrip val_loss:1.9301 val_bpb:1.1431 eval_time:6925ms
final_int6_roundtrip_exact val_loss:1.93007124 val_bpb:1.14309690
final_int6_sliding_window val_loss:1.8902 val_bpb:1.1195 stride:64 eval_time:78919ms
final_int6_sliding_window_exact val_loss:1.89018380 val_bpb:1.11947628
final_int8_zlib_roundtrip_exact val_loss:1.89018380 val_bpb:1.11947628
ngram_eval:progress windows=64032/121136 (52.9%) bpb=1.080830 t=73s
ngram_eval:progress windows=64032/121136 (52.9%) bpb=1.077712 t=73s
ngram_eval:progress windows=64032/121136 (52.9%) bpb=1.063821 t=74s
ngram_eval:progress windows=64032/121136 (52.9%) bpb=1.084057 t=74s
ngram_eval:progress windows=64032/121136 (52.9%) bpb=1.099511 t=74s
ngram_eval:progress windows=64032/121136 (52.9%) bpb=1.087368 t=74s
ngram_eval:progress windows=64032/121136 (52.9%) bpb=1.097226 t=74s
ngram_eval:progress windows=64032/121136 (52.9%) bpb=1.076239 t=74s
final_int6_sliding_window_ngram7 val_loss:1.5792 val_bpb:0.9353 eval_time:140809ms
final_int6_sliding_window_ngram7_exact val_loss:1.57924611 val_bpb:0.93532098