Comparative Analysis of Knowledge Distillation Methods for Medical Question Answering
A comprehensive framework for distilling large medical language models (70B parameters) into efficient smaller models (1.5B parameters) using multiple state-of-the-art distillation techniques.
- Overview
- Features
- Installation
- Quick Start
- Project Structure
- Distillation Methods
- Documentation
- Usage Examples
- Citation
This project implements and compares 10 different knowledge distillation methods for medical language models, focusing on transferring knowledge from a large teacher model (Meditron-70B) to a compact student model (Qwen2-1.5B). It is organized around four research questions:
- Which distillation method best preserves medical reasoning capabilities?
- How does distillation affect factual accuracy in medical QA?
- What is the trade-off between model size, inference speed, and accuracy?
- Can we measure evidence faithfulness in distilled models?
Implemented distillation methods:
- Supervised Fine-Tuning (SFT): Baseline imitation learning
- Logit-KD: Soft target distribution matching
- Adaptive-KD (AdaKD): Token-adaptive temperature adjustment
- Chain-of-Thought (CoT): Reasoning process distillation
- FitNets: Intermediate layer representation matching
- Attention Transfer: Attention pattern mimicry
- On-Policy RL: Policy gradient with teacher rewards
- PPO: Proximal policy optimization
- BOND: Best-of-N distillation
- SPIN: Self-play iterative refinement
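The two logit-level objectives above (Logit-KD and AdaKD) can be sketched in a few lines. This is a framework-free illustration for individual token positions, not the code in src/DistillationMethods.py. `logit_kd_loss` follows the standard Hinton-style recipe (temperature-softened KL to the teacher, scaled by T², mixed with hard-label cross-entropy), and `adaptive_temperatures` shows one possible token-adaptive temperature rule. The weighting convention for `alpha` and the entropy-based AdaKD rule are assumptions chosen to mirror the `--alpha`, `--temperature`, `--base_temperature`, `--min_temperature`, and `--max_temperature` flags; the project may define them differently.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable softmax of logits / temperature."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q) for discrete distributions; entries with p_i == 0 contribute nothing."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def entropy(p: Sequence[float]) -> float:
    """Shannon entropy in nats."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)


def logit_kd_loss(
    student_logits: Sequence[float],
    teacher_logits: Sequence[float],
    target: int,
    alpha: float = 0.5,
    temperature: float = 3.0,
) -> float:
    """Hinton-style KD loss for one token position.

    Assumption: alpha weights the soft (teacher) term and (1 - alpha) the hard-label term.
    """
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    # Multiplying by T^2 keeps soft-target gradients comparable across temperatures.
    soft = temperature ** 2 * kl_divergence(p_teacher, p_student)
    # Ordinary cross-entropy against the ground-truth token at T = 1.
    hard = -math.log(softmax(student_logits)[target])
    return alpha * soft + (1.0 - alpha) * hard


def adaptive_temperatures(
    teacher_logits_per_token: Sequence[Sequence[float]],
    base_temperature: float = 3.0,
    min_temperature: float = 1.0,
    max_temperature: float = 5.0,
) -> List[float]:
    """Hypothetical AdaKD rule: scale the base temperature by each token's teacher
    entropy relative to the sequence mean, then clamp to [min, max]."""
    entropies = [entropy(softmax(z)) for z in teacher_logits_per_token]
    mean_h = sum(entropies) / len(entropies)
    if mean_h == 0.0:
        return [base_temperature] * len(entropies)
    return [
        min(max_temperature, max(min_temperature, base_temperature * h / mean_h))
        for h in entropies
    ]


if __name__ == "__main__":
    teacher = [[0.0, 0.0, 0.0], [10.0, 0.0, 0.0]]  # uncertain token, confident token
    print(adaptive_temperatures(teacher))  # [5.0, 1.0]
    print(logit_kd_loss([2.0, 0.5, -1.0], [3.0, 0.0, -2.0], target=0))
```

Under this assumed rule, tokens where the teacher is uncertain get softer targets, while confident tokens are distilled close to T = 1; each token's temperature would then be passed to `logit_kd_loss` for that position.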
Evaluation:
- 4 Medical Benchmarks: MedQA, MedMCQA, PubMedQA, PubHealth
- Perplexity Analysis: Language modeling quality on MedPPL-10k
- Fidelity Metrics: KL divergence, BLEU, ROUGE, top-k overlap
- FidelityBench-Med: NLI-based fact verification and hallucination detection
- Automatic Visualization: Training curves, benchmark comparisons, radar charts
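Training and analysis:
- Quantized LoRA: 8-bit quantization of the base model + LoRA adapters for memory efficiency (the original QLoRA recipe uses 4-bit NF4 quantization)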
- Gradient Accumulation: Effective large batch training
- Automatic Mixed Precision: Faster training with FP16
- Ablation Studies: Hyperparameter sensitivity analysis
- Comparative Teacher Evaluation: Measure knowledge retention gap
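The KL-divergence and top-k overlap fidelity metrics compare teacher and student next-token distributions position by position. A minimal sketch follows, assuming that top-k overlap is the fraction of the teacher's k highest-scoring tokens that also appear in the student's top k, and that KL is computed as KL(teacher || student) at temperature 1 and averaged over aligned positions. The project's evaluation code may normalize or aggregate these differently.

```python
import math
from typing import List, Sequence, Set


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax at temperature 1."""
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_indices(logits: Sequence[float], k: int) -> Set[int]:
    """Indices of the k largest logits (stable sort, so ties favor the lower index)."""
    return set(sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k])


def top_k_overlap(teacher_logits: Sequence[float], student_logits: Sequence[float], k: int = 5) -> float:
    """Fraction of the teacher's top-k tokens that the student also ranks in its top k."""
    if not 0 < k <= len(teacher_logits):
        raise ValueError("k must be between 1 and the vocabulary size")
    shared = top_k_indices(teacher_logits, k) & top_k_indices(student_logits, k)
    return len(shared) / k


def mean_token_kl(
    teacher_seq: Sequence[Sequence[float]], student_seq: Sequence[Sequence[float]]
) -> float:
    """Average KL(teacher || student) over aligned token positions."""
    if not teacher_seq or len(teacher_seq) != len(student_seq):
        raise ValueError("sequences must be non-empty and aligned")
    total = 0.0
    for t_logits, s_logits in zip(teacher_seq, student_seq):
        p, q = softmax(t_logits), softmax(s_logits)
        total += sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)
    return total / len(teacher_seq)


if __name__ == "__main__":
    teacher = [5.0, 4.0, 3.0, 0.0, 0.0]
    student = [4.0, 5.0, 0.0, 3.0, 0.0]
    print(top_k_overlap(teacher, student, k=3))  # 2 of the 3 teacher top tokens are shared
    print(mean_token_kl([teacher], [student]))
```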
Prerequisites:
- Python 3.9+
- CUDA 11.8+ (for GPU training)
- 24GB+ VRAM (RTX 4090 or A100 recommended)
- 32GB+ System RAM
Install with uv:
# Install uv package manager
curl -LsSf https://astral.sh/uv/install.sh | sh
# Clone repository
git clone https://github.com/shreyanmitra/Medistillation.git
cd Medistillation
# Sync dependencies
uv sync

Alternatively, install with pip and a virtual environment:
# Clone repository
git clone https://github.com/shreyanmitra/Medistillation.git
cd Medistillation
# Create virtual environment
python -m venv .venv
source .venv/bin/activate # On Windows: .venv\Scripts\activate
# Install dependencies
pip install -r requirements.txt

Verify the installation:
python -c "import torch; print(f'PyTorch: {torch.__version__}'); print(f'CUDA Available: {torch.cuda.is_available()}')"

Prepare the datasets:
# Download and prepare all datasets (Med-DistillMix, benchmarks, etc.)
python src/DataLoader.py --prepare_all

This creates:
- data/processed/train.jsonl (~414k examples)
- data/processed/validation.jsonl
- data/processed/test.jsonl
- data/benchmarks/ (MedQA, MedMCQA, PubMedQA, PubHealth)
- data/medppl_10k.jsonl (perplexity corpus)
- data/fidelitybench_med.jsonl (faithfulness evaluation)
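Before training, it can be worth confirming that each prepared split parses cleanly as JSON Lines and has roughly the expected size (for example ~414k records in train.jsonl). The sketch below assumes only that every non-blank line is one JSON object; it does not assume any particular field names, and the helper names `count_jsonl_records` and `count_jsonl_file` are hypothetical, not part of the project.

```python
import json
from pathlib import Path
from typing import Iterable


def count_jsonl_records(lines: Iterable[str], source: str = "<lines>") -> int:
    """Count JSON records, skipping blank lines and failing loudly on malformed ones."""
    count = 0
    for line_no, line in enumerate(lines, start=1):
        if not line.strip():
            continue
        try:
            json.loads(line)
        except json.JSONDecodeError as err:
            raise ValueError(f"{source}:{line_no}: invalid JSON ({err.msg})") from err
        count += 1
    return count


def count_jsonl_file(path: Path) -> int:
    """Stream a .jsonl file line by line so large splits are never loaded whole."""
    with path.open(encoding="utf-8") as handle:
        return count_jsonl_records(handle, source=str(path))


if __name__ == "__main__":
    for name in ("train", "validation", "test"):
        path = Path("data/processed") / f"{name}.jsonl"
        if path.exists():
            print(f"{path}: {count_jsonl_file(path):,} records")
        else:
            print(f"{path}: not found (run src/DataLoader.py --prepare_all first)")
```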
# Supervised Fine-Tuning (SFT) - Baseline
python src/Trainer.py \
--method sft \
--teacher_model epfl-llm/meditron-70b \
--student_model Qwen/Qwen2-1.5B \
--num_epochs 3 \
--batch_size 8 \
--output_dir outputs/sft_run1

Inspect the results:
# Training curves and evaluation plots
ls outputs/sft_run1/results/*.png
# Evaluation metrics (JSON)
cat outputs/sft_run1/results/comprehensive_evaluation.json

Generated plots:
- training_curves.png: Loss and learning rate over epochs
- benchmark_comparison.png: Student vs. teacher accuracy
- fidelity_metrics.png: KL/BLEU/ROUGE scores
- fidelitybench_radar.png: Evidence faithfulness radar chart
Medistillation/
├── data/                         # Datasets (gitignored)
│   ├── raw/                      # Original downloaded data
│   ├── processed/                # Train/val/test splits
│   ├── benchmarks/               # Evaluation benchmarks
│   ├── medppl_10k.jsonl          # Perplexity corpus
│   └── fidelitybench_med.jsonl   # Faithfulness eval
│
├── src/                          # Source code
│   ├── DataLoader.py             # Dataset preparation and loading
│   ├── DistillationMethods.py    # Distillation algorithm implementations
│   └── Trainer.py                # Training loop and evaluation
│
├── docs/                         # Documentation
│   ├── guides/                   # User guides
│   │   ├── EXPERIMENT_PROCEDURE.md
│   │   ├── FIDELITYBENCH_GUIDE.md
│   │   └── VISUALIZATION_GUIDE.md
│   ├── implementation/           # Technical details
│   └── research/                 # Research materials
│
├── scripts/                      # Utility scripts
│   └── sample_medmcqa.py         # Data sampling utility
│
├── outputs/                      # Training results (gitignored)
│   └── {method}_{run_name}/
│       ├── checkpoints/
│       ├── results/
│       └── final_model/
│
├── tests/                        # Unit tests
├── requirements.txt              # Python dependencies
└── README.md                     # This file
Per-method training commands:
# SFT (baseline)
python src/Trainer.py --method sft --num_epochs 3

# Logit-KD
python src/Trainer.py --method logit_kd --alpha 0.5 --temperature 3.0

# Adaptive-KD
python src/Trainer.py --method adakd --base_temperature 3.0 --min_temperature 1.0 --max_temperature 5.0

# Chain-of-Thought
python src/Trainer.py --method cot --num_rationales 3 --sampling_temperature 0.7

# FitNets
python src/Trainer.py --method fitnets --layer_mapping '{"6":12,"12":24}' --use_projections

# Attention Transfer
python src/Trainer.py --method attention --layer_mapping '{"6":12,"12":24}' --match_all_heads

# PPO
python src/Trainer.py --method ppo --epsilon 0.2 --gamma 0.99

# BOND
python src/Trainer.py --method bond --num_samples 16

# SPIN
python src/Trainer.py --method spin --beta 0.1

Comprehensive guides are available in docs/:
- EXPERIMENT_PROCEDURE.md - Complete experimental protocol
- FIDELITYBENCH_GUIDE.md - Evidence faithfulness evaluation
- VISUALIZATION_GUIDE.md - Plotting and analysis
- Code_Organization_Architecture.md - Codebase structure
- DATASET_STRATEGY.md - Data preparation strategy
- VISUALIZATION_FEATURES.md - Plotting features
- project_details_latex.tex - LaTeX project description
- references.bib - BibTeX references
Run an ablation study:
# Test different temperatures for Logit-KD
python src/Trainer.py \
--run_ablation \
--ablation_type temperature \
--ablation_values "2.0,3.0,4.0,5.0" \
--method logit_kd
# View sensitivity analysis plot
start outputs/ablation_temperature/ablation_plot.png # Windows
# open outputs/ablation_temperature/ablation_plot.png  # macOS
# xdg-open outputs/ablation_temperature/ablation_plot.png  # Linux

Compare several methods:
# Train different methods
python src/Trainer.py --method sft --output_dir outputs/sft
python src/Trainer.py --method logit_kd --alpha 0.5 --temperature 3.0 --output_dir outputs/logit_kd
python src/Trainer.py --method adakd --output_dir outputs/adakd
# Compare results
python scripts/compare_results.py outputs/sft outputs/logit_kd outputs/adakd

If you use this code in your research, please cite:
@misc{medistillation2025,
title={Comparative Analysis of Knowledge Distillation Methods for Medical Question Answering},
author={CSE 493S Team},
year={2025},
publisher={University of Washington},
url={https://github.com/shreyanmitra/Medistillation}
}

Acknowledgments:
- Teacher Model: Meditron-70B by EPFL LLM Team
- Student Model: Qwen2-1.5B by Alibaba Cloud
- Datasets: MedQA, MedMCQA, PubMedQA, PubHealth teams
- Course: CSE 493S - Advanced Topics in Machine Learning, University of Washington
Happy Distilling! 🧪✨