
[MoE] Use CPU split-size sum for EP permute output size #3627

Open
sanketpurandare wants to merge 1 commit into sanketpurandare/stack/14 from sanketpurandare/stack/21

sanketpurandare commented Jun 10, 2026

Contributor

Stacked PRs:


[MoE] Use CPU split-size sum for EP permute output size

Compute the EP permute output size from the CPU output-splits tensor and pass that size into repeat_interleave(output_size=...) and the matching arange path. This keeps graph capture from constructing a long symbolic expression over the split-size list for the permute size.

This depends on PyTorch support for unbacked SymInt output sizes in repeat_interleave/arange:

pytorch/pytorch#186573

Related TorchTitan issue:

#3336

stack-info: PR: #3627, branch: sanketpurandare/stack/21
total = (
output_size
if output_size is not None
else num_global_tokens_per_local_expert_E.sum()

Contributor

Does this alone not work?

Contributor Author

It works, but num_global_tokens_per_local_expert_E is a GPU tensor, and summing it gives a GPU scalar value. If we pass this to repeat_interleave or arange, it will cause another CPU-GPU sync.

device = num_global_tokens_per_local_expert_E.device
total = (
output_size
if output_size is not None

Contributor

why if-else

Contributor Author

I kept the output size argument optional, so that callers of the _permute function may or may not pass it.

  # output[p] = input[input_starts[seg] + (p - output_starts[seg])]
  seg_ids = torch.arange(segment_lens.shape[0], device=device).repeat_interleave(
-     segment_lens
+     segment_lens, output_size=total

Contributor

do you want this to be on CPU or GPU

Contributor Author

total should be a Python int on the CPU.

Contributor

Can you use routed_input_RD.shape[0]? Here's what Claude (potentially hallucinating) says:


Why routed_input_RD.shape[0] is known without adding a sync — provenance

This is the part that makes it free rather than just relocating the sync. Trace where the shape
came from in dispatch():

  output_splits_list = output_splits.tolist()      # line 298 — the ONE existing D2H sync
  ...
  routed_input_RD = all_to_all_single_autograd(
      routed_input_ND, output_splits_list, input_splits_list, self.ep_mesh,   # line 301-306
  )

all_to_all_single is given the output split sizes as a Python list of host ints. The output
tensor's row count is sum(output_splits_list) — computed by host-side arithmetic at allocation
time, before any device work. So routed_input_RD.shape[0] is a host int that was fixed by the
.tolist() sync that already happened at line 298 (the unavoidable one, needed because
all_to_all_single requires list splits).

Therefore reading routed_input_RD.shape[0] in _permute reuses a value the host already knows. It
adds zero new transfers, while deleting the repeat_interleave sync from 2a. Net: two D2H syncs
→ one.

(Under CooR/compile the same holds symbolically: output_splits.tolist() produces unbacked
SymInts, sum(...) is a host-side symbolic int, the a2a output shape[0] is that SymInt, and
repeat_interleave accepts a SymInt output_size — so no .item() graph break is inserted either.)

  permuted_indices = (
      input_starts[seg_ids]
-     + torch.arange(seg_ids.shape[0], device=device)
+     + torch.arange(total, device=device)

Contributor

do you want this on CPU or GPU

Contributor Author

total should be a Python int on the CPU.
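
For readers following the thread, here is a minimal pure-Python sketch of the index formula in the comment `output[p] = input[input_starts[seg] + (p - output_starts[seg])]`, together with the optional `output_size` fallback discussed above. It is not the TorchTitan `_permute` implementation: the function name `permute_indices`, the list-based arguments, and the example segment layout are assumptions made for illustration, and plain lists stand in for `torch.arange` and `repeat_interleave`.

```python
from itertools import accumulate
from typing import Optional, Sequence


def permute_indices(
    segment_lens: Sequence[int],
    input_starts: Sequence[int],
    output_size: Optional[int] = None,
) -> list[int]:
    """Hypothetical stand-in for the index computation quoted in the diff."""
    # Prefer the caller-supplied host int. The fallback sum is the path that,
    # on a GPU tensor, would produce a device scalar instead of a Python int.
    total = output_size if output_size is not None else sum(segment_lens)

    # List equivalent of arange(num_segments).repeat_interleave(segment_lens).
    seg_ids = [seg for seg, n in enumerate(segment_lens) for _ in range(n)]
    assert len(seg_ids) == total, "output_size must equal sum(segment_lens)"

    # Exclusive prefix sum: the output position where each segment begins.
    output_starts = [0, *accumulate(segment_lens)][:-1]

    # output[p] = input[input_starts[seg] + (p - output_starts[seg])]
    return [
        input_starts[seg] + (p - output_starts[seg])
        for p, seg in enumerate(seg_ids)
    ]


# Illustrative layout (assumed): four segments of lengths 2, 1, 1, 3 whose
# source offsets in the input are 0, 3, 2, 4.
indices = permute_indices([2, 1, 1, 3], [0, 3, 2, 4], output_size=7)
```

Passing `output_size=7` and omitting it give the same indices here; the difference the PR cares about is only where the total comes from.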

Comment on lines +267 to +269
# This relies on PyTorch support for unbacked SymInt output sizes in
# repeat_interleave/arange. The CPU sum gives tracing one symbolic size
# instead of a long expression over the split-size list.

Contributor

Could you explain a bit more about before vs. after? I couldn't follow all the compiler language.

sanketpurandare, Jun 11, 2026

Contributor Author

When we convert a tensor to a list, each item of that list is a sym-int. If I sum the elements of the list, I get a sym-int expression. For instance, for a tensor t = [a, b, c, d], calling list_t = t.tolist() gives list_t = [u0, u1, u2, u3], where each ui is a sym-int. If I then sum it up, I get the expression u0+u1+u2+u3. When the compiler gets such an expression for size allocation in repeat_interleave or arange, it tries to calculate the upper bound, which ends up being large. Instead if we first call sum on the tensor and then call .item I will get one sym-int

Contributor

the sentence doesn't finish:

Instead if we first call sum on the tensor and then call .item I will get one sym-int

so that the largest bound is 65536? compared with the sum of a list of 65536's?
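
To make the before/after distinction concrete, here is a toy sketch of the two ways the size can reach the tracer. It does not use or model PyTorch's symbolic-shape machinery: the `Sym` class and both helper functions are hypothetical, and the sketch only counts how many opaque symbols end up in the size expression. It makes no claim about the upper bounds the compiler actually derives.

```python
from dataclasses import dataclass
from itertools import count

_fresh_ids = count()  # global counter used to name new opaque symbols


@dataclass(frozen=True)
class Sym:
    """Toy symbolic size: a sum of opaque symbols such as u0 + u1."""

    terms: tuple

    @staticmethod
    def fresh() -> "Sym":
        return Sym((f"u{next(_fresh_ids)}",))

    def __add__(self, other: "Sym") -> "Sym":
        return Sym(self.terms + other.terms)

    def __str__(self) -> str:
        return " + ".join(self.terms)


def size_from_list(num_splits: int) -> Sym:
    """Before: one symbol per list element, summed on the host."""
    if num_splits < 1:
        raise ValueError("need at least one split")
    elems = [Sym.fresh() for _ in range(num_splits)]
    total = elems[0]
    for elem in elems[1:]:
        total = total + elem
    return total


def size_from_scalar() -> Sym:
    """After: sum the tensor first, then read a single scalar."""
    return Sym.fresh()


before = size_from_list(4)  # an expression with four terms
after = size_from_scalar()  # a single symbol
```

In this toy model the list path grows by one term per split, while the scalar path always yields one symbol regardless of how many splits there are.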


Labels

ciflow/8gpu, CLA Signed

2 participants