Skip to content

Add basic EP support (no overlapping)#54

Open
GarlGuo wants to merge 15 commits into
mainfrom
basic-ep-support
Open

Add basic EP support (no overlapping)#54
GarlGuo wants to merge 15 commits into
mainfrom
basic-ep-support

Conversation

@GarlGuo
Copy link
Copy Markdown
Member

@GarlGuo GarlGuo commented Apr 30, 2026

This PR will use triton + symmetric memory to provide basic EP support for SonicMoE. We implement collectives with close-to-peak network bandwidth rate and change the metadata accordingly.

Co-authored by Claude Code


The forward dispatches each rank's T_local tokens to the experts that hold them via NVLink symmetric memory, runs the grouped GEMMs locally, and combines back across NVLink. A runtime NetworkProfiler benchmarks the three dispatch and three combine primitives on the local hardware and picks the fastest pair per workload.

EP world size 8:

torchrun --nproc_per_node=8 --standalone benchmarks/distributed/moe-ep.py --thiek 131072,4096,1536,128,8

The EP forward exposes two optional flags that trade off activation memory, NVLink bandwidth in backward, and a host-stall on the forward.

--redispatch_x_in_backward (default to False): instead of saving the post-dispatch x_compute for the backward, save only the pre-dispatch x_local and re-dispatch in the backward via a Copy-Engine all-gather on a side stream.

torchrun --nproc_per_node=8 --standalone benchmarks/distributed/moe-ep.py --thiek 131072,4096,1536,128,8 --redispatch_x_in_backward

--CPU_sync_on_runtime (default to False): initiate D2H sync to shrink the saved activation cache. The trade-off is a single host stall per forward. Inference mode skips this since no cache is saved.

torchrun --nproc_per_node=8 --standalone benchmarks/distributed/moe-ep.py --thiek 131072,4096,1536,128,8 --CPU_sync_on_runtime

Example usage:

import os
import torch
import torch.distributed as dist
from sonicmoe import MoE
from sonicmoe.distributed_utils import NetworkProfiler
from sonicmoe.enums import ActivationType
from sonicmoe.functional.ep import moe_ep_TC_softmax_topk_forward

rank = int(os.environ["RANK"])
local_rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
torch.cuda.set_device(local_rank)

device = torch.device(f"cuda:{local_rank}")
dist.init_process_group("nccl", device_id=device)

T, H, I, E, K = 131072, 4096, 1536, 128, 8   # T is the global token count
T_local, E_local = T // world_size, E // world_size

# Build the global MoE once, then slice each rank's E_local expert shard.
moe = MoE(
    num_experts=E,
    num_experts_per_tok=K,
    hidden_size=H,
    intermediate_size=I,
    activation_function=ActivationType.SWIGLU,
    add_bias=False,
    std=0.02,
).to(device=device, dtype=torch.bfloat16)
for p in moe.parameters():
    dist.broadcast(p.data, src=0)

# QuACK's grouped GEMM requires the original (E, *, *) strides, preserved via
# empty_strided + copy_ after permuting to the EP layout.

# EP: shard expert weights evenly across all ranks
w1_sharded = moe.c_fc.weight[rank * E_local : (rank + 1) * E_local].permute(1, 2, 0)    # (2I, H, E_local) view
w2_sharded = moe.c_proj.weight[rank * E_local : (rank + 1) * E_local].permute(0, 2, 1)  # (E_local, I, H) view
w1_sharded_contiguous = torch.empty_strided(w1_sharded.shape, w1_sharded.stride(), dtype=w1_sharded.dtype, device=device).copy_(w1_sharded)
w2_sharded_contiguous = torch.empty_strided(w2_sharded.shape, w2_sharded.stride(), dtype=w2_sharded.dtype, device=device).copy_(w2_sharded)

# !!!!! We assume the router weights are replicated across ranks !!!!!
router_w = moe.router.weight

# Pick the fastest dispatch and combine primitives for this GPU cluster once per (T_local, H, K, dtype).
# We have also construct a `sonicmoe.distributed_utils.RuntimeEPConfig` from scratch by overwriting the Dispatch and Combine mode.
ep_config = NetworkProfiler(T_local=T_local, H=H, K=K, dtype=torch.bfloat16).profile()

# we always assume DP -> EP -> DP !!!
x_local = torch.randn(T_local, H, device=device, dtype=torch.bfloat16)
output_local = moe_ep_TC_softmax_topk_forward(
    x_local,
    router_w,
    w1_sharded_contiguous, None,
    w2_sharded_contiguous, None,
    K=K, E=E,
    ep_config=ep_config,
    activation_type=ActivationType.SWIGLU,
)

Example output:

