2-Simplicial Transformer Hybrid Model

A research implementation of Hybrid GPT-2 + 2-Simplicial Attention for token-efficient reasoning on math benchmarks.

Overview

This project explores whether a smaller model with simplicial (trilinear) attention can match or exceed larger standard transformer models on reasoning tasks, particularly GSM8K and MATH benchmarks.

Key Innovation

  • 2-Simplicial Attention: Trilinear attention that captures three-way token relationships
  • Placement at End: Simplicial layers at the end of the network as "reasoning head"
  • Adapter Layer: Lightweight adapter (layer 4) with small LR for fine-tuning
  • Optional MoE: Mixture of Simplicial Experts for enhanced reasoning
Standard attention:    A = softmax(Q · Kᵀ) · V
Simplicial attention:  A = (softmax(Q · K₁ᵀ) ⊙ softmax(Q · K₂ᵀ)) · V
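The simplified formula above can be sketched in NumPy. This is illustrative only: the paper's actual kernel computes a trilinear score over token triples in Triton with sliding windows, and the function names here are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def standard_attention(Q, K, V):
    # A = softmax(Q K^T / sqrt(d)) V
    d = Q.shape[-1]
    return softmax(Q @ K.T / np.sqrt(d)) @ V

def simplicial_attention(Q, K1, K2, V):
    # Simplified form: two bilinear score maps are combined
    # elementwise before mixing the values.
    d = Q.shape[-1]
    A1 = softmax(Q @ K1.T / np.sqrt(d))
    A2 = softmax(Q @ K2.T / np.sqrt(d))
    return (A1 * A2) @ V

rng = np.random.default_rng(0)
Q, K1, K2, V = (rng.standard_normal((4, 8)) for _ in range(4))
out = simplicial_attention(Q, K1, K2, V)
```

Note that the elementwise product of two row-stochastic matrices is no longer row-stochastic; the repo's implementation may renormalize.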

Architecture

| Layer | Type                  | Trainable | LR   |
|-------|-----------------------|-----------|------|
| 0-3   | Standard (GPT-2 init) | ❌ Frozen | -    |
| 4     | Adapter               | ✅        | 1e-4 |
| 5-7   | Simplicial            | ✅        | 3e-4 |

Total: ~100M params, ~72M trainable
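The freeze/LR split above maps naturally onto optimizer parameter groups. A minimal sketch in plain Python, with layer indices standing in for the real parameter lists (which would be passed to torch.optim.AdamW in practice; the structure here is an assumption, not the repo's actual trainer code):

```python
# Map each layer index to (trainable, learning rate), per the table above.
LAYER_PLAN = {
    **{i: (False, 0.0) for i in range(0, 4)},   # 0-3: frozen GPT-2 layers
    4: (True, 1e-4),                            # 4: adapter, small LR
    **{i: (True, 3e-4) for i in range(5, 8)},   # 5-7: simplicial layers
}

def param_groups(n_layers=8):
    """Group layer indices by learning rate, skipping frozen layers."""
    groups = {}
    for i in range(n_layers):
        trainable, lr = LAYER_PLAN[i]
        if trainable:
            groups.setdefault(lr, []).append(i)
    return [{"lr": lr, "layers": idx} for lr, idx in sorted(groups.items())]

groups = param_groups()
```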

Project Structure

too-simplex/
├── simplicial/                 # Model implementation
│   ├── attention/             # Simplicial attention, RoPE, standard attention
│   ├── cache/                 # KV cache management
│   ├── layers/                # SimplicialBlock, StandardBlock, MoE
│   └── models/               # SimplicialModel
├── training/
│   ├── configs/               # Training configs
│   ├── datasets/             # Dataset loading
│   ├── plateau_handler.py    # Plateau detection
│   ├── synthetic_data_generator.py  # Synthetic math data
│   └── logic_trainer.py       # Gradual unfreezing (legacy)
├── scripts/
│   ├── train/                # Training scripts (train_hybrid.py)
│   ├── data/                 # Dataset download scripts
│   └── evaluate_checkpoint.py # Evaluation
├── data/logic/               # Training data (GSM8K, NuminaMath)
└── tests/                    # Test suite

Quick Start

1. Install Requirements

pip install -r requirements.txt

# Install PyTorch (for your CUDA version)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

2. Download Datasets

python scripts/data/download_datasets.py --datasets gsm8k numinamath math500

3. Train

python scripts/train/train_hybrid.py \
    --config training/configs/simplicial_final_layers.yaml \
    --checkpoint-dir checkpoints_v2 \
    --log-dir logs_v2

Training Configuration

Key parameters in training/configs/simplicial_final_layers.yaml:

model:
  n_layers: 8
  num_last_layers_simplicial: 3  # Last 3 are simplicial
  window_size_1: 1024  # Paper: 1024
  window_size_2: 32    # Paper: 32
  moe: false          # Enable for MoE boost

training:
  max_steps: 30000
  warmup_steps: 2000
  lr_simplicial: 3e-4
  lr_adapter: 1e-4
  eval_every: 2500
  save_every: 2500

Training Phases

| Phase     | Steps   | LR Schedule     | Action                |
|-----------|---------|-----------------|-----------------------|
| 0: Warmup | 0-2k    | Linear 0→3e-4   | Warmup                |
| 1: Main   | 2k-10k  | Cosine to 1e-6  | Monitor val           |
| 2: MoE    | 10k-30k | Cosine          | Optional: enable MoE  |
| 3: Final  | 30k+    | -               | Export model          |
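The warmup and cosine phases above can be written as one piecewise function. A sketch assuming the config values (warmup_steps=2000, max_steps=30000, peak LR 3e-4, floor 1e-6); the actual schedule in train_hybrid.py may differ:

```python
import math

def lr_at(step, warmup=2000, total=30000, peak=3e-4, floor=1e-6):
    """Linear warmup from 0 to `peak`, then cosine decay to `floor`."""
    if step < warmup:
        return peak * step / warmup
    t = min((step - warmup) / (total - warmup), 1.0)
    return floor + 0.5 * (peak - floor) * (1 + math.cos(math.pi * t))
```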

Features

  • Adapter Layer: Layer 4 with small LR (1e-4)
  • Plateau Detection: Auto-inject synthetic data on plateau
  • Synthetic Data: 5K generated math problems
  • MoE Toggle: Optional mixture of simplicial experts
  • GPT-2 Initialization: Standard layers inherit GPT-2 weights
  • Gradient Checkpointing: Memory efficient
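The plateau-detection feature can be as simple as a patience counter over validation loss. An illustrative sketch (the criteria in plateau_handler.py are not shown here, so `patience` and `min_delta` are assumptions):

```python
class PlateauDetector:
    """Flags a plateau when validation loss has not improved by at least
    `min_delta` for `patience` consecutive evaluations. Illustrative;
    the repo's plateau_handler.py may use different criteria."""

    def __init__(self, patience=3, min_delta=1e-3):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.stale = 0

    def update(self, val_loss):
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.stale = 0
        else:
            self.stale += 1
        return self.stale >= self.patience  # True -> inject synthetic data

d = PlateauDetector(patience=2)
hits = [d.update(v) for v in (1.0, 0.5, 0.5, 0.5)]
```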

Datasets

| Dataset        | Train   | Test  | Source      |
|----------------|---------|-------|-------------|
| GSM8K          | 7,473   | 1,319 | HuggingFace |
| NuminaMath-CoT | 859,494 | -     | HuggingFace |
| MATH-500       | -       | 500   | HuggingFace |
| Synthetic      | 5,000   | -     | Generated   |

Evaluation

python scripts/evaluate_checkpoint.py \
    --checkpoint checkpoints_v2/checkpoint_step_25000.pt \
    --datasets data/logic/gsm8k_test.jsonl
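GSM8K reference answers end with a "#### <number>" marker, so evaluation typically extracts the final such number from a model completion and compares it to the reference. A sketch of that parsing step (the exact logic in evaluate_checkpoint.py is an assumption):

```python
import re

def extract_gsm8k_answer(text):
    """Return the last '#### <number>' value in `text`, commas stripped,
    or None if the marker is absent."""
    matches = re.findall(r"####\s*(-?[\d,]+(?:\.\d+)?)", text)
    if not matches:
        return None
    return matches[-1].replace(",", "")

ans = extract_gsm8k_answer("She pays 18 each week... #### 18")
```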

Requirements

torch>=2.1.0
transformers>=4.35.0
datasets>=2.18.0
tensorboard>=2.14.0
triton>=2.1.0
numpy>=1.26.0
pandas>=2.0.0
pyyaml>=6.0

Hardware

  • GPU: A100 (40 GB) or RTX 4090 (24 GB) recommended
  • Memory: ~6-7GB with gradient checkpointing
  • Time: ~1.5-2 hours per 10k steps on A100

Citation

@article{simplicial2024,
  title={Fast and Simplex: 2-Simplicial Attention in Triton},
  author={},
  year={2024}
}
