Skip to content

sketch: preliminary data parallel training#2

Open
ckgresla wants to merge 3 commits into
DLYuanGod:mainfrom
ckgresla:dp
Open

sketch: preliminary data parallel training#2
ckgresla wants to merge 3 commits into
DLYuanGod:mainfrom
ckgresla:dp

Conversation

@ckgresla
Copy link
Copy Markdown

@ckgresla ckgresla commented Apr 10, 2026

Great work with the MegaTrain paper -- I thought it would be cool if we could have data parallel training and took a swing at it.

Would you folks be open to adding multigpu (and perhaps multinode) support?

Summary ala Claude

Adds N-GPU data parallelism to MegaTrain's CPU-offloaded streaming architecture. Each GPU gets its own worker process via torch.multiprocessing (spawn), shares CPU master weights through shared memory, and accumulates gradients with a lock. Single-GPU path is unchanged.

  • infinity/model/mp_state.py — SharedState: shared-memory weights, grads, spawn-context queues/locks, cudaHostRegister for pinned DMA speed
  • infinity/model/mp_worker.py — Worker process: GPU context, forward/backward, K-slab grad D2H, lock-based grad accumulation
  • infinity/model/cpu_master.py_init_multiprocessing(), _forward_and_backward_multiprocess(), overlapped grad zeroing
  • infinity/config/training.pydevices list, world_size, batch divisibility validation
  • examples/train.py--devices CLI arg, multi-GPU metrics
  • examples/dp_compare.py — Throughput comparison script
  • tests/test_data_parallel.py — Gradient equivalence tests

Key design decisions

  • spawn (not fork) to avoid CUDA fork guard — modules use .share_memory() for zero-copy pickle
  • cudaHostRegister on shared-memory flats recovers pinned DMA speed (~1.8x H2D improvement)
  • Global grad_lock for cross-process gradient accumulation
  • Overlapped grad zeroing — main zeros grads while workers do H2D, hiding the cost

Gradient equivalence (Gemma 2B, eager attention)

All tests pass (pytest tests/test_data_parallel.py):

+-------------------------------------------+------------+----------------+
| Test                                      | Loss diff  | Max grad diff  |
+-------------------------------------------+------------+----------------+
| float32: 1 GPU bs=4 vs 4 GPUs bs=1       | 1.42e-09   | 2.46e-05       |
| bf16:    1 GPU bs=4 vs 4 GPUs bs=1       | 6.36e-03   | 7.81e-02       |
| bf16:    1 GPU bs=4 vs 2 GPUs bs=2       | 1.27e-03   | 4.69e-02       |
+-------------------------------------------+------------+----------------+

float32 is essentially exact. bf16 divergence is comparable to standard PyTorch DDP:

+--------------------------------------------+------------+------------+
| bf16 comparison (Gemma 2B, 2 GPUs)         | Max abs    | Max rel    |
+--------------------------------------------+------------+------------+
| Standard DDP (HF fwd, NCCL all-reduce)     |   3.12e-02 |   1.98e-01 |
| Our DP (streaming fwd, shm grad accum)     |   4.69e-02 |   2.83e-01 |
+--------------------------------------------+------------+------------+

Both in the same order of magnitude (~1.4x factor). Our slightly larger divergence is expected: DDP uses NCCL all-reduce with fp32 reduction buckets, while we accumulate bf16 gradients directly via add_() in shared memory.

Throughput (Gemma 2B, 24 training steps, alpaca_gpt4_en)

+---------------------------------------+---------------+---------------+
| Config                                |   Step time   |   Throughput  |
+---------------------------------------+---------------+---------------+
| 1 GPU  [0]       bs=8   (8/gpu)      |     3.74s     |    274 tok/s  |
| 6 GPUs [0-5]     bs=48  (8/gpu)      |     2.89s     |   2124 tok/s  |
+---------------------------------------+---------------+---------------+
  Throughput speedup: 7.76x

Loss curves (24 steps)

1 GPU (bs=8)                          6 GPUs (bs=48, 8/gpu)

Loss                                  Loss
3.2 |*                                3.4 |*
    |                                     |
2.8 |                                 2.8 |
    |                                     |
2.4 |                                 2.4 | * *     *       * *   *
    | *                                   |   * * * * * * *     *   * * *
2.0 |                                 2.0 |               *           * *
    |   *                                 |
1.8 |                                 1.6 |
    | * * *   * *                         |
1.6 |       *       *                 1.2 |
    |                                     |
1.4 |         * *     * *   *  *      0.8 |
    |                       *             |
1.2 |                   *             0.4 |
    +--+--+--+--+--+--+--+--+--->        +--+--+--+--+--+--+--+--+--->
    0  3  6  9  12 15 18 21 24         0  3  6  9  12 15 18 21 24
                Step                              Step

Both runs converge. The 6-GPU run has higher absolute loss because each step sees 6x more diverse samples (larger effective batch). Per-sample convergence is expected to be slightly slower with larger batch — standard DP behavior.

Step-by-step loss comparison

