Add non-record JEPA byte-level encoder-decoder submission#696

Open
gravelBridge wants to merge 28 commits into openai:main from gravelBridge:main

Conversation

@gravelBridge

Non-record submission for the 16MB track using a JEPA (Joint Embedding Predictive Architecture) encoder-decoder as an alternative to the standard causal GPT used by all current leaderboard entries.

  • Architecture: Byte-level tokenizer (vocab 260, no BPE), 5-layer depth-recurrent encoder (2 repeats) with patched latent projection, 7-layer causal decoder conditioned on encoder latents. 24.6M parameters.
  • Quantization & Compression: INT6 optimal-clip quantization with STE QAT during warmdown; LZMA preset 9 compression. Total submission: 15.7 MB.
  • Test-Time Training: Sliding window TTT with SGD (lr=0.002), 2 epochs per chunk, stride 256.
  • Result: Final TTT val_bpb 1.2622 (pre-quantization 1.2957), trained for 10,635 / 20,000 steps.
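The bullet points above can be collected into a single configuration sketch. Field names here are hypothetical (the submission's actual code may use different identifiers); the values are those stated in the description:

```python
from dataclasses import dataclass

@dataclass
class JEPAConfig:
    # Values from the submission description; names are illustrative.
    vocab_size: int = 260           # byte-level tokenizer, no BPE
    patch_size: int = 8             # bytes per patch
    encoder_unique_layers: int = 5  # unique depth-recurrent blocks
    encoder_repeats: int = 2        # cycled 2x -> 10 effective layers
    decoder_layers: int = 7         # causal decoder over all bytes
    model_dim: int = 480

    @property
    def encoder_effective_layers(self) -> int:
        return self.encoder_unique_layers * self.encoder_repeats
```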

Replace the isolated per-patch decoder (8-byte window with no cross-patch
information flow) with a full-sequence causal decoder over all bytes.
Each byte can now attend to all preceding bytes across patch boundaries,
with patch-level context upsampled and added as conditioning. This removes
the critical information bottleneck where byte predictions at patch
boundaries had no access to preceding bytes from other patches.
Shift parameter budget from the encoder to the decoder, where val_bpb
is determined. Encoder goes from 10 unique blocks to 5 unique blocks
cycled 2x (same 10 effective layers, half the unique params). Decoder
grows from 2 to 6 layers, tripling capacity for byte-level prediction.
Total unique params drops ~1.5M but decoder gets ~4M more.
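The upsample-and-add conditioning described above can be sketched as a minimal NumPy illustration (function and argument names are assumptions, not the submission's API):

```python
import numpy as np

def add_patch_conditioning(byte_emb, patch_ctx, patch_size=8):
    """Upsample patch-level context to byte resolution and add it.

    byte_emb:  (batch, n_bytes, dim) byte embeddings for the decoder
    patch_ctx: (batch, n_bytes // patch_size, dim) encoder latents

    Each byte receives its patch's latent as additive conditioning,
    while the decoder's causal self-attention lets it also see all
    preceding bytes across patch boundaries.
    """
    up = np.repeat(patch_ctx, patch_size, axis=1)  # (batch, n_bytes, dim)
    return byte_emb + up
```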

Replace the decoder's conditioning signal from pred_latent (predictor's
noisy estimate, routed through a latent_dim bottleneck) with the encoder's
context output directly (model_dim, shifted by 1 patch for causality).
The JEPA predictor + MSE loss remain as an auxiliary training objective,
but the decoder now receives the exact encoder representations instead
of a noisy compressed proxy. Removes decoder_cond projection layer
since context is already model_dim.

The JEPA prediction loss had 2x the gradient weight of the actual
compression objective (CE). Flip the ratio: CE gets 3x weight, pred
gets 0.5x. This directs more gradient signal toward byte-level
prediction quality, with JEPA serving as a lighter regularizer.
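A minimal sketch of the shifted conditioning and the re-weighted objective (the weights come from the text above; the helper names are hypothetical):

```python
import numpy as np

CE_WEIGHT = 3.0    # compression objective (cross-entropy)
PRED_WEIGHT = 0.5  # auxiliary JEPA prediction (MSE) loss

def shift_context(patch_ctx):
    """Shift encoder context right by one patch so bytes in patch i
    are conditioned only on encoder states of patches < i (causality)."""
    shifted = np.zeros_like(patch_ctx)
    shifted[:, 1:] = patch_ctx[:, :-1]
    return shifted

def total_loss(ce_loss, pred_loss):
    # CE dominates; JEPA prediction acts as a lighter regularizer.
    return CE_WEIGHT * ce_loss + PRED_WEIGHT * pred_loss
```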
Current compressed model is only 9MB of the 16MB limit. Increase
model_dim from 384 to 480 and decoder_layers from 6 to 8, bringing
total params from ~14.7M to ~26.4M (compressed ~15.8MB). Nearly all
the extra capacity goes to the decoder where val_bpb is determined.

Sliding window eval scores each byte with near-maximum context.
Windows of seq_len advance by stride (default 512 bytes = 64 patches).
Only the tail stride bytes per window are scored (first window scores
all). Adds forward_logits() method that returns per-position logits
without computing loss. Only the final int8+zlib roundtrip eval uses
sliding window; periodic training eval stays fast (non-overlapping).
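The window/stride bookkeeping can be sketched as a generator (a hypothetical helper; the submission's forward_logits-based loop may differ in detail):

```python
def sliding_windows(n_bytes, seq_len=2048, stride=256):
    """Yield (win_start, win_end, score_start) triples.

    Windows of seq_len advance by stride. Only bytes in
    [score_start, win_end) are scored, i.e. the tail stride bytes of
    each window, except the first window, which scores all its bytes.
    Every byte is scored exactly once, with near-maximum context.
    """
    end = min(seq_len, n_bytes)
    yield 0, end, 0  # first window: score everything it covers
    while end < n_bytes:
        new_end = min(end + stride, n_bytes)
        yield max(0, new_end - seq_len), new_end, end
        end = new_end
```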
- MLP activation: relu² → LeakyReLU(0.5)² (matching SOTA)
- EMA weight averaging (decay=0.997) applied before serialization
- SWA snapshots collected every 50 steps when lr scale < 0.2
- Test-time training: score-first legal TTT with SGD (lr=0.002,
  momentum=0.9, 3 epochs, 32K chunks, cosine LR decay)
- Eval stride reduced to 64 (matching SOTA)
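Taken literally, the activation swap in the first bullet replaces relu(x)² with LeakyReLU(0.5)(x)²; a sketch (not the submission's code):

```python
import numpy as np

def leaky_relu_sq(x, negative_slope=0.5):
    # relu(x)**2 replaced by LeakyReLU(negative_slope)(x)**2,
    # so the negative branch contributes a (slope * x)**2 response.
    y = np.where(x >= 0, x, negative_slope * x)
    return y * y
```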
SWA snapshots collected during warmdown are now averaged and applied
instead of being discarded. TTT adaptation uses forward_logits + CE
loss directly, avoiding unnecessary prediction/SIGReg gradient signal.

Replace INT8+zlib with mixed INT6/INT8+LZMA to reduce serialized model
size. MLP/attn/other weights use INT6 ([-31,31]) with per-row MSE-optimal
clip search; embeddings stay INT8. Add STE quantization-aware training
that activates during warmdown (LR scale < 0.15). Switch compression
from zlib to LZMA for better entropy exploitation on low-range values.

Also bump eval batch defaults (val_sliding_batch, ttt_batch_seqs) from
8 to 32 to match SOTA, and add infer/adapt timing breakdown to TTT logs.
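The per-row MSE-optimal clip search can be sketched as follows; this is an illustration of the scheme described above under assumed details (grid-search granularity, clip range), not the submission's implementation:

```python
import numpy as np

def quantize_int6_row(row, n_candidates=32):
    """Symmetric INT6 ([-31, 31]) quantization of one weight row with a
    grid search over clip values minimizing reconstruction MSE."""
    amax = float(np.abs(row).max()) or 1.0  # avoid zero scale
    best_err, best_q, best_scale = np.inf, None, None
    for clip in np.linspace(0.3 * amax, amax, n_candidates):
        scale = clip / 31.0
        q = np.clip(np.round(row / scale), -31, 31)
        err = float(((q * scale - row) ** 2).sum())
        if err < best_err:
            best_err, best_q, best_scale = err, q.astype(np.int8), scale
    return best_q, best_scale  # quantized row, per-row scale
```

Low-range INT6 values leave most of each byte's dynamic range unused, which is why LZMA (with its larger dictionary and range coder) recovers more than zlib does.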
Halve the context length, reducing attention cost roughly 4x per forward
pass and making sliding window eval and TTT feasible within the time budget.
524032 = 2047 * 32 * 8, ensuring divisibility across 8 GPUs.
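The sequence-length arithmetic checks out:

```python
# Sequence-length arithmetic from the commit above.
seq_len = 524032
assert seq_len == 2047 * 32 * 8  # factorization given in the text
assert seq_len % 8 == 0          # divides evenly across 8 GPUs
per_gpu = seq_len // 8           # 65504 bytes per GPU
```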
Bump LZMA compression from preset 6 to 9 and quantize embeddings
to INT6 (previously INT8). Previous run was 363KB over budget.

TTT was catastrophically diverging (bpb 1.24 -> 2.53) because
sequential adaptation was destroying the JEPA encoder. Now only
decoder parameters adapt during TTT. Also disable QAT during TTT
to avoid injecting quantization noise on already-dequantized weights.

Revert encoder freeze (divergence was from QAT during TTT, not encoder
adaptation). Increase sliding window stride from 64 to 256 and reduce
TTT epochs from 3 to 1 for faster eval.
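The TTT update itself is plain SGD with momentum; a minimal sketch with the stated hyperparameters (the dict-of-arrays parameter representation is illustrative):

```python
import numpy as np

TTT_LR = 0.002
TTT_MOMENTUM = 0.9

def ttt_step(params, grads, velocity, lr=TTT_LR, momentum=TTT_MOMENTUM):
    """One test-time-training update: SGD with momentum.

    QAT stays disabled here: the weights being adapted are already
    dequantized, so STE quantization noise would only hurt.
    """
    for name, g in grads.items():
        velocity[name] = momentum * velocity[name] + g
        params[name] = params[name] - lr * velocity[name]
```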
Each decoder block is ~1.84M params (~1.1MB compressed). Previous run
was 572KB over the 16MB limit. Dropping one decoder layer should
provide enough headroom.

Drop eval_val_sliding (was burning 88s for a diagnostic log). TTT
already does sliding window evaluation with adaptation. Bump TTT
epochs from 1 to 2 with the freed time budget.

Match the openai#1 record's optimizer settings:
- MATRIX_LR/SCALAR_LR: 0.015 -> 0.025
- MUON_MOMENTUM: 0.95 -> 0.99
- MUON_MOMENTUM_WARMUP_START: 0.85 -> 0.92
- MUON_MOMENTUM_WARMUP_STEPS: 500 -> 1500
- WARMDOWN_ITERS: 1200 -> 3500
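Collected as a config sketch; only the endpoint values come from the list above, and the linear-warmup shape for the Muon momentum is an assumption:

```python
# Optimizer settings from the commit; the warmup shape is assumed linear.
OPT = dict(
    MATRIX_LR=0.025,
    SCALAR_LR=0.025,
    MUON_MOMENTUM=0.99,
    MUON_MOMENTUM_WARMUP_START=0.92,
    MUON_MOMENTUM_WARMUP_STEPS=1500,
    WARMDOWN_ITERS=3500,
)

def muon_momentum(step):
    # Linearly warm momentum from the start value to its final value.
    frac = min(step / OPT["MUON_MOMENTUM_WARMUP_STEPS"], 1.0)
    lo, hi = OPT["MUON_MOMENTUM_WARMUP_START"], OPT["MUON_MOMENTUM"]
    return lo + frac * (hi - lo)
```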