
na3d incompatible with torch.func.jvp and functorch transforms — missing setup_context in CutlassFNAAutogradFns #319

@JayGC

Problem

natten.functional.na3d (and na1d/na2d) are incompatible with functorch transforms
(torch.func.jvp, torch.func.grad, torch.func.vmap) because CutlassFNAAutogradFns
uses the old-style torch.autograd.Function API without implementing setup_context.

Error

RuntimeError: In order to use an autograd.Function with functorch transforms
(vmap, grad, jvp, jacrev, ...), it must override the setup_context staticmethod.
For more details, please see https://pytorch.org/docs/main/notes/extending.func.html

Minimal Reproduction

import torch
import natten

# Create the tensors directly on the GPU so q stays a leaf tensor.
q = torch.randn(1, 2, 4, 4, 4, 16, device="cuda", requires_grad=True)
k = torch.randn(1, 2, 4, 4, 4, 16, device="cuda")
v = torch.randn(1, 2, 4, 4, 4, 16, device="cuda")

def fn(q):
    return natten.functional.na3d(q, k, v, (2, 3, 3), scale=1.0)

tangent = torch.randn_like(q)
out, jvp_out = torch.func.jvp(fn, (q,), (tangent,))
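The same error can be reproduced on CPU without natten. The sketch below uses a hypothetical toy Function, OldStyleSquare (not part of natten), written in the old style where forward takes ctx. We assume it fails for the same reason as CutlassFNAAutogradFns.

```python
import torch


class OldStyleSquare(torch.autograd.Function):
    """Hypothetical toy op in the old style: forward receives ctx directly."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_output


x = torch.randn(3)
tangent = torch.ones_like(x)

try:
    torch.func.jvp(OldStyleSquare.apply, (x,), (tangent,))
    error_message = None
except RuntimeError as exc:
    # Raised because the Function does not override setup_context.
    error_message = str(exc)

print(error_message)
```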

Environment

  • natten version: 0.21.5
  • PyTorch version: 2.8.0+cu128
  • CUDA: 12.8
  • GPU: L40S (compute capability 8.9)

Context

We are training a MeanFlow model (https://github.com/Gsunshine/meanflow) for
Navier-Stokes PDE forecasting. MeanFlow requires torch.func.jvp through the
entire network forward pass to compute the JVP-based training loss. This makes
natten's neighborhood attention unusable in our training pipeline.
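For illustration, the kind of JVP call involved looks roughly like the sketch below. ToyMeanVelocityNet is a hypothetical stand-in for our real network (which would contain the na3d layers), and the tangent choice (v, 0, 1) and the target formula follow our reading of the MeanFlow paper, so treat both as assumptions rather than a reference implementation.

```python
import torch
import torch.nn as nn


class ToyMeanVelocityNet(nn.Module):
    """Hypothetical stand-in for the real network u(z, r, t)."""

    def __init__(self, dim=8, hidden=32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim + 2, hidden), nn.Tanh(), nn.Linear(hidden, dim)
        )

    def forward(self, z, r, t):
        return self.net(torch.cat([z, r, t], dim=-1))


torch.manual_seed(0)
model = ToyMeanVelocityNet()
z = torch.randn(4, 8)      # noisy sample
v = torch.randn(4, 8)      # instantaneous velocity
t = torch.rand(4, 1)
r = t * torch.rand(4, 1)   # start time, chosen so that r <= t

# Forward-mode pass through the whole network: tangents are (v, 0, 1).
u, dudt = torch.func.jvp(
    model, (z, r, t), (v, torch.zeros_like(r), torch.ones_like(t))
)

# Assumed MeanFlow-style target, with gradients stopped on the target.
u_tgt = (v - (t - r) * dudt).detach()
loss = ((u - u_tgt) ** 2).mean()
loss.backward()
```

Every op inside the network has to support torch.func.jvp for this call to work, which is where na3d currently fails.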

Expected Behavior

natten.functional.na3d should support functorch transforms by implementing
the new-style autograd.Function API:

class CutlassFNAAutogradFn(torch.autograd.Function):
    @staticmethod
    def setup_context(ctx, inputs, output):  # ← required
        ...

    @staticmethod
    def forward(q, k, v, ...):
        ...

    @staticmethod
    def jvp(ctx, *grad_inputs):  # ← optional but ideal
        ...
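As a concrete, runnable illustration of this pattern, here is a minimal sketch on a toy elementwise op. ScaledCube is a hypothetical example and has nothing to do with natten's kernels; it only shows how forward, setup_context, backward, and jvp fit together in the new-style API.

```python
import torch


class ScaledCube(torch.autograd.Function):
    """Hypothetical toy op y = scale * x**3 in the new-style API."""

    @staticmethod
    def forward(x, scale):
        # New style: forward does not receive ctx.
        return scale * x ** 3

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, scale = inputs
        ctx.save_for_backward(x)  # used by backward (reverse mode)
        ctx.save_for_forward(x)   # used by jvp (forward mode)
        ctx.scale = scale

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # One gradient per input; scale is a plain float, so it gets None.
        return grad_output * 3 * ctx.scale * x ** 2, None

    @staticmethod
    def jvp(ctx, x_tangent, _scale_tangent):
        (x,) = ctx.saved_tensors
        return 3 * ctx.scale * x ** 2 * x_tangent


x = torch.randn(4)
tangent = torch.randn(4)

out, jvp_out = torch.func.jvp(lambda a: ScaledCube.apply(a, 2.0), (x,), (tangent,))
grad_x = torch.func.grad(lambda a: ScaledCube.apply(a, 2.0).sum())(x)
```

With setup_context defined, both torch.func.jvp and torch.func.grad accept the Function. The jvp staticmethod is what lets forward-mode AD run directly instead of going through double backward.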

Workarounds Attempted

1. torch.autograd.functional.jvp

We tried torch.autograd.functional.jvp as an alternative, since it does not
require setup_context. However, it has several critical issues in practice:

Incompatible with torch.compile:
torch.compile with aot_autograd does not support double backward, which
torch.autograd.functional.jvp requires internally (create_graph=True).

RuntimeError: torch.compile with aot_autograd does not currently support double backward

Incompatible with FSDP:
FSDP flattens parameters into 1D shards. When torch.autograd.functional.jvp
tries to reshape parameters during the forward pass, FSDP raises:

RuntimeError: Cannot writeback when the parameter shape changes
Expects torch.Size([9216]) but got torch.Size([384, 24])

Memory overhead:
Holds 3 full forward pass graphs simultaneously (primal + tangent + backward
graph), making it impractical for large models even at small batch sizes.

Speed:
~3x slower than a normal forward+backward due to the double backward requirement.
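For reference, the double-backward trick behind this workaround can be sketched roughly as follows. jvp_double_backward is a hypothetical helper written for illustration, not PyTorch's actual implementation. It builds a VJP with create_graph=True and then differentiates that VJP with respect to a dummy cotangent.

```python
import torch


def jvp_double_backward(fn, x, tangent):
    """Hypothetical helper: JVP computed via two reverse-mode passes."""
    x = x.detach().requires_grad_(True)
    out = fn(x)
    # Dummy cotangent. Its value does not matter because the VJP is linear in u.
    u = torch.zeros_like(out, requires_grad=True)
    # First backward pass: vjp = J^T u, keeping the graph so it can be differentiated.
    (vjp,) = torch.autograd.grad(out, x, u, create_graph=True)
    # Second backward pass: d(vjp)/du contracted with the tangent gives J @ tangent.
    (jvp_out,) = torch.autograd.grad(vjp, u, tangent)
    return out.detach(), jvp_out


x = torch.randn(5)
tangent = torch.randn(5)

out_db, jvp_db = jvp_double_backward(torch.sin, x, tangent)
out_fw, jvp_fw = torch.func.jvp(torch.sin, (x,), (tangent,))
```

Both paths give the same JVP here, but the first needs two backward passes over a retained graph, while torch.func.jvp computes the tangent alongside the forward pass.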

2. flex_attention (torch.nn.attention.flex_attention)

We tried replacing na3d with PyTorch's flex_attention using a neighborhood mask.
While flex_attention is functorch-compatible, it fails inside torch.compile with:

torch._dynamo.exc.Unsupported: view functorch tensors are not supported
by meta conversion
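A neighborhood mask of the kind we mean can be sketched as below. make_neighborhood_mask_mod is a hypothetical helper, and the window convention (a fixed-size window that shifts inward at the borders) is our assumption about the intended semantics, not something taken from natten's source. The function follows the (b, h, q_idx, kv_idx) signature that flex_attention mask functions use, but it is evaluated here with plain tensors only.

```python
import torch


def make_neighborhood_mask_mod(dims, kernel):
    """Hypothetical helper: 3D neighborhood mask over flattened token indices."""
    _, size_y, size_z = dims

    def unravel(idx):
        # Flattened index -> (x, y, z) coordinates, row-major order.
        return idx // (size_y * size_z), (idx // size_z) % size_y, idx % size_z

    def mask_mod(b, h, q_idx, kv_idx):
        mask = None
        for q, kv, size, k in zip(unravel(q_idx), unravel(kv_idx), dims, kernel):
            # Window start is clamped so the window always holds k positions.
            start = torch.clamp(q - k // 2, 0, size - k)
            inside = (kv >= start) & (kv < start + k)
            mask = inside if mask is None else mask & inside
        return mask

    return mask_mod


dims, kernel = (4, 4, 4), (3, 3, 3)
mask_mod = make_neighborhood_mask_mod(dims, kernel)

num_tokens = dims[0] * dims[1] * dims[2]
q_idx = torch.arange(num_tokens)[:, None]
kv_idx = torch.arange(num_tokens)[None, :]
mask = mask_mod(0, 0, q_idx, kv_idx)  # dense (64, 64) boolean mask
```

Under this convention each query attends to exactly 3 * 3 * 3 = 27 keys, including itself.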

3. natten reference backend

We tried forcing the pure PyTorch reference backend via natten.use_reference_fna(True)
but this attribute does not exist in natten 0.21.5.

We also tried natten.allow_flex_compile() and natten.allow_flex_compile_backprop()
but these still route through CutlassFNAAutogradFns and hit the same error.

Current Status

We are forced to use global attention (F.scaled_dot_product_attention) instead
of neighborhood attention, losing the locality inductive bias and memory efficiency
of natten. We are also forced to run without torch.compile entirely, significantly
impacting training speed.

Request

Please add setup_context support to CutlassFNAAutogradFns to enable
compatibility with torch.func.jvp and other functorch transforms. This would
make natten usable in any training pipeline that requires higher-order derivatives
or functional transforms, such as MeanFlow, neural ODEs, or any physics-informed
neural network training.

Ideally, please also add a jvp staticmethod, which would allow true forward-mode AD
through natten kernels without falling back to double backward.
