
Speed regression in rmsnorm_fwd when used with torch compile #77

@Maharshi-Pandya


On a B200, when I run the benchmark_rmsnorm file without wrapping rmsnorm_fwd in torch.compile, the results look like this:

=== RMSNorm Forward Pass Benchmark ===
Tensor dimensions: [9450, 5120]
Input and Output Data type: torch.bfloat16
Input tensor shapes:
x: torch.Size([9450, 5120]), dtype: torch.bfloat16
w: torch.Size([5120]), dtype: torch.float32
Executing kernel...
Kernel execution time: 0.0413 ms
Mem throughput: 4685.00 GB/s
Ref kernel execution time: 0.0649 ms
Ref mem throughput: 2981.00 GB/s

However, after wrapping it with compiled_func_cute = torch.compile(rmsnorm_fwd), I get:

=== RMSNorm Forward Pass Benchmark ===
Tensor dimensions: [9450, 5120]
Input and Output Data type: torch.bfloat16
Input tensor shapes:
x: torch.Size([9450, 5120]), dtype: torch.bfloat16
w: torch.Size([5120]), dtype: torch.float32
Executing kernel...
Kernel execution time: 0.1067 ms
Mem throughput: 1815.00 GB/s
Ref kernel execution time: 0.0624 ms
Ref mem throughput: 3103.00 GB/s

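I wonder where the speed regression comes from after applying torch.compile. The kernel time goes from 0.0413 ms to 0.1067 ms, roughly 2.6x slower, which makes the compiled CuTe kernel slower than the reference kernel. If we use this kernel inside a larger block of operations under torch.compile, this overhead will eat the whole speedup, even if the CuTe kernel is compatible with torch.compile. Any help is appreciated :)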
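To put the two runs side by side, here is a small sanity check on the reported numbers. It is only a sketch: the traffic model (read x, write the output, read w) is my assumption about how the benchmark computes "Mem throughput", although it reproduces the printed values to within timer rounding. The variable names are illustrative and not taken from the benchmark file.

```python
# Shapes and dtypes copied from the benchmark output above.
M, N = 9450, 5120
BF16_BYTES, FP32_BYTES = 2, 4

# Assumption: bytes moved = read x (bf16) + write out (bf16) + read w (fp32).
bytes_moved = 2 * M * N * BF16_BYTES + N * FP32_BYTES


def throughput_gbps(time_ms: float) -> float:
    """Effective memory throughput in GB/s, with 1 GB = 1e9 bytes."""
    return bytes_moved / (time_ms * 1e-3) / 1e9


# Kernel times reported for the two runs.
eager_ms, compiled_ms = 0.0413, 0.1067

eager_gbps = throughput_gbps(eager_ms)
compiled_gbps = throughput_gbps(compiled_ms)
slowdown = compiled_ms / eager_ms
extra_us = (compiled_ms - eager_ms) * 1e3  # added time per call, in microseconds

print(f"eager:    {eager_gbps:.0f} GB/s")
print(f"compiled: {compiled_gbps:.0f} GB/s")
print(f"slowdown: {slowdown:.2f}x, extra time per call: {extra_us:.1f} us")
```

So the compiled call costs about 65 microseconds more per invocation for the same amount of data moved.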
