Disentangling Visual Transformers: Patch-level Interpretability for Image Classification
Official Code Repository
Download the ImageNet-pretrained weights and config files from Hugging Face.
Note: We do not provide pretrained weights for other datasets, but you can train them yourself.
Preparing the datasets is straightforward: just download them; no postprocessing is required.
For ImageNet, organize the folder as follows:
path/to/imagenet/
├── train/
│ ├── class1/
│ ├── class2/
│ └── ...
└── val/
  ├── class1/
  ├── class2/
  └── ...
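Before launching a long job, it can help to check that the folder matches this layout. The following is a minimal stdlib-only sketch, not part of this repository: `scan_imagefolder` and `check_imagenet_layout` are hypothetical helpers that follow torchvision's `ImageFolder` convention of sorting class folders alphabetically and mapping them to consecutive labels (unlike `ImageFolder`, they do not filter files by extension).

```python
from pathlib import Path


def scan_imagefolder(root):
    """Return (samples, class_to_idx) for a root/<class_name>/<file> layout.

    Class folders are sorted alphabetically and mapped to labels 0..N-1,
    following torchvision's ImageFolder convention. Unlike ImageFolder,
    no file-extension filtering is done here.
    """
    root = Path(root)
    classes = sorted(d.name for d in root.iterdir() if d.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    samples = [
        (str(path), class_to_idx[name])
        for name in classes
        for path in sorted((root / name).iterdir())
        if path.is_file()
    ]
    return samples, class_to_idx


def check_imagenet_layout(imagenet_root):
    """Check that train/ and val/ exist and share the same class folders.

    Returns the number of classes found.
    """
    root = Path(imagenet_root)
    _, train_classes = scan_imagefolder(root / "train")
    _, val_classes = scan_imagefolder(root / "val")
    if train_classes != val_classes:
        raise ValueError("train/ and val/ class folders do not match")
    return len(train_classes)
```

For example, `check_imagenet_layout("path/to/imagenet")` should return 1000 for ImageNet-1k; a mismatch between the two splits raises an error instead of silently producing shifted labels.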
Other datasets:
- Stanford Dogs: Download
- Stanford Cars: (download link not found; feel free to help!)
- CUB-200-2011: Download
- Aircraft & Oxford Pets: downloaded automatically.
Adding a custom dataset is simple! Modify the build_dataset function inside core.datasets.py to integrate your dataset.
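One way to structure such an integration is a small name-to-builder registry that `build_dataset` dispatches on. This is a hypothetical sketch, not the repository's code: it assumes `build_dataset(is_train, args)` returns `(dataset, nb_classes)` and reads `args.data_set` / `args.data_path`, as DeiT's version does, so check `core.datasets.py` before adapting it. `DATASET_BUILDERS`, `register_dataset`, and the `"MYDATA"` entry are illustrative names only.

```python
from types import SimpleNamespace

# Hypothetical registry mapping a --data-set name to a builder function.
DATASET_BUILDERS = {}


def register_dataset(name):
    """Decorator that registers a builder under the given --data-set name."""
    def decorator(fn):
        DATASET_BUILDERS[name] = fn
        return fn
    return decorator


@register_dataset("MYDATA")
def build_mydata(is_train, args):
    # Placeholder: replace with your real dataset object (e.g. an
    # ImageFolder rooted at args.data_path). Here it is a list of
    # (path, label) pairs over 3 classes.
    split = "train" if is_train else "val"
    samples = [(f"{args.data_path}/{split}/img_{i}.jpg", i % 3) for i in range(6)]
    return samples, 3  # (dataset, number of classes)


def build_dataset(is_train, args):
    """Dispatch on args.data_set and return (dataset, nb_classes)."""
    try:
        builder = DATASET_BUILDERS[args.data_set]
    except KeyError:
        raise ValueError(f"Unknown dataset: {args.data_set!r}") from None
    return builder(is_train, args)


if __name__ == "__main__":
    args = SimpleNamespace(data_set="MYDATA", data_path="path/to/mydata")
    dataset, nb_classes = build_dataset(is_train=True, args=args)
    print(len(dataset), nb_classes)  # 6 3
```

Whatever `nb_classes` your builder returns should match the `--model.params.num_classes $NUMCLASSES int` override used in the commands below.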
We provide multiple evaluation scripts for both quantitative and qualitative assessment of HiT.
Run the following to compute accuracy:
python main.py --eval \
--data-path $IMNETPATH --batch 512 --input-size 224 \
--eval-crop-ratio 0.875 \
--config-file $CONFIGFILE \
--mini-batches 1 --output_dir $OUTPUTPATH \
--num_workers 8 \
--resume path/to/checkpoint
Run this to compute insertion/deletion scores:
python CausalMetrics.py --gpu-id $GPU \
--config-file $CONFIGFILE \
--weights $WEIGHTS \
--data-path $DATAPATH \
--data-set $DATASET \
--cam-type base \
--model.params.num_classes $NUMCLASSES int
⚠️ Note: The metric is unnormalized. You'll need to normalize it manually.
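This README does not specify which quantity `CausalMetrics.py` reports or how it should be normalized. As an assumption, if you have either the per-step insertion/deletion curve or the raw sum of per-step scores, the two hypothetical helpers below rescale it to [0, 1] when the scores are probabilities. Verify against the script's output before relying on either.

```python
def normalized_auc(scores):
    """Area under a per-step score curve, with the x-axis rescaled to [0, 1].

    scores[k] is the target-class score after k insertion (or deletion)
    steps, so scores[0] is the starting image and scores[-1] the final one.
    """
    n = len(scores) - 1  # number of steps between points
    if n < 1:
        raise ValueError("need at least two points on the curve")
    # Trapezoidal rule with uniform spacing 1/n.
    area = sum((scores[k] + scores[k + 1]) / 2 for k in range(n))
    return area / n


def normalize_raw_sum(raw_sum, num_points):
    """Alternative: turn a raw sum of per-step scores into their mean."""
    return raw_sum / num_points
```

For example, `normalized_auc([0.0, 0.5, 1.0])` gives 0.5, and a curve that stays at 1.0 gives 1.0. The two helpers differ slightly (trapezoid versus plain mean), so use the same one for every method you compare.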
To produce visualizations, run visualization.py with the same arguments as the insertion/deletion script, plus an output path:
python visualization.py --gpu-id $GPU \
--config-file $CONFIGFILE \
--weights $WEIGHTS \
--data-path $DATAPATH \
--data-set $DATASET \
--cam-type base \
--output-path $OUTPUTPATH \
--model.params.num_classes $NUMCLASSES int
For layer saliency, use layer-saliency.py, which takes the same arguments:
python layer-saliency.py --gpu-id $GPU \
--config-file $CONFIGFILE \
--weights $WEIGHTS \
--data-path $DATAPATH \
--data-set $DATASET \
--cam-type base \
--model.params.num_classes $NUMCLASSES int
We trained our models on 8 A100 GPUs with the following command:
IMNETPATH=path/to/imagenet
OUTPUTPATH=path/to/output
CONFIGFILE=configs/hit-b.yaml
torchrun --standalone --nnodes=1 --nproc_per_node=8 main.py \
--data-path $IMNETPATH --batch 512 --lr 8e-4 \
--epochs 600 --weight-decay 0.05 --sched cosine --input-size 224 \
--eval-crop-ratio 0.875 --reprob 0.0 --warmup-epochs 20 \
--drop 0.0 --seed 0 --opt adamw --warmup-lr 1e-6 \
--mixup 0.8 --drop-path 0.05 --cutmix 1.0 --unscale-lr --repeated-aug \
--smoothing 0.1 --color-jitter 0.3 --ThreeAugment \
--config-file $CONFIGFILE \
--mini-batches 1 --output_dir $OUTPUTPATH \
--num_workers 8 --distributed --world_size 8 \
--model.params.drop_path 0.0 float --model.params.attention_dropout 0.2 float
To override config values on-the-fly, use:
--model.params.param_to_change VALUE TYPE
For example, --model.name hit-no-pool str disables pooling.
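The trailing TYPE is needed because command-line values arrive as strings. The sketch below shows how such `--dotted.key VALUE TYPE` triples can be applied to a nested config; `apply_overrides` and `CASTS` are hypothetical names, not the repository's implementation, and the starting config (`"hit"`, 1000 classes) is illustrative.

```python
# Hypothetical: maps the TYPE token to a cast applied to the VALUE string.
CASTS = {"int": int, "float": float, "str": str}


def apply_overrides(config, tokens):
    """Apply a flat list of '--dotted.key VALUE TYPE' triples to a nested dict."""
    if len(tokens) % 3 != 0:
        raise ValueError("overrides must come in '--key VALUE TYPE' triples")
    for flag, raw_value, type_name in zip(tokens[0::3], tokens[1::3], tokens[2::3]):
        if not flag.startswith("--"):
            raise ValueError(f"expected a flag, got {flag!r}")
        if type_name not in CASTS:
            raise ValueError(f"unsupported TYPE {type_name!r}")
        *parents, leaf = flag[2:].split(".")
        node = config
        for key in parents:
            node = node.setdefault(key, {})  # walk down, creating levels if missing
        node[leaf] = CASTS[type_name](raw_value)
    return config


config = {"model": {"name": "hit", "params": {"num_classes": 1000}}}
apply_overrides(config, ["--model.name", "hit-no-pool", "str",
                         "--model.params.num_classes", "200", "int"])
```

After this call, `config` is `{'model': {'name': 'hit-no-pool', 'params': {'num_classes': 200}}}`: the model name is replaced as a string and the class count is stored as an integer rather than the string "200".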
To resume training, add:
--resume $OUTPUTPATH/checkpoint.pth
To finetune on a new dataset:
python main.py \
--data-path $DATAPATH \
--finetune $PATHWEIGHTS \
--data-set $DATASET \
[...] \
--model.params.num_classes $NUMCLASSES int
Note: Multi-node, multi-GPU setups have not been tested.
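The README does not say how `--finetune` handles the classifier head when `$NUMCLASSES` differs from ImageNet's 1000 classes. DeiT, which this training setup appears to build on, deletes checkpoint head weights whose shape no longer matches the model before loading. The stdlib sketch below illustrates that filtering step on key-to-shape dicts (in PyTorch you could build them with `{k: tuple(v.shape) for k, v in state_dict.items()}`); the function name, key names, and shapes are illustrative assumptions.

```python
def drop_mismatched_keys(checkpoint_shapes, model_shapes):
    """Split checkpoint keys into (keep, dropped) before loading.

    A key is kept only if the model has it with exactly the same shape.
    Mismatches (typically the classifier head, whose output size is
    num_classes) are dropped so the model keeps its fresh initialization.
    """
    keep, dropped = [], []
    for key, shape in checkpoint_shapes.items():
        if key in model_shapes and tuple(model_shapes[key]) == tuple(shape):
            keep.append(key)
        else:
            dropped.append(key)
    return keep, dropped


# Illustrative shapes: an ImageNet checkpoint finetuned to 200 classes.
ckpt = {"blocks.0.mlp.fc1.weight": (3072, 768),
        "head.weight": (1000, 768), "head.bias": (1000,)}
model = {"blocks.0.mlp.fc1.weight": (3072, 768),
         "head.weight": (200, 768), "head.bias": (200,)}
keep, dropped = drop_mismatched_keys(ckpt, model)
```

Here `dropped` is `["head.weight", "head.bias"]`. In PyTorch you would then load only the kept entries with `model.load_state_dict(filtered, strict=False)` so the missing head keys do not raise an error.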
If you use this work, please cite us:
@InProceedings{jeanneret2025disentangling,
author = {Jeanneret, Guillaume and Simon, Lo{\"\i}c and Jurie, Fr{\'e}d{\'e}ric},
title = {Disentangling Visual Transformers: Patch-level Interpretability for Image Classification},
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops},
month = {June},
year = {2025}
}
We thank the authors of the DeiT repository for their amazing work!
If you find a bug or something doesn't work, feel free to reach out or open an issue!
