Rohan Shad, Cyril Zakka, Dhamanpreet Kaur, Mrudang Mathur, Robyn Fong, Joseph Cho, Ross Warren Filice, John Mongan, Kimberly Kalianos, Nishith Khandwala, David Eng, Matthew Leipzig, Walter Witschey, Alejandro de Feria, Victor Ferrari, Euan Ashley, Michael A. Acker, Curtis Langlotz, William Hiesinger
Here we describe a transformer-based vision system that learns complex pathophysiological visual representations from a large multi-institutional dataset of 19,041 CMR scans, guided by natural language supervision from the text reports accompanying each CMR study. We use a large language model to help 'teach' a vision network to generate meaningful low-dimensional representations of CMR studies, by showing it examples of how radiologists describe what they see while drafting their reports. Training uses a contrastive learning objective (InfoNCE). The video encoder is an implementation of MViT (Multiscale Vision Transformers) initialized with Kinetics-400 pre-trained weights. The text encoder is an implementation of BERT (Bidirectional Encoder Representations from Transformers) pretrained on PubMed abstracts with a custom vocabulary. Please see our paper for more details.
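As a rough illustration of the contrastive objective (this is not the training code from this repository; function and variable names are ours), a symmetric InfoNCE loss over a batch of paired video/report embeddings can be sketched in NumPy as follows. Each video embedding is treated as a positive match for its own report embedding, with every other report in the batch serving as a negative:

```python
import numpy as np

def info_nce_loss(video_emb, text_emb, temperature=0.07):
    """Symmetric InfoNCE over a batch of paired (video, text) embeddings.

    Row i of `video_emb` is the positive match for row i of `text_emb`;
    all other rows in the batch act as negatives.
    """
    # L2-normalize so the dot product is a cosine similarity
    v = video_emb / np.linalg.norm(video_emb, axis=1, keepdims=True)
    t = text_emb / np.linalg.norm(text_emb, axis=1, keepdims=True)

    logits = v @ t.T / temperature  # (batch, batch) similarity matrix

    def cross_entropy(l):
        # Positives lie on the diagonal; subtract row max for stability
        l = l - l.max(axis=1, keepdims=True)
        log_probs = l - np.log(np.exp(l).sum(axis=1, keepdims=True))
        idx = np.arange(len(l))
        return -log_probs[idx, idx].mean()

    # Average the video-to-text and text-to-video directions
    return 0.5 * (cross_entropy(logits) + cross_entropy(logits.T))
```

Matched pairs drive the diagonal similarities up relative to the off-diagonal negatives, so the loss falls as the two encoders agree.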
MRI cine sequences are stored within HDF5 files for portability and performance. Pixel information is stored as arrays under a top-level directory for each unique patient. Certain views may have more than one 'video' taken at multiple parallel sections (e.g., the SAX view typically has numerous sequences taken from base to apex of the heart). Attributes such as 'slice frame index' demarcate where each unique video begins and ends. Please see the paper and supplementary appendix for additional details on how we prepare and structure our augmentations.
```
patient_id
├── accession_number_1.h5
├── accession_number_2.h5
    ├── 4CH {data: 4d array (c, f, h, w)} {attr: total images, slice frame index}
    ├── SAX {data: 4d array (c, f, h, w)} {attr: total images, slice frame index}
    ├── 2CH {data: 4d array (c, f, h, w)} {attr: total images, slice frame index}
    ├── 3CH {data: 4d array (c, f, h, w)} {attr: total images, slice frame index}
```
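To illustrate the layout above, the sketch below writes a small in-memory HDF5 file with the same structure and then splits one view back into its component videos using the 'slice frame index' attribute. Exact dataset and attribute names in your files may differ, so inspect them with h5py (or h5dump) before relying on this:

```python
import numpy as np
import h5py

# Build a tiny in-memory HDF5 file mimicking the layout above
# (dataset/attribute names are illustrative)
f = h5py.File("demo.h5", "w", driver="core", backing_store=False)
sax = np.random.rand(3, 50, 128, 128).astype(np.float32)  # (c, f, h, w)
dset = f.create_dataset("SAX", data=sax)
dset.attrs["total images"] = 50
# Start index of each cine sequence along the stacked frame axis
dset.attrs["slice frame index"] = np.array([0, 25])

# Read one view back and split it into its component videos
view = f["SAX"]
starts = list(view.attrs["slice frame index"]) + [view.shape[1]]
videos = [view[:, starts[i]:starts[i + 1]] for i in range(len(starts) - 1)]
f.close()
```

Here the 50 stacked SAX frames decompose into two 25-frame cine sequences, one per parallel section.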
This repository contains template code for finetuning and evaluation, and also contains all model classes required to load our weights for use in your own projects. To use this repository as-is for finetuning on your own datasets, you will need to use Wandb for experiment tracking, or make the appropriate changes to use Tensorboard instead. The repository also relies on a local_config.yaml file (not included) to set some variables (e.g., PRETRAIN_WEIGHTS or ATTN_DIR). The format of this file is as follows:
```yaml
# Device specific options
your_computer_name:
  tmp_dir: 'some/temp/dir'
  pretrain_dir: '/some/pretrain/dir'
  attn_dir: 'some/dir/attention_maps'
```
Using the repo without a local_config.yaml file is possible if you hard-code those variables in model_factory.py and mri_trainer.py.
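As a sketch of how a per-machine config like this might be consumed (this is not code from the repository; the helper name and hostname handling are illustrative), using PyYAML:

```python
import socket
import yaml  # PyYAML

# Illustrative contents of local_config.yaml, keyed by machine name
CONFIG_TEXT = """
your_computer_name:
  tmp_dir: 'some/temp/dir'
  pretrain_dir: '/some/pretrain/dir'
  attn_dir: 'some/dir/attention_maps'
"""

def load_local_config(text, hostname=None):
    """Return the config block matching this machine's hostname."""
    cfg = yaml.safe_load(text)
    hostname = hostname or socket.gethostname()
    if hostname not in cfg:
        raise KeyError(f"No entry for host '{hostname}' in local_config.yaml")
    return cfg[hostname]

settings = load_local_config(CONFIG_TEXT, hostname="your_computer_name")
```

Keying on hostname lets the same checkout run on multiple machines without editing tracked files.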
Weights for our pretrained CMR encoders are available on Huggingface for non-commercial use (CC-BY-NC 4.0): https://huggingface.co/rohanshad/cmr_c0.1.
If you don't need finetuning or the full training pipeline and just want to load the pretrained encoder and generate embeddings from your own data, use minimal_run.py. It pulls weights directly from Hugging Face and runs a forward pass with no config files, no W&B setup, and no local_config.yaml required.
```
pip install torch torchvision huggingface_hub
python minimal_run.py
```

This will download the checkpoint, load the MViT encoder, and run a forward pass on a random [1, 3, 16, 224, 224] tensor (batch × RGB × frames × H × W), printing the resulting embedding and its shape. Swap in your own preprocessed HDF5 data in place of demo_input.
If you're working from raw DICOM data, run it through cmr_toolkit first to produce HDF5 files, then apply spatial transforms before passing to the model:
```python
from torchvision.transforms import v2

val_transforms = v2.Compose([v2.Resize(size=244), v2.CenterCrop(size=224)])
```

Tested with CUDA on Ubuntu 20.04, 24.04, and CentOS 7.
- Create a new conda environment using Python 3.9:

```
conda create -n mri_torch python=3.9
conda activate mri_torch
```

- Install dependencies:

```
cat requirements.txt | xargs -n 1 pip install --force-reinstall
pip install torch==1.11.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113 --force
```

Specific torch download links are needed for CUDA-enabled PyTorch versions. Codebase not tested with torch > 1.11.
- Download example data: please reach out to me over email if you wish to test your models on the University of Pennsylvania Cardiac MRI dataset. The Kaggle and ACDC datasets are publicly available. Kaggle data will require conversion to HDF5 via the supplied preprocessing scripts; ACDC datasets are directly usable in their native NIfTI format.
Researchers with access to the UK Biobank may wish to use our models on its CMR data. We use scripts available in our cmr_toolkit to prepare and pre-process the data: first run the entire UK Biobank data directory through tar_compress.py to restructure the data from each scan into a single parent-level tarfile unique to that scan, then use the preprocess_mri.py and build_dataset.py scripts to build the final HDF5 datastore. Instructions are in the cmr_toolkit repository.
```
python mri_trainer.py validate --config configs/acdc_evaluation.yaml
python mri_trainer.py fit --config configs/finetune_config.yaml
python mri_trainer.py test --config configs/eval_config.yaml
```
If you use this codebase, or otherwise found our work valuable, please cite:
```bibtex
@article{shad2026generalizabledeeplearningcardiac,
  title={A Generalizable Deep Learning System for Cardiac MRI},
  author={Rohan Shad and Cyril Zakka and Dhamanpreet Kaur and Robyn Fong and Ross Warren Filice and John Mongan and Kimberly Kalianos and Nishith Khandwala and David Eng and Matthew Leipzig and Walter Witschey and Alejandro de Feria and Victor Ferrari and Euan Ashley and Michael A. Acker and Curtis Langlotz and William Hiesinger},
  journal={Nature Biomedical Engineering},
  year={2026},
  doi={10.1038/s41551-026-01637-3},
  url={https://www.nature.com/articles/s41551-026-01637-3},
}
```