Skip to content

Tuner12/Shazam

Repository files navigation

Shazam Logo

Shazam

A lightweight model for feature knowledge distillation using histopathology foundational models.


📌 Project Overview

Shazam proposes a small and efficient model that distills knowledge from extracted features using histopathology foundational models. This approach effectively leverages the strong representational power of large-scale foundational models while optimizing computational efficiency through a lightweight distillation process.

✅ Key Highlights:

  • Feature Knowledge Distillation
    Transfers rich representations from foundational models into a smaller, more efficient model.

  • Lightweight and Scalable
    Achieves high accuracy with lower computational cost, suitable for practical deployment in clinical settings.

  • Superior Performance
    Outperforms existing CPath models and other fusion-based methods across multiple evaluation benchmarks.

🔬 Shazam v2

📂 Project Structure

Project Structure

  1. Feature Extraction: Leverages pretrained foundational histopathology models to extract low-level, mid-level and high-level features from images.
  2. Knowledge Distillation: A small model learns to replicate the representational power of the foundational models.
  3. Model Evaluation: The distilled model is evaluated and compared against existing methods like Virchow2.

This pipeline supports survival prediction using multi-teacher distillation from foundational models.

  1. Case-to-feature Mapping

    • File: survival_analysis/jsonlink.py
    • Map case IDs to feature .pt paths using a JSON dictionary.
  2. WSI Patch Extraction

    • File: CLAM/create_patches_features_fp.py
    • Cut patches from WSIs and store in .h5 files.
    • ⚠️ If patches/ contains fewer .h5 files than the number of WSIs, verify the original .svs slides.
  3. CSV Splitting for Multi-GPU

    • File: survival_analysis/splitcsv.py
    • Generate per-fold CSV files for multi-GPU training.
  4. Feature Extraction with Multi-teacher Models

    • Files: CLAM/extract_BRCA4cls.sh
    • Extract features using foundational models (Virchow2, Uni_v2, etc.).
  5. Single-model Training

    • Files: survival_analysis/single_BRCA4cls.sh
    • Train baseline single-model (non-distilled) classifiers.
  6. Multi-teacher Distillation Training

    • File: Shazam_v2/multi_moe_distill_v3.py Shazam_v2/multi_moe_distill4cls.py
    • Train student model with attention-based distillation across modalities.


📂 Shazam v1

Project Structure

  1. Feature Extraction: Leverages pretrained foundational histopathology models to extract high-level features from images.
  2. Knowledge Distillation: A small model learns to replicate the representational power of the foundational models.
  3. Model Evaluation: The distilled model is evaluated and compared against existing methods like Virchow2.

⚙️ Environment Setup

We directly use the environment configuration provided by the CLAM project.

1. Create the Conda Environment

conda env create -f env.yml

2. Activate the Environment

conda activate clam_latest

3. Train the Model

python train.py 

Tutorial for Shazam

This section explains the end-to-end tensor shape transformations inside the CrossAttentionClassifierWithDistillation model.

🔢 Input Tensors

Each feature .pt file contains a tuple:

(features, labels) = torch.load("xxx_features.pt")
  • features: shape = [N, C_i]
    where:

    • N: number of patches (WSIs)
    • C_i: feature dimension of model i, e.g., 1280 (Virchow), 1024 (Uni), etc.
  • labels: shape = [N] (long, class indices)

During training:

train_dataset = TensorDataset(*train_features_list, train_labels)

which means input to model:

features = [x1, x2, x3, x4]   # x_i shape: [B, C_i]

🧠 Step 1: Feature Mapping

Each foundational model's features x_i ∈ [B, C_i] are mapped into a shared dimension d_model:

Output shape: `[B, d_model]` for each modality

🧠 Step 2: Stack Features Across Modalities

After mapping:

features_stacked = torch.stack([mapped_1, mapped_2, mapped_3, mapped_4], dim=1)
  • Shape: [B, 4, d_model]
    (treat each feature source as a token in attention)

🔁 Step 3: Self-Attention Layers

Each layer applies attention across the 4 modalities (tokens):

Q, K, V: [B, 4, d_model] → AttentionOutput: [B, 4, d_model]

Repeated num_layers times (e.g. 5).


🔄 Step 4: Feature Fusion

fused_features = features.mean(dim=1)
  • Shape: [B, d_model]
    (aggregated representation for classification)

🎯 Step 5: Classifier

fused_featuresLinearReLULayerNormLinearlogits
  • Output logits: [B, num_classes]

🧪 Step 6: Feature Distillation Loss

For distillation:

student_features: [B, d_model]
expert_features_list: [B, C_i]  # for each i
mapped_expert = FeatureMapper(C_id_model)
  • Compute cosine + Huber loss between student_features and each mapped_expert

🧾 Final Summary

Stage Shape Description
Raw Input [B, C_i] One per modality
After Mapping [B, d_model] × 4 Standardized into shared dimension
Stack (4 modalities) [B, 4, d_model] Cross-attention input
After Cross-Attention [B, 4, d_model] Contextually refined features
Mean Fusion [B, d_model] Aggregated single representation
Classifier Output [B, num_classes] Final prediction logits
Expert Mapping [B, d_model] Used in distillation loss


About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors