From 8acddec81e68215e64748da1720e050e79a2f246 Mon Sep 17 00:00:00 2001 From: factnn <166481866+factnn@users.noreply.github.com> Date: Mon, 25 May 2026 15:16:03 +0800 Subject: [PATCH] Fix _checkpointed_forward missing padding_mask parameter _checkpointed_forward() was missing the padding_mask parameter that the call site passes (added in #2645). The refactor in #4593 did not include it in the new signature, causing TypeError on the recompute path. Add padding_mask to the method signature, custom_forward closure, and both checkpoint call sites. Fixes #4933 --- megatron/core/transformer/multi_token_prediction.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py index d9d22587a07..c4131b1d1aa 100755 --- a/megatron/core/transformer/multi_token_prediction.py +++ b/megatron/core/transformer/multi_token_prediction.py @@ -1095,6 +1095,7 @@ def _checkpointed_forward( hidden_states: Tensor, decoder_input: Tensor, attention_mask: Optional[Tensor] = None, + padding_mask: Optional[Tensor] = None, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, rotary_pos_emb: Optional[Tensor] = None, @@ -1130,6 +1131,7 @@ def custom_forward( hidden_states, decoder_input, attention_mask, + padding_mask, context, context_mask, rotary_pos_emb, @@ -1141,6 +1143,7 @@ def custom_forward( hidden_states=hidden_states, decoder_input=decoder_input, attention_mask=attention_mask, + padding_mask=padding_mask, context=context, context_mask=context_mask, rotary_pos_emb=rotary_pos_emb, @@ -1187,6 +1190,7 @@ def checkpoint_handler(): hidden_states, decoder_input, attention_mask, + padding_mask, context, context_mask, rotary_pos_emb, @@ -1206,6 +1210,7 @@ def checkpoint_handler(): hidden_states, decoder_input, attention_mask, + padding_mask, context, context_mask, rotary_pos_emb,