A PyTorch framework for knowledge distillation from large language models (LLMs) to hybrid architectures that combine Transformer attention with State-Space Models (SSMs) such as Mamba.
- Model building blocks for Llama, Qwen2, Falcon, and Phi-style hybrids.
- A composable YAML config system (`LOAD`-based inheritance).
- Distillation objectives: `supervised`, `hstates`, `matrices`, and `dpo`.
- Distributed training wrappers (DDP/FSDP/centralized).
- Evaluation utilities for perplexity and lm-eval-harness tasks.
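Objectives are selected by name. The names `matrices` and `hstates` suggest the matrix-orientation and hidden-state-alignment stages described in the MOHAWK paper (cited below), but that mapping is an assumption. The stdlib-only sketch below illustrates what two such name-keyed losses could compute; the registry, the function names, and the list-of-vectors format are hypothetical and are not the repository's API, whose objectives operate on PyTorch tensors.

```python
import math
from typing import Callable, Dict, List

Vector = List[float]
Loss = Callable[[List[Vector], List[Vector]], float]


def hidden_state_alignment(teacher: List[Vector], student: List[Vector]) -> float:
    """Hypothetical `hstates`-style loss: mean per-position L2 distance
    between teacher and student block outputs for the same input."""
    if not teacher or len(teacher) != len(student):
        raise ValueError("need matching, non-empty sequences of hidden states")
    total = 0.0
    for t_vec, s_vec in zip(teacher, student):
        total += math.sqrt(sum((t - s) ** 2 for t, s in zip(t_vec, s_vec)))
    return total / len(teacher)


def matrix_orientation(teacher_mix: List[Vector], student_mix: List[Vector]) -> float:
    """Hypothetical `matrices`-style loss: Frobenius distance between two
    (seq_len x seq_len) sequence-mixing matrices, e.g. an attention matrix
    and a materialized SSM mixing matrix."""
    if len(teacher_mix) != len(student_mix):
        raise ValueError("mixing matrices must have the same shape")
    return math.sqrt(
        sum(
            (t - s) ** 2
            for t_row, s_row in zip(teacher_mix, student_mix)
            for t, s in zip(t_row, s_row)
        )
    )


# Hypothetical name -> loss registry, mirroring objective selection by name.
OBJECTIVES: Dict[str, Loss] = {
    "hstates": hidden_state_alignment,
    "matrices": matrix_orientation,
}


def compute_objective(name: str, teacher: List[Vector], student: List[Vector]) -> float:
    """Look up an objective by name and evaluate it."""
    try:
        loss_fn = OBJECTIVES[name]
    except KeyError:
        raise ValueError(f"unknown objective: {name!r}") from None
    return loss_fn(teacher, student)
```

A registry keyed by string keeps the choice of objective in the YAML config rather than in code; an unknown name fails immediately instead of silently falling back to a default.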
This is not a packaged library. You run scripts directly from the repo.
- Python 3.8+
- PyTorch 2.1+ with CUDA
- One or more NVIDIA GPUs for training/eval
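A quick way to confirm a machine meets these requirements before launching a run. This is a standalone helper sketch, not a script shipped with the repository; the function names are hypothetical. Version parsing only looks at the leading numeric release, so local tags such as `+cu121` are ignored.

```python
import importlib.util
import re
import sys
from typing import Tuple


def parse_version(text: str) -> Tuple[int, ...]:
    """Extract the leading numeric release, e.g. '2.1.0+cu121' -> (2, 1, 0)."""
    match = re.match(r"(\d+(?:\.\d+)*)", text)
    if match is None:
        raise ValueError(f"unparseable version: {text!r}")
    return tuple(int(part) for part in match.group(1).split("."))


def at_least(version: str, minimum: str) -> bool:
    """True if `version` >= `minimum`, comparing numeric components."""
    v, m = parse_version(version), parse_version(minimum)
    width = max(len(v), len(m))
    # Pad with zeros so '2.1' compares equal to '2.1.0'.
    return v + (0,) * (width - len(v)) >= m + (0,) * (width - len(m))


def check_environment() -> None:
    """Print whether Python, PyTorch, CUDA, and flash-attn look usable."""
    py = "{}.{}.{}".format(*sys.version_info[:3])
    print(f"Python {py}: {'ok' if at_least(py, '3.8') else 'too old'}")
    # find_spec returns None (without importing) when the package is absent.
    print(f"flash-attn installed: {importlib.util.find_spec('flash_attn') is not None}")
    try:
        import torch
    except ImportError:
        print("PyTorch: not installed")
        return
    status = "ok" if at_least(torch.__version__, "2.1") else "too old"
    print(f"PyTorch {torch.__version__}: {status}")
    print(f"CUDA available: {torch.cuda.is_available()}, GPUs: {torch.cuda.device_count()}")
```

Call `check_environment()` from a Python shell inside the environment you plan to train in. Numeric comparison matters here: a plain string comparison would rank `3.10` below `3.8`.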
```bash
git clone https://github.com/avivbick/mohawk.git
cd mohawk
pip install -r requirements.txt
```

Optional accelerators:

```bash
pip install flash-attn --no-build-isolation
```

Use environment variables for credentials and runtime settings instead of hardcoding them:

- `HF_TOKEN` for private/gated Hugging Face models
- `WANDB_API_KEY` for experiment tracking
- `CUDA_VISIBLE_DEVICES` to pin GPUs
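If you write your own helper scripts around the repository, the same pattern applies: read secrets from the environment and fail loudly only when a value is actually required. The helper name below is hypothetical. The `huggingface_hub` and `wandb` clients already read `HF_TOKEN` and `WANDB_API_KEY` from the environment on their own, and `CUDA_VISIBLE_DEVICES` must be set before CUDA is initialized in the process, so it is best exported in the shell.

```python
import os
from typing import Optional


def get_credential(name: str, required: bool = False) -> Optional[str]:
    """Read a credential from the environment instead of hardcoding it.

    Returns None for unset or empty variables unless `required` is True,
    in which case a RuntimeError names the missing variable.
    """
    value = os.environ.get(name)
    if not value:
        if required:
            raise RuntimeError(f"environment variable {name} is not set")
        return None
    return value
```

For example, `get_credential("HF_TOKEN")` returning `None` is fine for public models, while `get_credential("HF_TOKEN", required=True)` suits gated ones.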
Global runtime defaults live in configs/management.yaml.
Single process:

```bash
python run.py --config configs/Qwen2/1.5B/hybrid/adapter.yaml
```

Multi-GPU on one node (8 GPUs):

```bash
torchrun --standalone --nproc_per_node=8 run.py \
  --config configs/Qwen2/1.5B/hybrid/adapter.yaml
```

`--config` also accepts a comma-separated list; configs are loaded and run sequentially.
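The sequential behavior of a comma-separated `--config` value can be pictured as follows. This is a hypothetical sketch of the documented semantics, not code from `run.py`: `run_one` is a placeholder for a single complete run, and whether `run.py` strips whitespace or skips empty entries is not documented, so that handling is an assumption.

```python
from typing import Callable, List


def split_config_arg(arg: str) -> List[str]:
    """Split a value like 'a.yaml,b.yaml' into paths, dropping blank entries."""
    return [part.strip() for part in arg.split(",") if part.strip()]


def run_sequentially(arg: str, run_one: Callable[[str], None]) -> List[str]:
    """Run each config in order; each run finishes before the next begins."""
    completed: List[str] = []
    for path in split_config_arg(arg):
        run_one(path)  # placeholder for one full load-and-train cycle
        completed.append(path)
    return completed
```

Chaining configs this way is convenient for multi-stage recipes where each stage is its own YAML file.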
Every run is driven by YAML. The important top-level sections are:
- `ComponentsConfig`: architecture definition (block sequence and layer types)
- `TrainConfig`: optimization schedule and training length
- `DistillConfig`: objective selection and logging run name
- `TeacherConfig`: teacher checkpoint/path and tokenizer context
- `TrainDataConfig`: dataset source and loader strategy
- `LoadConfig`: initialization and checkpoint loading rules
- `ManagementConfig`: cache paths, W&B config, environment defaults
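To illustrate `LOAD`-based inheritance over these sections, here is a minimal sketch assuming that `LOAD` names a parent config whose sections the child overrides key by key. The actual merge rules in `utils/` may differ (for example in how lists or multiple parents are handled). Configs are passed as already-parsed dicts, and the keys inside each section (`lr`, `steps`, `objective`) are hypothetical.

```python
import copy
from typing import Any, Dict, FrozenSet, Mapping

Config = Dict[str, Any]


def deep_merge(base: Mapping[str, Any], override: Mapping[str, Any]) -> Config:
    """Merge `override` into a copy of `base`: nested dicts merge, other values replace."""
    merged: Config = copy.deepcopy(dict(base))
    for key, value in override.items():
        if isinstance(value, Mapping) and isinstance(merged.get(key), Mapping):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


def resolve(name: str, store: Mapping[str, Config],
            _seen: FrozenSet[str] = frozenset()) -> Config:
    """Resolve a config whose optional 'LOAD' key names a parent config.

    `store` maps names to parsed YAML dicts; a real loader would read files.
    """
    if name in _seen:
        raise ValueError(f"circular LOAD chain at {name!r}")
    raw = dict(store[name])  # shallow copy so popping LOAD leaves `store` intact
    parent = raw.pop("LOAD", None)
    if parent is None:
        return copy.deepcopy(raw)
    return deep_merge(resolve(parent, store, _seen | {name}), raw)
```

Usage: a child that only sets `{"LOAD": "base", "TrainConfig": {"steps": 50}}` inherits every other section and key from `base`, which keeps experiment variants short.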
Useful starting points:
- `configs/Qwen2/1.5B/hybrid/adapter.yaml`
- `configs/Llama/1B/hybrid/mohawk_8.yaml`
- `configs/Llama/8B/bases/_supervised.yaml`
Perplexity evaluation is built into the training/eval wrappers; `evals/eval_ppl.py` implements the evaluator class those wrappers use.
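For reference, perplexity is the exponential of the mean per-token negative log-likelihood. The sketch below is independent of `evals/eval_ppl.py` and its interface is hypothetical. It pools NLL sums and token counts across batches instead of averaging per-batch perplexities, which would overweight short batches.

```python
import math
from typing import Iterable, Tuple


def corpus_perplexity(batches: Iterable[Tuple[float, int]]) -> float:
    """Compute exp(total NLL / total tokens).

    Each item is (sum of token NLLs in nats, number of predicted tokens)
    for one batch; padding tokens should already be excluded from both.
    """
    total_nll, total_tokens = 0.0, 0
    for nll_sum, n_tokens in batches:
        total_nll += nll_sum
        total_tokens += n_tokens
    if total_tokens == 0:
        raise ValueError("no tokens to evaluate")
    return math.exp(total_nll / total_tokens)
```

As a sanity check, a model that is uniform over a vocabulary of size V assigns each token NLL ln V, so its perplexity is exactly V.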
To run lm-eval-harness tasks:

```bash
python evals/benchmark.py --dir <checkpoint_or_hf_model_dir> --tasks mmlu
```

`--tasks` is a comma-separated list, for example `arc_challenge,arc_easy,piqa,winogrande,hellaswag,mmlu`.
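When sweeping several checkpoints, it can help to build the invocation programmatically. The wrapper below is hypothetical: it uses only the `--dir` and `--tasks` flags shown above, assumes it is run from the repository root, and uses the current interpreter rather than whatever `python` resolves to on `PATH`.

```python
import shlex
import subprocess
import sys
from typing import List, Sequence


def benchmark_command(model_dir: str, tasks: Sequence[str]) -> List[str]:
    """Build the evals/benchmark.py argv; task names are joined with commas."""
    if not tasks:
        raise ValueError("at least one task is required")
    return [sys.executable, "evals/benchmark.py", "--dir", model_dir, "--tasks", ",".join(tasks)]


def run_benchmark(model_dir: str, tasks: Sequence[str]) -> None:
    """Print and run the command, raising CalledProcessError on failure."""
    cmd = benchmark_command(model_dir, tasks)
    print(shlex.join(cmd))  # shlex.join requires Python 3.8+
    subprocess.run(cmd, check=True)
```

Passing an argv list (rather than a shell string) avoids quoting problems with checkpoint paths that contain spaces.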
- `tools/hybrid_weights_transfer.py`: copies selected attention heads from a teacher to a hybrid student. Uses `--config` and expects a supported `TeacherConfig.dir`.
- `tools/benchmark_throughput.py`: CUDA-graph throughput microbenchmark. This script is research-oriented and currently contains model-specific assumptions and hardcoded defaults.
- `tools/visualize_attention.py`: produces attention heatmaps for manually selected heads on a fixed example. Useful for qualitative inspection, not automated evaluation.
- `generation/generate.py`: inference/sampling script with timing output.
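To make "copying selected attention heads" concrete, the stdlib sketch below copies head-sized row blocks of a projection weight from teacher to student. It is not the implementation in `tools/hybrid_weights_transfer.py`; it assumes an output-major layout in which head `h` occupies rows `[h * head_dim, (h + 1) * head_dim)`. With grouped-query attention, key/value projections have fewer heads than queries, and an output projection would be sliced by columns rather than rows, so a real transfer must handle each projection separately.

```python
from typing import Dict, List

Matrix = List[List[float]]


def copy_heads(teacher_w: Matrix, student_w: Matrix, head_dim: int,
               mapping: Dict[int, int]) -> Matrix:
    """Return a copy of `student_w` with selected teacher heads copied in.

    `mapping` maps a teacher head index to the student head slot it fills.
    The input matrices are not modified.
    """
    if not teacher_w or not student_w:
        raise ValueError("weights must be non-empty")
    if len(teacher_w[0]) != len(student_w[0]):
        raise ValueError("teacher and student hidden sizes differ")
    out = [row[:] for row in student_w]
    for t_head, s_head in mapping.items():
        t0, s0 = t_head * head_dim, s_head * head_dim
        if t0 + head_dim > len(teacher_w) or s0 + head_dim > len(student_w):
            raise IndexError(f"head {t_head} -> {s_head} is out of range")
        for i in range(head_dim):
            out[s0 + i] = teacher_w[t0 + i][:]
    return out
```

Returning a new matrix keeps the original student initialization available for comparison; heads not listed in `mapping` keep their student weights.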
```text
mohawk/
├── components/         # Blocks, mixers, LM heads
├── configs/            # Train/eval architecture recipes
├── dataloaders/        # Dataset generators and wrappers
├── distill/            # Run orchestration and objective steps
├── evals/              # Evaluation entrypoints and adapters
├── external_models/    # External model definitions integrated here
├── generation/         # Text generation utilities
├── training_wrapper/   # DDP/FSDP/centralized wrappers
├── utils/              # Config, logging, init, distributed helpers
└── run.py              # Main training entrypoint
```
This codebase was used in the following research publications:
```bibtex
@misc{bick2026retrieval,
  title={Retrieval-Aware Distillation for Transformer-SSM Hybrids},
  author={Aviv Bick and Eric P. Xing and Albert Gu},
  year={2026},
  eprint={2602.11374},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2602.11374},
}

@article{bick2025llamba,
  title={Llamba: Scaling distilled recurrent models for efficient language processing},
  author={Bick, Aviv and Katsch, Tobias and Sohoni, Nimit and Desai, Arjun and Gu, Albert},
  journal={arXiv preprint arXiv:2502.14458},
  year={2025}
}

@misc{paliotta2025thinking,
  title={Thinking Slow, Fast: Scaling Inference Compute with Distilled Reasoners},
  author={Daniele Paliotta and Junxiong Wang and Matteo Pagliardini and Kevin Y. Li and Aviv Bick and J. Zico Kolter and Albert Gu and François Fleuret and Tri Dao},
  year={2025},
  eprint={2502.20339},
  archivePrefix={arXiv},
  primaryClass={cs.CL},
  url={https://arxiv.org/abs/2502.20339},
}

@misc{mohawk,
  title={Transformers to SSMs: Distilling Quadratic Knowledge to Subquadratic Models},
  author={Aviv Bick and Kevin Y. Li and Eric P. Xing and J. Zico Kolter and Albert Gu},
  year={2025},
  eprint={2408.10189},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2408.10189},
}
```

If this repository is useful in your work, cite:
```bibtex
@software{mohawk,
  title = {Knowledge Distillation for Hybrid Transformer-SSM Models},
  author = {Aviv Bick},
  year = {2024},
  url = {https://github.com/avivbick/mohawk}
}
```

License: MIT. See LICENSE.
Contribution workflow and expectations are documented in CONTRIBUTING.md.