[MoE] Use CPU split-size sum for EP permute output size #3627
sanketpurandare wants to merge 1 commit
Compute the EP permute output size from the CPU output-splits tensor and pass that size into repeat_interleave(output_size=...) and the matching arange path. This keeps graph capture from constructing a long symbolic expression over the split-size list for the permute size.
This depends on PyTorch support for unbacked SymInt output sizes in repeat_interleave/arange: pytorch/pytorch#186573
Related TorchTitan issue: #3336
stack-info: PR: #3627, branch: sanketpurandare/stack/21
    total = (
        output_size
        if output_size is not None
        else num_global_tokens_per_local_expert_E.sum()
It works, but num_global_tokens_per_local_expert_E is a GPU tensor, so summing it gives a GPU scalar value. If we pass this to repeat_interleave or arange, it will cause another CPU-GPU sync.
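For illustration, a minimal sketch of the difference between the two kinds of total. It runs on CPU only, so the "device" tensor is just a stand-in, and `output_splits_cpu` is a hypothetical name for a CPU copy of the split sizes, not something taken from the PR.

```python
import torch

# Stand-in for num_global_tokens_per_local_expert_E; in the PR this lives on the GPU.
num_tokens_per_expert = torch.tensor([2, 0, 3, 1])

# Summing the tensor yields a 0-dim tensor, not a Python int. When that tensor is
# on the GPU, using it as a size means copying the value back to the host.
device_total = num_tokens_per_expert.sum()

# Hypothetical CPU copy of the split sizes: reading its sum is host-only work.
output_splits_cpu = torch.tensor([2, 0, 3, 1], device="cpu")
total = int(output_splits_cpu.sum().item())

# With a known total, repeat_interleave does not have to derive the output size itself.
seg_ids = torch.arange(4).repeat_interleave(num_tokens_per_expert, output_size=total)
```

Here `device_total` is a tensor and `total` is a plain int, which is the form `output_size=` and `arange` can consume without reading device memory.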
    device = num_global_tokens_per_local_expert_E.device
    total = (
        output_size
        if output_size is not None
I kept output_size optional, so that callers of the _permute function can choose whether to pass it.
      # output[p] = input[input_starts[seg] + (p - output_starts[seg])]
      seg_ids = torch.arange(segment_lens.shape[0], device=device).repeat_interleave(
    -     segment_lens
    +     segment_lens, output_size=total
Do you want this to be on CPU or GPU?
total should be a Python int on CPU.
Can you use routed_input_RD.shape[0]? Here's what Claude says (potentially hallucinated):
Why routed_input_RD.shape[0] is known without adding a sync — provenance
This is the part that makes it free rather than just relocating the sync. Trace where the shape
came from in dispatch():
    output_splits_list = output_splits.tolist()  # line 298: the ONE existing D2H sync
    ...
    routed_input_RD = all_to_all_single_autograd(
        routed_input_ND, output_splits_list, input_splits_list, self.ep_mesh,  # line 301-306
    )
all_to_all_single is given the output split sizes as a Python list of host ints. The output
tensor's row count is sum(output_splits_list) — computed by host-side arithmetic at allocation
time, before any device work. So routed_input_RD.shape[0] is a host int that was fixed by the
.tolist() sync that already happened at line 298 (the unavoidable one, needed because
all_to_all_single requires list splits).
Therefore reading routed_input_RD.shape[0] in _permute reuses a value the host already knows. It
adds zero new transfers, while deleting the repeat_interleave sync from 2a. Net: two D2H syncs
→ one.
(Under CooR/compile the same holds symbolically: output_splits.tolist() produces unbacked
SymInts, sum(...) is a host-side symbolic int, the a2a output shape[0] is that SymInt, and
repeat_interleave accepts a SymInt output_size — so no .item() graph break is inserted either.)
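A single-process sketch of that provenance argument, for illustration only. `fake_all_to_all` is a hypothetical stand-in that mimics just the output allocation from host-side split sizes; it does no communication and is not how all_to_all_single_autograd is implemented.

```python
import torch

def fake_all_to_all(x, output_splits_list):
    # Hypothetical stand-in: only mimics allocating the output from host ints.
    rows = sum(output_splits_list)  # plain Python arithmetic, no device work
    return x.new_zeros((rows, x.shape[1]))

output_splits = torch.tensor([3, 1, 2])      # made-up split sizes
output_splits_list = output_splits.tolist()  # the one sync the quoted text refers to

routed_input_RD = fake_all_to_all(torch.ones(4, 8), output_splits_list)

# Shape is tensor metadata held on the host, so reading it is not a device read.
total = routed_input_RD.shape[0]
```

In eager mode `total` is an ordinary Python int equal to `sum(output_splits_list)`; whether the same holds symbolically under compile is the claim made in the quoted text, not something this sketch checks.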
      permuted_indices = (
          input_starts[seg_ids]
    -     + torch.arange(seg_ids.shape[0], device=device)
    +     + torch.arange(total, device=device)
Do you want this on CPU or GPU?
total should be a Python int on CPU.
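A small worked example of the index formula in the diff, output[p] = input[input_starts[seg] + (p - output_starts[seg])], with `total` passed in as a Python int. This is a sketch rather than the PR's _permute: `counts_RE`, `in_lens`, and `in_starts_RE` are hypothetical names, the counts are made up, and the rank-major to expert-major regrouping is an assumption about the layout.

```python
import torch

# Token counts per (source rank, local expert); the values are made up.
counts_RE = torch.tensor([[1, 2], [3, 1]])
total = int(counts_RE.sum().item())  # in the PR this comes from CPU split sizes

# Assumed layout: input rows grouped rank-major, output rows grouped expert-major.
in_lens = counts_RE.flatten()
in_starts_RE = (in_lens.cumsum(0) - in_lens).view_as(counts_RE)
segment_lens = counts_RE.t().flatten()
input_starts = in_starts_RE.t().flatten()
output_starts = segment_lens.cumsum(0) - segment_lens

# One segment id per output row, then the gather index for each output row.
seg_ids = torch.arange(segment_lens.shape[0]).repeat_interleave(
    segment_lens, output_size=total
)
permuted_indices = input_starts[seg_ids] + torch.arange(total) - output_starts[seg_ids]
```

Both `repeat_interleave` and `arange` take their length from `total`, so neither needs to inspect `segment_lens` or `seg_ids` to learn how many rows to produce.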
    # This relies on PyTorch support for unbacked SymInt output sizes in
    # repeat_interleave/arange. The CPU sum gives tracing one symbolic size
    # instead of a long expression over the split-size list.
Could you explain a bit more on before vs. after? I couldn't follow all the compiler language.
When we convert a tensor to a list, each item of that list is a sym-int. If I sum the elements of the list, I get a sym-int expression. For instance, for a tensor t = [a, b, c, d], when I call list_t = t.tolist(), list_t = [u0, u1, u2, u3], where each ui is a sym-int. If I then sum it up, I get the expression u0+u1+u2+u3. When the compiler gets such an expression for size allocation in repeat_interleave or arange, it tries to calculate the upper bound, which ends up being large. Instead if we first call sum on the tensor and then call .item I will get one sym-int
the sentence doesn't finish:
Instead if we first call sum on the tensor and then call .item I will get one sym-int
so that the largest bound is 65536? compared with the sum of a list of 65536's?
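The two ways of getting the total that the explanation above contrasts, as an eager-mode sketch. The values are made up, and the comments about tracing restate the explanation in this thread rather than demonstrate it; in eager mode both paths simply return the same Python int.

```python
import torch

splits = torch.tensor([4, 1, 0, 3])  # CPU tensor of split sizes (made-up values)

# Before: convert first, then add. Per the explanation above, under tracing each
# list element is its own sym-int, so the total is an expression over all of them.
total_from_list = sum(splits.tolist())

# After: reduce on the tensor, then read one scalar. Per the explanation above,
# under tracing this yields a single sym-int.
total_from_tensor = splits.sum().item()
```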