Skip to content

[Bug] Fix MoE SP token combine indices#3604

Open
Matrix-Z97 wants to merge 1 commit into
pytorch:mainfrom
Matrix-Z97:EP_BUG_FIX
Open

[Bug] Fix MoE SP token combine indices#3604
Matrix-Z97 wants to merge 1 commit into
pytorch:mainfrom
Matrix-Z97:EP_BUG_FIX

Conversation

@Matrix-Z97

Copy link
Copy Markdown

Summary

Fix incorrect MoE token placement when Tensor Parallelism, Expert Parallelism, and Sequence Parallelism are enabled together.

When SP is enabled, the MoE input x_BLD is sharded along the sequence dimension. After flattening the local shard, token indices are only local to the current SP shard. The previous combine logic converted local indices to global indices with a simple per-rank offset:

# Previous logic
if self.sp_size > 1:
    token_indices_experts_sorted_N = (
        metadata.token_indices_experts_sorted_N + x_TD.shape[0] * self.sp_rank
    )
else:
    token_indices_experts_sorted_N = metadata.token_indices_experts_sorted_N

This is equivalent to:

global_idx = local_idx + local_num_tokens * sp_rank

This only works when B == 1. For batched inputs, each batch has its own sequence shard, so the simple offset places tokens into the wrong global positions. This can cause noticeable numerical differences when running MoE with TP + EP.

This PR fixes the issue by passing the original local (B, L, D) shape through dispatcher metadata and reconstructing global token indices with batch-aware sequence offsets:

local_pos = local_idx % local_seq_len
batch_idx = local_idx // local_seq_len
global_idx = batch_idx * global_seq_len + sp_rank * local_seq_len + local_pos

Changes

  • Add input_shape_BLD to MoE dispatch metadata.
  • Add a shared _sp_global_token_indices() helper in the token dispatcher.
  • Fix SP scatter indices in AllToAllTokenDispatcher.combine().
  • Apply the same SP placement fix to DeepEPTokenDispatcher and HybridEPTokenDispatcher.
  • Pass (B, L, D) from common MoE and GPT-OSS MoE dispatch call sites.

Validation

Precision validation was performed on a DeepSeek-V3 debug model using 4 H100 GPUs.

Output comparison:

  • FSDP2 + TP2
dsv3_debug_tp
  • FSDP2 + TP2 + EP2, without this PR
dsv3_debug_tp_ep_old
  • FSDP2 + TP2 + EP2, with this PR
dsv3_debug_tp_ep_new

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jun 10, 2026
@pytorch-bot

pytorch-bot Bot commented Jun 10, 2026

Copy link
Copy Markdown

The following ciflow label(s) have been added but CI has not been triggered yet because the workflows are awaiting approval:

  • ciflow/8gpu

Once a maintainer approves the workflows (scroll to the bottom of the PR page), the corresponding CI jobs will be triggered automatically. Please ping one of the reviewers if you do not have access to approve and run workflows.

@Matrix-Z97 Matrix-Z97 changed the title Fix MoE SP token combine indices [Bug] Fix MoE SP token combine indices Jun 10, 2026
@wwwjn wwwjn self-assigned this Jun 10, 2026
topk_expert_ids_TK: torch.Tensor,
num_local_tokens_per_expert_E: torch.Tensor,
*,
input_shape_BLD: tuple[int, int, int] | None = None,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to pass the input_shape_BLD into dispatch()? We can directly pass to combine() and change combine()'s function signature. Now we are passing it via dispatch() and store it in dispatch() Metatdata.

I find a orthogonal problem (uneven shard when SP on L dimension #3619) , but ends up using the similar solution with yours.

Comment on lines +691 to +694
num_output_tokens = combined_TD.shape[0] * self.sp_size
if metadata.input_shape_BLD is not None:
B, L, _ = metadata.input_shape_BLD
num_output_tokens = B * L * self.sp_size

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems a bug when with flattening / unflattening introduced in recent refactor. cc @acisseJZhong

Thanks for catching it!

Could you rebase on latest main? We just introduced a field to make it work with uneven sharding.

@tianyu-l tianyu-l added bug Something isn't working high priority labels Jun 10, 2026
"""Metadata returned by AllToAllTokenDispatcher.dispatch() for use in combine()."""

input_shape: tuple # for _unpermute
input_shape_BLD: tuple[int, int, int] | None # for SP global token indices

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

bad name in two regards:

  • almost the same as one line above; does it mean very different things? If so should differentiate in naming
  • _BLD suffix means the tensor itself has this shape, according to the convention. Here it's actually a tuple of 3 integers.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

bug Something isn't working ciflow/8gpu CLA Signed This label is managed by the Meta Open Source bot. high priority

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants