fix: forward padding_mask through MTP recompute path #4983

Closed
ko3n1g wants to merge 1 commit into NVIDIA:main from ko3n1g:ko3n1g/fix/mtp-checkpointed-forward-padding-mask

ko3n1g (Contributor) commented May 26, 2026

Claude summary

Fixes #4933.

Bug

MultiTokenPredictionLayer.forward calls self._checkpointed_forward(..., padding_mask=padding_mask, ...) at megatron/core/transformer/multi_token_prediction.py:1305, but _checkpointed_forward (and its inner custom_forward) never accepted padding_mask. With recompute_granularity == 'full' and self.training, this raised:

TypeError: MultiTokenPredictionLayer._checkpointed_forward() got an unexpected keyword argument 'padding_mask'
  at megatron/core/transformer/multi_token_prediction.py:1301
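
The failure mode can be reproduced without Megatron. The following is a minimal sketch: `ToyLayer` and `try_forward` are hypothetical stand-ins invented for illustration, not code from `multi_token_prediction.py`. It only shows how a call site that passes a keyword the helper never declared fails on the recompute branch and stays silent otherwise.

```python
class ToyLayer:
    """Hypothetical stand-in for MultiTokenPredictionLayer (not Megatron code)."""

    def __init__(self, recompute_granularity=None, training=True):
        self.recompute_granularity = recompute_granularity
        self.training = training

    def _checkpointed_forward(self, hidden_states, attention_mask=None):
        # The helper's signature was never extended with padding_mask.
        return hidden_states

    def forward(self, hidden_states, attention_mask=None, padding_mask=None):
        if self.recompute_granularity == "full" and self.training:
            # The call site passes a keyword the helper does not accept.
            return self._checkpointed_forward(
                hidden_states,
                attention_mask=attention_mask,
                padding_mask=padding_mask,
            )
        # Any other configuration never reaches the broken call.
        return hidden_states


def try_forward(layer):
    """Return the TypeError message, or None if forward succeeds."""
    try:
        layer.forward([1.0, 2.0])
        return None
    except TypeError as exc:
        return str(exc)


error_full = try_forward(ToyLayer(recompute_granularity="full"))
error_none = try_forward(ToyLayer(recompute_granularity=None))
```

In this sketch only the `"full"` configuration produces an error, which mirrors why the failure appeared only in the recompute-enabled test parametrizations.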

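Two commits that don't compose: #2645 introduced the padding_mask kwarg at the call site, and the _checkpointed_forward refactor in #4593 dropped padding_mask from the recompute path.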

The failure surfaced across TestMultiTokenPrediction::test_forward_backward[*-*-True] (9 combos), test_fp8_support[True], and test_packed_sequences_with_full_recompute. It also broke three Megatron-FSDP nodes in test_compatible_with_nd_parallel.

Fix

Forward padding_mask through the recompute path:

  1. Add padding_mask: Optional[Tensor] = None to _checkpointed_forward's signature.
  2. Add padding_mask to custom_forward's signature so it reaches _proj_and_transformer_layer (which already accepts it).
  3. Pass padding_mask positionally to both te_checkpoint and tensor_parallel.checkpoint. It's a tensor (rolled in _get_embeddings), so it follows the same positional-arg pattern as attention_mask rather than being captured via closure like non-tensor args (attention_bias, inference_params, packed_seq_params).
  4. Also forward padding_mask in the recompute_method == 'block' fallback path which calls _proj_and_transformer_layer directly.
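
The steps above can be sketched in plain Python. This is an illustrative approximation, not the actual patch: `toy_checkpoint` is a hypothetical stand-in for a checkpoint utility that forwards positional arguments only, `ToyLayer` is invented for the example, and lists stand in for tensors. It shows the pattern the fix describes, where tensor-like arguments travel positionally and non-tensor arguments are captured by the closure.

```python
from typing import Any, Callable, Optional


def toy_checkpoint(function: Callable[..., Any], *args: Any) -> Any:
    """Hypothetical checkpoint stand-in: forwards positional args only."""
    return function(*args)


class ToyLayer:
    """Hypothetical stand-in for the MTP layer (not Megatron code)."""

    def _proj_and_transformer_layer(
        self, hidden_states, attention_mask=None, padding_mask=None, attention_bias=None
    ):
        # Report what arrived so the caller can inspect it.
        return {
            "hidden_states": hidden_states,
            "attention_mask": attention_mask,
            "padding_mask": padding_mask,
            "attention_bias": attention_bias,
        }

    def _checkpointed_forward(
        self,
        hidden_states,
        attention_mask=None,
        padding_mask: Optional[Any] = None,  # step 1: accepted by the helper
        attention_bias=None,
    ):
        # Step 2: the inner function declares padding_mask as well.
        def custom_forward(hidden_states, attention_mask, padding_mask):
            return self._proj_and_transformer_layer(
                hidden_states,
                attention_mask=attention_mask,
                padding_mask=padding_mask,
                attention_bias=attention_bias,  # non-tensor arg: closure capture
            )

        # Step 3: tensor-like args, including padding_mask, go positionally.
        return toy_checkpoint(custom_forward, hidden_states, attention_mask, padding_mask)


result = ToyLayer()._checkpointed_forward(
    [1.0, 2.0], attention_mask=None, padding_mask=[False, True], attention_bias="bias"
)
```

Here `result["padding_mask"]` holds the mask passed in, showing that it reaches the inner layer call through the positional path while `attention_bias` arrives through the closure.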

Test cleanup

Removed @pytest.mark.flaky_in_dev from three tests that were marked in #4931 specifically to mask this failure:

  • TestMultiTokenPrediction::test_forward_backward
  • TestMultiTokenPrediction::test_fp8_support
  • TestMultiTokenPrediction::test_packed_sequences_with_full_recompute

Repro

pytest tests/unit_tests/transformer/test_multi_token_prediction.py::TestMultiTokenPrediction::test_forward_backward

MultiTokenPredictionLayer.forward calls self._checkpointed_forward(
padding_mask=padding_mask, ...) (multi_token_prediction.py:1305), but
_checkpointed_forward and its inner custom_forward never accepted
padding_mask. With recompute_granularity == 'full' and self.training,
this raised:

    TypeError: MultiTokenPredictionLayer._checkpointed_forward() got
    an unexpected keyword argument 'padding_mask'

at multi_token_prediction.py:1301. The kwarg was introduced in NVIDIA#2645
on the call site; the _checkpointed_forward refactor in NVIDIA#4593 dropped
padding_mask from the recompute path.

Add padding_mask:
  * to _checkpointed_forward's signature
  * to custom_forward's signature so it flows into _proj_and_transformer_layer
  * positionally to te_checkpoint and tensor_parallel.checkpoint, matching the
    other tensor / None args (padding_mask is a rolled tensor, not a non-tensor
    closure-captured arg like attention_bias)
  * to the recompute_method == 'block' fallback that also calls
    _proj_and_transformer_layer directly

Also remove the @pytest.mark.flaky_in_dev markers from
test_forward_backward, test_fp8_support, and test_packed_sequences_with_full_recompute,
which were added in NVIDIA#4931 to mask this exact failure.

Closes NVIDIA#4933

Signed-off-by: oliver könig <okoenig@nvidia.com>

copy-pr-bot (Bot) commented May 26, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

ko3n1g (Contributor, Author) commented May 26, 2026

/ok to test b216826

ko3n1g (Contributor, Author) commented May 26, 2026

close in favor of #4963

ko3n1g closed this May 26, 2026

Development

Successfully merging this pull request may close these issues.

🐛 CI failure: MultiTokenPredictionLayer._checkpointed_forward() got unexpected kwarg 'padding_mask'