A focused Neural ODE that models how cognitive states (attention, fatigue, stress) evolve through time using EEG features from the DEAP dataset.
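It learns a differential equation dS/dt = f(S, U, θ) from EEG data, where S is the cognitive state, U the EEG band-power input, and θ the learned network parameters. Given an initial state and EEG band powers, it predicts how attention, fatigue, and stress evolve continuously through time.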
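To make the equation concrete, here is a minimal NumPy sketch. It is not the project's `neural_ode.py`: the vector field f(S, U, θ) is a one-hidden-layer tanh MLP over the concatenated state and input, and a fixed-step RK4 loop stands in for torchdiffeq's adaptive dopri5 solver. The layer sizes, the random weights, and the choice to hold U constant over the window are illustrative assumptions.

```python
import numpy as np

# Hypothetical sizes: 3-D state S and 3-D input U (alpha, theta, beta powers).
STATE_DIM, INPUT_DIM, HIDDEN = 3, 3, 16


def init_params(seed: int = 0) -> dict:
    """Random weights for a one-hidden-layer MLP f(S, U; theta) (illustrative only)."""
    rng = np.random.default_rng(seed)
    return {
        "W1": rng.normal(0.0, 0.3, (HIDDEN, STATE_DIM + INPUT_DIM)),
        "b1": np.zeros(HIDDEN),
        "W2": rng.normal(0.0, 0.3, (STATE_DIM, HIDDEN)),
        "b2": np.zeros(STATE_DIM),
    }


def vector_field(s: np.ndarray, u: np.ndarray, theta: dict) -> np.ndarray:
    """dS/dt = f(S, U, theta): concatenate state and input, pass through a tanh MLP."""
    h = np.tanh(theta["W1"] @ np.concatenate([s, u]) + theta["b1"])
    return theta["W2"] @ h + theta["b2"]


def integrate_rk4(s0, u, theta, duration=10.0, n_steps=40):
    """Fixed-step RK4 stand-in for an adaptive solver such as dopri5.

    U is held constant over the window (an assumption).
    Returns times with shape (n_steps + 1,) and states with shape (n_steps + 1, 3).
    """
    dt = duration / n_steps
    s = np.asarray(s0, dtype=float)
    u = np.asarray(u, dtype=float)
    times = np.linspace(0.0, duration, n_steps + 1)
    states = [s.copy()]
    for _ in range(n_steps):
        k1 = vector_field(s, u, theta)
        k2 = vector_field(s + 0.5 * dt * k1, u, theta)
        k3 = vector_field(s + 0.5 * dt * k2, u, theta)
        k4 = vector_field(s + dt * k3, u, theta)
        s = s + dt / 6.0 * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
        states.append(s.copy())
    return times, np.stack(states)


if __name__ == "__main__":
    theta = init_params()
    t, traj = integrate_rk4([0.65, 0.25, 0.30], [0.40, 0.30, 0.25], theta)
    print(traj.shape)  # (41, 3)
```

An adaptive solver like dopri5 instead chooses its own internal step sizes to meet an error tolerance and only reports the state at the requested output times, which is why the real pipeline can afford stiff or fast-changing dynamics that a fixed step would handle poorly.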
```text
DEAP EEG Dataset (32-channel, 32 subjects, 40 emotion trials each)
↓
EEG Band Features → [alpha_power, theta_power, beta_power]
↓
State Mapping → S = [attention, fatigue, stress] ∈ ℝ³
↓
Neural ODE → dS/dt = f(S, U, θ) learned via torchdiffeq
↓
Trajectory S(t) → dopri5 adaptive solver
↓
FastAPI → React frontend
```
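The band-feature step can be sketched as follows. This assumes a 128 Hz sampling rate (the rate of DEAP's preprocessed release), conventional band edges (theta 4-8 Hz, alpha 8-13 Hz, beta 13-30 Hz), and a plain FFT periodogram; the project's `preprocess.py` may use different bands, windowing, or Welch averaging. `relative_band_powers` is a hypothetical helper name.

```python
import numpy as np

# Assumed band edges in Hz (common conventions, not taken from the project code).
BANDS = {"theta": (4.0, 8.0), "alpha": (8.0, 13.0), "beta": (13.0, 30.0)}


def relative_band_powers(signal: np.ndarray, fs: float = 128.0) -> dict:
    """Relative theta/alpha/beta power of a 1-D EEG signal via an FFT periodogram.

    Each band's power is divided by the total power in 4-30 Hz. Because the
    three bands tile that range, the returned values sum to 1.
    """
    x = np.asarray(signal, dtype=float)
    x = x - x.mean()                              # remove the DC offset
    freqs = np.fft.rfftfreq(x.size, d=1.0 / fs)   # bin centre frequencies in Hz
    psd = np.abs(np.fft.rfft(x)) ** 2             # unnormalized periodogram
    total = psd[(freqs >= 4.0) & (freqs < 30.0)].sum()
    out = {}
    for name, (lo, hi) in BANDS.items():
        mask = (freqs >= lo) & (freqs < hi)
        out[f"{name}_power"] = float(psd[mask].sum() / total)
    return out
```

For the 32-channel DEAP recordings, one option is to compute these per channel and average across channels to get the three scalar inputs.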
```text
neuro_ode/
├── data/
│   └── preprocess.py      ← EEG feature extraction + state mapping
├── model/
│   └── neural_ode.py      ← ODEFunc + NeuralODE (PyTorch)
├── scripts/
│   └── train.py           ← training loop, evaluation, saves weights
├── api/
│   ├── backend.py         ← FastAPI: /api/health + /api/simulate
│   ├── weights.pt         ← saved after training
│   └── norm_stats.json    ← normalization stats
├── frontend/
│   └── index.html         ← React + Recharts visualization
├── requirements.txt
└── README.md
```
```bash
pip install -r requirements.txt

# Train (uses synthetic DEAP-like data if the real DEAP files are not present)
python scripts/train.py

# Serve the API
python api/backend.py
# → http://localhost:8000
```

Then open `frontend/index.html` in a browser.

To train on the real DEAP data:

- Request access at https://www.eecs.qmul.ac.uk/mmv/datasets/deap/
- Place the `.mat` files in `data/deap/`
- Re-run `python scripts/train.py`
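A minimal loader sketch for the placed files. It assumes the layout of DEAP's preprocessed MATLAB release as we understand it (check against the dataset's own readme): each `sNN.mat` holds `data` shaped (40 trials, 40 channels, 8064 samples) at 128 Hz plus `labels` shaped (40, 4), the first 32 channels are EEG, and the first 3 s of each trial are a pre-trial segment. `extract_eeg`, `load_subject`, and `load_all` are hypothetical helpers, and SciPy is assumed to be installed for the loading step.

```python
from pathlib import Path

import numpy as np

FS = 128          # assumed sampling rate of the preprocessed release (Hz)
N_EEG = 32        # assumed: the first 32 of 40 channels are EEG
BASELINE_S = 3    # assumed: pre-trial seconds at the start of each trial


def extract_eeg(data: np.ndarray) -> np.ndarray:
    """Keep the EEG channels and drop the pre-trial segment.

    data: (trials, channels, samples) -> (trials, 32, samples - 384)
    """
    if data.ndim != 3:
        raise ValueError(f"expected a 3-D array, got shape {data.shape}")
    return data[:, :N_EEG, BASELINE_S * FS:]


def load_subject(path: Path) -> tuple[np.ndarray, np.ndarray]:
    """Load one subject file (e.g. data/deap/s01.mat) and return (eeg, labels)."""
    from scipy.io import loadmat  # imported lazily so extract_eeg works without SciPy

    mat = loadmat(str(path))
    return extract_eeg(mat["data"]), mat["labels"]


def load_all(folder: str = "data/deap") -> list[tuple[np.ndarray, np.ndarray]]:
    """Load every .mat file in the folder, in filename order."""
    return [load_subject(p) for p in sorted(Path(folder).glob("*.mat"))]
```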
The `/api/health` endpoint returns:

```json
{ "status": "ok", "trained": true, "state_dim": 3 }
```

`/api/simulate` takes a request body such as:

```json
{
  "attention": 0.65, "fatigue": 0.25, "stress": 0.30,
  "alpha_power": 0.40, "theta_power": 0.30, "beta_power": 0.25,
  "duration": 10.0, "n_steps": 40
}
```

and returns `{ trajectory: [{t, attention, fatigue, stress}], mse, n_steps }`.
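A small standard-library client sketch for `/api/simulate`. It assumes the endpoint accepts a POST carrying the flat JSON body from the request example and that the server listens on `http://localhost:8000`; `build_simulate_request` and `simulate` are hypothetical helper names.

```python
import json
import urllib.request

BASE_URL = "http://localhost:8000"  # default address from the run instructions


def build_simulate_request(state: dict, bands: dict, duration: float = 10.0,
                           n_steps: int = 40) -> urllib.request.Request:
    """Build (but do not send) a JSON request for /api/simulate.

    Assumes the endpoint takes a POST with a flat JSON body.
    """
    payload = {**state, **bands, "duration": duration, "n_steps": n_steps}
    return urllib.request.Request(
        f"{BASE_URL}/api/simulate",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def simulate(state: dict, bands: dict, **kwargs) -> list[dict]:
    """Send the request and return the trajectory list (requires a running server)."""
    req = build_simulate_request(state, bands, **kwargs)
    with urllib.request.urlopen(req, timeout=30) as resp:
        return json.loads(resp.read())["trajectory"]
```

Usage, with the backend running: `simulate({"attention": 0.65, "fatigue": 0.25, "stress": 0.30}, {"alpha_power": 0.40, "theta_power": 0.30, "beta_power": 0.25})`. Splitting request construction from sending keeps the payload format testable without a server.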
| State | EEG Proxy | Interpretation |
|---|---|---|
| Attention | ↑ alpha relative power | Focused, alert |
| Fatigue | ↑ theta relative power | Drowsy, slowing |
| Stress | ↑ beta relative power | Anxious, activated |
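One way to turn this table into numbers, offered as an assumption rather than the mapping actually used in `preprocess.py`: z-score each proxy band power against dataset statistics, then squash it with a logistic sigmoid so each state lies in (0, 1), like the values in the API example. `map_state` and its arguments are hypothetical.

```python
import math

# Which EEG proxy drives which state, following the State / EEG Proxy table.
PROXY = {"attention": "alpha_power", "fatigue": "theta_power", "stress": "beta_power"}


def map_state(bands: dict, mean: dict, std: dict) -> dict:
    """Map relative band powers to attention, fatigue, and stress in (0, 1).

    Each state is the sigmoid of its proxy's z-score: a band power equal to
    the dataset mean maps to 0.5, and higher relative power maps above 0.5.
    """
    state = {}
    for name, band in PROXY.items():
        z = (bands[band] - mean[band]) / (std[band] + 1e-8)
        state[name] = 1.0 / (1.0 + math.exp(-z))
    return state
```

The sigmoid keeps the mapping monotonic, so the "↑ proxy power means ↑ state" reading of the table is preserved while outliers cannot push a state outside its range.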
| Before | After |
|---|---|
| SyntheticDataLayer (4 fake datasets) | Single real DEAP pipeline |
| CORAL multi-domain alignment | Simple z-score normalization |
| Hand-crafted coupling matrices | Learned ODEFunc (nn.Sequential) |
| RK4 duplicated in frontend JS | Backend-only inference |
| HNNBridge, domain_gains, 5-D state | Clean 3-D state space |
| 8 API endpoints | 2 endpoints |
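The z-score normalization step might look like the following sketch. The per-feature `{"mean": ..., "std": ...}` layout written to `api/norm_stats.json` is an assumption, not the file's documented schema, and `fit_norm_stats`, `apply_norm`, and `save_norm_stats` are hypothetical names. The idea is to fit the statistics on training data only and reuse the saved file at inference.

```python
import json
from pathlib import Path
from statistics import fmean, pstdev

FEATURES = ("alpha_power", "theta_power", "beta_power")


def fit_norm_stats(rows: list[dict]) -> dict:
    """Per-feature mean and population standard deviation over training rows."""
    return {
        f: {"mean": fmean([r[f] for r in rows]), "std": pstdev([r[f] for r in rows])}
        for f in FEATURES
    }


def apply_norm(row: dict, stats: dict, eps: float = 1e-8) -> dict:
    """z = (x - mean) / std; eps avoids division by zero for constant features."""
    return {f: (row[f] - stats[f]["mean"]) / (stats[f]["std"] + eps) for f in FEATURES}


def save_norm_stats(stats: dict, path: str = "api/norm_stats.json") -> None:
    """Write the fitted statistics as JSON (hypothetical layout)."""
    Path(path).parent.mkdir(parents=True, exist_ok=True)
    Path(path).write_text(json.dumps(stats, indent=2))
```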