Problem
natten.functional.na3d (and na1d/na2d) are incompatible with functorch transforms
(torch.func.jvp, torch.func.grad, torch.func.vmap) because CutlassFNAAutogradFns
uses the old-style torch.autograd.Function API without implementing setup_context.
Error
RuntimeError: In order to use an autograd.Function with functorch transforms
(vmap, grad, jvp, jacrev, ...), it must override the setup_context staticmethod.
For more details, please see https://pytorch.org/docs/main/notes/extending.func.html
Minimal Reproduction
import torch
import natten
q = torch.randn(1, 2, 4, 4, 4, 16, requires_grad=True).cuda()
k = torch.randn(1, 2, 4, 4, 4, 16).cuda()
v = torch.randn(1, 2, 4, 4, 4, 16).cuda()
def fn(q):
    return natten.functional.na3d(q, k, v, (2, 3, 3), scale=1.0)
tangent = torch.randn_like(q)
out, jvp_out = torch.func.jvp(fn, (q,), (tangent,))
Environment
- natten version: 0.21.5
- PyTorch version: 2.8.0+cu128
- CUDA: 12.8
- GPU: L40S (compute capability 8.9)
Context
We are training a MeanFlow model (https://github.com/Gsunshine/meanflow) for
Navier-Stokes PDE forecasting. MeanFlow requires torch.func.jvp through the
entire network forward pass to compute the JVP-based training loss. This makes
natten's neighborhood attention unusable in our training pipeline.
Expected Behavior
natten.functional.na3d should support functorch transforms by implementing
the new-style autograd.Function API:
class CutlassFNAAutogradFn(torch.autograd.Function):
    @staticmethod
    def setup_context(ctx, inputs, output):  # ← required
        ...

    @staticmethod
    def forward(q, k, v, ...):
        ...

    @staticmethod
    def jvp(ctx, *grad_inputs):  # ← optional but ideal
        ...
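To make the requested pattern concrete, here is a minimal sketch of a new-style autograd.Function on a toy op. ScaledExp is a hypothetical stand-in and has nothing to do with natten's kernels; it only shows the split between forward (no ctx), setup_context, backward, and jvp that torch.func transforms rely on.

```python
import torch


class ScaledExp(torch.autograd.Function):
    """Toy op y = exp(scale * x). Hypothetical example, not NATTEN code."""

    @staticmethod
    def forward(x, scale):
        # New-style forward takes no ctx argument.
        return torch.exp(scale * x)

    @staticmethod
    def setup_context(ctx, inputs, output):
        _x, scale = inputs
        ctx.scale = scale
        # Save the output for reverse mode (backward) and forward mode (jvp).
        ctx.save_for_backward(output)
        ctx.save_for_forward(output)

    @staticmethod
    def backward(ctx, grad_out):
        (y,) = ctx.saved_tensors
        # dy/dx = scale * y; the non-tensor `scale` input gets no gradient.
        return grad_out * ctx.scale * y, None

    @staticmethod
    def jvp(ctx, x_tangent, _scale_tangent):
        (y,) = ctx.saved_tensors
        # Forward-mode rule: tangent_out = (dy/dx) * tangent_in.
        return x_tangent * ctx.scale * y


def fn(x):
    return ScaledExp.apply(x, 0.5)


x = torch.randn(3)
tangent = torch.randn(3)
out, jvp_out = torch.func.jvp(fn, (x,), (tangent,))
grad_x = torch.func.grad(lambda t: fn(t).sum())(x)
```

With this structure, torch.func.jvp dispatches to the jvp staticmethod directly, so no double backward is involved. Supporting torch.func.vmap would additionally need a vmap staticmethod or generate_vmap_rule, which this sketch leaves out.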
Workarounds Attempted
1. torch.autograd.functional.jvp
We tried torch.autograd.functional.jvp as an alternative since it does not
require setup_context. However, it has several critical issues in practice:
Incompatible with torch.compile:
torch.compile with aot_autograd does not support double backward, which
torch.autograd.functional.jvp requires internally (create_graph=True).
RuntimeError: torch.compile with aot_autograd does not currently support double backward
Incompatible with FSDP:
FSDP flattens parameters into 1D shards. When torch.autograd.functional.jvp
tries to reshape parameters during the forward pass, FSDP raises:
RuntimeError: Cannot writeback when the parameter shape changes
Expects torch.Size([9216]) but got torch.Size([384, 24])
Memory overhead:
Holds 3 full forward pass graphs simultaneously (primal + tangent + backward
graph), making it impractical for large models even at small batch sizes.
Speed:
~3x slower than a normal forward+backward due to the double backward requirement.
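For readers unfamiliar with why this route needs double backward, the trick can be sketched in a few lines. This is an illustration of the general technique under our own naming (jvp_double_backward is a hypothetical helper), not the exact internals of torch.autograd.functional.jvp.

```python
import torch


def jvp_double_backward(fn, x, tangent):
    """Compute J(x) @ tangent using two reverse-mode passes."""
    x = x.detach().requires_grad_(True)
    y = fn(x)
    # Dummy cotangent u: the VJP g(u) = J^T u is linear in u.
    u = torch.zeros_like(y, requires_grad=True)
    # First backward, kept differentiable (this is the create_graph=True step).
    (g,) = torch.autograd.grad(y, x, grad_outputs=u, create_graph=True)
    # Second backward: d/du <g(u), tangent> = J @ tangent.
    (jvp_out,) = torch.autograd.grad(g, u, grad_outputs=tangent)
    return y.detach(), jvp_out


x = torch.randn(5)
tangent = torch.randn(5)
out, jvp_out = jvp_double_backward(torch.tanh, x, tangent)
```

The second torch.autograd.grad call differentiates through the first backward pass, which is the step that aot_autograd rejects and that accounts for the extra graph held in memory.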
2. flex_attention (torch.nn.attention.flex_attention)
We tried replacing na3d with PyTorch's flex_attention using a neighborhood mask.
While flex_attention is functorch-compatible, it fails inside torch.compile with:
torch._dynamo.exc.Unsupported: view functorch tensors are not supported
by meta conversion
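For reference, the neighborhood pattern such a mask encodes can be written as a dense boolean mask in plain PyTorch. The sketch below is 1D only and assumes the non-dilated, non-causal window definition (the window is shifted at the borders so every query sees exactly kernel_size keys) and a [batch, heads, length, head_dim] layout. Both function names are hypothetical. It materializes the full score matrix, so it is useful for checking numerics under torch.func.jvp, not as a memory-efficient replacement.

```python
import torch


def neighborhood_mask_1d(length, kernel_size):
    """Boolean [length, length] mask; True where query i may attend to key j."""
    assert 0 < kernel_size <= length
    idx = torch.arange(length)
    # Window start, clamped so the window stays inside [0, length).
    start = (idx - kernel_size // 2).clamp(0, length - kernel_size)
    keys = idx[None, :]
    return (keys >= start[:, None]) & (keys < start[:, None] + kernel_size)


def dense_na1d(q, k, v, kernel_size, scale):
    """Dense masked attention; q, k, v are [batch, heads, length, head_dim]."""
    mask = neighborhood_mask_1d(q.shape[-2], kernel_size).to(q.device)
    scores = (q @ k.transpose(-2, -1)) * scale
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v


q = torch.randn(1, 2, 8, 16)
k = torch.randn(1, 2, 8, 16)
v = torch.randn(1, 2, 8, 16)
tangent = torch.randn_like(q)
out, jvp_out = torch.func.jvp(
    lambda q_: dense_na1d(q_, k, v, kernel_size=3, scale=1.0), (q,), (tangent,)
)
```

Because it uses only matmul, masked_fill, and softmax, this version goes through forward-mode AD without a custom autograd.Function.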
3. natten reference backend
We tried forcing the pure PyTorch reference backend via natten.use_reference_fna(True)
but this attribute does not exist in natten 0.21.5.
We also tried natten.allow_flex_compile() and natten.allow_flex_compile_backprop()
but these still route through CutlassFNAAutogradFns and hit the same error.
Current Status
We are forced to use global attention (F.scaled_dot_product_attention) instead
of neighborhood attention, losing the locality inductive bias and memory efficiency
of natten. We are also forced to run without torch.compile entirely, significantly
impacting training speed.
Request
Please add setup_context support to CutlassFNAAutogradFns to enable
compatibility with torch.func.jvp and other functorch transforms. This would
make natten usable in any training pipeline that requires higher-order derivatives
or functional transforms, such as MeanFlow, neural ODEs, or any physics-informed
neural network training.
Ideally, a jvp staticmethod would also be added, which would allow true
forward-mode AD through natten kernels without falling back to double backward.