Skip to content

Using "virtual padding" to calculate number_local_tokens per SP rank, and fix combine() shape mismatch#3577

Closed
wwwjn wants to merge 7 commits into
mainfrom
moe-padding
Closed

Using "virtual padding" to calculate number_local_tokens per SP rank, and fix combine() shape mismatch#3577
wwwjn wants to merge 7 commits into
mainfrom
moe-padding

Conversation

@wwwjn

@wwwjn wwwjn commented Jun 8, 2026

Copy link
Copy Markdown
Contributor

What's the problem

  • [Currently] combine() wrongly assume the tokens are evenly sharded on each rank

    out_TD = torch.zeros(
    x_TD.shape[0] * self.sp_size,
    x_TD.shape[-1],
    device=x_TD.device,
    dtype=x_TD.dtype,
    )
    if not self.score_before_experts:
    routed_output_RD = (
    routed_output_RD.to(torch.float32)
    * metadata.topk_scores_experts_sorted_N.reshape(-1, 1)
    ).to(routed_output_RD.dtype)
    # With SP, token indices are 0-based within the local shard.
    # Offset to global positions for the full-size scatter buffer.
    if self.sp_size > 1:
    token_indices_experts_sorted_N = (
    metadata.token_indices_experts_sorted_N + x_TD.shape[0] * self.sp_rank
    (infer global SPMD in local SPMD region)

    • If uneven sharded, out_TD will have different shapes across SP ranks.
    • We should directly ban if input number of tokens in input batch can not be evenly sharded by SP ranks
  • [Future] Router will use spmd_types soon, and router is per SP rank. Per SP rank should have even sharding

  • [Future] we want to avoid dispatch/load_balacing the padded token, we should be able to do that by adding metadata field to record the actually local tokens for each sp rank

What does this PR do?

This PR is doing "virtual padding" , and passing metadata around

  • Calculate num_local_tokens_after_padding = (T + pad_tokens) // sp_size in MoE module
  • Pass num_local_tokens_after_padding to GroupedExperts module, then to combine()
  • combine() returns a tensor with shape (num_local_tokens_after_padding * sp_rank, .... )
  • slice the combined tensor to (T, ...) in MoE

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jun 8, 2026
self.top_k = config.top_k
self.score_before_experts = config.score_before_experts
# Sequence-parallel split coordinates. EP dispatchers update these in
# wire_meshes(); the local dispatcher keeps the TP=1 defaults.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Introducing this sp_rank, sp_size in LocalTokenDispatcher is not ideal, but it's used in local_num_valid_tokens() , I want to share local_num_valid_tokens() implementation across AllToAllTokenDispather, DeepEP/HybridEP

Comment thread torchtitan/models/common/token_dispatcher.py Outdated
Comment thread torchtitan/models/common/moe.py Outdated
Comment thread torchtitan/models/common/decoder.py Outdated
# are never routed and are sliced off below.
out_TD = torch.zeros(
x_TD.shape[0] * self.sp_size,
num_local_tokens_after_padding * self.sp_size,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My impression is that this "virtual padding" idea would work when GroupedExperts is in a local region inside a DTensor global region, where DTensor handles the shard / all-gather.

After we migrate to spmd_types, the all-gather would require this metadata as well (spmd.redistribute takes this shape arg), with which spmd_types could also achieve similar "pad / unpad only around collectives" effect. O/w we still have to do "real padding" in model code. cc @pianpwk

I'm OK with this temporary solution (after cleanup) to unblock your vLLM + MoE work.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes let me clean up.

For spmd_types in MoE region, we don't need "real padding" in our current setup, just passing this "vitural padded shape" as metadata around.

[`moe_sharding.py`](moe_sharding.py).

Comment thread torchtitan/models/common/moe.py
@pytorch-bot pytorch-bot Bot added the ciflow/rl label Jun 9, 2026
@wwwjn wwwjn marked this pull request as ready for review June 9, 2026 18:46
@wwwjn wwwjn requested review from fegin and wconstab as code owners June 9, 2026 18:46
@wwwjn wwwjn changed the title [Do not review] Add local-shard MoE padding Using "virtual padding" to calculate number_local_tokens per SP rank, and fix combine() shape mismatc Jun 9, 2026
@wwwjn wwwjn changed the title Using "virtual padding" to calculate number_local_tokens per SP rank, and fix combine() shape mismatc Using "virtual padding" to calculate number_local_tokens per SP rank, and fix combine() shape mismatch Jun 9, 2026
@wwwjn

wwwjn commented Jun 9, 2026

Copy link
Copy Markdown
Contributor Author

Close this because of #3595 . Let's move discussion there

@wwwjn wwwjn closed this Jun 9, 2026

@tianyu-l tianyu-l left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds right to me, one concrete issue before landing

@tianyu-l tianyu-l deleted the moe-padding branch June 9, 2026 21:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/rl ciflow/8gpu CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants