
Fix _checkpointed_forward missing padding_mask parameter #4966

Closed
factnn wants to merge 1 commit into NVIDIA:main from factnn:fix/mtp-checkpointed-forward-padding-mask

factnn commented May 25, 2026

Summary

_checkpointed_forward() in MultiTokenPredictionLayer was missing the padding_mask parameter that the call site passes (added in #2645). The refactor in #4593 left it out of the new signature, so the recompute path (recompute_granularity == 'full') raised a TypeError.
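The failure mode can be reproduced in isolation. The sketch below is a toy stand-in, not Megatron code: ToyLayer, its recompute flag, and the argument lists are hypothetical and only mirror the shape of the mismatch, where a call site passes padding_mask as a keyword to a method whose signature does not accept it.

```python
class ToyLayer:
    """Hypothetical stand-in for a layer with a recompute path (not Megatron code)."""

    def _checkpointed_forward(self, hidden_states, attention_mask=None):
        # The signature lacks padding_mask, mirroring the mismatch described above.
        return hidden_states

    def forward(self, hidden_states, attention_mask=None, padding_mask=None, recompute=False):
        if recompute:
            # The call site passes padding_mask as a keyword argument.
            return self._checkpointed_forward(
                hidden_states, attention_mask=attention_mask, padding_mask=padding_mask
            )
        return hidden_states


layer = ToyLayer()
plain_result = layer.forward([1.0, 2.0])  # the non-recompute path is unaffected

error_message = None
try:
    layer.forward([1.0, 2.0], recompute=True)
except TypeError as exc:
    # Only the recompute branch reaches the mismatched signature.
    error_message = str(exc)
    print(error_message)
```

Because only the recompute branch calls the helper, a mismatch like this can go unnoticed until a configuration that enables recompute is exercised.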

Changes

Add padding_mask to:

  • _checkpointed_forward() method signature
  • custom_forward closure signature and its call to _proj_and_transformer_layer
  • Both te_checkpoint and tensor_parallel.checkpoint positional arg lists
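The three changes listed above follow one pattern: the new argument has to be threaded through the method signature, the closure, and the positional argument list handed to the checkpoint wrapper. A minimal sketch of that pattern, assuming a wrapper that forwards positional arguments only; toy_checkpoint, ToyLayer, and the masking rule (True marks a padded position) are hypothetical and do not reproduce the real te_checkpoint or tensor_parallel.checkpoint signatures.

```python
def toy_checkpoint(fn, *args):
    """Hypothetical stand-in for a checkpoint wrapper: forwards positional args only."""
    return fn(*args)


class ToyLayer:
    def _proj_and_transformer_layer(self, hidden_states, attention_mask=None, padding_mask=None):
        # Assumed rule for illustration: zero out positions flagged as padding.
        if padding_mask is None:
            return list(hidden_states)
        return [0.0 if masked else h for h, masked in zip(hidden_states, padding_mask)]

    def _checkpointed_forward(self, hidden_states, attention_mask=None, padding_mask=None):
        # 1. padding_mask is part of the method signature.
        def custom_forward(hidden_states, attention_mask, padding_mask):
            # 2. The closure accepts padding_mask and passes it to the inner call.
            return self._proj_and_transformer_layer(
                hidden_states, attention_mask=attention_mask, padding_mask=padding_mask
            )

        # 3. padding_mask is included in the positional args given to the wrapper,
        #    in the same order as the closure's parameters.
        return toy_checkpoint(custom_forward, hidden_states, attention_mask, padding_mask)


result = ToyLayer()._checkpointed_forward(
    [1.0, 2.0, 3.0], padding_mask=[False, True, False]
)
print(result)
```

Since the wrapper passes arguments by position, the order in the argument list has to match the closure's parameter order; omitting the value at any of the three points either raises a TypeError or silently drops the mask.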

Fixes #4933


copy-pr-bot Bot commented May 25, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

ko3n1g (Contributor) commented May 26, 2026

Thanks for your PR! Our colleague was a bit faster with #4963, so let's close this (and mine)

ko3n1g closed this May 26, 2026

Development

Successfully merging this pull request may close these issues.

🐛 CI failure: MultiTokenPredictionLayer._checkpointed_forward() got unexpected kwarg 'padding_mask'

2 participants