Summary
The `mamba3_siso_bwd_kernel_dqkv` Triton kernel runs 38.7x slower on GB200 (SM100a / Blackwell) than on H200 (SM90 / Hopper). For an identical 1.4B-parameter Mamba3 SISO model, training reaches only 5% MFU on GB200 versus 33% MFU on H200, and this backward kernel alone accounts for 90.76% of total GPU time on GB200.
The root cause is an NVIDIA `ptxas` internal compiler error (C7907) that eliminates all `num_warps=4` and `num_warps=8` autotuner configurations for the backward kernel on SM100a. The 3 surviving `num_warps=2` configs spill registers heavily (42 to 50 KB per thread), which directly causes the regression.
Mamba2 (SSD) is unaffected by this issue: it reaches 24% to 33% MFU on every GPU tested, and its throughput on GB200 is higher than on H200.
Training throughput
Measured with d_model=2048, n_layers=48, batch=4, bf16, seq_len=4096:
| GPU | Mamba2 TFLOP/s | Mamba2 MFU | Mamba3 TFLOP/s | Mamba3 MFU |
|---|---|---|---|---|
| A100 | 102.5 | 33% | 37.5 | 12% |
| H200 | 239.2 | 24% | 323.6 | 33% |
| GB200 | 360.4 | 32% | 56.0 | 5% |
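The table already contains enough to compare the two models without assuming any peak FLOP/s figure. The sketch below simply re-keys the measured TFLOP/s values into a dictionary (`THROUGHPUT`, `mamba3_over_mamba2` and `speedup_vs` are illustrative names, not part of mamba_ssm) and computes Mamba3 throughput relative to Mamba2 on each GPU, plus GB200 throughput relative to H200 for each model.

```python
"""Relative throughput derived from the training-throughput table."""

# Measured TFLOP/s from the table: gpu -> (mamba2, mamba3)
THROUGHPUT = {
    "A100": (102.5, 37.5),
    "H200": (239.2, 323.6),
    "GB200": (360.4, 56.0),
}
MODELS = ("Mamba2", "Mamba3")


def mamba3_over_mamba2(gpu: str) -> float:
    """Mamba3 throughput as a fraction of Mamba2 throughput on the same GPU."""
    mamba2, mamba3 = THROUGHPUT[gpu]
    return mamba3 / mamba2


def speedup_vs(gpu: str, baseline: str, model: str) -> float:
    """Throughput of `model` on `gpu` relative to the same model on `baseline`."""
    idx = MODELS.index(model)
    return THROUGHPUT[gpu][idx] / THROUGHPUT[baseline][idx]


if __name__ == "__main__":
    for gpu in THROUGHPUT:
        print(f"{gpu}: Mamba3 / Mamba2 = {mamba3_over_mamba2(gpu):.2f}")
    for model in MODELS:
        print(f"{model}: GB200 / H200 = {speedup_vs('GB200', 'H200', model):.2f}")
```

From these numbers, Mamba3 is about 1.35x faster than Mamba2 on H200 but delivers only about 0.16x of Mamba2's throughput on GB200; moving from H200 to GB200 speeds Mamba2 up by about 1.51x while Mamba3 drops to about 0.17x.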
Kernel-level comparison
Total self CUDA time across 10 timed iterations:
| Kernel | H200 (ms) | GB200 (ms) | GB200 / H200 |
|---|---|---|---|
| `mamba3_siso_bwd_kernel_dqkv` | 587 | 22,746 | 38.7x (slower) |
| `mamba3_siso_fwd_kernel` | 602 | 450 | 0.75x (faster) |
| `mamba3_siso_bwd_kernel_rotary_bias_angles` | 368 | 375 | ~1.0x |
| `aten::mm` (cuBLAS) | 1,324 | 543 | 0.41x (faster) |
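The ratio column is the GB200 time divided by the H200 time, so values above 1 mean the kernel is slower on GB200. A minimal sketch that recomputes it and converts the totals to per-iteration averages, assuming the reported times are sums across the 10 timed iterations (the dictionary and helper names are illustrative):

```python
"""Recompute the GB200 / H200 ratio column from the kernel-level table."""

ITERATIONS = 10  # assumption: the table reports totals across 10 timed iterations

# Self CUDA time in ms: kernel -> (h200_ms, gb200_ms)
SELF_CUDA_MS = {
    "mamba3_siso_bwd_kernel_dqkv": (587, 22_746),
    "mamba3_siso_fwd_kernel": (602, 450),
    "mamba3_siso_bwd_kernel_rotary_bias_angles": (368, 375),
    "aten::mm": (1_324, 543),
}


def gb200_over_h200(kernel: str) -> float:
    """Ratio > 1 means the kernel is slower on GB200 than on H200."""
    h200_ms, gb200_ms = SELF_CUDA_MS[kernel]
    return gb200_ms / h200_ms


def per_iteration_ms(kernel: str) -> tuple[float, float]:
    """Average self CUDA time per iteration as (H200, GB200)."""
    h200_ms, gb200_ms = SELF_CUDA_MS[kernel]
    return h200_ms / ITERATIONS, gb200_ms / ITERATIONS


if __name__ == "__main__":
    for name in SELF_CUDA_MS:
        h200, gb200 = per_iteration_ms(name)
        print(f"{name}: {gb200_over_h200(name):.2f}x, "
              f"{h200:.1f} ms/iter (H200) vs {gb200:.1f} ms/iter (GB200)")
```

Under that assumption, the backward kernel averages about 58.7 ms per iteration on H200 but about 2,274.6 ms per iteration on GB200.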
Root Cause
When Triton compiles `mamba3_siso_bwd_kernel_dqkv` for SM100a, `ptxas` raises C7907 on every `num_warps=4` and `num_warps=8` config (5 produce trap stubs, 1 crashes outright). Only the 3 `num_warps=2` configs survive, and they are capped at 32 registers with 42 to 50 KB of spill traffic per thread. With 64 threads per block (2 warps of 32 threads), that amounts to roughly 2.7 to 3.2 MB of spill traffic per thread block, which dwarfs the kernel's useful data movement.
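The per-block figure is the per-thread spill multiplied by the block size, which for `num_warps=2` is 2 × 32 = 64 threads. A minimal sketch of that arithmetic, assuming decimal kilobytes and that every thread in the block spills the same amount (both assumptions; `spill_bytes_per_block` and `CONFIG_OUTCOMES` are hypothetical names), together with a tally of the reported SM100a config outcomes:

```python
"""Per-block spill volume and autotune config outcomes on SM100a."""

WARP_SIZE = 32  # threads per warp on NVIDIA GPUs
KB = 1_000      # assumption: decimal kilobytes (use 1_024 if the figures are KiB)


def spill_bytes_per_block(num_warps: int, spill_bytes_per_thread: int) -> int:
    """Each thread spills independently, so per-thread volumes add up."""
    threads = num_warps * WARP_SIZE
    return threads * spill_bytes_per_thread


# Outcome of the 9 autotune configs on SM100a, as reported above.
CONFIG_OUTCOMES = {
    "num_warps=2: compiles (with spills)": 3,
    "num_warps=4/8: C7907, trap stub": 5,
    "num_warps=4/8: C7907, ptxas crash": 1,
}

if __name__ == "__main__":
    print(f"configs accounted for: {sum(CONFIG_OUTCOMES.values())}")
    for per_thread_kb in (42, 50):
        total = spill_bytes_per_block(num_warps=2,
                                      spill_bytes_per_thread=per_thread_kb * KB)
        print(f"{per_thread_kb} KB/thread x 64 threads = {total / 1e6:.2f} MB per block")
```

Under these assumptions the surviving configs spill about 2.69 to 3.20 MB per thread block, and the three outcome buckets cover all 9 configs that compile on H200.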
On H200, all 9 autotuner configs compile and launch, so the autotuner can also choose `num_warps=4` and `num_warps=8` configs, which spread the same tile across more threads and therefore carry substantially lower per-thread register pressure.
Additionally, SM100's tensor memory budget independently prevents the `num_warps>=4` configs of the forward kernel from launching (required tensor memory: 544, hardware limit: 512). This suggests the Mamba3 kernels may need tile-size tuning or register-pressure reduction on Blackwell regardless of the C7907 fix.
Error output
```
Internal Triton PTX codegen error
`ptxas` stderr:
ptxas fatal : (C7907) Internal compiler error.
ptxas fatal : Ptx assembly aborted due to errors
```
Companion issue filed upstream: triton-lang/triton#9933
Environment
| Component | Version |
|---|---|
| GPU | NVIDIA GB200 (Blackwell, SM100a, compute capability 10.0) |
| GPU (baseline) | NVIDIA H200 (Hopper, SM90, compute capability 9.0) |
| Driver | 580.65.06 |
| CUDA toolkit | 13.1 (V13.1.115, built 2025-12-16) |
| Triton | 3.6.0 |
| PyTorch | 2.11.0a0+eb65b36914.nv26.02 |
| mamba_ssm | 2.3.1 |
Workaround
None known at this time. The forward pass and other kernels are unaffected.