Using "virtual padding" to calculate number_local_tokens per SP rank, and fix combine() shape mismatch#3595
Conversation
tianyu-l
left a comment
There was a problem hiding this comment.
sounds right to me, one concrete issue before landing
| trainer_parallelism = self.trainer.parallelism | ||
| sp_degree = trainer_parallelism.tensor_parallel_degree | ||
| # RL policy inputs are shaped by BatchConfig, not TrainingConfig. | ||
| seq_len = self.batcher.batch.seq_len |
There was a problem hiding this comment.
checking these in post_init is not safe, because CLI can override -- we have to do these check in update_from_config
There was a problem hiding this comment.
I see, this is valid. Even TrainingConfig is not used in PolicyTrainer today, a user could also mistakenly override by CLI --trainer.training.seq_len to hack that.
Then the problem is passing BatchConfig from controller, into PolicyTrainer's TrainingConfig(), then calling self.model.update_from_config(config). So my updated plan is:
- remove the check in post_init
- Pass
PolicyTrainer.TrianingConfig.seq_lento beBatchConfig.seq_lenafter parsing CLI override here - Then we are good to only check in Decoder.update_config()
There was a problem hiding this comment.
Pass PolicyTrainer.TrianingConfig.seq_len to be BatchConfig.seq_len after parsing CLI override here
As discussed earlier, PolicyTrainer.TrainingConfig shouldn't have seq_len. In fact, for pretaining, we should have BatchConfig in Dataloader.Config, not in Trainer.TrainingConfig
There was a problem hiding this comment.
I see, that's the right direction, let me add a TODO and also a github issue to track this
|
sorry, I saw this in the closed PR:
what's the shape arg mentioned here? |
This parameter here, |
oh my question was, redistribute() takes this? |
tianyu-l
left a comment
There was a problem hiding this comment.
nice! some nit comments
| f"length {max_seq_len}." | ||
| ) | ||
|
|
||
|
|
|
CPU test failing might be cause of spmd_types version, and seems not related this PR
Command result: CUDA_VISIBLE_DEVICES= titan-rl/bin/python -m pytest tests/unit_tests/ |
Stack from ghstack (oldest at bottom):
What's the problem
[Currently] combine() wrongly assume the tokens are evenly sharded on each rank
torchtitan/torchtitan/models/common/token_dispatcher.py
Lines 439 to 456 in c0428bb
[Future] Router will use spmd_types soon, and router is per SP rank. Per SP rank should have even sharding
[Future] we want to avoid dispatch/load_balacing the padded token, we should be able to do that by adding metadata field to record the actually local tokens for each sp rank
What does this PR do?
This PR is doing "virtual padding" , and passing metadata around