| title | Medical image retrieval using a CLIP model |
|---|---|
| emoji | 🩺 |
| colorFrom | red |
| colorTo | white |
| sdk | streamlit |
| app_file | app.py |
| pinned | true |
This repository contains the code for fine-tuning a CLIP model [Arxiv paper][OpenAI Github Repo] on the ROCO dataset, a collection of radiology images paired with captions. This work was done as part of the Flax/JAX community week organized by Hugging Face and Google.
SciBERT (allenai/scibert_scivocab_uncased on 🤗) is used as the text encoder.
[🤗 Model card] [Streamlit demo]
You can try a Streamlit demo app that uses this model on 🤗 Spaces. You may have to sign up for the 🤗 Spaces private beta to access this app (screenshot shown below).

From the repository root, the demo can be run locally in the browser with
streamlit run app.py
Each image is accompanied by a textual caption. The caption length varies from a few characters (a single word) to 2,000 characters (multiple sentences). During preprocessing we remove all images that have a caption shorter than 10 characters. Training set: 57,780 images with their captions. Validation set: 7,200 images. Test set: 7,650 images.
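The caption-length filter described above can be sketched in a few lines of Python. This is only an illustration: the record layout (a dict with `image` and `caption` keys), the function name `filter_short_captions`, and the sample file names are assumptions, not the repository's actual preprocessing code.

```python
from typing import Dict, Iterable, List

MIN_CAPTION_LENGTH = 10  # threshold stated in the text above


def filter_short_captions(
    samples: Iterable[Dict[str, str]], min_length: int = MIN_CAPTION_LENGTH
) -> List[Dict[str, str]]:
    """Keep only samples whose caption has at least `min_length` characters."""
    return [s for s in samples if len(s["caption"]) >= min_length]


# Hypothetical records; the real ROCO file layout may differ.
samples = [
    {"image": "img_0001.jpg", "caption": "CT"},
    {"image": "img_0002.jpg", "caption": "Chest X-ray showing a left pleural effusion."},
]

kept = filter_short_captions(samples)
print([s["image"] for s in kept])
```

Only the second record survives, because the caption of the first one is shorter than 10 characters.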
[ ] Give an example
The image embeddings for the test set are stored using Git LFS. To download them, you need to have Git LFS installed.
# Install Git LFS (example for Ubuntu/Debian)
sudo apt install git-lfs
# Set up Git LFS for your user account (this only needs to be run once)
git lfs install
# Download the LFS files
git lfs pull
This will download all LFS files in the repository. If you only want to download the data for this project, you might need to specify the directory:
git lfs pull --include="data/*"
This repo depends on the master branch of the Hugging Face Transformers library. First you need to clone the transformers repository and then install it locally (preferably inside a virtual environment) with pip install -e ".[flax]".
You can load the pretrained model from the Hugging Face Hub with
from medclip.modeling_hybrid_clip import FlaxHybridCLIP
model = FlaxHybridCLIP.from_pretrained("flax-community/medclip-roco")
Alternatively you can download the model checkpoint from [🤗 Model card].
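Once text and image embeddings are available, retrieval comes down to ranking images by their similarity to a query embedding. The following is a minimal sketch, assuming the embeddings are plain NumPy arrays and that cosine similarity is the ranking metric; the function `retrieve_top_k` and the toy vectors are hypothetical and are not part of this repository.

```python
import numpy as np


def retrieve_top_k(text_embedding, image_embeddings, k=5):
    """Return indices and cosine similarities of the k best-matching images."""
    # L2-normalize so that the dot product equals cosine similarity.
    text = text_embedding / np.linalg.norm(text_embedding)
    images = image_embeddings / np.linalg.norm(
        image_embeddings, axis=1, keepdims=True
    )
    scores = images @ text  # shape: (num_images,)
    top = np.argsort(-scores)[:k]  # highest similarity first
    return top, scores[top]


# Toy 2-D vectors standing in for real model embeddings (assumption).
image_embeddings = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
text_embedding = np.array([2.0, 0.1])

indices, scores = retrieve_top_k(text_embedding, image_embeddings, k=2)
print(indices, scores)
```

In practice the image embeddings would be computed once and cached, so that only the text query needs to be encoded at search time.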
The model is trained using Flax/JAX on a cloud TPU-v3-8.
You can fine-tune a CLIP model implemented in Flax by running sh run_medclip.sh.
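As a rough illustration of what fine-tuning optimizes, CLIP-style models are usually trained with a symmetric contrastive loss over a batch of image-caption pairs. The sketch below uses NumPy rather than Flax/JAX for readability. It assumes a fixed temperature of 0.07, whereas CLIP learns this scale during training, and whether the training script uses exactly this form is an assumption; the name `clip_contrastive_loss` is hypothetical.

```python
import numpy as np


def log_softmax(logits, axis):
    # Subtract the max along the axis for numerical stability.
    shifted = logits - logits.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


def clip_contrastive_loss(image_emb, text_emb, temperature=0.07):
    """Symmetric cross-entropy over the image-text similarity matrix."""
    image_emb = image_emb / np.linalg.norm(image_emb, axis=1, keepdims=True)
    text_emb = text_emb / np.linalg.norm(text_emb, axis=1, keepdims=True)
    logits = image_emb @ text_emb.T / temperature  # shape: (batch, batch)
    diag = np.arange(logits.shape[0])  # matching pairs lie on the diagonal
    loss_i2t = -log_softmax(logits, axis=1)[diag, diag].mean()
    loss_t2i = -log_softmax(logits, axis=0)[diag, diag].mean()
    return (loss_i2t + loss_t2i) / 2.0


# Random vectors standing in for encoder outputs (assumption).
rng = np.random.default_rng(0)
emb = rng.normal(size=(4, 8))

aligned = clip_contrastive_loss(emb, emb)        # every pair matches
shuffled = clip_contrastive_loss(emb, emb[::-1])  # pairs are mismatched
print(aligned, shuffled)
```

The loss is small when each image embedding is closest to its own caption embedding and grows when the pairs are mismatched.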
This is the validation loss curve we observed when we trained the model using the run_medclip.sh script.

The current model is capable of identifying higher-level features such as the modality of an image (e.g., whether a given radiology image is a PET scan or an ultrasound scan). However, it fails to distinguish a brain scan from a lung scan. ❗️This model should not be used in a medical setting without further evaluation❗️.
Huge thanks to the Hugging Face 🤗 team and the Google JAX/Flax team for organizing the community week and letting us use cloud compute for 2 weeks. We especially thank @patil-suraj & @patrickvonplaten for their continued support on Slack and their detailed feedback.
[ ] Mention more examples
[ ] Evaluation on down-stream tasks
[ ] Zero-shot learning performance
