[graph_trainer] Device-hoist CooR artifacts (drop --virtual-local-rank)#3620
Draft
bobrenjc93 wants to merge 1 commit into
Draft
[graph_trainer] Device-hoist CooR artifacts (drop --virtual-local-rank)#3620bobrenjc93 wants to merge 1 commit into
bobrenjc93 wants to merge 1 commit into
Conversation
a2440e7 to
5840e06
Compare
A CooR precompile artifact is compiled on one rank (cuda:0) but loaded on every rank. Until now that required torchrun --virtual-local-rank, which sets CUDA_VISIBLE_DEVICES so every worker sees its GPU as cuda:0. That hack hides peer GPUs and breaks features that need them visible (e.g. symmetric-memory expert parallelism, #3561). This device-hoists the artifact so each rank runs on its real device (cuda:{local_rank}) with all GPUs visible: - inductor_passes.py: enable torch._inductor.config.runtime_device_index during precompile (gated on compile_on_one_rank), for full and regional inductor. The inductor regions then resolve the launch device from their input tensors at runtime instead of the baked compile-time index. - precompile.py: _hoist_graph_device_to_current remaps the eager FX graph's baked constants (plain attrs, buffers, params) and device= literals (torch.device and string forms, incl. nested args) onto the current device at load. - run_train_precompile.sh: drop --virtual-local-rank. - test_precompile.py: unit test for the FX device remap. Requires the inductor-side runtime_device_index support (separate pytorch PR). Verified e2e on llama3 debugmodel (TP=2 x FSDP=4), full and regional inductor, without --virtual-local-rank, bitwise-equal loss/grad_norm to the prior VLR path.
6e91e50 to
2f45776
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Device-hoists CooR precompile artifacts so each rank runs on its real device
(
cuda:{local_rank}) with all GPUs visible, removing the--virtual-local-rankhack. That hack set
CUDA_VISIBLE_DEVICESso every worker saw its GPU ascuda:0— which hides peer GPUs and breaks features that need them visible(e.g. symmetric-memory expert parallelism, #3561).
How it works
A CooR artifact is compiled on one rank (cuda:0) but loaded on all ranks. Two
device-baked layers are hoisted:
inductor_passes.pyenablestorch._inductor.config.runtime_device_indexduring precompile (gated oncompile_on_one_rank), for full and regional inductor. The generated wrapperthen derives the launch device from an input tensor
(
_runtime_device = arg0_1.device.index) instead of the baked compile-timeindex — an explicit, data-driven contract that stays correct regardless of
ambient
torch.cuda.current_device().precompile._hoist_graph_device_to_currentremaps thedeserialized graph's baked constants (plain attrs, registered buffers, params)
and
device=literals (torch.deviceand string forms, incl. nested args)onto the current device at load. This fixes eager ops (e.g. vocab-parallel
embedding / SDPA) mixing a
cuda:0constant withcuda:{local_rank}activations.
run_train_precompile.shdrops--virtual-local-rank. A unit test for the FXremap is added.
Validation (llama3 debugmodel, TP=2 × FSDP=4, no --virtual-local-rank)
Full no-VLR is bitwise-identical to the canonical VLR path (
--debug.seed 42 --debug.deterministic): 8.13228 / 7.81150 / 7.04731.Follow-ups before un-drafting
tensor.device.index == local_rankmid-step.scripts/loss_compare.py.runtime_device_indexPR first.