On a B200, when I run the benchmark_rmsnorm file WITHOUT wrapping rmsnorm_fwd in torch.compile, the results look like this:
=== RMSNorm Forward Pass Benchmark ===
Tensor dimensions: [9450, 5120]
Input and Output Data type: torch.bfloat16
Input tensor shapes:
x: torch.Size([9450, 5120]), dtype: torch.bfloat16
w: torch.Size([5120]), dtype: torch.float32
Executing kernel...
Kernel execution time: 0.0413 ms
Mem throughput: 4685.00 GB/s
Ref kernel execution time: 0.0649 ms
Ref mem throughput: 2981.00 GB/s
However, after wrapping it with compiled_func_cute = torch.compile(rmsnorm_fwd), I get:
=== RMSNorm Forward Pass Benchmark ===
Tensor dimensions: [9450, 5120]
Input and Output Data type: torch.bfloat16
Input tensor shapes:
x: torch.Size([9450, 5120]), dtype: torch.bfloat16
w: torch.Size([5120]), dtype: torch.float32
Executing kernel...
Kernel execution time: 0.1067 ms
Mem throughput: 1815.00 GB/s
Ref kernel execution time: 0.0624 ms
Ref mem throughput: 3103.00 GB/s
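For reference, the reported throughput numbers are consistent with counting one bf16 read of x, one bf16 write of the output, and one fp32 read of w per call. This byte count is inferred from the numbers above, not taken from the benchmark source. Because the bytes moved are identical in both runs, the throughput drop comes entirely from the longer time per call, roughly 0.065 ms extra:

```python
# Reproduce the "Mem throughput" figures from the tensor shapes.
# Assumption: the benchmark counts read(x) + write(out) + read(w).
M, N = 9450, 5120
BF16_BYTES, FP32_BYTES = 2, 4

bytes_moved = M * N * BF16_BYTES * 2 + N * FP32_BYTES  # x in, out, w


def gbps(time_ms: float) -> float:
    """Effective bandwidth in GB/s for one call taking time_ms milliseconds."""
    return bytes_moved / (time_ms * 1e-3) / 1e9


eager = gbps(0.0413)     # ~4687 GB/s (log says 4685.00)
compiled = gbps(0.1067)  # ~1814 GB/s (log says 1815.00)
extra_ms = 0.1067 - 0.0413  # extra time per call after torch.compile

print(f"eager {eager:.0f} GB/s, compiled {compiled:.0f} GB/s, "
      f"extra {extra_ms * 1e3:.1f} us per call")
```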
The compiled version is about 2.6x slower (0.1067 ms vs 0.0413 ms), while the reference kernel is essentially unchanged (0.0624 ms vs 0.0649 ms). Where is this regression coming from? It matters because if we use this kernel inside a larger block of operations under torch.compile, the overhead will eat all of the speedup, even if the CuTe kernel is compatible with torch.compile. Any help is appreciated :)
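To narrow this down, one option is to time each call with warmup excluded, so that one-time torch.compile tracing or first-call JIT work is not mixed into the steady-state number. The sketch below is a stdlib-only helper with a hypothetical name (time_per_call, not part of the benchmark). The toy kernel/wrapped pair only simulates a fixed per-call host cost; it is not a model of what torch.compile actually does here.

```python
import statistics
import time
from typing import Callable, Optional


def time_per_call(fn: Callable, *args, warmup: int = 10, iters: int = 100,
                  sync: Optional[Callable[[], None]] = None) -> float:
    """Median wall-clock seconds per call of fn(*args), excluding warmup.

    Warmup calls absorb one-time costs so they are not counted.
    Pass sync (e.g. torch.cuda.synchronize) when fn launches GPU work,
    so each sample waits for the work to finish, not just the launch.
    """
    for _ in range(warmup):
        fn(*args)
    if sync is not None:
        sync()
    samples = []
    for _ in range(iters):
        start = time.perf_counter()
        fn(*args)
        if sync is not None:
            sync()
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


def kernel(n: int) -> int:
    # Stand-in for a fast op.
    return sum(range(n))


def wrapped(n: int) -> int:
    # Stand-in for a dispatch layer with a fixed host-side cost per call.
    time.sleep(0.0005)
    return kernel(n)


base = time_per_call(kernel, 1000)
slow = time_per_call(wrapped, 1000)
print(f"overhead per call: {(slow - base) * 1e3:.3f} ms")
```

In the real benchmark one would pass rmsnorm_fwd and compiled_func_cute (with sync=torch.cuda.synchronize) instead of the toy functions. If a gap of this size remains after warmup, running with TORCH_LOGS="graph_breaks,recompiles" should show whether Dynamo is breaking the graph around the CuTe call or recompiling it repeatedly. For precise kernel timing, CUDA events or triton.testing.do_bench are better suited than wall-clock timing.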