Add Multi-GPU Sigmoid-based Loss Training Pipeline#2

Open
tintindas wants to merge 1 commit into GenIntel:main from tintindas:sigmoid_loss

@tintindas

Summary

This PR integrates a new distributed training script into the NOVUM framework that enables multi-GPU training with a Sigmoid-based loss.

The implementation uses torchrun for process management and torch.distributed for synchronization.

Key Changes

Distributed Training Setup

  • Added setup() and cleanup() helpers that initialize and destroy the torch.distributed process group, using the nccl backend.

  • Automatically assigns one device to each rank, so the script scales across multiple GPUs.

  • Training script can be launched via:

    torchrun --standalone --nnodes=1 --nproc_per_node=2 src/multi_train.py \
    --config config/default.yaml \
    --experiment_name <name_of_experiment>

Model & Feature Bank

  • Integrated NetE2E backbone wrapped in DistributedDataParallel.

  • Set up FeatureBank with support for SigLIP-style updates (forward_siglip).

  • Added checkpoint saving that includes model state, FeatureBank memory, and experiment metadata.

Loss Function

  • Introduced SigmoidLoss as the main criterion for contrastive learning between image features and feature bank embeddings.

  • Learnable parameters: t_prime (logit scale) and b (logit bias).
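
For orientation, here is a hedged sketch of a sigmoid contrastive loss in the style of SigLIP with the two learnable parameters named above. The tensor shapes, the targets argument, the feature normalization, and the initial values (log 10 for t_prime and -10 for b, as in the SigLIP paper) are assumptions; the SigmoidLoss class in this PR may differ.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SigmoidLoss(nn.Module):
    """Pairwise sigmoid loss between image features and feature-bank rows (sketch)."""

    def __init__(self, init_t_prime=math.log(10.0), init_b=-10.0):
        super().__init__()
        self.t_prime = nn.Parameter(torch.tensor(init_t_prime))  # log of the logit scale
        self.b = nn.Parameter(torch.tensor(init_b))              # logit bias

    def forward(self, features, bank, targets):
        # features: (N, D); bank: (M, D); targets: (N,) index of each positive bank row
        features = F.normalize(features, dim=-1)
        bank = F.normalize(bank, dim=-1)
        logits = features @ bank.t() * self.t_prime.exp() + self.b
        # Label +1 for the positive pair in each row, -1 for every other pair.
        labels = torch.full_like(logits, -1.0)
        labels.scatter_(1, targets.unsqueeze(1), 1.0)
        # Sum the per-pair losses and average over the N query features.
        return -F.logsigmoid(labels * logits).sum() / features.shape[0]
```

Because every pair is scored independently, the loss needs no softmax over the whole bank, which is the usual motivation for the sigmoid formulation.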

Data Handling

  • Uses DistributedSampler to shard the dataset across GPUs.

  • Calls sampler.set_epoch(epoch) at the start of each epoch, so shuffling is reproducible and the shuffle order is reseeded per epoch.
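
The sharding behaviour can be sketched without launching torchrun by passing num_replicas and rank to DistributedSampler explicitly. The toy dataset and the epoch_indices helper below are illustrative assumptions, not part of the PR.

```python
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

# Toy stand-in for the real dataset: eight samples holding the values 0..7.
dataset = TensorDataset(torch.arange(8))


def epoch_indices(rank, epoch, world_size=2):
    """Return the dataset indices that `rank` would visit in `epoch`."""
    sampler = DistributedSampler(
        dataset, num_replicas=world_size, rank=rank, shuffle=True, seed=0
    )
    # The shuffle is seeded with seed + epoch, so every rank derives the same
    # permutation for a given epoch and takes its own slice of it.
    sampler.set_epoch(epoch)
    return list(sampler)


def make_loader(rank, world_size=2, batch_size=4):
    """Build a DataLoader that only iterates over this rank's shard."""
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, seed=0)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler), sampler
```

In a training loop the sampler returned by make_loader would receive sampler.set_epoch(epoch) before each pass over the loader; without that call every epoch reuses the same order.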

Feature Bank Synchronization

  • After each epoch, synchronizes FeatureBank memory across GPUs via all_reduce.

  • Keeps the negative-sample pool consistent across the distributed processes.
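
One way the end-of-epoch synchronization could be written is sketched below. The PR only states that all_reduce is used; summing and dividing by the world size, and re-normalizing each row afterwards, are assumptions made for this example, as is the sync_memory name.

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F


@torch.no_grad()
def sync_memory(memory):
    """Average a (M, D) bank tensor across ranks in place, then L2-normalize each row."""
    if dist.is_available() and dist.is_initialized():
        # Every rank contributes its local copy; afterwards all ranks hold the sum.
        dist.all_reduce(memory, op=dist.ReduceOp.SUM)
        memory /= dist.get_world_size()
    # Assumption: bank rows are kept at unit length, so restore that after averaging.
    memory.copy_(F.normalize(memory, dim=-1))
    return memory
```

When no process group is initialized, the function skips the collective call and only normalizes, which makes it usable in single-GPU debugging runs.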

Checkpointing

  • Saves model + FeatureBank state every 5 epochs.

  • Includes timestamp and args for reproducibility.
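
A hedged sketch of the checkpoint layout described above. The dictionary keys, the should_save helper, and the file naming are assumptions; only the contents (model state, FeatureBank memory, timestamp, args) and the five-epoch interval come from this description.

```python
import os
import time

import torch
import torch.nn as nn


def should_save(epoch, every=5):
    """True on the 5th, 10th, ... epoch when epochs are counted from zero."""
    return (epoch + 1) % every == 0


def save_checkpoint(path, model, bank_memory, epoch, args):
    """Write model state, bank memory and run metadata to `path` (hypothetical key names)."""
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    torch.save(
        {
            "epoch": epoch,
            # Saved as-is: for a DDP-wrapped model the keys carry a "module." prefix.
            "model_state": model.state_dict(),
            "bank_memory": bank_memory.detach().cpu(),
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
            "args": dict(args),  # e.g. vars(parsed_args)
        },
        path,
    )
```

In a multi-GPU run it is common to call save_checkpoint on rank 0 only, so that the processes do not write the same file concurrently.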

Validation

  • Verified training launches correctly with torchrun using 2 GPUs.

  • Confirmed metrics written to CSV files with headers.

  • Verified that FeatureBank synchronization produces consistent embeddings across ranks.

  • Confirmed that checkpoint files include the DDP-wrapped model state and the FeatureBank memory.
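
The CSV metric logging checked above could be implemented along these lines. The append_metrics helper and the column names are hypothetical; the PR does not list the actual metric names or file paths.

```python
import csv
import os


def append_metrics(path, row, fieldnames):
    """Append one row of metrics, writing the header only when the file is new or empty."""
    write_header = not os.path.exists(path) or os.path.getsize(path) == 0
    with open(path, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        if write_header:
            writer.writeheader()
        writer.writerow(row)
```

Calling append_metrics("train_metrics.csv", {"epoch": 0, "loss": 0.5}, ["epoch", "loss"]) once per epoch would yield a file with a single header line followed by one row per epoch.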
