Using "virtual padding" to calculate number_local_tokens per SP rank, and fix combine() shape mismatch #3577
wwwjn wants to merge 7 commits into
| self.top_k = config.top_k | ||
| self.score_before_experts = config.score_before_experts | ||
| # Sequence-parallel split coordinates. EP dispatchers update these in | ||
| # wire_meshes(); the local dispatcher keeps the TP=1 defaults. |
Introducing sp_rank and sp_size in LocalTokenDispatcher is not ideal, but they are used in local_num_valid_tokens(), and I want to share the local_num_valid_tokens() implementation across AllToAllTokenDispatcher and the DeepEP/HybridEP dispatchers.
| # are never routed and are sliced off below. | ||
| out_TD = torch.zeros( | ||
| x_TD.shape[0] * self.sp_size, | ||
| num_local_tokens_after_padding * self.sp_size, |
My impression is that this "virtual padding" idea would work when GroupedExperts is in a local region inside a DTensor global region, where DTensor handles the shard / all-gather.
After we migrate to spmd_types, the all-gather would require this metadata as well (spmd.redistribute takes this shape arg), so spmd_types could also achieve a similar "pad / unpad only around collectives" effect. Otherwise we would still have to do "real padding" in model code. cc @pianpwk
I'm OK with this temporary solution (after cleanup) to unblock your vLLM + MoE work.
Yes, let me clean up.
For spmd_types in the MoE region, we don't need "real padding" in our current setup; we just pass this "virtual padded shape" around as metadata.
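To make the "pad / unpad only around collectives" idea concrete, here is a minimal single-process sketch. It does not use spmd_types, DTensor, or any real collective: `PaddedShape` and `simulated_all_gather` are hypothetical names, and the "ranks" are plain Python lists. It only illustrates that, given per-rank metadata, padding can be confined to the gather step while model code keeps the unpadded shards.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PaddedShape:
    """Hypothetical metadata: rows a rank pretends to own vs. rows it really owns."""
    padded_rows: int
    valid_rows: int


def simulated_all_gather(shards, metas, pad_value=0):
    """Pad each shard only for the gather, then drop the padding again."""
    gathered = []
    for shard, meta in zip(shards, metas):
        assert len(shard) == meta.valid_rows
        # Every contribution has the same length, as a real all-gather would need.
        gathered.append(shard + [pad_value] * (meta.padded_rows - meta.valid_rows))
    # Unpad: keep only the valid rows of each rank's contribution.
    return [tok for padded, meta in zip(gathered, metas)
            for tok in padded[: meta.valid_rows]]


# Illustrative numbers: 10 tokens over 4 ranks, padded to 3 rows per rank.
shards = [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
metas = [PaddedShape(padded_rows=3, valid_rows=len(s)) for s in shards]
result = simulated_all_gather(shards, metas)
```

Outside the gather, no shard ever holds a padded row; only the metadata travels with it.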
Closing this because of #3595. Let's move the discussion there.
tianyu-l
left a comment
Sounds right to me; one concrete issue before landing.
What's the problem
[Currently] combine() wrongly assumes that tokens are evenly sharded across ranks:
torchtitan/torchtitan/models/common/token_dispatcher.py
Lines 439 to 456 in c0428bb
[Future] The router will use spmd_types soon, and the router runs per SP rank, so sharding should be even across SP ranks.
[Future] We want to avoid dispatching / load-balancing padded tokens; we should be able to do that by adding a metadata field that records the actual number of local tokens on each SP rank.
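A small arithmetic sketch of the current problem, under assumed numbers: if each rank sizes the combined output as its own local row count times sp_size (as in the `x_TD.shape[0] * self.sp_size` line of the diff), ranks with uneven shards disagree on the output shape. The shard sizes below are illustrative and not taken from the PR.

```python
sp_size = 4
# Illustrative: 10 tokens split unevenly across 4 SP ranks.
local_rows_per_rank = [3, 3, 3, 1]

# Output rows each rank would allocate if it assumed even sharding.
assumed_total = [rows * sp_size for rows in local_rows_per_rank]
# Rows that actually exist across all ranks.
actual_total = sum(local_rows_per_rank)
# True only if every rank computes the same output size.
ranks_agree = len(set(assumed_total)) == 1
```

Here the last rank would allocate 4 rows while the others allocate 12, and none of them matches the 10 real tokens.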
What does this PR do?
This PR implements "virtual padding" and passes the resulting metadata around.
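One possible reading of "virtual padding", as a hedged sketch rather than the PR's actual code: every rank agrees on a padded per-rank token count derived from the global count, and separately computes how many of those rows are real. The function name `local_num_valid_tokens` appears in the review thread, but the signature and body below are assumptions, as is `padded_tokens_per_rank`.

```python
def padded_tokens_per_rank(total_tokens: int, sp_size: int) -> int:
    """Rows each SP rank pretends to own (ceiling division). Hypothetical helper."""
    return -(-total_tokens // sp_size)


def local_num_valid_tokens(total_tokens: int, sp_rank: int, sp_size: int) -> int:
    """Real (non-padding) rows on sp_rank. Assumed signature, not the PR's."""
    padded = padded_tokens_per_rank(total_tokens, sp_size)
    start = sp_rank * padded
    # Clamp to [0, padded]: trailing ranks may own fewer rows, or none.
    return max(0, min(padded, total_tokens - start))


total_tokens, sp_size = 10, 4
padded = padded_tokens_per_rank(total_tokens, sp_size)
valid = [local_num_valid_tokens(total_tokens, r, sp_size) for r in range(sp_size)]
# Identical on every rank, so a combine() output buffer sized this way is consistent.
buffer_rows = padded * sp_size
```

Rows beyond `total_tokens` in such a buffer would be padding that is never routed and can be sliced off afterwards.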