+------+--------------+--------------+
| Step |   1GPU loss  |   6GPU loss  |
+------+--------------+--------------+
|    1 |       2.2282 |       3.3193 |
|    2 |       3.1984 |       2.3277 |
|    3 |       2.2133 |       2.2392 |
|    4 |       1.7525 |       2.1685 |
|    5 |       1.9061 |       2.1347 |
|    6 |       1.5827 |       2.1951 |
|    7 |       1.6861 |       2.0498 |
|    8 |       1.4348 |       2.2318 |
|    9 |       1.3381 |       1.9840 |
|   10 |       1.5077 |       2.0424 |
|   11 |       1.9615 |       2.2301 |
|   12 |       1.6933 |       2.0737 |
|   13 |       1.3753 |       2.0553 |
|   14 |       1.4499 |       2.2821 |
|   15 |       1.7189 |       2.2272 |
|   16 |       1.4370 |       2.2489 |
|   17 |       1.4989 |       2.0788 |
|   18 |       1.5219 |       2.3426 |
|   19 |       1.2279 |       2.2766 |
|   20 |       1.4889 |       2.3957 |
|   21 |       1.3022 |       2.0365 |
|   22 |       1.4804 |       2.1234 |
|   23 |       1.7118 |       2.1813 |
|   24 |       1.3957 |       2.0079 |
+------+--------------+--------------+

Hardware

  • tinybox: 6× RTX 4090 (24GB each), 512GB DDR5, AMD Ryzen
  • PCIe topology: GPU pairs [2,3] and [4,5] share host bridges (PHB)
  • /dev/shm backed shared memory for cross-process weight/grad sharing

Test plan

  • pytest tests/test_data_parallel.py — 3/3 gradient equivalence tests pass
  • Single-GPU path unchanged (backward compatible)
  • 24-step training comparison: both setups converge
  • 7.76x throughput scaling on 6 GPUs
  • Test on larger model (31B) where compute dominates H2D

🤖 Generated with Claude Code

ckgresla and others added 3 commits April 9, 2026 17:19
Single-process, multi-GPU data parallelism that preserves MegaTrain's
CPU-offloaded layer streaming architecture. Each GPU gets its own
execution context (buffers, streams, events, templates, grad slabs)
while sharing CPU master weights and optimizer states.

Key changes:
- _GPUContext dataclass encapsulates per-GPU state
- forward_and_backward orchestrator splits batches across GPUs via threads
- Each GPU's gradients use global_valid_tokens for correct loss scaling
- _prepare_4d_causal_mask for SDPA/eager attention compatibility
- Config: devices list, world_size, --devices CLI arg

Numerical validation (Gemma 2 2B, eager attn, deterministic):

| Test                    | Loss diff | Grad diff (abs) | Grad diff (rel) |
|-------------------------+-----------+-----------------+-----------------|
| float32: 1 GPU vs 4 GPU |    1.4e-9 |          2.5e-5 | 2.8e-5          |
| bf16: 1 GPU vs 4 GPU    |    6.4e-3 |          3.1e-2 | ~1%             |
| bf16: 1 GPU vs 2 GPU    |    1.3e-3 |          1.6e-2 | ~0.8%           |

Walltime (Qwen3.5-0.8B, flash_attention_2, batch=8, 5 steps):

| Config | Time/step | Speedup |
|--------+-----------+---------|
| 1 GPU  |     2.31s | 1.0x    |
| 2 GPUs |     3.58s | 0.64x   |
| 4 GPUs |     7.48s | 0.31x   |

Multi-GPU is currently slower due to GIL-serialized CPU->pinned memory
copies in _load_layer_to_buffer_async. Shared pinned buffer optimization
needed for actual speedup.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Switch from fork to spawn to resolve CUDA fork guard. Workers receive
SharedState via pickle with zero-copy shm handles for CPU master weights.

- mp_state.py: SharedState with shared-memory modules, grad tensors,
  spawn-context queues/locks, cudaHostRegister for pinned DMA speed
- mp_worker.py: per-GPU worker process with independent GPU context,
  forward/backward, K-slab grad D2H, lock-based grad accumulation
- cpu_master.py: _init_multiprocessing spawns workers, overlapped
  grad zeroing, _forward_and_backward_multiprocess dispatches to workers
- train.py: removed fork start method (spawn context used internally)

Gradient equivalence validated:
  float32 DP4: loss diff 1.4e-9, grad diff 2.5e-5
  bf16 DP4:    loss diff 6.4e-3, grad diff 3.9e-1
  bf16 DP2:    loss diff 1.3e-3, grad diff 2.8e-1

Throughput (Gemma 2B, 2x 4090):
  1 GPU:  1518 tok/s
  2 GPUs: 2462 tok/s (1.62x)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
6 GPU bs=48 vs 1 GPU bs=8 on Gemma 2B, 24 steps each.
Result: 7.76x throughput (274 → 2124 tok/s).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@DLYuanGod
Copy link
Copy Markdown
Owner

Hi, thanks for the PR — really appreciate the contribution!

We’ll take some time to carefully review the changes and get back to you as soon as possible. If everything looks good, we’ll proceed with merging.

Thanks again! 🚀

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants