Self-supervised distribution shift detection in clinical ECG data using contrastive learning
A core challenge in deploying machine learning models in clinical settings is distribution shift: test data that differs systematically from the training data. A model trained on ECGs from one patient cohort may fail silently when deployed on a different population, giving no warning that its predictions have become unreliable.
This project addresses that problem directly. Rather than assuming that test data mirrors training conditions, ClinicalShift uses self-supervised contrastive learning to learn a compact representation of in-distribution ECG patterns. It then flags samples that fall outside the learned distribution, which is a prerequisite for trustworthy clinical AI.
This work is motivated by the EU AI Act (2024). The Act requires high-risk AI systems, a category that covers many medical applications, to be robust and to support human oversight, including the ability to detect anomalies and unexpected performance.
- Problem Statement
- Methodology
- Dataset
- Architecture
- Mathematical Formulation
- Results
- Project Structure
- Setup and Usage
- References
Given:
- A model trained on ECG data from younger patients (age 20–50)
- New incoming ECGs from older patients (age 70+)
Can we detect this shift automatically, without labels, before the model makes predictions?
This simulates real deployment scenarios where hospitals apply models across patient populations that differ from the original training cohort.
The pipeline has three stages:
Stage 1: Representation Learning
Train a Conv1D encoder with SimCLR-style contrastive learning
on in-distribution ECG windows. No labels are required.
Stage 2: Distribution Fitting
Extract embeddings for all training samples.
Fit a multivariate Gaussian: compute μ and Σ.
Stage 3: Shift Detection
For each new sample, compute its Mahalanobis distance from μ (using Σ).
Samples whose distance exceeds the threshold are flagged as distribution-shifted.
PTB-XL: a large, publicly available ECG dataset from the Physikalisch-Technische Bundesanstalt (PTB), Berlin, Germany.
| Property | Value |
|---|---|
| Total recordings | 21,837 |
| Sampling rate | 100 Hz (used here); 500 Hz also available |
| Signal duration | 10 seconds |
| ECG leads | 12 |
| Signal shape | (1000, 12) per recording at 100 Hz |
Distribution Shift Simulation:
| Split | Age Range | Recordings | Role |
|---|---|---|---|
| In-Distribution | 20–50 years | 5,148 | Training domain |
| Out-of-Distribution | 70+ years | 6,726 | Shifted domain |
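A minimal sketch of the age-based split, assuming inclusive bounds (20 ≤ age ≤ 50 and age ≥ 70) and an array of patient ages taken from PTB-XL's metadata (for example the `age` column of `ptbxl_database.csv`). The function name `age_split_masks` is hypothetical and not taken from `src/dataset.py`.

```python
import numpy as np


def age_split_masks(ages, in_range=(20, 50), ood_min_age=70):
    """Return boolean masks for the in-distribution and shifted cohorts.

    Assumes inclusive bounds; the exact boundary handling in the
    original notebook is not stated.
    """
    ages = np.asarray(ages, dtype=float)
    lo, hi = in_range
    in_mask = (ages >= lo) & (ages <= hi)
    ood_mask = ages >= ood_min_age
    # Patients aged 51-69 and records with a missing age (NaN compares
    # False) fall into neither split and are left out of the experiment.
    return in_mask, ood_mask


in_mask, ood_mask = age_split_masks([25, 50, 51, 69, 70, 85, float("nan")])
print(in_mask, ood_mask)
```

The gap between 50 and 70 keeps the two cohorts well separated, which makes the simulated shift easier to attribute to age rather than to borderline cases.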
The ECGEncoder processes single-lead (Lead II) ECG windows through a convolutional backbone followed by a projection head:
Input (batch, 1, 250)
↓ Conv Block 1: Conv1D(1→64, k=7) + BatchNorm + ReLU + MaxPool
↓ Conv Block 2: Conv1D(64→128, k=5) + BatchNorm + ReLU + MaxPool
↓ Conv Block 3: Conv1D(128→256, k=3) + BatchNorm + ReLU + MaxPool
↓ Global Average Pooling → (batch, 256)
↓ Embedding Layer: Linear(256→128) + ReLU + Dropout(0.3)
↓ Projection Head: Linear(128→64) + L2 Normalization
Output: (batch, 64), one unit-norm embedding vector per window
Total parameters: ~182K by the layer sizes above (deliberately lightweight for clinical deployment)
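As a sanity check on model size, the trainable parameters implied by the layer list above can be tallied by hand. This sketch assumes every convolution has a bias term and counts BatchNorm's learnable scale and shift; pooling, ReLU, dropout, and L2 normalization add no parameters. The helper names are hypothetical.

```python
def conv1d_params(c_in, c_out, k, bias=True):
    # Weight tensor of shape (c_out, c_in, k) plus one bias per output channel
    return c_out * c_in * k + (c_out if bias else 0)


def batchnorm_params(c):
    # Learnable scale (gamma) and shift (beta) per channel; running stats are buffers
    return 2 * c


def linear_params(d_in, d_out):
    # Weight matrix plus bias vector
    return d_in * d_out + d_out


layers = {
    "conv1": conv1d_params(1, 64, 7) + batchnorm_params(64),
    "conv2": conv1d_params(64, 128, 5) + batchnorm_params(128),
    "conv3": conv1d_params(128, 256, 3) + batchnorm_params(256),
    "embedding": linear_params(256, 128),
    "projection": linear_params(128, 64),
}
total = sum(layers.values())

for name, n in layers.items():
    print(f"{name:>10}: {n:,}")
print(f"{'total':>10}: {total:,}")
```

Under these assumptions the total is 182,208 (181,760 if the convolutions are bias-free, as is common before BatchNorm). The third convolution alone accounts for over half of it.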
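Min-Max Normalization:
$$\tilde{x}[t] = \frac{x[t] - \min_{t'} x[t']}{\max_{t'} x[t'] - \min_{t'} x[t']}$$
Sliding Window Segmentation:
$$x^{(k)} = \tilde{x}\big[kS : kS + W\big], \qquad k = 0, 1, \dots, \left\lfloor \frac{T - W}{S} \right\rfloor$$
where $T = 1000$ is the number of samples per recording (10 s at 100 Hz), $W = 250$ is the window length (2.5 s, matching the encoder input), and $S$ is the stride.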
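The two preprocessing steps can be sketched in NumPy as follows. This assumes normalization is applied to the whole recording before windowing. The stride of 125 samples (50% overlap) and the `eps` guard against flat signals are illustrative assumptions, not values from the original.

```python
import numpy as np


def min_max_normalize(x, eps=1e-8):
    """Scale a 1-D signal to [0, 1]; eps avoids division by zero for constant signals."""
    x = np.asarray(x, dtype=np.float32)
    x_min, x_max = x.min(), x.max()
    return (x - x_min) / (x_max - x_min + eps)


def sliding_windows(x, window=250, stride=125):
    """Split a 1-D signal into windows of shape (n_windows, window).

    stride=125 is a hypothetical default for illustration.
    """
    if len(x) < window:
        raise ValueError("signal is shorter than one window")
    n = (len(x) - window) // stride + 1
    return np.stack([x[k * stride : k * stride + window] for k in range(n)])


# Example: one 10 s Lead II recording at 100 Hz has 1000 samples
rng = np.random.default_rng(0)
lead_ii = rng.standard_normal(1000)
windows = sliding_windows(min_max_normalize(lead_ii))
print(windows.shape)  # (7, 250)
```

With $T = 1000$, $W = 250$ and $S = 125$, the window count is $\lfloor 750 / 125 \rfloor + 1 = 7$; a non-overlapping stride of 250 would give 4.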
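For each ECG window $x$, two views $x_i = t(x)$ and $x_j = t'(x)$ are produced by independently sampled augmentations $t$ and $t'$; the two views form a positive pair for contrastive learning.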
ECG augmentations used:
- Gaussian noise: $x' = x + \mathcal{N}(0, 0.05)$
- Random scaling: $x' = x \cdot s$, $s \sim \mathcal{U}(0.8, 1.2)$
- Time shift: $x' = \text{roll}(x, \delta)$, $\delta \sim \mathcal{U}(-20, 20)$
- Random masking: $x'[a:a+l] = 0$ (a random segment of length $l$ starting at $a$ is zeroed)
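A sketch of the augmentation pipeline applied to one window. Several details are assumptions not stated in the original: all four augmentations are applied in sequence every time, 0.05 is the noise standard deviation, the shift is an integer number of samples (circular, via `np.roll`), and the mask length is drawn from 1 to a hypothetical `max_mask_len`.

```python
import numpy as np


def augment(x, rng, max_mask_len=50):
    """Apply noise, scaling, time shift and masking to a 1-D window."""
    x = np.asarray(x, dtype=np.float32)
    x = x + rng.normal(0.0, 0.05, size=x.shape)    # Gaussian noise
    x = x * rng.uniform(0.8, 1.2)                   # random amplitude scaling
    x = np.roll(x, int(rng.integers(-20, 21)))      # circular shift by -20..20 samples
    l = int(rng.integers(1, max_mask_len + 1))      # mask length
    a = int(rng.integers(0, len(x) - l + 1))        # mask start
    x[a:a + l] = 0.0                                # zero out the segment
    return x


def two_views(x, rng):
    """Two independently augmented views of the same window form a positive pair."""
    return augment(x, rng), augment(x, rng)


rng = np.random.default_rng(42)
window = np.sin(np.linspace(0, 4 * np.pi, 250))
view_a, view_b = two_views(window, rng)
```

Because both views share the same underlying beat morphology but differ in noise, amplitude, timing and occlusion, the encoder is pushed to encode the morphology rather than these nuisance factors.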
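NT-Xent Loss (Normalized Temperature-scaled Cross Entropy). For a batch of $N$ windows, giving $2N$ projected embeddings $z_1, \dots, z_{2N}$, the loss for a positive pair $(i, j)$ is:
$$\ell_{i,j} = -\log \frac{\exp\big(\mathrm{sim}(z_i, z_j)/\tau\big)}{\sum_{k=1}^{2N} \mathbb{1}_{[k \neq i]} \exp\big(\mathrm{sim}(z_i, z_k)/\tau\big)}$$
and the total loss averages $\ell_{i,j}$ over all $2N$ ordered positive pairs.
Where cosine similarity is:
$$\mathrm{sim}(u, v) = \frac{u^\top v}{\lVert u \rVert \, \lVert v \rVert}$$
which reduces to $u^\top v$ for the L2-normalized outputs of the projection head.
Temperature $\tau > 0$ scales the similarities; smaller values sharpen the softmax over the negatives.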
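A NumPy sketch of the NT-Xent loss, written for readability rather than as the project's training code (training requires a framework with automatic differentiation). Positive pairs are stored as two aligned arrays `z_a` and `z_b` instead of interleaved indices, and `tau=0.5` is an illustrative default, not the project's setting.

```python
import numpy as np


def nt_xent(z_a, z_b, tau=0.5):
    """NT-Xent loss for N positive pairs (z_a[k], z_b[k]), averaged over all 2N anchors."""
    z = np.concatenate([z_a, z_b], axis=0).astype(np.float64)  # (2N, d)
    z = z / np.linalg.norm(z, axis=1, keepdims=True)            # unit norm: dot product = cosine sim
    sim = z @ z.T / tau                                         # (2N, 2N) scaled similarities
    np.fill_diagonal(sim, -np.inf)                              # drop k == i from the denominator
    n = z_a.shape[0]
    # Anchor i in z_a has its positive at i + n, and vice versa
    pos = np.concatenate([np.arange(n, 2 * n), np.arange(0, n)])
    # Numerically stable log of the denominator (log-sum-exp per row)
    row_max = sim.max(axis=1, keepdims=True)
    log_denom = row_max[:, 0] + np.log(np.exp(sim - row_max).sum(axis=1))
    return float(np.mean(log_denom - sim[np.arange(2 * n), pos]))


# Toy case: two orthogonal unit vectors, each view identical to its partner
e = np.eye(2)
loss = nt_xent(e, e, tau=0.5)
```

In the toy case every anchor sees one positive with similarity 1 and two negatives with similarity 0, so the loss equals $\log(1 + 2e^{-1/\tau})$, which is $\log(1 + 2e^{-2}) \approx 0.24$ for $\tau = 0.5$.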
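Step 1: Fit a multivariate Gaussian on the $M$ training embeddings $z_1, \dots, z_M$:
$$\mu = \frac{1}{M}\sum_{m=1}^{M} z_m, \qquad \Sigma = \frac{1}{M-1}\sum_{m=1}^{M} (z_m - \mu)(z_m - \mu)^\top$$
Step 2: Mahalanobis distance for a test sample $z$:
$$D(z) = \sqrt{(z - \mu)^\top \Sigma^{-1} (z - \mu)}$$
Step 3: Detection threshold (2-sigma rule), using the mean and standard deviation of $D$ over the training embeddings:
$$\theta = \operatorname{mean}_m D(z_m) + 2\,\operatorname{std}_m D(z_m)$$
Samples where $D(z) > \theta$ are flagged as distribution-shifted.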
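Steps 1–3 map onto a small NumPy class. This is a sketch under assumptions: the class name and the `ridge` term added to Σ before inversion are hypothetical (a near-singular covariance is common when embeddings are high-dimensional), and the threshold is computed from training-set distances as in Step 3.

```python
import numpy as np


class GaussianShiftDetector:
    """Fit N(mu, Sigma) on in-distribution embeddings; flag samples by Mahalanobis distance."""

    def __init__(self, ridge=1e-6):
        self.ridge = ridge  # hypothetical regularizer keeping Sigma invertible

    def fit(self, z_train):
        z_train = np.asarray(z_train, dtype=np.float64)
        self.mu = z_train.mean(axis=0)                              # Step 1: mean
        cov = np.cov(z_train, rowvar=False)                         # Step 1: unbiased (M-1) covariance
        cov += self.ridge * np.eye(cov.shape[0])
        self.precision = np.linalg.inv(cov)
        d_train = self.distance(z_train)
        self.threshold = d_train.mean() + 2.0 * d_train.std()       # Step 3: 2-sigma rule
        return self

    def distance(self, z):
        diff = np.asarray(z, dtype=np.float64) - self.mu
        # Step 2: row-wise (z - mu)^T Sigma^{-1} (z - mu); clip tiny negative rounding errors
        sq = np.einsum("ij,jk,ik->i", diff, self.precision, diff)
        return np.sqrt(np.maximum(sq, 0.0))

    def predict(self, z):
        return self.distance(z) > self.threshold


rng = np.random.default_rng(0)
z_train = rng.normal(size=(2000, 8))
detector = GaussianShiftDetector().fit(z_train)
flag_in = detector.predict(rng.normal(size=(1000, 8)))
flag_out = detector.predict(rng.normal(loc=3.0, size=(1000, 8)))
print(flag_in.mean(), flag_out.mean())
```

On this synthetic data only a small fraction of in-distribution samples crosses the threshold while nearly all mean-shifted samples do; real ECG embeddings overlap far more, which is consistent with the moderate AUROC reported below.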
| Metric | Value |
|---|---|
| Starting NT-Xent Loss | 5.1325 |
| Final NT-Xent Loss | 4.6131 |
| Improvement | 10.1% |
| Epochs | 50 |
| Metric | Value |
|---|---|
| AUROC | 0.6839 |
| Average Precision | 0.6892 |
| Detection Threshold (2σ) | see drift plot |
| In-dist flagged as shifted | ~5% (expected) |
| Out-dist flagged as shifted | significantly higher |
An AUROC of 0.68 is achieved with no label supervision during training. The encoder never received patient age or diagnosis information, so the shift signal emerges purely from contrastive representation learning.
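Assuming the Mahalanobis distance is used as the shift score with the 70+ cohort as the positive class, an AUROC of 0.68 means a randomly chosen shifted sample receives a higher distance than a randomly chosen in-distribution sample about 68% of the time (ties counted as half). A minimal pairwise sketch of that definition, using a hypothetical `auroc` helper:

```python
import numpy as np


def auroc(scores_in, scores_out):
    """AUROC with the shifted group as the positive class.

    Computes P(score_out > score_in) + 0.5 * P(tie), the normalized
    Mann-Whitney U statistic. Memory is O(n_in * n_out), which suits a
    few thousand scores per group.
    """
    s_in = np.asarray(scores_in, dtype=float)[None, :]
    s_out = np.asarray(scores_out, dtype=float)[:, None]
    return float(np.mean(s_out > s_in) + 0.5 * np.mean(s_out == s_in))


print(auroc([0.1, 0.3], [0.2, 0.4]))
```

For larger score sets, `sklearn.metrics.roc_auc_score(y_true, y_score)` computes the same quantity from labels and scores without the pairwise matrix.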
The UMAP projection reveals that the contrastive encoder learns embeddings where in-distribution (young) and out-of-distribution (old) patients occupy partially distinct regions of the latent space โ without ever being told patient age during training.
ClinicalShift/
│
├── src/
│   ├── dataset.py              # PTB-XL loading, normalization, sliding window
│   ├── augmentations.py        # ECG augmentation functions
│   ├── model.py                # Conv1D contrastive encoder
│   ├── loss.py                 # NT-Xent contrastive loss
│   ├── trainer.py              # Training loop with cosine LR schedule
│   └── shift_detector.py       # Mahalanobis distance shift detection
│
├── assets/
│   ├── ecg_sample.png          # ECG signal comparison
│   ├── dataset_distribution.png # Age split visualization
│   ├── preprocessing.png       # Preprocessing pipeline
│   ├── architecture.png        # Model architecture diagram
│   ├── augmentations.png       # Augmentation examples
│   ├── training_loss.png       # Training curve
│   ├── drift_scores.png        # Shift detection results
│   └── umap_embeddings.png     # Embedding space visualization
│
├── ClinicalShift.ipynb         # Complete reproducible notebook
├── requirements.txt
└── README.md
git clone https://github.com/dingdingpista/ClinicalShift.git
cd ClinicalShift
pip install -r requirements.txt
Download PTB-XL from PhysioNet and place it in the data/ folder.
Open ClinicalShift.ipynb in Google Colab or Jupyter.
Run all cells in order; the notebook is self-contained.
- Multi-lead encoding: use all 12 leads instead of Lead II only
- Pathology-based shift: define the shift by diagnosis rather than age
- Online detection: sliding-window shift detection in real time
- Transformer encoder: replace the Conv1D backbone with temporal attention
- Calibrated scores: convert Mahalanobis distances to probabilities
- Chen, T., et al. (2020). A Simple Framework for Contrastive Learning of Visual Representations (SimCLR). ICML. arXiv:2002.05709
- Wagner, P., et al. (2020). PTB-XL, a large publicly available electrocardiography dataset. Scientific Data. doi:10.1038/s41597-020-0495-6
- Lee, K., et al. (2018). A Simple Unified Framework for Detecting Out-of-Distribution Samples and Adversarial Attacks. NeurIPS. arXiv:1807.03888
MIT License. See the LICENSE file.
Built as part of a portfolio for MSc Data Science applications (Germany, 2025/26).