Ignore if this is intentional, but `do_bench_cudagraph` replays the same buffers, so the working set stays resident in L2 and the benchmarks are measuring L2 bandwidth, not DRAM. This is easy to see at N=16384, where the 256 MB working set blows past the 96 MB L2 and the numbers drop to something physically plausible.
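A quick back-of-the-envelope check of the working set per row count makes the L2 crossover obvious (96 MB L2 is assumed from the comment above; the working set counts one read of the input plus one write of the output in float16):

```python
# Working set per benchmark iteration: input read + output write, float16.
M, BYTES = 4096, 2               # rows, bytes per fp16 element
L2_BYTES = 96 * 1024**2          # assumed RTX 5090 L2 capacity

for N in (1024, 2048, 4096, 8192, 16384):
    ws = 2 * M * N * BYTES       # 2 tensors: input + output
    fits = "fits in L2" if ws <= L2_BYTES else "exceeds L2"
    print(f"N={N:5d}: working set {ws / 2**20:6.0f} MiB -> {fits}")
```

At N=8192 and above the working set (128 MiB and 256 MiB) already exceeds L2, which lines up with the bandwidth drop in the table below N=8192.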
Also, the PyTorch baseline is six unfused kernels via `.pow(2).mean().rsqrt()`; `F.rms_norm` is the right comparison, since that's what anyone would actually replace.
Numbers on RTX 5090, float16, M=4096 (all in GB/s):
| N | cutile (current) | pytorch (current) | cutile (mem bandwidth) | pytorch (fused) |
|---|---|---|---|---|
| 1024 | 3117 | 404 | 1253 | 1374 |
| 2048 | 4142 | 408 | 1344 | 1355 |
| 4096 | 5019 | 262 | 1434 | 1381 |
| 8192 | 1778 | 192 | 1382 | 1270 |
| 16384 | 1451 | 185 | 1366 | 1019 |
I think both issues (or maybe not issues, just a preference) apply to all the memory-bound kernels.
cc @arjkesh - probably a well-thought-out choice on your side, but leaving this open in case anyone wonders why they're seeing 3000+ GB/s and thinks their kernel broke physics.