sketch: preliminary data parallel training#2
Open
ckgresla wants to merge 3 commits into
Open
Conversation
Single-process, multi-GPU data parallelism that preserves MegaTrain's CPU-offloaded layer streaming architecture. Each GPU gets its own execution context (buffers, streams, events, templates, grad slabs) while sharing CPU master weights and optimizer states. Key changes: - _GPUContext dataclass encapsulates per-GPU state - forward_and_backward orchestrator splits batches across GPUs via threads - Each GPU's gradients use global_valid_tokens for correct loss scaling - _prepare_4d_causal_mask for SDPA/eager attention compatibility - Config: devices list, world_size, --devices CLI arg Numerical validation (Gemma 2 2B, eager attn, deterministic): | Test | Loss diff | Grad diff (abs) | Grad diff (rel) | |-------------------------+-----------+-----------------+-----------------| | float32: 1 GPU vs 4 GPU | 1.4e-9 | 2.5e-5 | 2.8e-5 | | bf16: 1 GPU vs 4 GPU | 6.4e-3 | 3.1e-2 | ~1% | | bf16: 1 GPU vs 2 GPU | 1.3e-3 | 1.6e-2 | ~0.8% | Walltime (Qwen3.5-0.8B, flash_attention_2, batch=8, 5 steps): | Config | Time/step | Speedup | |--------+-----------+---------| | 1 GPU | 2.31s | 1.0x | | 2 GPUs | 3.58s | 0.64x | | 4 GPUs | 7.48s | 0.31x | Multi-GPU is currently slower due to GIL-serialized CPU->pinned memory copies in _load_layer_to_buffer_async. Shared pinned buffer optimization needed for actual speedup. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Switch from fork to spawn to resolve CUDA fork guard. Workers receive SharedState via pickle with zero-copy shm handles for CPU master weights. - mp_state.py: SharedState with shared-memory modules, grad tensors, spawn-context queues/locks, cudaHostRegister for pinned DMA speed - mp_worker.py: per-GPU worker process with independent GPU context, forward/backward, K-slab grad D2H, lock-based grad accumulation - cpu_master.py: _init_multiprocessing spawns workers, overlapped grad zeroing, _forward_and_backward_multiprocess dispatches to workers - train.py: removed fork start method (spawn context used internally) Gradient equivalence validated: float32 DP4: loss diff 1.4e-9, grad diff 2.5e-5 bf16 DP4: loss diff 6.4e-3, grad diff 3.9e-1 bf16 DP2: loss diff 1.3e-3, grad diff 2.8e-1 Throughput (Gemma 2B, 2x 4090): 1 GPU: 1518 tok/s 2 GPUs: 2462 tok/s (1.62x) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
6 GPU bs=48 vs 1 GPU bs=8 on Gemma 2B, 24 steps each. Result: 7.76x throughput (274 → 2124 tok/s). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Owner
|
Hi, thanks for the PR — really appreciate the contribution! We’ll take some time to carefully review the changes and get back to you as soon as possible. If everything looks good, we’ll proceed with merging. Thanks again! 🚀 |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Great work with the MegaTrain paper -- I thought it would be cool if we could have data parallel training and took a swing at it.
Would you folks be open to adding multigpu (and perhaps multinode) support?
Summary ala Claude
Adds N-GPU data parallelism to MegaTrain's CPU-offloaded streaming architecture. Each GPU gets its own worker process via
torch.multiprocessing(spawn), shares CPU master weights through shared memory, and accumulates gradients with a lock. Single-GPU path is unchanged.infinity/model/mp_state.py— SharedState: shared-memory weights, grads, spawn-context queues/locks,cudaHostRegisterfor pinned DMA speedinfinity/model/mp_worker.py— Worker process: GPU context, forward/backward, K-slab grad D2H, lock-based grad accumulationinfinity/model/cpu_master.py—_init_multiprocessing(),_forward_and_backward_multiprocess(), overlapped grad zeroinginfinity/config/training.py—deviceslist,world_size, batch divisibility validationexamples/train.py—--devicesCLI arg, multi-GPU metricsexamples/dp_compare.py— Throughput comparison scripttests/test_data_parallel.py— Gradient equivalence testsKey design decisions
.share_memory()for zero-copy picklecudaHostRegisteron shared-memory flats recovers pinned DMA speed (~1.8x H2D improvement)grad_lockfor cross-process gradient accumulationGradient equivalence (Gemma 2B, eager attention)
All tests pass (
pytest tests/test_data_parallel.py):float32 is essentially exact. bf16 divergence is comparable to standard PyTorch DDP:
Both in the same order of magnitude (~1.4x factor). Our slightly larger divergence is expected: DDP uses NCCL all-reduce with fp32 reduction buckets, while we accumulate bf16 gradients directly via
add_()in shared memory.Throughput (Gemma 2B, 24 training steps, alpaca_gpt4_en)
Loss curves (24 steps)
Both runs converge. The 6-GPU run has higher absolute loss because each step sees 6x more diverse samples (larger effective batch). Per-sample convergence is expected to be slightly slower with larger batch — standard DP behavior.
Step-by-step loss comparison
Hardware
/dev/shmbacked shared memory for cross-process weight/grad sharingTest plan
pytest tests/test_data_parallel.py— 3/3 gradient equivalence tests pass🤖 Generated with Claude Code