
Kaushalya/medclip


title: Medical image retrieval using a CLIP model
emoji: 🩺
colorFrom: red
colorTo: white
sdk: streamlit
app_file: app.py
pinned: true

MedCLIP: Fine-tuning a CLIP model on the ROCO medical dataset


Summary

This repository contains the code for fine-tuning a CLIP model [Arxiv paper][OpenAI Github Repo] on the ROCO dataset, a dataset of radiology images paired with captions. This work was done as part of the Flax/JAX community week organized by Hugging Face and Google.

SciBERT (allenai/scibert_scivocab_uncased on 🤗) is used as the text encoder.

[🤗 Model card] [Streamlit demo]

Demo

You can try a Streamlit demo app that uses this model on 🤗 Spaces. You may have to sign up for the 🤗 Spaces private beta to access the app (screenshot shown below).

[Figure: screenshot of the Streamlit app]

To run the demo locally (it opens in your browser), run the following from the root of the repository:

streamlit run app.py

Dataset 🧩

Each image is accompanied by a textual caption. Caption length varies from a few characters (a single word) to 2,000 characters (multiple sentences). During preprocessing we remove all images whose caption is shorter than 10 characters. Split sizes:

Training set: 57,780 images with their captions
Validation set: 7,200 images
Test set: 7,650 images
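The caption-length filter can be sketched in a few lines. This is a minimal illustration, not the repository's preprocessing code: it assumes each sample is an (image_path, caption) pair, and the helper name `filter_short_captions` is hypothetical. Stripping surrounding whitespace before measuring the length is also an assumption.

```python
# Hypothetical sketch of the caption-length filter described above.
# Assumes each sample is an (image_path, caption) pair; the real
# preprocessing code and data layout in this repository may differ.
from typing import Iterable, List, Tuple

MIN_CAPTION_CHARS = 10  # captions shorter than this are dropped


def filter_short_captions(
    samples: Iterable[Tuple[str, str]],
    min_chars: int = MIN_CAPTION_CHARS,
) -> List[Tuple[str, str]]:
    """Keep only pairs whose stripped caption has at least `min_chars` characters."""
    return [
        (image_path, caption)
        for image_path, caption in samples
        if len(caption.strip()) >= min_chars
    ]


if __name__ == "__main__":
    samples = [
        ("img_001.jpg", "CT"),
        ("img_002.jpg", "Axial CT scan of the chest showing a nodule."),
        ("img_003.jpg", "  X-ray   "),
    ]
    kept = filter_short_captions(samples)
    print([path for path, _ in kept])  # ['img_002.jpg']
```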

[ ] Give an example

Downloading the data

The image embeddings for the test set are stored using Git LFS. To download these files, you need to have Git LFS installed.

# Install Git LFS (example for Ubuntu/Debian)
sudo apt install git-lfs

# Initialize Git LFS in your repository (run this once per clone)
git lfs install

# Download the LFS files
git lfs pull

This downloads all LFS files in the repository. To download only the data for this project, restrict the pull to the data directory:

git lfs pull --include="data/*" 
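If `git lfs pull` is skipped or fails, LFS-tracked files stay on disk as small text "pointer" files, and loading them fails in confusing ways. The sketch below checks for that case. It is a hypothetical helper, not part of this repository; it assumes the data lives under `data/` (as in the command above) and relies on the Git LFS pointer format, whose first line is `version https://git-lfs.github.com/spec/v1`.

```python
# Hypothetical helper (not part of this repository): list files that are
# still Git LFS pointer files, i.e. `git lfs pull` has not fetched them yet.
# A pointer file is a small text file whose first line is
# "version https://git-lfs.github.com/spec/v1".
from pathlib import Path
from typing import List

LFS_POINTER_PREFIX = b"version https://git-lfs.github.com/spec/v1"


def is_lfs_pointer(path: Path) -> bool:
    """Return True if the file starts with the Git LFS pointer header."""
    with path.open("rb") as f:
        head = f.read(len(LFS_POINTER_PREFIX))
    return head == LFS_POINTER_PREFIX


def find_unpulled_files(root: str = "data") -> List[Path]:
    """Return files under `root` that are still LFS pointers (empty if root is missing)."""
    root_path = Path(root)
    if not root_path.is_dir():
        return []
    return sorted(
        p for p in root_path.rglob("*") if p.is_file() and is_lfs_pointer(p)
    )


if __name__ == "__main__":
    # "data" is assumed to be the directory passed to `git lfs pull --include`.
    pending = find_unpulled_files("data")
    if pending:
        print(f"{len(pending)} file(s) are still LFS pointers; run `git lfs pull`.")
    else:
        print("No LFS pointer files found under data/.")
```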

Installation 💽

This repo depends on the master branch of the Hugging Face Transformers library. First clone the transformers repository, then install it locally (preferably inside a virtual environment) with pip install -e ".[flax]".

The Model ⚙️

You can load the pretrained model from the Hugging Face Hub with

from medclip.modeling_hybrid_clip import FlaxHybridCLIP

model = FlaxHybridCLIP.from_pretrained("flax-community/medclip-roco")

Alternatively you can download the model checkpoint from [🤗 Model card].
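Once text and image embeddings have been computed with the model, text-to-image retrieval reduces to ranking images by cosine similarity to the query. The sketch below shows only that ranking step with NumPy. It assumes `image_embeds` is an (N, D) array of precomputed image embeddings and `text_embed` a (D,) query embedding from the same model; the function name `rank_images` is hypothetical, and the demo app's actual implementation may differ.

```python
# Minimal sketch of text-to-image retrieval with CLIP-style embeddings.
# Assumptions: `image_embeds` is an (N, D) array of precomputed image
# embeddings and `text_embed` a (D,) query embedding from the same model.
# The function name and shapes are illustrative, not this repo's API.
import numpy as np


def rank_images(text_embed: np.ndarray, image_embeds: np.ndarray, top_k: int = 5) -> np.ndarray:
    """Return indices of the `top_k` images most similar to the query (cosine similarity)."""
    # L2-normalise so that dot products equal cosine similarities.
    text = text_embed / np.linalg.norm(text_embed)
    images = image_embeds / np.linalg.norm(image_embeds, axis=1, keepdims=True)
    scores = images @ text  # shape (N,)
    # argsort is ascending, so sort the negated scores to get the best first.
    return np.argsort(-scores)[:top_k]


if __name__ == "__main__":
    # Toy 2-D embeddings standing in for real model outputs.
    image_embeds = np.array([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]])
    query = np.array([0.9, 0.1])
    print(rank_images(query, image_embeds, top_k=2))  # [0 2]
```

Normalising both sides makes the score independent of embedding magnitude, which is how CLIP-style models compare the two modalities.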

Training

The model is trained using Flax/JAX on a Cloud TPU v3-8. You can fine-tune a CLIP model implemented in Flax by running sh run_medclip.sh. The validation loss curve below was observed when training the model with the run_medclip.sh script.

[Figure: validation loss curve]
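For readers new to CLIP training, the objective is usually the symmetric contrastive loss from the CLIP paper: within a batch, each image should score highest against its own caption and vice versa. The sketch below writes that loss in NumPy for readability (jax.numpy provides the same functions). It assumes run_medclip.sh uses this standard objective, which is not confirmed here; the function names and the fixed `temperature=0.07` default are illustrative.

```python
# Sketch of the symmetric contrastive loss from the CLIP paper, in NumPy.
# Assumption: the training script uses this standard CLIP objective; the
# exact loss (e.g. how the temperature is learned) may differ.
import numpy as np


def log_softmax(x: np.ndarray, axis: int) -> np.ndarray:
    """Numerically stable log-softmax along `axis`."""
    shifted = x - np.max(x, axis=axis, keepdims=True)
    return shifted - np.log(np.sum(np.exp(shifted), axis=axis, keepdims=True))


def clip_loss(image_embeds: np.ndarray, text_embeds: np.ndarray, temperature: float = 0.07) -> float:
    """Mean of image-to-text and text-to-image cross-entropy.

    Row i of `image_embeds` is assumed to match row i of `text_embeds`.
    """
    img = image_embeds / np.linalg.norm(image_embeds, axis=1, keepdims=True)
    txt = text_embeds / np.linalg.norm(text_embeds, axis=1, keepdims=True)
    logits = img @ txt.T / temperature  # (B, B) scaled cosine similarities
    # Matching pairs lie on the diagonal; every other entry is a negative.
    loss_i2t = -np.mean(np.diag(log_softmax(logits, axis=1)))
    loss_t2i = -np.mean(np.diag(log_softmax(logits, axis=0)))
    return float((loss_i2t + loss_t2i) / 2)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    batch = rng.normal(size=(4, 16))
    print(clip_loss(batch, batch))        # low: every pair matches
    print(clip_loss(batch, batch[::-1]))  # higher: pairs are mismatched
```

Because every other caption in the batch acts as a negative, larger batches give the model more negatives per step.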

Limitations 🚨

The current model can identify higher-level features such as the modality of an image (e.g., whether a given radiology image is a PET scan or an ultrasound scan). However, it fails to distinguish a brain scan from a lung scan. ❗️This model should not be used in a medical setting without further evaluation❗️.

Acknowledgements

Huge thanks to the Hugging Face 🤗 team and the Google JAX/Flax team for organizing the community week and letting us use cloud compute for 2 weeks. We especially thank @patil-suraj and @patrickvonplaten for their continued support on Slack and their detailed feedback.

TODO

[ ] Mention more examples

[ ] Evaluation on downstream tasks

[ ] Zero-shot learning performance
