diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py index d9d22587a07..9a87331b6f8 100755 --- a/megatron/core/transformer/multi_token_prediction.py +++ b/megatron/core/transformer/multi_token_prediction.py @@ -1095,6 +1095,7 @@ def _checkpointed_forward( hidden_states: Tensor, decoder_input: Tensor, attention_mask: Optional[Tensor] = None, + padding_mask: Optional[Tensor] = None, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, rotary_pos_emb: Optional[Tensor] = None, @@ -1130,6 +1131,7 @@ def custom_forward( hidden_states, decoder_input, attention_mask, + padding_mask, context, context_mask, rotary_pos_emb, @@ -1141,6 +1143,7 @@ def custom_forward( hidden_states=hidden_states, decoder_input=decoder_input, attention_mask=attention_mask, + padding_mask=padding_mask, context=context, context_mask=context_mask, rotary_pos_emb=rotary_pos_emb, @@ -1187,6 +1190,7 @@ def checkpoint_handler(): hidden_states, decoder_input, attention_mask, + padding_mask, context, context_mask, rotary_pos_emb, @@ -1206,6 +1210,7 @@ def checkpoint_handler(): hidden_states, decoder_input, attention_mask, + padding_mask, context, context_mask, rotary_pos_emb, @@ -1232,6 +1237,7 @@ def checkpoint_handler(): hidden_states=hidden_states, decoder_input=decoder_input, attention_mask=attention_mask, + padding_mask=padding_mask, context=context, context_mask=context_mask, rotary_pos_emb=rotary_pos_emb, diff --git a/tests/unit_tests/transformer/test_multi_token_prediction.py b/tests/unit_tests/transformer/test_multi_token_prediction.py index 583d4a16fa5..605da7d481d 100644 --- a/tests/unit_tests/transformer/test_multi_token_prediction.py +++ b/tests/unit_tests/transformer/test_multi_token_prediction.py @@ -422,7 +422,6 @@ def test_sharded_state_dict(self, tp, cp): assert f"mtp.layers.{i}.hnorm.weight" in sharded_state_dict.keys() assert f"mtp.layers.{i}.eh_proj.weight" in sharded_state_dict.keys() - @pytest.mark.flaky_in_dev @pytest.mark.skipif( not HAVE_TE or not is_te_min_version("2.1.0"), reason="grouped_gemm requires TransformerEngine >= 2.1.0", @@ -530,7 +529,6 @@ def set_ckpt_path(ckpt_path): for name, param in gpt_model[0].named_parameters(): assert param.main_grad is not None - @pytest.mark.flaky_in_dev @pytest.mark.skipif( not HAVE_TE or not is_te_min_version("1.7.0"), reason="Only transformer-engine>=1.7.0 supports MoE FP8 training", @@ -627,7 +625,6 @@ def test_packed_sequences(self, tp, cp): for name, param in gpt_model[0].named_parameters(): assert param.main_grad is not None, f"Gradient missing for {name}" - @pytest.mark.flaky_in_dev @pytest.mark.skipif( not HAVE_TE or not is_te_min_version("2.1.0"), reason="grouped_gemm requires TransformerEngine >= 2.1.0",