Fix MTP recompute padding mask forwarding #4963

Open
BestJuly wants to merge 3 commits into
NVIDIA:main from
BestJuly:lit/issue_4933

Conversation

@BestJuly
Contributor

@BestJuly BestJuly commented May 25, 2026

Summary

Fixes #4933.

Validation

  • PYTHONPATH=.venv/lib/python3.12/site-packages:$PYTHONPATH /usr/bin/python3.12 -m torch.distributed.run --standalone --nproc-per-node 8 -m pytest -q --tb=short --capture=fd tests/unit_tests/transformer/test_multi_token_prediction.py::TestMultiTokenPrediction::test_forward_backward[1-1-True] tests/unit_tests/transformer/test_multi_token_prediction.py::TestMultiTokenPrediction::test_forward_backward[1-2-True] tests/unit_tests/transformer/test_multi_token_prediction.py::TestMultiTokenPrediction::test_forward_backward[1-4-True] tests/unit_tests/transformer/test_multi_token_prediction.py::TestMultiTokenPrediction::test_forward_backward[2-1-True] tests/unit_tests/transformer/test_multi_token_prediction.py::TestMultiTokenPrediction::test_forward_backward[2-2-True] tests/unit_tests/transformer/test_multi_token_prediction.py::TestMultiTokenPrediction::test_forward_backward[2-4-True] tests/unit_tests/transformer/test_multi_token_prediction.py::TestMultiTokenPrediction::test_forward_backward[4-1-True] tests/unit_tests/transformer/test_multi_token_prediction.py::TestMultiTokenPrediction::test_forward_backward[4-2-True] tests/unit_tests/transformer/test_multi_token_prediction.py::TestMultiTokenPrediction::test_fp8_support[True] tests/unit_tests/transformer/test_multi_token_prediction.py::TestMultiTokenPrediction::test_packed_sequences_with_full_recompute
  • PYTHONPATH=.venv/lib/python3.12/site-packages:$PYTHONPATH /usr/bin/python3.12 -m torch.distributed.run --standalone --nproc-per-node 8 -m pytest -q --tb=short --capture=fd tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py::TestMegatronFSDPE2E::test_compatible_with_nd_parallel[optim_grads_params_double_buffer-TP2]
  • PYTHONPATH=.venv/lib/python3.12/site-packages:$PYTHONPATH /usr/bin/python3.12 -m torch.distributed.run --standalone --nproc-per-node 8 -m pytest -q --tb=short --capture=fd tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py::TestMegatronFSDPE2E::test_compatible_with_nd_parallel[optim_grads_params_double_buffer-EP2_ETP2] tests/unit_tests/distributed/megatron_fsdp/test_mcore_fully_sharded_data_parallel.py::TestMegatronFSDPE2E::test_compatible_with_nd_parallel[optim_grads_params_double_buffer-OUTER_DP2_EP2]

Dev branch

Checked latest origin/dev (56481b050) and this issue is not present there: megatron/core/transformer/multi_token_prediction.py contains no padding_mask usage, so there is no _checkpointed_forward/padding_mask signature mismatch to fix. No dev PR is needed.

@copy-pr-bot

copy-pr-bot Bot commented May 25, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@BestJuly
Contributor Author

Hi @ko3n1g, this should fix the issue. Could you help merge this? Thanks.

@ko3n1g
Contributor

ko3n1g commented May 26, 2026

@copy-pr-bot copy-pr-bot Bot requested a deployment to public May 26, 2026 10:46 Abandoned
MultiTokenPredictionLayer.forward calls self._checkpointed_forward(
padding_mask=padding_mask, ...) (multi_token_prediction.py:1305), but
_checkpointed_forward and its inner custom_forward never accepted
padding_mask. With recompute_granularity == 'full' and self.training,
this raised:

    TypeError: MultiTokenPredictionLayer._checkpointed_forward() got
    an unexpected keyword argument 'padding_mask'

at multi_token_prediction.py:1301. The kwarg was introduced in NVIDIA#2645
on the call site; the _checkpointed_forward refactor in NVIDIA#4593 dropped
padding_mask from the recompute path.
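
The failure mode described above can be reproduced without Megatron. The following is a minimal sketch under stated assumptions: `ToyLayer` and `call_and_capture` are hypothetical stand-ins, not the actual `MultiTokenPredictionLayer` code. It only shows how a call site that passes `padding_mask=` to a method whose signature lacks it raises `TypeError`, and only on the training path with full recompute.

```python
class ToyLayer:
    """Hypothetical stand-in for the layer; not Megatron code."""

    def __init__(self, training=True, recompute_granularity="full"):
        self.training = training
        self.recompute_granularity = recompute_granularity

    def _checkpointed_forward(self, hidden_states, attention_mask=None):
        # padding_mask is absent from this signature, mirroring the bug.
        return hidden_states

    def forward(self, hidden_states, attention_mask=None, padding_mask=None):
        if self.recompute_granularity == "full" and self.training:
            # The call site passes a keyword the callee does not accept.
            return self._checkpointed_forward(
                hidden_states,
                attention_mask=attention_mask,
                padding_mask=padding_mask,
            )
        # Non-recompute path: the mismatch is never exercised.
        return hidden_states


def call_and_capture(layer):
    """Return the TypeError message raised by forward, or None if it succeeds."""
    try:
        layer.forward([1.0, 2.0], padding_mask=[False, True])
    except TypeError as exc:
        return str(exc)
    return None


error_train = call_and_capture(ToyLayer(training=True))
error_eval = call_and_capture(ToyLayer(training=False))
```

In this sketch `error_train` holds the "unexpected keyword argument" message, while `error_eval` stays `None`, which is consistent with the failure only appearing when recompute is active during training.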

Add padding_mask:
  * to _checkpointed_forward's signature
  * to custom_forward's signature so it flows into _proj_and_transformer_layer
  * positionally to te_checkpoint and tensor_parallel.checkpoint, matching the
    other tensor / None args (padding_mask is a rolled tensor, not a non-tensor
    closure-captured arg like attention_bias)
  * to the recompute_method == 'block' fallback that also calls
    _proj_and_transformer_layer directly
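
The shape of the fix listed above can be sketched in plain Python. This is an illustration under assumptions, not the patch itself: `toy_checkpoint` is a hypothetical stand-in for a checkpoint helper that forwards only positional arguments, and `FixedToyLayer` with its masking logic is invented so that the arrival of `padding_mask` is observable.

```python
def toy_checkpoint(function, *args):
    """Hypothetical stand-in: forwards positional args only, like a checkpoint helper."""
    return function(*args)


class FixedToyLayer:
    """Hypothetical stand-in for the layer after the fix; not Megatron code."""

    def __init__(self, training=True, recompute_granularity="full"):
        self.training = training
        self.recompute_granularity = recompute_granularity

    def _proj_and_transformer_layer(self, hidden_states, padding_mask=None):
        # Zero out padded positions so callers can see that the mask arrived.
        if padding_mask is None:
            return list(hidden_states)
        return [0.0 if pad else h for h, pad in zip(hidden_states, padding_mask)]

    def _checkpointed_forward(self, hidden_states, padding_mask=None):
        # padding_mask is now part of both signatures.
        def custom_forward(hidden_states, padding_mask):
            return self._proj_and_transformer_layer(
                hidden_states, padding_mask=padding_mask
            )

        # Passed positionally, alongside the other tensor / None arguments.
        return toy_checkpoint(custom_forward, hidden_states, padding_mask)

    def forward(self, hidden_states, padding_mask=None):
        if self.recompute_granularity == "full" and self.training:
            return self._checkpointed_forward(
                hidden_states, padding_mask=padding_mask
            )
        return self._proj_and_transformer_layer(
            hidden_states, padding_mask=padding_mask
        )


states = [1.0, 2.0, 3.0]
mask = [False, True, False]
recomputed = FixedToyLayer(training=True).forward(states, padding_mask=mask)
direct = FixedToyLayer(training=False).forward(states, padding_mask=mask)
```

With the mask threaded through every hop, the recompute path and the direct path produce the same result in this sketch, which is the property the re-enabled tests are meant to guard.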

Also remove the @pytest.mark.flaky_in_dev markers from
test_forward_backward, test_fp8_support, and test_packed_sequences_with_full_recompute,
which were added in NVIDIA#4931 to mask this exact failure.

Closes NVIDIA#4933

Signed-off-by: oliver könig <okoenig@nvidia.com>
@ko3n1g
Contributor

ko3n1g commented May 26, 2026

/ok to test 4a8785b

@ko3n1g
Contributor

ko3n1g commented May 26, 2026

/ok to test c89ec01

Development

Successfully merging this pull request may close these issues.

🐛 CI failure: MultiTokenPredictionLayer._checkpointed_forward() got unexpected kwarg 'padding_mask'

3 participants