
[feature] Multi-GPU support#17

Closed
yoshikisd wants to merge 125 commits into cdtools-developers:master from yoshikisd:feature/multi-gpu

Conversation


@yoshikisd yoshikisd commented Feb 24, 2025

Summary of added feature (as of 07/23/2025)

Check out this comment for a summary of features added to enable multi-GPU support for CDTools. This update also includes a separation of the optimizer from the CDIModel, which was done in an attempt to directly use PyTorch's DistributedDataParallel.

The outdated first comment

This PR is a starting point to address Issue #8: adding multi-GPU support for CDTools.

This is a work in progress. I'm interested in exploring a couple of different parallelization approaches while trying to preserve the simplicity of the high-level CDTools interface and ensure backwards compatibility. This PR is not in a state that I feel is ready to merge into the master branch. I've submitted it as a draft to see if you folks have any thoughts about handling multi-GPU support. If you have any recommendations on things to try/test, I'd be happy to discuss them!

Multi-GPU implementation based on DistributedDataParallel

I've gotten a naive implementation of multi-GPU support working using PyTorch's DistributedDataParallel to perform data parallelism (more details here: https://pytorch.org/tutorials/beginner/dist_overview.html):

  • 2ad6db2, ad3d816: Changed how the base CDIModel sets up and samples the DataLoader when multiple GPUs are used.
  • c393f47, fc40b77, e8ed010, 94dcc3f, efe2ac1: Created a multi-GPU version of examples/fancy_ptycho.py called examples/fancy_ptycho_multi_gpu_ddp.py. The dataset and model inspection methods work even when using multiple GPUs.
  • 1eb8704: Created examples/fancy_ptycho_multi_gpu_ddp_speed_test.py to compare reconstruction speed/loss as a function of the number of GPUs used. This is based on examples/fancy_ptycho.py and examples/fancy_ptycho_multi_gpu_ddp.py.
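The per-GPU batching that distributed data parallelism relies on can be illustrated with a pure-Python sketch of how a DistributedSampler-style partition splits dataset indices across ranks. Note that `partition_indices` is illustrative, not part of CDTools or PyTorch:

```python
import math

def partition_indices(n_samples, rank, world_size):
    """Split dataset indices across ranks, padding so every rank
    gets the same number of samples (mirrors the default behavior of
    torch.utils.data.DistributedSampler with shuffle=False)."""
    indices = list(range(n_samples))
    per_rank = math.ceil(n_samples / world_size)
    # Pad by repeating leading indices so the total divides evenly.
    padded = indices + indices[: per_rank * world_size - n_samples]
    # Each rank takes a strided slice of the padded index list.
    return padded[rank::world_size]
```

For example, with 10 diffraction patterns and 4 GPUs, each rank receives 3 indices and the union of all ranks covers the full dataset, with a couple of samples duplicated by padding.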

Below is an output from examples/fancy_ptycho_multi_gpu_ddp_speed_test.py, tested with up to 2 NVIDIA RTX 6000 Ada Generation cards on a Linux server. Both the single- and double-GPU tests were run with 2 trials over 100 total epochs. The plots show the mean and standard deviation of the time at which each epoch completed (the timer started before the dataset was loaded), as well as the associated loss at each epoch. The horizontal shift between the two curves likely reflects the longer time it takes to load the model onto multiple GPUs. The width of the 2-GPU curve (i.e., the total time taken for reconstruction) is roughly half that of the 1-GPU curve.
[Image: speed test results for 1- and 2-GPU reconstructions]

Items to look into

  • Look into why the losses become less stable as the GPU count increases
  • Make sure that the batching scheme behaves as expected as the GPU count increases
  • Avoid changing the high-level interface of CDTools
  • Add PyTests
  • Look into multi-GPU functionality through Hugging Face Accelerator

Diagnosing issues

I've had issues running parallelized PyTorch scripts that may not be caused by PyTorch itself, but rather stem from communication issues between NVIDIA GPUs via the NCCL (NVIDIA Collective Communications Library) backend. These issues seem to depend strongly on the exact details of how the machine is set up. To check for an NCCL issue, build and run the following tests from https://github.com/NVIDIA/cuda-samples:

  • cuda-samples/Samples/0_Introduction/simpleP2P and/or
  • cuda-samples/Samples/5_Domain_Specific/p2pBandwidthLatencyTest (per the NCCL troubleshooting guide)

If these tests fail, or hang for several minutes, you may have a GPU-to-GPU communication issue.

Another symptom of communication-related hanging is all active GPUs reporting 100% usage in nvidia-smi while nothing seems to be happening.

I've included some websites below which may be helpful for solving your issues.

@yoshikisd yoshikisd changed the title Feature/multi gpu [WIP] Support for multi-GPU Feb 24, 2025
@yoshikisd yoshikisd linked an issue Feb 24, 2025 that may be closed by this pull request
@yoshikisd yoshikisd marked this pull request as ready for review July 21, 2025 02:53
@yoshikisd yoshikisd marked this pull request as draft July 21, 2025 03:38

yoshikisd commented Jul 21, 2025

I've gotten both the reconstructor and the multi-GPU implementation into a stable state that's ready for review. It may be best to have a call to go over these changes, but I'll leave this documentation here as an overview of the changes made in this PR.

cdtools.reconstructors

Basic idea:

The module cdtools.reconstructors handles the optimization ('reconstruction') of CDTools ptychography models given a CDIModel and a corresponding CDataset. Similar to the so-called 'trainers' used by the PyTorch community, the Reconstructor class and its subclasses set up nested loops that run optimization and data-loading steps.

This module was created to separate optimization-related code from CDIModel. The motivation for this change was to make CDTools more compatible with other PyTorch-related packages that expect libraries/scripts to follow specific code patterns. Specifically, problems arise if optimization is a method of torch.nn.Module/cdtools.models.CDIModel, as was the case as of v0.3.0.pypi1.
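The trainer-style separation described above can be sketched as a minimal skeleton. This is illustrative only; the names and signatures below do not match the actual CDTools classes:

```python
class Reconstructor:
    """Owns the optimization loop so the model stays a plain module."""

    def __init__(self, model, dataset):
        self.model = model
        self.dataset = dataset

    def setup_dataloader(self, batch_size):
        # Toy stand-in: yield the dataset in fixed-size chunks.
        for i in range(0, len(self.dataset), batch_size):
            yield self.dataset[i:i + batch_size]

    def adjust_optimizer(self, **hyperparams):
        # Subclasses update optimizer hyperparameters in place.
        raise NotImplementedError

    def optimize(self, epochs, batch_size=1):
        # Nested loops: epochs outside, minibatches inside, as in a
        # typical PyTorch 'trainer'. Yields the loss per epoch.
        for _ in range(epochs):
            epoch_loss = 0.0
            for batch in self.setup_dataloader(batch_size):
                # The model is callable and returns a loss per batch.
                epoch_loss += self.model(batch)
            yield epoch_loss
```

Because the loop lives outside the model, the model itself remains a plain torch.nn.Module that other PyTorch tooling can wrap or inspect.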

Implementation:

cdtools.reconstructors is largely based on methods from cdtools.models.CDIModel. For the base class cdtools.reconstructors.Reconstructor:

  • _run_epoch and optimize are based on cdtools.models.CDIModel.AD_Optimize.
  • setup_dataloader is based on the similar blocks of data-loading code present in cdtools.models.CDIModel.{Adam, LBFGS, SGD}_optimize.
  • adjust_optimizer is not implemented in the base class and must be defined by subclasses to adjust the optimizer hyperparameters.

For the subclasses cdtools.reconstructors.{Adam, LBFGS, SGD}:

  • __init__ initializes an optimizer attribute as a torch.optim.{Adam, LBFGS, SGD} using the model parameters.
  • adjust_optimizer changes the relevant optimizer hyperparameters.
  • optimize is based on cdtools.models.CDIModel.{Adam, LBFGS, SGD}_optimize.

The main change in code behavior is that hyperparameters are updated between reconstruction loops without creating a new torch.optim.{Adam, LBFGS, SGD}. In some cases, this can improve reconstruction quality, as discussed in #17 (comment).
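Updating hyperparameters without recreating the optimizer preserves internal optimizer state, such as Adam's running moment estimates. In PyTorch terms this amounts to mutating optimizer.param_groups in place; a minimal sketch (adjust_optimizer here is illustrative, not the CDTools method itself):

```python
import torch

def adjust_optimizer(optimizer, lr=None):
    # Mutate hyperparameters in place so the optimizer's internal
    # state (e.g. Adam's running moments) survives the change.
    if lr is not None:
        for group in optimizer.param_groups:
            group['lr'] = lr

# Toy usage: one parameter tensor, lowered learning rate mid-run.
params = [torch.zeros(3, requires_grad=True)]
opt = torch.optim.Adam(params, lr=0.02)
adjust_optimizer(opt, lr=0.005)
```

Recreating the optimizer instead would reset those moment estimates to zero, which is the behavioral difference discussed above.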

Changes in cdtools.models.CDIModel and backwards compatibility

The following changes have been made in cdtools.models.CDIModel methods:

  • AD_Optimize has been removed.
  • {Adam, LBFGS, SGD}_Optimize now initialize a reconstructor attribute as a cdtools.reconstructors.{Adam, LBFGS, SGD} and make a call to their optimize method.

The latter change allows scripts written using v0.3.0.pypi1 to be backwards compatible with this update.

However, this comes at the cost of creating a circular dependency between cdtools.models.CDIModel and cdtools.reconstructors. From a reconstruction standpoint this doesn't seem to cause any issues, but it's unclear whether it may hurt maintainability in the long term. Thus, we should consider deprecating {Adam, LBFGS, SGD}_Optimize in favor of the following code pattern:

  import cdtools

  filename = ...
  dataset = ...
  model = cdtools.models.FancyPtycho.from_dataset(...)

  device = 'cuda'
  model.to(device=device)
  dataset.get_as(device=device)

+ recon = cdtools.reconstructors.Adam(model, dataset)

- for loss in model.Adam_optimize(50, dataset, lr=0.02, batch_size=10):
+ for loss in recon.optimize(50, lr=0.02, batch_size=10):
    print(model.report())
    if model.epoch % 10 == 0:
        model.inspect(dataset)

PyTest

Three tests have been created in test_reconstructors.py, which function similarly to the tests for the models. These tests run reconstruction loops using both the traditional method (model.XXX_Optimize) and the new method (recon.optimize), but use different optimizers (Adam, LBFGS, SGD), datasets (gold ball dataset for Adam and SGD, Siemens star optical data for LBFGS), and models (FancyPtycho for Adam and SGD, RPI for LBFGS). They check the following:

  • Calls to Reconstructor.adjust_optimizer update the hyperparameters
  • The model attribute of the Reconstructor points to the original model provided as an initialization parameter
  • Reconstructions performed by the old and new methods produce exactly identical results given the same random seed
  • The final loss of the reconstructions remains below a specified threshold

Additionally, the first test, test_Adam_gold_balls, checks that the data-loading methods specific to single-GPU operation work as intended.

cdtools.tools.distributed and other changes made to enable multi-GPU

Basic idea:

The module cdtools.tools.distributed enables CDTools reconstruction scripts to be run as single-node multi-GPU jobs. Parallelization is based on distributed data parallelism, where identical copies of the model are kept on each participating GPU and each GPU works on a fraction of the total dataset. After each gradient calculation, the gradients are synchronized across all GPUs before the model updates occur.

Multi-GPU jobs can be launched through torchrun or the convenience console script cdt-torchrun without making any changes to a reconstruction script. In the simplest case, a multi-GPU job can be launched as:

cdt-torchrun \
    --ngpus=<nGPUs> \
    YOUR_RECONSTRUCTION_SCRIPT.py

which is equivalent to the following torchrun call

torchrun \
    --standalone \
    --nnodes=1 \
    --nproc_per_node=$nGPUs \
    -m cdtools.tools.distributed.single_to_multi_gpu \
    --backend=nccl \
    --timeout=30 \
    --nccl_p2p_disable=1 \
    YOUR_RECONSTRUCTION_SCRIPT.py
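A sketch of how a wrapper like cdt-torchrun can expand the short form into the full torchrun invocation. The helper function below is hypothetical; the real console script's internals may differ:

```python
def build_torchrun_argv(script, ngpus, backend='nccl',
                        timeout=30, nccl_p2p_disable=1):
    """Expand the short cdt-torchrun form into the full torchrun call."""
    return [
        'torchrun',
        '--standalone',
        '--nnodes=1',
        f'--nproc_per_node={ngpus}',
        '-m', 'cdtools.tools.distributed.single_to_multi_gpu',
        f'--backend={backend}',
        f'--timeout={timeout}',
        f'--nccl_p2p_disable={nccl_p2p_disable}',
        script,
    ]
```

A wrapper would then hand this list to something like subprocess.run, so the user never has to remember the long form.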

All standard plotting and saving methods will be executed from a single GPU to avoid redundant plot/file generation.

Implementation:

The simple "nutshell" explanation of what multi-GPU CDTools does is described by the flowchart below:
[Image: nutshell flowchart of multi-GPU CDTools]

The actual implementation with torchrun involves several steps:

Process spawning (torchrun and cdt-torchrun)

  1. Spawn several subprocesses using torchrun, which provides environment variables that tell each process its rank and world_size (the total number of GPUs).

Setting up distributed communication and the script's starting conditions (cdtools.tools.distributed.run_single_gpu_script)

  2. In each subprocess, set additional environment variables to enable/disable NCCL peer-to-peer communication, assign a GPU ID to each subprocess, and force each subprocess to only "see" its assigned GPU (e.g., on a 4-GPU node, the rank-3 process can only use GPU ID 3). This latter behavior enables the use of device='cuda' in the reconstruction script rather than device='cuda:RANK'.
  3. In each subprocess, initialize the process group using the NCCL backend (by default), which allows the subprocesses and their GPUs to communicate with each other through torch.distributed.all_reduce and torch.distributed.broadcast calls.
  4. In each subprocess, broadcast or receive rank 0's RNG seed.
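The per-rank environment setup described above can be sketched in pure Python. The helper below is illustrative; the actual run_single_gpu_script may set these variables differently:

```python
import os

def setup_rank_env(rank, nccl_p2p_disable=1):
    """Pin a subprocess to its own GPU and configure NCCL.

    With CUDA_VISIBLE_DEVICES restricted to a single ID, the process
    sees exactly one GPU, so the reconstruction script can simply use
    device='cuda' instead of device=f'cuda:{rank}'.
    """
    env = {
        'CUDA_VISIBLE_DEVICES': str(rank),
        'NCCL_P2P_DISABLE': str(nccl_p2p_disable),
    }
    os.environ.update(env)
    return env
```

These are real CUDA/NCCL environment variables; only the helper's name and signature are invented for this sketch.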

Running the reconstruction script in each subprocess (CDIModel, CDataset, and cdtools.tools.distributed.sync_and_avg_gradients)

  5a. Set up rank, world_size, and multi_gpu_used attributes in CDIModel and CDataset so that plotting/saving methods can be performed from a single GPU (preventing duplicate saving/plotting actions) and multi-GPU-specific actions (i.e., gradient/loss/hyperparameter synchronization) can be triggered.
  5b. When executing a reconstruction loop for the first time (or when creating cdtools.reconstructors.Adam directly), set up a distributed data loader that lets each GPU work on a fraction of the dataset while ensuring the whole dataset is used during each epoch.
  5c. When executing a reconstruction loop, adjust the hyperparameters, scheduler, and data loader epoch (which allows dataset shuffling). During the loop, synchronize all gradients, losses, and hyperparameter changes (from the scheduler) between all participating GPUs.
  5d. Update the model with the synchronized gradients.
  5e. Repeat 5c-5d until the user-defined number of epochs has been completed.
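The gradient synchronization described above amounts to an all-reduce average across ranks. A pure-Python sketch of the arithmetic (the real implementation operates on GPU tensors via torch.distributed.all_reduce):

```python
def all_reduce_mean(grads_per_rank):
    """Average each parameter's gradient across ranks.

    grads_per_rank[r][p] is rank r's gradient for parameter p; after
    the all-reduce, every rank holds the same averaged gradients, so
    the model copies stay in lockstep after each update.
    """
    world_size = len(grads_per_rank)
    n_params = len(grads_per_rank[0])
    averaged = [
        sum(rank_grads[p] for rank_grads in grads_per_rank) / world_size
        for p in range(n_params)
    ]
    # Each rank receives (and applies) the identical averaged gradient.
    return [list(averaged) for _ in range(world_size)]
```

For example, two ranks with per-parameter gradients [1.0, 4.0] and [3.0, 0.0] both end up with [2.0, 2.0].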

Terminating distributed communication

  6. In each subprocess, close the process group. (cdtools.tools.distributed.run_single_gpu_script)

Flowchart of the multi-GPU job creation/execution
[Image: flowchart of the multi-GPU job creation/execution]

Module cdtools.tools.distributed (specifically the file single_to_multi_gpu.py and function run_single_gpu_script) handles steps 2-6 (excluding 5a-5d). Step 5a is handled within the __init__ of CDIModel and CDataset. Steps 5b-5d are handled within the base Reconstructor class.

Convenience functions in cdtools.tools.distributed

  • run_single_to_multi_gpu: Runs a single-GPU reconstruction script as a single-node multi-GPU job via torchrun. This function can be called from the CLI as the Python console script cdt-torchrun, which, as shown earlier, is shorthand for a lengthy torchrun command that allows a single-GPU script to be executed as a multi-GPU job.
  • run_speed_test and report_speed_test: These functions run tests assessing reconstruction performance and speed as a function of GPU count. run_speed_test executes the test using N GPUs and M trials on a reconstruction script, whose main function needs to be wrapped by report_speed_test as shown below:
  import cdtools

+ @cdtools.tools.distributed.report_speed_test
+ def main():
      filename = ...
      dataset = ...
      model = cdtools.models.FancyPtycho.from_dataset(...)

      device = 'cuda'
      model.to(device=device)
      dataset.get_as(device=device)

      recon = cdtools.reconstructors.Adam(model, dataset)
      for loss in recon.optimize(50, lr=0.02, batch_size=10):
          print(model.report())
          if model.epoch % 10 == 0:
              model.inspect(dataset)
+     return model

+ if __name__ == '__main__':
+     main()

The speed test reads each trial's data from a pickle dump file whose location and name are specified by run_speed_test.

What is supported/has been tested by multi-GPU?

In principle, any subclass of the base models discussed below should support multi-GPU operations.

  • Models: FancyPtycho and Bragg2DPtycho have been tested. The following CDIModel plotting and saving methods are tested in the PyTests: save_to_h5, save_on_exit, save_on_exception, compare, save_figures, save_checkpoint.
  • Datasets: Only Ptycho2DDataset has been tested. The following CDataset plotting and saving methods are tested in the PyTests: to_cxi, inspect
  • Reconstructors: Only Adam has been tested.
  • Communication backend: Only NCCL is supported

Example scripts

Two new example scripts have been created to demonstrate how the speed test works. fancy_ptycho_speed_test.py is the fancy ptycho example script, modified to be compatible with the speed test. distributed_speed_test.py runs the actual speed test on fancy_ptycho_speed_test.py.

PyTests

Two tests have been created to assess multi-GPU performance, marked with the pytest marker multigpu. These tests use 2 GPUs and check whether...

  • Plotting/saving methods of the model and dataset work as intended
  • Reconstruction performance using 2 GPUs is similar to a single GPU (tested via run_speed_test and report_speed_test)

@yoshikisd yoshikisd marked this pull request as ready for review July 21, 2025 18:54
@yoshikisd yoshikisd changed the title [WIP] Multi-GPU support and separation of optimizers from CDIModel [feature] Multi-GPU support and separation of optimizers from CDIModel Jul 21, 2025
from cdtools.reconstructors import Adam
self.reconstructor = Adam(model=self,
                          dataset=dataset,
                          subset=subset)

This needs to be updated so that if the user changes which optimizer they're using, it doesn't pass a silent error. I would lean towards just creating a new optimizer for each new function (less good behavior, but preserving old behavior for old scripts), but I could be swayed for sure. What do other folks think?


We should update the examples to use the new pattern with an explicitly constructed reconstructor

def __init__(self,
             model: CDIModel,
             dataset: CDataset,
             subset: Union[int, List[int]] = None):

Maybe subset goes in the optimize function call?

self.model.training_history += self.model.report() + '\n'
return loss

def optimize(self,

I think we haven't quite hit the perfect code pattern yet for how to update learning rates, minibatch sizes, etc. We should cycle back.

@yoshikisd

Per discussions with @allevitan, we've decided it would be best to separate this gargantuan PR into two. The first will deal with the creation of the Reconstructors class and refactoring of the CDIModel methods. The second will introduce multi-GPU CDTools.

For now, I'll keep this PR open if (1) we can nicely merge the changes made by the first PR into this one and (2) we can squash the commits. Otherwise, I'll just create another feature branch for multi-GPU.

@yoshikisd yoshikisd changed the title [feature] Multi-GPU support and separation of optimizers from CDIModel [feature] Multi-GPU support Nov 4, 2025
@yoshikisd

I'll be closing this PR in favor of opening a different one which, IMO, has a somewhat nicer implementation of multi-GPU operation that avoids having to set up processes just to call torchrun. The cost of doing this is adding a few extra lines of code to the reconstruction script, but perhaps this implementation will be easier to review.

Also, having >100 commits (majority of them being experimental) might make our lives difficult if we ever need to roll back to a previous version for whatever reason.


Development

Successfully merging this pull request may close these issues.

Add support for multi-GPU

3 participants