A research implementation of Hybrid GPT-2 + 2-Simplicial Attention for token-efficient reasoning on math benchmarks.
This project explores whether a smaller model with simplicial (trilinear) attention can match or exceed larger standard transformer models on reasoning tasks, particularly GSM8K and MATH benchmarks.
- 2-Simplicial Attention: Trilinear attention that captures three-way token relationships
- Placement at End: Simplicial layers at the end of the network as "reasoning head"
- Adapter Layer: Lightweight adapter (layer 4) with small LR for fine-tuning
- Optional MoE: Mixture of Simplicial Experts for enhanced reasoning
Standard attention: A = softmax(Q · Kᵀ) · V
2-Simplicial attention: A = softmax(⟨Q, K₁, K₂⟩) · (V₁ ⊙ V₂), where ⟨Q, K₁, K₂⟩ᵢⱼₖ = Σ_d Qᵢ_d K₁ⱼ_d K₂ₖ_d is a trilinear form and the softmax normalizes jointly over the key pair (j, k)
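A minimal, unoptimized sketch of the trilinear attention above for a single head (the real implementation lives in `simplicial/attention/` and uses Triton kernels plus windowing; the function name, the absence of causal masking, and the 1/√d scaling choice here are illustrative assumptions):

```python
import torch
import torch.nn.functional as F

def two_simplicial_attention(q, k1, k2, v1, v2):
    """2-simplicial (trilinear) attention for one head, shapes (seq, dim).

    Sketch only: no windowing, no causal mask, no batching.
    """
    seq, dim = q.shape
    # Trilinear logits: logits[i, j, k] = sum_d q[i,d] * k1[j,d] * k2[k,d]
    logits = torch.einsum("id,jd,kd->ijk", q, k1, k2) / dim ** 0.5
    # Joint softmax over all (j, k) key pairs for each query i
    attn = F.softmax(logits.reshape(seq, -1), dim=-1).reshape(seq, seq, seq)
    # Combine values elementwise: out[i] = sum_{j,k} attn[i,j,k] * v1[j] * v2[k]
    return torch.einsum("ijk,jd,kd->id", attn, v1, v2)

q = torch.randn(8, 16)
out = two_simplicial_attention(q, *[torch.randn(8, 16) for _ in range(4)])
print(out.shape)  # torch.Size([8, 16])
```

Note the O(seq³) cost of the dense trilinear form; this is why the paper (and the config below) restricts K₁ and K₂ to local windows.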
| Layer | Type | Trainable | LR |
|---|---|---|---|
| 0-3 | Standard (GPT-2 init) | ❌ Frozen | - |
| 4 | Adapter | ✅ | 1e-4 |
| 5-7 | Simplicial | ✅ | 3e-4 |
Total: ~100M params, ~72M trainable
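One way to realize the freeze/LR split in the table is optimizer parameter groups. A sketch, assuming the blocks live in a `ModuleList` attribute named `layers` (hypothetical — adjust to the actual attribute names in `SimplicialModel`):

```python
import torch.nn as nn

def build_param_groups(model: nn.Module):
    """Freeze layers 0-3 and give adapter/simplicial layers their own LRs."""
    adapter, simplicial = [], []
    for name, p in model.named_parameters():
        # Parse the layer index from names like "layers.5.attn.weight"
        idx = int(name.split("layers.")[1].split(".")[0]) if "layers." in name else -1
        if 0 <= idx <= 3:
            p.requires_grad = False          # layers 0-3: frozen GPT-2 weights
        elif idx == 4:
            adapter.append(p)                # layer 4: adapter, LR 1e-4
        else:
            simplicial.append(p)             # layers 5-7 (and remaining params): LR 3e-4
    return [
        {"params": adapter, "lr": 1e-4},
        {"params": simplicial, "lr": 3e-4},
    ]
```

The returned list can be passed straight to `torch.optim.AdamW(...)`, which applies each group's `lr` independently.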
```
too-simplex/
├── simplicial/                       # Model implementation
│   ├── attention/                    # Simplicial attention, RoPE, standard attention
│   ├── cache/                        # KV cache management
│   ├── layers/                       # SimplicialBlock, StandardBlock, MoE
│   └── models/                       # SimplicialModel
├── training/
│   ├── configs/                      # Training configs
│   ├── datasets/                     # Dataset loading
│   ├── plateau_handler.py            # Plateau detection
│   ├── synthetic_data_generator.py   # Synthetic math data
│   └── logic_trainer.py              # Gradual unfreezing (legacy)
├── scripts/
│   ├── train/                        # Training scripts (train_hybrid.py)
│   ├── data/                         # Dataset download scripts
│   └── evaluate_checkpoint.py        # Evaluation
├── data/logic/                       # Training data (GSM8K, NuminaMath)
└── tests/                            # Test suite
```
```bash
pip install -r requirements.txt

# Install PyTorch (for your CUDA version)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
```

```bash
python scripts/data/download_datasets.py --datasets gsm8k numinamath math500
```

```bash
python scripts/train/train_hybrid.py \
    --config training/configs/simplicial_final_layers.yaml \
    --checkpoint-dir checkpoints_v2 \
    --log-dir logs_v2
```

Key parameters in `training/configs/simplicial_final_layers.yaml`:

```yaml
model:
  n_layers: 8
  num_last_layers_simplicial: 3   # Last 3 are simplicial
  window_size_1: 1024             # Paper: 1024
  window_size_2: 32               # Paper: 32
  moe: false                      # Enable for MoE boost

training:
  max_steps: 30000
  warmup_steps: 2000
  lr_simplicial: 3e-4
  lr_adapter: 1e-4
  eval_every: 2500
  save_every: 2500
```

| Phase | Steps | LR Schedule | Action |
|---|---|---|---|
| 0: Warmup | 0-2k | Linear 0→3e-4 | Warmup |
| 1: Main | 2k-10k | Cosine to 1e-6 | Monitor val |
| 2: MoE | 10k-30k | Cosine | Optional: enable MoE |
| 3: Final | 30k+ | - | Export model |
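The warmup-then-cosine schedule in the table can be sketched as a step-to-LR function (constants taken from the config above; the actual trainer may use a PyTorch scheduler instead):

```python
import math

def lr_at(step, max_lr=3e-4, min_lr=1e-6, warmup=2000, total=10000):
    """Linear warmup 0 -> max_lr, then cosine decay max_lr -> min_lr by `total`."""
    if step < warmup:
        return max_lr * step / warmup
    t = min(1.0, (step - warmup) / (total - warmup))
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t))

print(lr_at(0))       # 0.0
print(lr_at(2000))    # 0.0003  (end of warmup)
print(lr_at(10000))   # 1e-06   (fully decayed)
```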
- ✅ Adapter Layer: Layer 4 with small LR (1e-4)
- ✅ Plateau Detection: Auto-inject synthetic data on plateau
- ✅ Synthetic Data: 5K generated math problems
- ✅ MoE Toggle: Optional mixture of simplicial experts
- ✅ GPT-2 Initialization: Standard layers inherit GPT-2 weights
- ✅ Gradient Checkpointing: Memory efficient
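The plateau trigger can be illustrated with a simple "no improvement over the last few evals" check (the real logic in `training/plateau_handler.py` may use different windows and thresholds — this is a minimal sketch):

```python
def plateaued(val_losses, window=3, min_delta=1e-3):
    """True when the last `window` eval losses fail to beat the earlier best.

    When this fires, the trainer would inject synthetic math problems.
    """
    if len(val_losses) <= window:
        return False  # not enough history yet
    best_before = min(val_losses[:-window])
    return min(val_losses[-window:]) > best_before - min_delta
```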
| Dataset | Train | Test | Source |
|---|---|---|---|
| GSM8K | 7,473 | 1,319 | HuggingFace |
| NuminaMath-CoT | 859,494 | - | HuggingFace |
| MATH-500 | - | 500 | HuggingFace |
| Synthetic | 5,000 | - | Generated |
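The downloaded splits are stored as JSONL (one JSON object per line, e.g. `data/logic/gsm8k_test.jsonl`). A minimal loader; the `question`/`answer` field names follow GSM8K's published schema, and other datasets here may use different keys:

```python
import json

def load_jsonl(path):
    """Read a JSONL file into a list of dicts, skipping blank lines."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
```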
```bash
python scripts/evaluate_checkpoint.py \
    --checkpoint checkpoints_v2/checkpoint_step_25000.pt \
    --datasets data/logic/gsm8k_test.jsonl
```

Dependencies (`requirements.txt`):

```
torch>=2.1.0
transformers>=4.35.0
datasets>=2.18.0
tensorboard>=2.14.0
triton>=2.1.0
numpy>=1.26.0
pandas>=2.0.0
pyyaml>=6.0
```
- GPU: A100 (40 GB) or RTX 4090 (24 GB) recommended
- Memory: ~6-7GB with gradient checkpointing
- Time: ~1.5-2 hours per 10k steps on A100
```bibtex
@article{simplicial2024,
  title={Fast and Simplex: 2-Simplicial Attention in Triton},
  author={},
  year={2024}
}
```