Skip to content

[graph_trainer] Device-hoist CooR artifacts (drop --virtual-local-rank)#3620

Draft
bobrenjc93 wants to merge 1 commit into
bobren/gt-full-inductor-tp-constantsfrom
bobren/gt-device-hoisting
Draft

[graph_trainer] Device-hoist CooR artifacts (drop --virtual-local-rank)#3620
bobrenjc93 wants to merge 1 commit into
bobren/gt-full-inductor-tp-constantsfrom
bobren/gt-device-hoisting

Conversation

@bobrenjc93

Copy link
Copy Markdown
Contributor

Summary

Device-hoists CooR precompile artifacts so each rank runs on its real device
(cuda:{local_rank}) with all GPUs visible, removing the --virtual-local-rank
hack. That hack set CUDA_VISIBLE_DEVICES so every worker saw its GPU as
cuda:0 — which hides peer GPUs and breaks features that need them visible
(e.g. symmetric-memory expert parallelism, #3561).

Depends on the inductor-side runtime_device_index support in PyTorch
(separate PR). Stacked on #3596.

How it works

A CooR artifact is compiled on one rank (cuda:0) but loaded on all ranks. Two
device-baked layers are hoisted:

  1. Inductor regionsinductor_passes.py enables
    torch._inductor.config.runtime_device_index during precompile (gated on
    compile_on_one_rank), for full and regional inductor. The generated wrapper
    then derives the launch device from an input tensor
    (_runtime_device = arg0_1.device.index) instead of the baked compile-time
    index — an explicit, data-driven contract that stays correct regardless of
    ambient torch.cuda.current_device().
  2. Eager FX graphprecompile._hoist_graph_device_to_current remaps the
    deserialized graph's baked constants (plain attrs, registered buffers, params)
    and device= literals (torch.device and string forms, incl. nested args)
    onto the current device at load. This fixes eager ops (e.g. vocab-parallel
    embedding / SDPA) mixing a cuda:0 constant with cuda:{local_rank}
    activations.

run_train_precompile.sh drops --virtual-local-rank. A unit test for the FX
remap is added.

Validation (llama3 debugmodel, TP=2 × FSDP=4, no --virtual-local-rank)

config result step-1/2 loss
full 8.13228 / 7.81150
regional 8.13265 / 7.81226
full, VLR (back-compat) 8.13228 / 7.81150

Full no-VLR is bitwise-identical to the canonical VLR path (--debug.seed 42 --debug.deterministic): 8.13228 / 7.81150 / 7.04731.

Follow-ups before un-drafting

  • Multi-rank integration test asserting tensor.device.index == local_rank mid-step.
  • Full-precision loss+grad_norm parity (all ranks) via scripts/loss_compare.py.
  • EP / CP configs to exercise peer/comm kernels under the input-derived device contract.
  • Land the PyTorch runtime_device_index PR first.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jun 10, 2026
@bobrenjc93 bobrenjc93 force-pushed the bobren/gt-full-inductor-tp-constants branch from a2440e7 to 5840e06 Compare June 10, 2026 20:41
A CooR precompile artifact is compiled on one rank (cuda:0) but loaded on every
rank. Until now that required torchrun --virtual-local-rank, which sets
CUDA_VISIBLE_DEVICES so every worker sees its GPU as cuda:0. That hack hides
peer GPUs and breaks features that need them visible (e.g. symmetric-memory
expert parallelism, #3561).

This device-hoists the artifact so each rank runs on its real device
(cuda:{local_rank}) with all GPUs visible:

- inductor_passes.py: enable torch._inductor.config.runtime_device_index during
  precompile (gated on compile_on_one_rank), for full and regional inductor.
  The inductor regions then resolve the launch device from their input tensors
  at runtime instead of the baked compile-time index.
- precompile.py: _hoist_graph_device_to_current remaps the eager FX graph's
  baked constants (plain attrs, buffers, params) and device= literals (torch.device
  and string forms, incl. nested args) onto the current device at load.
- run_train_precompile.sh: drop --virtual-local-rank.
- test_precompile.py: unit test for the FX device remap.

Requires the inductor-side runtime_device_index support (separate pytorch PR).

Verified e2e on llama3 debugmodel (TP=2 x FSDP=4), full and regional inductor,
without --virtual-local-rank, bitwise-equal loss/grad_norm to the prior VLR path.
@bobrenjc93 bobrenjc93 force-pushed the bobren/gt-device-hoisting branch from 6e91e50 to 2f45776 Compare June 10, 2026 20:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/8gpu CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant