From b1b302613f2e1f1e5738c2e8363ce5e6452ab241 Mon Sep 17 00:00:00 2001 From: Carlos Gomes Date: Fri, 8 May 2026 23:08:37 +0200 Subject: [PATCH] guard fuser grad checks on non-leaf nodes (#2919) * guard fuser grad checks on non-leaf nodes Signed-off-by: CarlosGomes98 * rely on set_output_requires_grad flag, update docstring Signed-off-by: CarlosGomes98 * make code clearer Signed-off-by: CarlosGomes98 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: CarlosGomes98 * rely on set_output_requires_grad flag, update docstring Signed-off-by: CarlosGomes98 * make code clearer Signed-off-by: CarlosGomes98 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Revert cudnn-frontend submodule bump Signed-off-by: CarlosGomes98 Made-with: Cursor * Tweak comment Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --------- Signed-off-by: CarlosGomes98 Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- transformer_engine/pytorch/ops/fuser.py | 28 ++++++++++++++++--------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index a3c7e1bac7..5283af8144 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -65,6 +65,7 @@ def forward( input_: torch.Tensor, fuser: OperationFuser, basic_op_kwargs: list[dict[str, Any]], + set_output_requires_grad: bool, *params_and_extra_inputs: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, ...]: """Forward pass @@ -79,6 +80,8 @@ def forward( Container for the pipeline of operations to run basic_op_kwargs: list of dict Keyword arguments to BasicOperation + set_output_requires_grad: bool + Whether to set ``requires_grad`` flags on returned tensors *params_and_extra_inputs: torch.Tensor Other tensor inputs to include in autograd graph. Consists of parameter tensors, followed by extra operation inputs. @@ -138,7 +141,8 @@ def forward( ) for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs): for y in ys: - y.requires_grad_(idx >= fuser.first_op_requiring_backward) + if set_output_requires_grad: + y.requires_grad_(idx >= fuser.first_op_requiring_backward) extra_outputs[idx] = ys # Flatten list of extra outputs @@ -190,7 +194,8 @@ def forward( for tensor in [x] + extra_outputs_flat: tensor._do_not_clear = True - x.requires_grad_(fuser.first_op_requiring_backward < fuser._num_basic_ops) + if set_output_requires_grad: + x.requires_grad_(fuser.first_op_requiring_backward < fuser._num_basic_ops) if extra_outputs_flat: return x, *extra_outputs_flat @@ -293,6 +298,7 @@ def backward( dx, # input_ None, # fuser None, # basic_op_kwargs + None, # set_output_requires_grad *grad_params_flat, *grad_extra_inputs_flat, ) @@ -501,20 +507,22 @@ def __call__( op.pre_fuser_forward(requires_grad=idx >= self.first_op_requiring_backward) # Fuser forward pass - if is_grad_enabled: - forward_func = _OperationFuserAutogradFunction.apply - args = [] - else: - forward_func = _OperationFuserAutogradFunction.forward - args = [None] - args += ( + # Note: We call forward directly when is_grad_enabled=False, + # which can expose non-leaf tensors to the inner ops. Avoid + # problems in this case by passing set_output_requires_grad=False. + args = ( input, self, basic_op_kwargs, + is_grad_enabled, # set_output_requires_grad *self._flat_basic_op_params, *extra_inputs, ) - return forward_func(*args) + + if not is_grad_enabled: + return _OperationFuserAutogradFunction.forward(None, *args) + + return _OperationFuserAutogradFunction.apply(*args) def register_forward_fusion(