diff --git a/megatron/core/transformer/multi_token_prediction.py b/megatron/core/transformer/multi_token_prediction.py index d9d22587a07..c4131b1d1aa 100755 --- a/megatron/core/transformer/multi_token_prediction.py +++ b/megatron/core/transformer/multi_token_prediction.py @@ -1095,6 +1095,7 @@ def _checkpointed_forward( hidden_states: Tensor, decoder_input: Tensor, attention_mask: Optional[Tensor] = None, + padding_mask: Optional[Tensor] = None, context: Optional[Tensor] = None, context_mask: Optional[Tensor] = None, rotary_pos_emb: Optional[Tensor] = None, @@ -1130,6 +1131,7 @@ def custom_forward( hidden_states, decoder_input, attention_mask, + padding_mask, context, context_mask, rotary_pos_emb, @@ -1141,6 +1143,7 @@ def custom_forward( hidden_states=hidden_states, decoder_input=decoder_input, attention_mask=attention_mask, + padding_mask=padding_mask, context=context, context_mask=context_mask, rotary_pos_emb=rotary_pos_emb, @@ -1187,6 +1190,7 @@ def checkpoint_handler(): hidden_states, decoder_input, attention_mask, + padding_mask, context, context_mask, rotary_pos_emb, @@ -1206,6 +1210,7 @@ def checkpoint_handler(): hidden_states, decoder_input, attention_mask, + padding_mask, context, context_mask, rotary_pos_emb,