EP forward+backward + local baselines  EP world size W 4, Minibatch size T 131072 (Per-rank microbatch size T_local 32768), H 4096, I 1536, E 128 (E_local 32), K 8, dtype bf16, routing: softmax_over_topk, w1 layout: interleaved , bias: False, redispatch_x_in_backward: False, CPU_sync_on_runtime: False
[NetworkProfiler] T_local=32768 H=4096 K=8 W=4
  Dispatch:    AG_TRITON=1.20ms (668.9 GB/s)  A2A_TRITON=2.30ms (699.9 GB/s)  RANK_DEDUP_DISPATCH_TRITON=1.10ms (658.8 GB/s)  →  winner=RANK_DEDUP_DISPATCH_TRITON
  Combine: A2A_TRITON=2.42ms (665.7 GB/s)  RS_COMBINE_TRITON=1.87ms  RANK_DEDUP_COMBINE_TRITON=1.70ms  →  winner=RANK_DEDUP_COMBINE_TRITON
Final config: dispatch=RANK_DEDUP_DISPATCH_TRITON (profiled), agg=RANK_DEDUP_COMBINE_TRITON (profiled)
Dispatch + Combine time: 1.10 + 1.70 = 2.80 ms
max ref o val                    0.057444
mean ref o val                   0.007419
max abs diff on o                0.000308
mean rel diff on o               0.015812

max abs ref value dx             0.080078
mean abs ref value dx            0.010559
max abs diff on dx               0.000977
mean rel diff on dx              0.026390

max abs ref value drouter_w      3.581285
mean abs ref value drouter_w     0.568302
max abs diff on drouter_w        0.025783
mean rel diff on drouter_w       0.028664

max abs ref value dw1            0.492188
mean abs ref value dw1           0.060791
max abs diff on dw1              0.001953
mean rel diff on dw1             0.017429

max abs ref value dw2            0.468750
mean abs ref value dw2           0.060547
max abs diff on dw2              0.001953
mean rel diff on dw2             0.021249


══ Saved-activation cache audit (training, redispatch_x=False, CPU_sync=False) ══
  X cache:               1.00 GiB  shape=(131072, 4096)
  h cache:               6.00 GiB  shape=(1048576, 3072)
  Total:                 7.00 GiB

── EP throughput ──
 EP Fwd (inference mode) Average time: 10.35    ms, Per-rank TFLOPS: 956  , Net EP TFLOPS: 3824 
 EP Fwd (training mode)  Average time: 10.36    ms, Per-rank TFLOPS: 955  , Net EP TFLOPS: 3820 
 EP Bwd (derived)        Average time: 17.98    ms, Per-rank TFLOPS: 1101 , Net EP TFLOPS: 4404 

── Per-rank baselines (T=32768, E=128, K=8, single-GPU with T_local tokens) ──
 Per-rank Fwd (T=32768, E=128, K=8, training)     Average time: 8.68     ms, TFLOPS: 1140 
 Per-rank Bwd (T=32768, E=128, K=8)               Average time: 18.80    ms, TFLOPS: 1053 

── Per-rank baselines (T=32768, E=32, K=8, single-GPU with T_local tokens) ──
 Per-rank Fwd (T=32768, E=32, K=8, training)      Average time: 8.70     ms, TFLOPS: 1137 
 Per-rank Bwd (T=32768, E=32, K=8)                Average time: 19.24    ms, TFLOPS: 1028 

── Per-rank baselines (T=131072, E=128, K=8, single-GPU full EP scale) ──
 Per-rank Fwd (T=131072, E=128, K=8, training)    Average time: 30.81    ms, TFLOPS: 1285 
 Per-rank Bwd (T=131072, E=128, K=8)              Average time: 63.81    ms, TFLOPS: 1241 

══ Exposed network latency (EP vs. local T_local) ══
  Training fwd:
    EP Fwd:                                  10.36    ms
    Per-rank Fwd (T_local, E=128):           8.68     ms
        Slowdown:                            1.68     ms (16.2%)
    Per-rank Fwd (T_local, E=32):            8.70     ms
        Slowdown:                            1.66     ms (16.0%)

  Backward:
    EP Bwd:                                  17.98    ms
    Per-rank Bwd (T_local, E=128):           18.80    ms
        Slowdown:                            -0.82    ms (-4.6%)
    Per-rank Bwd (T_local, E=32):            19.24    ms
        Slowdown:                            -1.27    ms (-7.0%)

══ EP scaling efficiency ══
  Training fwd:
    Single-GPU full EP-scale (T=131072):     30.81    ms
    EP W=4 (T=131072):                       10.36    ms
    Observed speedup:                        2.97    × (ideal 4×, 74.3% over linear scaling)

  Backward:
    Single-GPU full EP-scale (T=131072):     63.81    ms
    EP W=4 (T=131072):                       17.98    ms
    Observed speedup:                        3.55    × (ideal 4×, 88.7% over linear scaling)
PASS

@GarlGuo GarlGuo marked this pull request as draft April 30, 2026 03:11
@GarlGuo GarlGuo marked this pull request as ready for review May 11, 2026 08:38
@GarlGuo GarlGuo requested a review from mayank31398 May 11, 2026 08:38
@GarlGuo GarlGuo self-assigned this May 11, 2026
@yifanzhang-pro
Copy link
Copy Markdown

@codex review

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants