From b216826c7f4f33d71a0b82fc3a08fcb46e927e9e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?oliver=20k=C3=B6nig?=
Date: Tue, 26 May 2026 10:07:51 +0000
Subject: [PATCH] fix: forward padding_mask through MTP recompute path
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

MultiTokenPredictionLayer.forward calls
self._checkpointed_forward(padding_mask=padding_mask, ...)
(multi_token_prediction.py:1305), but _checkpointed_forward and its
inner custom_forward never accepted padding_mask. With
recompute_granularity == 'full' and self.training, this raised:

    TypeError: MultiTokenPredictionLayer._checkpointed_forward() got an
    unexpected keyword argument 'padding_mask'

at multi_token_prediction.py:1301.

The kwarg was introduced in #2645 at the call site; the
_checkpointed_forward refactor in #4593 dropped padding_mask from the
recompute path.

Add padding_mask:
* to _checkpointed_forward's signature
* to custom_forward's signature so it flows into
  _proj_and_transformer_layer
* positionally to te_checkpoint and tensor_parallel.checkpoint,
  matching the other tensor / None args (padding_mask is a rolled
  tensor, not a non-tensor closure-captured arg like attention_bias)
* to the recompute_method == 'block' fallback that also calls
  _proj_and_transformer_layer directly

Also remove the @pytest.mark.flaky_in_dev markers from
test_forward_backward, test_fp8_support, and
test_packed_sequences_with_full_recompute, which were added in #4931 to
mask this exact failure.

Closes #4933 Signed-off-by: oliver könig --- megatron/core/transformer/multi_token_prediction.py | 6 ++++++ tests/unit_tests/transformer/test_multi_token_prediction.py | 3 --- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py index d9d22587a07..9a87331b6f8 100755 --- a/megatron/core/transformer/multi_token_prediction.py +++ b/megatron/core/transformer/multi_token_prediction.py @@ -1095,6 +1095,7 @@ def _checkpointed_forward( hidden_states: Tensor, decoder_input: Tensor, attention_mask: Optional[Tensor] = None, + padding_mask: Optional[Tensor] = None, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, rotary_pos_emb: Optional[Tensor] = None, @@ -1130,6 +1131,7 @@ def custom_forward( hidden_states, decoder_input, attention_mask, + padding_mask, context, context_mask, rotary_pos_emb, @@ -1141,6 +1143,7 @@ def custom_forward( hidden_states=hidden_states, decoder_input=decoder_input, attention_mask=attention_mask, + padding_mask=padding_mask, context=context, context_mask=context_mask, rotary_pos_emb=rotary_pos_emb, @@ -1187,6 +1190,7 @@ def checkpoint_handler(): hidden_states, decoder_input, attention_mask, + padding_mask, context, context_mask, rotary_pos_emb, @@ -1206,6 +1210,7 @@ def checkpoint_handler(): hidden_states, decoder_input, attention_mask, + padding_mask, context, context_mask, rotary_pos_emb, @@ -1232,6 +1237,7 @@ def checkpoint_handler(): hidden_states=hidden_states, decoder_input=decoder_input, attention_mask=attention_mask, + padding_mask=padding_mask, context=context, context_mask=context_mask, rotary_pos_emb=rotary_pos_emb, diff --git a/tests/unit_tests/transformer/test_multi_token_prediction.py b/tests/unit_tests/transformer/test_multi_token_prediction.py index 583d4a16fa5..605da7d481d 100644 --- a/tests/unit_tests/transformer/test_multi_token_prediction.py +++ 
b/tests/unit_tests/transformer/test_multi_token_prediction.py @@ -422,7 +422,6 @@ def test_sharded_state_dict(self, tp, cp): assert f"mtp.layers.{i}.hnorm.weight" in sharded_state_dict.keys() assert f"mtp.layers.{i}.eh_proj.weight" in sharded_state_dict.keys() - @pytest.mark.flaky_in_dev @pytest.mark.skipif( not HAVE_TE or not is_te_min_version("2.1.0"), reason="grouped_gemm requires TransformerEngine >= 2.1.0", @@ -530,7 +529,6 @@ def set_ckpt_path(ckpt_path): for name, param in gpt_model[0].named_parameters(): assert param.main_grad is not None - @pytest.mark.flaky_in_dev @pytest.mark.skipif( not HAVE_TE or not is_te_min_version("1.7.0"), reason="Only transformer-engine>=1.7.0 supports MoE FP8 training", @@ -627,7 +625,6 @@ def test_packed_sequences(self, tp, cp): for name, param in gpt_model[0].named_parameters(): assert param.main_grad is not None, f"Gradient missing for {name}" - @pytest.mark.flaky_in_dev @pytest.mark.skipif( not HAVE_TE or not is_te_min_version("2.1.0"), reason="grouped_gemm requires TransformerEngine >= 2.1.0",