Skip to content

Using "virtual padding" to calculate number_local_tokens per SP rank, and fix combine() shape mismatch#3595

Merged
wwwjn merged 6 commits into
gh/wwwjn/22/basefrom
gh/wwwjn/22/head
Jun 10, 2026
Merged

Using "virtual padding" to calculate number_local_tokens per SP rank, and fix combine() shape mismatch#3595
wwwjn merged 6 commits into
gh/wwwjn/22/basefrom
gh/wwwjn/22/head

Conversation

@wwwjn

@wwwjn wwwjn commented Jun 9, 2026

Copy link
Copy Markdown
Contributor

Stack from ghstack (oldest at bottom):

What's the problem

  • [Currently] combine() wrongly assume the tokens are evenly sharded on each rank

    out_TD = torch.zeros(
    x_TD.shape[0] * self.sp_size,
    x_TD.shape[-1],
    device=x_TD.device,
    dtype=x_TD.dtype,
    )
    if not self.score_before_experts:
    routed_output_RD = (
    routed_output_RD.to(torch.float32)
    * metadata.topk_scores_experts_sorted_N.reshape(-1, 1)
    ).to(routed_output_RD.dtype)
    # With SP, token indices are 0-based within the local shard.
    # Offset to global positions for the full-size scatter buffer.
    if self.sp_size > 1:
    token_indices_experts_sorted_N = (
    metadata.token_indices_experts_sorted_N + x_TD.shape[0] * self.sp_rank
    (infer global SPMD in local SPMD region)

    • If uneven sharded, out_TD will have different shapes across SP ranks.
    • We should directly ban if input number of tokens in input batch can not be evenly sharded by SP ranks
  • [Future] Router will use spmd_types soon, and router is per SP rank. Per SP rank should have even sharding

  • [Future] we want to avoid dispatch/load_balacing the padded token, we should be able to do that by adding metadata field to record the actually local tokens for each sp rank

What does this PR do?

This PR is doing "virtual padding" , and passing metadata around

  • Calculate num_local_tokens_after_padding = (T + pad_tokens) // sp_size in MoE module
  • Pass num_local_tokens_after_padding to GroupedExperts module, then to combine()
  • combine() returns a tensor with shape (num_local_tokens_after_padding * sp_rank, .... )
  • slice the combined tensor to (T, ...) in MoE

[ghstack-poisoned]
@wwwjn wwwjn requested review from fegin, tianyu-l and wconstab as code owners June 9, 2026 20:57
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jun 9, 2026
@wwwjn wwwjn changed the title Apply moe-padding changes Using "virtual padding" to calculate number_local_tokens per SP rank, and fix combine() shape mismatch Jun 9, 2026

@tianyu-l tianyu-l left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds right to me, one concrete issue before landing

Comment thread torchtitan/experiments/rl/trainer.py Outdated
Comment on lines +246 to +249
trainer_parallelism = self.trainer.parallelism
sp_degree = trainer_parallelism.tensor_parallel_degree
# RL policy inputs are shaped by BatchConfig, not TrainingConfig.
seq_len = self.batcher.batch.seq_len

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

checking these in post_init is not safe, because CLI can override -- we have to do these check in update_from_config

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, this is valid. Even TrainingConfig is not used in PolicyTrainer today, a user could also mistakenly override by CLI --trainer.training.seq_len to hack that.

Then the problem is passing BatchConfig from controller, into PolicyTrainer's TrainingConfig(), then calling self.model.update_from_config(config). So my updated plan is:

  • remove the check in post_init
  • Pass PolicyTrainer.TrianingConfig.seq_len to be BatchConfig.seq_len after parsing CLI override here
  • Then we are good to only check in Decoder.update_config()

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pass PolicyTrainer.TrianingConfig.seq_len to be BatchConfig.seq_len after parsing CLI override here

As discussed earlier, PolicyTrainer.TrainingConfig shouldn't have seq_len. In fact, for pretaining, we should have BatchConfig in Dataloader.Config, not in Trainer.TrainingConfig

@wwwjn wwwjn Jun 9, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, that's the right direction, let me add a TODO and also a github issue to track this

Comment thread torchtitan/models/common/token_dispatcher.py Outdated
@pianpwk

pianpwk commented Jun 9, 2026

Copy link
Copy Markdown
Contributor

sorry, I saw this in the closed PR:

After we migrate to spmd_types, the all-gather would require this metadata as well (spmd.redistribute takes this shape arg), with which spmd_types could also achieve similar "pad / unpad only around collectives" effect.

what's the shape arg mentioned here?

@wwwjn

wwwjn commented Jun 9, 2026

Copy link
Copy Markdown
Contributor Author

sorry, I saw this in the closed PR:

After we migrate to spmd_types, the all-gather would require this metadata as well (spmd.redistribute takes this shape arg), with which spmd_types could also achieve similar "pad / unpad only around collectives" effect.

what's the shape arg mentioned here?

num_local_tokens_after_padding = (T + pad_tokens) // sp_size

This parameter here, num_local_tokens_after_padding we are passing this to combine() and doing a virtual padding ,

@pianpwk

pianpwk commented Jun 9, 2026

Copy link
Copy Markdown
Contributor

This parameter here, num_local_tokens_after_padding we are passing this to combine() and doing a virtual padding ,

oh my question was, redistribute() takes this?

wwwjn added 2 commits June 9, 2026 20:35
[ghstack-poisoned]
[ghstack-poisoned]

@tianyu-l tianyu-l left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice! some nit comments

Comment thread torchtitan/models/common/decoder.py Outdated
f"length {max_seq_len}."
)


Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

accidental?

Comment thread torchtitan/trainer.py Outdated
Comment thread torchtitan/trainer.py Outdated
Comment thread torchtitan/experiments/rl/trainer.py Outdated
Comment thread torchtitan/experiments/rl/trainer.py Outdated
[ghstack-poisoned]
@wwwjn wwwjn changed the base branch from gh/wwwjn/22/base to main June 10, 2026 14:52
[ghstack-poisoned]
@wwwjn wwwjn changed the base branch from main to gh/wwwjn/22/base June 10, 2026 14:57
[ghstack-poisoned]
@wwwjn

wwwjn commented Jun 10, 2026

Copy link
Copy Markdown
Contributor Author

CPU test failing might be cause of spmd_types version, and seems not related this PR
I rerun the test locally

  • torch 2.13.0.dev20260609+cpu
  • spmd_types==0.2.1

Command result:

CUDA_VISIBLE_DEVICES= titan-rl/bin/python -m pytest tests/unit_tests/
Result: 1 passed in 10.50s.

@wwwjn wwwjn merged commit bbd66c6 into gh/wwwjn/22/base Jun 10, 2026
10 of 11 checks passed
wwwjn added a commit that referenced this pull request Jun 10, 2026
… and fix combine() shape mismatch #3595 (#3619)

#3595 is merged to wrong base, replay that PR
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/rl ciflow/8gpu CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants