[Bug] Fix MoE SP token combine indices#3604
Conversation
|
The following ciflow label(s) have been added but CI has not been triggered yet because the workflows are awaiting approval:
Once a maintainer approves the workflows (scroll to the bottom of the PR page), the corresponding CI jobs will be triggered automatically. Please ping one of the reviewers if you do not have access to approve and run workflows. |
| topk_expert_ids_TK: torch.Tensor, | ||
| num_local_tokens_per_expert_E: torch.Tensor, | ||
| *, | ||
| input_shape_BLD: tuple[int, int, int] | None = None, |
There was a problem hiding this comment.
Do we need to pass the input_shape_BLD into dispatch()? We can directly pass to combine() and change combine()'s function signature. Now we are passing it via dispatch() and store it in dispatch() Metatdata.
I find a orthogonal problem (uneven shard when SP on L dimension #3619) , but ends up using the similar solution with yours.
| num_output_tokens = combined_TD.shape[0] * self.sp_size | ||
| if metadata.input_shape_BLD is not None: | ||
| B, L, _ = metadata.input_shape_BLD | ||
| num_output_tokens = B * L * self.sp_size |
There was a problem hiding this comment.
Seems a bug when with flattening / unflattening introduced in recent refactor. cc @acisseJZhong
Thanks for catching it!
Could you rebase on latest main? We just introduced a field to make it work with uneven sharding.
| """Metadata returned by AllToAllTokenDispatcher.dispatch() for use in combine().""" | ||
|
|
||
| input_shape: tuple # for _unpermute | ||
| input_shape_BLD: tuple[int, int, int] | None # for SP global token indices |
There was a problem hiding this comment.
bad name in two regards:
- almost the same as one line above; does it mean very different things? If so should differentiate in naming
_BLDsuffix means the tensor itself has this shape, according to the convention. Here it's actually a tuple of 3 integers.
Summary
Fix incorrect MoE token placement when Tensor Parallelism, Expert Parallelism, and Sequence Parallelism are enabled together.
When SP is enabled, the MoE input
x_BLDis sharded along the sequence dimension. After flattening the local shard, token indices are only local to the current SP shard. The previous combine logic converted local indices to global indices with a simple per-rank offset:This is equivalent to:
This only works when B == 1. For batched inputs, each batch has its own sequence shard, so the simple offset places tokens into the wrong global positions. This can cause noticeable numerical differences when running MoE with TP + EP.
This PR fixes the issue by passing the original local (B, L, D) shape through dispatcher metadata and reconstructing global token indices with batch-aware sequence offsets:
Changes
Validation
Precision validation was performed on a DeepSeek-V3 debug model using 4 H100 GPUs.
Output comparison:
FSDP2 + TP2FSDP2 + TP2 + EP2, without this PRFSDP2 + TP2 + EP2, with this PR