A lightweight model for feature knowledge distillation using histopathology foundational models.
Shazam is a small, efficient model that distills knowledge from features extracted by histopathology foundational models. It draws on the representational power of large-scale foundational models while keeping computational cost low through a lightweight distillation process.
- Feature Knowledge Distillation: Transfers rich representations from foundational models into a smaller, more efficient model.
- Lightweight and Scalable: Achieves high accuracy with lower computational cost, suitable for practical deployment in clinical settings.
- Superior Performance: Outperforms existing CPath models and other fusion-based methods across multiple evaluation benchmarks.

- Feature Extraction: Leverages pretrained foundational histopathology models to extract low-level, mid-level and high-level features from images.
- Knowledge Distillation: A small model learns to replicate the representational power of the foundational models.
- Model Evaluation: The distilled model is evaluated and compared against existing methods like Virchow2.
This pipeline supports survival prediction using multi-teacher distillation from foundational models.
- Case-to-feature Mapping
  - File: `survival_analysis/jsonlink.py`
  - Map case IDs to feature `.pt` paths using a JSON dictionary.
- WSI Patch Extraction
  - File: `CLAM/create_patches_features_fp.py`
  - Cut patches from WSIs and store them in `.h5` files. ⚠️ If `patches/` contains fewer `.h5` files than the number of WSIs, verify the original `.svs` slides.
- CSV Splitting for Multi-GPU
  - File: `survival_analysis/splitcsv.py`
  - Generate per-fold CSV files for multi-GPU training.
- Feature Extraction with Multi-teacher Models
  - Files: `CLAM/extract_BRCA4cls.sh`
  - Extract features using foundational models (Virchow2, Uni_v2, etc.).
- Single-model Training
  - Files: `survival_analysis/single_BRCA4cls.sh`
  - Train baseline single-model (non-distilled) classifiers.
- Multi-teacher Distillation Training
  - Files: `Shazam_v2/multi_moe_distill_v3.py`, `Shazam_v2/multi_moe_distill4cls.py`
  - Train student model with attention-based distillation across modalities.
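
As an illustration of the case-to-feature mapping step, the sketch below builds a JSON dictionary from case IDs to feature `.pt` paths. It is not the code in `survival_analysis/jsonlink.py`: the function names, the one-subdirectory-per-teacher layout, and the rule that the case ID equals the file name without its suffix are all assumptions made for this example.

```python
import json
from pathlib import Path


def build_case_to_feature_map(feature_root, suffix=".pt"):
    """Return {case_id: {teacher_name: feature_path}}.

    Assumed (hypothetical) layout: <feature_root>/<teacher_name>/<case_id>.pt
    """
    mapping = {}
    for path in sorted(Path(feature_root).glob(f"*/*{suffix}")):
        teacher = path.parent.name              # sub-directory name = teacher model
        case_id = path.name[: -len(suffix)]     # file name without the suffix
        mapping.setdefault(case_id, {})[teacher] = str(path)
    return mapping


def save_mapping(mapping, json_path):
    """Write the mapping as a JSON dictionary."""
    with open(json_path, "w", encoding="utf-8") as fh:
        json.dump(mapping, fh, indent=2, sort_keys=True)
```

A case that is missing from one teacher's sub-directory shows up as a shorter inner dictionary, which makes incomplete extractions easy to spot before training.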
We directly use the environment configuration provided by the CLAM project.

    conda env create -f env.yml
    conda activate clam_latest
    python train.py

This section explains the end-to-end tensor shape transformations inside the `CrossAttentionClassifierWithDistillation` model.

Each feature `.pt` file contains a tuple:

    (features, labels) = torch.load("xxx_features.pt")

- features: shape = [N, C_i], where:
  - N: number of patches (WSIs)
  - C_i: feature dimension of model i, e.g., 1280 (Virchow), 1024 (Uni), etc.
- labels: shape = [N] (long, class indices)
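
The file layout described above can be checked with a small loader. This is a minimal sketch, not code from the repository: `load_feature_file` is a hypothetical helper, and the synthetic tensor sizes are chosen only for the demonstration.

```python
import io

import torch


def load_feature_file(source):
    """Load a (features, labels) tuple and verify the documented shapes."""
    features, labels = torch.load(source)
    if features.dim() != 2:
        raise ValueError(f"expected features of shape [N, C_i], got {tuple(features.shape)}")
    if labels.dim() != 1 or labels.shape[0] != features.shape[0]:
        raise ValueError("labels must have shape [N] with the same N as features")
    if labels.dtype != torch.long:
        raise ValueError("labels must be long class indices")
    return features, labels


# Round-trip a synthetic file through an in-memory buffer instead of disk.
buffer = io.BytesIO()
torch.save((torch.randn(8, 1280), torch.randint(0, 4, (8,))), buffer)
buffer.seek(0)
features, labels = load_feature_file(buffer)
```

Passing a path such as `"xxx_features.pt"` instead of the buffer works the same way, since `torch.load` accepts both.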

During training:

    train_dataset = TensorDataset(*train_features_list, train_labels)

which means the input to the model is:

    features = [x1, x2, x3, x4]  # x_i shape: [B, C_i]

Each foundational model's features x_i ∈ [B, C_i] are mapped into a shared dimension d_model. Output shape: `[B, d_model]` for each modality.

After mapping:

    features_stacked = torch.stack([mapped_1, mapped_2, mapped_3, mapped_4], dim=1)

- Shape: [B, 4, d_model] (treat each feature source as a token in attention)
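
The mapping and stacking steps can be sketched as follows. The internals of `FeatureMapper` are not given here, so the `Linear` + `LayerNorm` projection is an assumption, as are the batch size, `d_model = 256`, and the last two input widths (only 1280 and 1024 are mentioned above).

```python
import torch
import torch.nn as nn


class FeatureMapper(nn.Module):
    """Project one teacher's features from C_i to the shared width d_model.

    Assumed structure (Linear followed by LayerNorm); the real module may differ.
    """

    def __init__(self, in_dim, d_model):
        super().__init__()
        self.proj = nn.Sequential(nn.Linear(in_dim, d_model), nn.LayerNorm(d_model))

    def forward(self, x):
        return self.proj(x)


B, d_model = 16, 256                      # illustrative sizes
in_dims = [1280, 1024, 768, 512]          # C_i per teacher; last two are made up
mappers = nn.ModuleList(FeatureMapper(c, d_model) for c in in_dims)

features = [torch.randn(B, c) for c in in_dims]          # x_i: [B, C_i]
mapped = [m(x) for m, x in zip(mappers, features)]       # each: [B, d_model]
features_stacked = torch.stack(mapped, dim=1)            # [B, 4, d_model]
```

Stacking along `dim=1` is what turns the four feature sources into a length-4 token sequence for the attention layers.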

Each layer applies attention across the 4 modalities (tokens):

    Q, K, V: [B, 4, d_model] → Attention → Output: [B, 4, d_model]

This is repeated num_layers times (e.g. 5).

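
A shape-preserving attention stack over the four modality tokens might look like the sketch below. It uses PyTorch's `nn.MultiheadAttention` with queries, keys, and values all taken from the same token sequence; the residual connection, the `LayerNorm`, the head count, and the class name `ModalityAttentionLayer` are assumptions, not details confirmed by the model's source.

```python
import torch
import torch.nn as nn


class ModalityAttentionLayer(nn.Module):
    """One attention layer over modality tokens; input and output are [B, T, d_model]."""

    def __init__(self, d_model, num_heads=4):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, tokens):
        attended, _ = self.attn(tokens, tokens, tokens)   # Q = K = V = tokens
        return self.norm(tokens + attended)               # residual + norm (assumed)


d_model, num_layers = 256, 5
layers = nn.ModuleList(ModalityAttentionLayer(d_model) for _ in range(num_layers))

tokens = torch.randn(16, 4, d_model)      # [B, 4, d_model]
for layer in layers:
    tokens = layer(tokens)                # shape is unchanged by every layer
```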
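
The attended tokens are then averaged over the modality dimension:

    fused_features = features.mean(dim=1)

- Shape: [B, d_model] (aggregated representation for classification)

The fused representation passes through the classification head:

    fused_features → Linear → ReLU → LayerNorm → Linear → logits

- Output logits: [B, num_classes]
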
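
The fusion and classification head can be written out in a few lines. This is a sketch under assumptions: the hidden width is set equal to `d_model` and `num_classes = 4` is picked for illustration, since neither value is stated above.

```python
import torch
import torch.nn as nn

B, d_model, num_classes = 16, 256, 4      # illustrative sizes
hidden = d_model                          # assumed hidden width of the head

# Linear → ReLU → LayerNorm → Linear, in the order listed above.
classifier = nn.Sequential(
    nn.Linear(d_model, hidden),
    nn.ReLU(),
    nn.LayerNorm(hidden),
    nn.Linear(hidden, num_classes),
)

tokens = torch.randn(B, 4, d_model)       # output of the attention layers
fused_features = tokens.mean(dim=1)       # average over the 4 modality tokens
logits = classifier(fused_features)       # [B, num_classes]
```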
For distillation:

    student_features: [B, d_model]
    expert_features_list: [B, C_i]  # for each i
    mapped_expert = FeatureMapper(C_i → d_model)

- Compute cosine + Huber loss between student_features and each mapped_expert.

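
One way to combine the cosine and Huber terms is sketched below. How the two terms are weighted and whether they are averaged over experts is not specified here, so `huber_weight`, the averaging, and the use of plain `nn.Linear` layers as expert mappers are assumptions for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def distillation_loss(student_features, expert_features_list, expert_mappers, huber_weight=1.0):
    """Average of (cosine distance + weighted Huber loss) over all experts."""
    total = student_features.new_zeros(())
    for expert, mapper in zip(expert_features_list, expert_mappers):
        mapped_expert = mapper(expert)                                   # [B, d_model]
        cosine_term = 1.0 - F.cosine_similarity(student_features, mapped_expert, dim=-1).mean()
        huber_term = F.huber_loss(student_features, mapped_expert)
        total = total + cosine_term + huber_weight * huber_term
    return total / len(expert_mappers)


B, d_model = 16, 256
in_dims = [1280, 1024]                                                   # C_i per expert
expert_mappers = nn.ModuleList(nn.Linear(c, d_model) for c in in_dims)   # stand-in mappers
student = torch.randn(B, d_model, requires_grad=True)
experts = [torch.randn(B, c) for c in in_dims]

loss = distillation_loss(student, experts, expert_mappers)
loss.backward()                                                          # gradients reach the student
```

The cosine term aligns feature directions while the Huber term penalizes differences in magnitude without being dominated by outliers, which is a common reason for pairing the two.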

| Stage | Shape | Description |
|---|---|---|
| Raw Input | [B, C_i] | One per modality |
| After Mapping | [B, d_model] × 4 | Standardized into shared dimension |
| Stack (4 modalities) | [B, 4, d_model] | Cross-attention input |
| After Cross-Attention | [B, 4, d_model] | Contextually refined features |
| Mean Fusion | [B, d_model] | Aggregated single representation |
| Classifier Output | [B, num_classes] | Final prediction logits |
| Expert Mapping | [B, d_model] | Used in distillation loss |


