
ClinicalShift 🫀

Self-supervised distribution shift detection in clinical ECG data using contrastive learning



🧠 Motivation

A core challenge in deploying machine learning models in clinical settings is distribution shift: the phenomenon where test data differs systematically from training data. A model trained on ECGs from one patient cohort may silently fail when deployed on a different population, with no warning signal.

This project addresses that problem directly. Rather than assuming test data mirrors training conditions, ClinicalShift learns a compact representation of normal ECG patterns using self-supervised contrastive learning, then flags samples that fall outside the learned distribution. Detecting such samples is a prerequisite for trustworthy clinical AI.

This work is motivated by the EU AI Act (2024), which requires high-risk AI systems, a category covering many medical applications, to include risk-management and monitoring mechanisms that can catch anomalous inputs.



โ“ Problem Statement

Given:

  • A model trained on ECG data from young patients (age 20–50)
  • New incoming ECGs from older patients (age 70+)

Can we detect this shift automatically, without labels, before the model makes predictions?

This simulates real deployment scenarios where hospitals apply models across patient populations that differ from the original training cohort.


🔬 Methodology

The pipeline has three stages:

Stage 1 - Representation Learning
    Train a Conv1D encoder using SimCLR-style contrastive learning
    on in-distribution ECG windows. No labels required.

Stage 2 - Distribution Fitting
    Extract embeddings for all training samples.
    Fit a multivariate Gaussian: compute μ and Σ.

Stage 3 - Shift Detection
    For each new sample, compute its Mahalanobis distance from μ.
    Samples exceeding the threshold are flagged as distribution shift.

๐Ÿ—ƒ๏ธ Dataset

PTB-XL is a large publicly available ECG dataset from the Physikalisch-Technische Bundesanstalt, Berlin, Germany.

| Property | Value |
| --- | --- |
| Total recordings | 21,837 |
| Sampling rate | 100 Hz (used) / 500 Hz |
| Signal duration | 10 seconds |
| ECG leads | 12 |
| Signal shape | (1000, 12) per recording |

Distribution Shift Simulation:

| Split | Age Range | Recordings | Role |
| --- | --- | --- | --- |
| In-Distribution | 20–50 years | 5,148 | Training domain |
| Out-of-Distribution | 70+ years | 6,726 | Shifted domain |

Dataset Distribution


๐Ÿ—๏ธ Architecture

Architecture

The ECGEncoder processes single-lead ECG windows through a convolutional backbone followed by a projection head:

Input (batch, 1, 250)
    → Conv Block 1: Conv1D(1→64,   k=7) + BatchNorm + ReLU + MaxPool
    → Conv Block 2: Conv1D(64→128, k=5) + BatchNorm + ReLU + MaxPool
    → Conv Block 3: Conv1D(128→256, k=3) + BatchNorm + ReLU + MaxPool
    → Global Average Pooling           → (batch, 256)
    → Embedding Layer: Linear(256→128) + ReLU + Dropout(0.3)
    → Projection Head: Linear(128→64)  + L2 Normalization
Output: (batch, 64), a unit-norm embedding vector

Total parameters: ~110K (deliberately lightweight for clinical deployment)
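The diagram above can be sketched in PyTorch. This is a minimal reconstruction from the shape annotations only; the pooling size (2) and padding choices are assumptions, and the reference implementation lives in src/model.py and may differ in detail.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ECGEncoder(nn.Module):
    """Conv1D contrastive encoder sketched from the diagram above.
    Kernel sizes match the diagram; pool size 2 and padding are assumed."""

    def __init__(self, emb_dim: int = 128, proj_dim: int = 64):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv1d(1, 64, kernel_size=7, padding=3),
            nn.BatchNorm1d(64), nn.ReLU(), nn.MaxPool1d(2),
            nn.Conv1d(64, 128, kernel_size=5, padding=2),
            nn.BatchNorm1d(128), nn.ReLU(), nn.MaxPool1d(2),
            nn.Conv1d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm1d(256), nn.ReLU(), nn.MaxPool1d(2),
        )
        self.embedding = nn.Sequential(
            nn.Linear(256, emb_dim), nn.ReLU(), nn.Dropout(0.3)
        )
        self.projection = nn.Linear(emb_dim, proj_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.backbone(x)           # (batch, 256, 31)
        h = h.mean(dim=-1)             # global average pooling -> (batch, 256)
        z = self.projection(self.embedding(h))
        return F.normalize(z, dim=-1)  # unit-norm (batch, 64) embeddings
```

A quick shape check: `ECGEncoder()(torch.randn(8, 1, 250))` returns an (8, 64) tensor whose rows have unit L2 norm.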


๐Ÿ“ Mathematical Formulation

1. ECG Preprocessing

Min-Max Normalization:

$$x_{norm} = \frac{x - x_{min}}{x_{max} - x_{min} + \epsilon}$$

Sliding Window Segmentation:

$$n_{windows} = \left\lfloor \frac{L - W}{S} \right\rfloor + 1$$

Where $L=1000$ (signal length), $W=250$ (window), $S=125$ (step, 50% overlap)
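Together, the two preprocessing steps can be sketched in NumPy (an illustrative version; the project's own implementation is in src/dataset.py and may differ):

```python
import numpy as np

def preprocess(signal: np.ndarray, window: int = 250, step: int = 125,
               eps: float = 1e-8) -> np.ndarray:
    """Min-max normalize a 1-D ECG signal, then cut it into
    overlapping windows of length `window` with stride `step`."""
    x = (signal - signal.min()) / (signal.max() - signal.min() + eps)
    n_windows = (len(x) - window) // step + 1   # floor((L - W) / S) + 1
    return np.stack([x[i * step : i * step + window]
                     for i in range(n_windows)])

# A 10 s lead at 100 Hz (L = 1000) yields floor((1000-250)/125) + 1 = 7 windows
windows = preprocess(np.sin(np.linspace(0, 20 * np.pi, 1000)))
print(windows.shape)  # (7, 250)
```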


2. Contrastive Learning (SimCLR)

For each ECG window $x$, two augmented views are created: $\tilde{x}_i = t(x)$, $\tilde{x}_j = t'(x)$ where $t, t'$ are random augmentations.

ECG Augmentations used:

  • Gaussian noise: $x' = x + \mathcal{N}(0, 0.05)$
  • Random scaling: $x' = x \cdot s$, $s \sim \mathcal{U}(0.8, 1.2)$
  • Time shift: $x' = \text{roll}(x, \delta)$, $\delta \sim \mathcal{U}(-20, 20)$
  • Random masking: $x'[a:a+l] = 0$

Augmentations
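The four augmentations can be sketched in NumPy. Parameter ranges follow the list above; the mask length of 25 samples and the function name are assumptions for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)

def augment(x: np.ndarray) -> np.ndarray:
    """Apply the four stochastic ECG augmentations listed above."""
    x = x + rng.normal(0.0, 0.05, size=x.shape)  # Gaussian noise
    x = x * rng.uniform(0.8, 1.2)                # random scaling
    x = np.roll(x, rng.integers(-20, 21))        # time shift
    a = rng.integers(0, len(x) - 25)             # random masking:
    x[a : a + 25] = 0.0                          # zero a short span (length assumed)
    return x

# Two independent augmented views of the same window, as SimCLR requires
view_1, view_2 = augment(np.ones(250)), augment(np.ones(250))
```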

NT-Xent Loss (Normalized Temperature-scaled Cross Entropy):

$$\mathcal{L}_{i,j} = -\log \frac{\exp(\text{sim}(z_i, z_j)/\tau)}{\sum_{k=1}^{2N} \mathbf{1}_{[k \neq i]} \exp(\text{sim}(z_i, z_k)/\tau)}$$

Where cosine similarity is:

$$\text{sim}(u, v) = \frac{u \cdot v}{\lVert u \rVert \, \lVert v \rVert}$$

Temperature $\tau = 0.5$ controls the sharpness of the distribution.
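The loss can be written directly from the formula. Below is a NumPy sketch for a batch of 2N stacked, unit-norm embeddings in which rows k and k+N are the positive pair; the repository's PyTorch version in src/loss.py may be organized differently:

```python
import numpy as np

def nt_xent(z: np.ndarray, tau: float = 0.5) -> float:
    """NT-Xent over 2N unit-norm embeddings; rows k and k+N are positives."""
    n2 = z.shape[0]                      # 2N
    n = n2 // 2
    sim = z @ z.T / tau                  # cosine similarity (rows are unit-norm)
    np.fill_diagonal(sim, -np.inf)       # exclude k = i from the denominator
    log_den = np.log(np.exp(sim).sum(axis=1))
    pos = np.concatenate([np.arange(n) + n, np.arange(n)])  # each row's positive
    log_num = sim[np.arange(n2), pos]
    return float(np.mean(log_den - log_num))  # mean over all 2N anchors
```

Sanity check: when the two views of each sample map to identical embeddings while negatives stay orthogonal, the loss is much lower than when all embeddings are mutually orthogonal.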


3. Distribution Shift Detection

Step 1 - Fit a multivariate Gaussian to the training embeddings:

$$\mu = \frac{1}{N}\sum_{i=1}^{N} z_i, \quad \Sigma = \frac{1}{N}\sum_{i=1}^{N}(z_i - \mu)(z_i - \mu)^T$$

Step 2 - Mahalanobis distance for a test sample $z$:

$$D_M(z) = \sqrt{(z - \mu)^T \Sigma^{-1} (z - \mu)}$$

Step 3 - Detection threshold (2-sigma rule):

$$\theta = \mu_{train} + 2\sigma_{train}$$

Here $\mu_{train}$ and $\sigma_{train}$ denote the mean and standard deviation of the Mahalanobis distances computed on the training set (not the embedding statistics from Step 1). Samples with $D_M(z) > \theta$ are flagged as distribution shift.
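Steps 1-3 can be sketched end-to-end in NumPy. The small ridge term added to Σ before inversion is an assumption for numerical stability, not something stated above:

```python
import numpy as np

class MahalanobisDetector:
    """Fit a Gaussian to training embeddings; flag far-away test points."""

    def fit(self, z_train: np.ndarray) -> "MahalanobisDetector":
        self.mu = z_train.mean(axis=0)
        cov = np.cov(z_train, rowvar=False)
        # Ridge term (assumed) keeps the inverse well-conditioned
        self.cov_inv = np.linalg.inv(cov + 1e-6 * np.eye(cov.shape[0]))
        d_train = self.score(z_train)
        self.threshold = d_train.mean() + 2.0 * d_train.std()  # 2-sigma rule
        return self

    def score(self, z: np.ndarray) -> np.ndarray:
        diff = z - self.mu
        return np.sqrt(np.einsum("nd,de,ne->n", diff, self.cov_inv, diff))

    def is_shifted(self, z: np.ndarray) -> np.ndarray:
        return self.score(z) > self.threshold

# Toy usage: a far-away cluster is flagged, in-distribution data mostly is not
rng = np.random.default_rng(0)
det = MahalanobisDetector().fit(rng.normal(0, 1, (2000, 8)))
flags = det.is_shifted(rng.normal(4, 1, (200, 8)))
```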


📊 Results

ECG Signal Comparison

ECG Sample

Preprocessing Pipeline

Preprocessing

Training Curve

Training Loss

| Metric | Value |
| --- | --- |
| Starting NT-Xent Loss | 5.1325 |
| Final NT-Xent Loss | 4.6131 |
| Improvement | 10.1% |
| Epochs | 50 |

Shift Detection Performance

Drift Scores

| Metric | Value |
| --- | --- |
| AUROC | 0.6839 |
| Average Precision | 0.6892 |
| Detection Threshold (2σ) | see drift plot |
| In-dist flagged as shifted | ~5% (expected) |
| Out-dist flagged as shifted | significantly higher |

An AUROC of 0.68 is achieved with zero label supervision during training. The encoder never received patient age or diagnosis information; the shift signal emerges purely from contrastive representation learning.

UMAP Embedding Space

UMAP Embeddings

The UMAP projection reveals that the contrastive encoder learns embeddings where in-distribution (young) and out-of-distribution (old) patients occupy partially distinct regions of the latent space, without ever being told patient age during training.


๐Ÿ“ Project Structure

ClinicalShift/
│
├── src/
│   ├── dataset.py          # PTB-XL loading, normalization, sliding window
│   ├── augmentations.py    # ECG augmentation functions
│   ├── model.py            # Conv1D contrastive encoder
│   ├── loss.py             # NT-Xent contrastive loss
│   ├── trainer.py          # Training loop with cosine LR schedule
│   └── shift_detector.py   # Mahalanobis distance shift detection
│
├── assets/
│   ├── ecg_sample.png           # ECG signal comparison
│   ├── dataset_distribution.png # Age split visualization
│   ├── preprocessing.png        # Preprocessing pipeline
│   ├── architecture.png         # Model architecture diagram
│   ├── augmentations.png        # Augmentation examples
│   ├── training_loss.png        # Training curve
│   ├── drift_scores.png         # Shift detection results
│   └── umap_embeddings.png      # Embedding space visualization
│
├── ClinicalShift.ipynb     # Complete reproducible notebook
├── requirements.txt
└── README.md

โš™๏ธ Setup and Usage

1. Clone the repository

git clone https://github.com/dingdingpista/ClinicalShift.git
cd ClinicalShift

2. Install dependencies

pip install -r requirements.txt

3. Download PTB-XL dataset

Download from PhysioNet and place in data/ folder.

4. Run the complete pipeline

Open ClinicalShift.ipynb in Google Colab or Jupyter. Run all cells sequentially; the notebook is self-contained.


🔮 Future Work

  1. Multi-lead encoding: use all 12 leads instead of Lead II only
  2. Pathology-based shift: shift by diagnosis rather than age
  3. Online detection: sliding-window shift detection in real time
  4. Transformer encoder: replace Conv1D with temporal attention
  5. Calibrated scores: convert Mahalanobis distances to probabilities

📚 References

  1. Chen, T., et al. (2020). A Simple Framework for Contrastive Learning of Visual Representations (SimCLR). ICML. arXiv:2002.05709

  2. Wagner, P., et al. (2020). PTB-XL, a large publicly available electrocardiography dataset. Scientific Data. doi:10.1038/s41597-020-0495-6

  3. Lee, K., et al. (2018). A Simple Unified Framework for Detecting Out-of-Distribution Samples and Adversarial Attacks. NeurIPS. arXiv:1807.03888


📄 License

MIT License - see the LICENSE file.


Built as part of a portfolio for MSc Data Science applications (Germany, 2025/26)
