
tune_kernels.py cannot beat the PyTorch baseline #8

@Dr-Left

Description


My Environment

torch: 2.8.0
CUDA toolkit (nvcc): 12.1.105 (Build cuda_12.1.r12.1/compiler.32688072_0)
triton: 3.2.0
Hardware: NVIDIA H100 NVL 94GB

Question

For many of the kernels, the PyTorch baseline achieves latency comparable to, or even better than, the best-tuned Triton configs. What's wrong?

Tuning Summary:

PyTorch baseline latencies:

  • multi_lora_xw_sb: 0.279ms
  • multi_lora_dyw_dsa: 0.279ms
  • multi_lora_dyw_dsa_tma: 0.277ms
  • lora_XW_SB_TMA: 0.278ms
  • lora_dyw_dsa: 0.277ms
2025-11-07 14:40:30.252 | INFO     | __main__:main:550 - ================================================================================
2025-11-07 14:40:30.252 | INFO     | __main__:main:551 - TUNING SUMMARY
2025-11-07 14:40:30.252 | INFO     | __main__:main:552 - ================================================================================
2025-11-07 14:40:30.252 | INFO     | __main__:main:554 - Device short name: h100-nvl
2025-11-07 14:40:30.252 | INFO     | __main__:main:556 - fused_lora_xw_sb: 0.293ms
2025-11-07 14:40:30.252 | INFO     | __main__:main:557 -   Config: LoRATritonConfig(block_size_m=128, block_size_n=256, block_size_k=64, group_size_m=8, num_stages=3, num_warps=8, epilogue_subtile=None, loop_unroll_factor=None, flatten=None)
2025-11-07 14:40:30.252 | INFO     | __main__:main:556 - fused_lora_dyw_dsa: 0.284ms
2025-11-07 14:40:30.252 | INFO     | __main__:main:557 -   Config: LoRATritonConfig(block_size_m=128, block_size_n=256, block_size_k=64, group_size_m=8, num_stages=3, num_warps=8, epilogue_subtile=None, loop_unroll_factor=None, flatten=None)
2025-11-07 14:40:30.252 | INFO     | __main__:main:556 - fused_lora_dys_dyb: 0.040ms
2025-11-07 14:40:30.252 | INFO     | __main__:main:557 -   Config: LoRATritonConfig(block_size_m=128, block_size_n=None, block_size_k=128, group_size_m=8, num_stages=3, num_warps=8, epilogue_subtile=None, loop_unroll_factor=None, flatten=None)
2025-11-07 14:40:30.252 | INFO     | __main__:main:556 - fused_lora_xw_sb_tma: 0.272ms
2025-11-07 14:40:30.252 | INFO     | __main__:main:557 -   Config: LoRATritonConfig(block_size_m=128, block_size_n=256, block_size_k=64, group_size_m=8, num_stages=3, num_warps=8, epilogue_subtile=None, loop_unroll_factor=None, flatten=None)
2025-11-07 14:40:30.252 | INFO     | __main__:main:556 - fused_lora_dyw_dsa_tma: 0.272ms
2025-11-07 14:40:30.252 | INFO     | __main__:main:557 -   Config: LoRATritonConfig(block_size_m=128, block_size_n=256, block_size_k=64, group_size_m=8, num_stages=3, num_warps=8, epilogue_subtile=None, loop_unroll_factor=None, flatten=None)
2025-11-07 14:40:30.252 | INFO     | __main__:main:556 - fused_multi_lora_xw_sb: 0.511ms
2025-11-07 14:40:30.252 | INFO     | __main__:main:557 -   Config: LoRATritonConfig(block_size_m=128, block_size_n=256, block_size_k=64, group_size_m=8, num_stages=3, num_warps=8, epilogue_subtile=None, loop_unroll_factor=None, flatten=None)
2025-11-07 14:40:30.252 | INFO     | __main__:main:556 - fused_multi_lora_dyw_dsa: 0.259ms
2025-11-07 14:40:30.252 | INFO     | __main__:main:557 -   Config: LoRATritonConfig(block_size_m=128, block_size_n=256, block_size_k=64, group_size_m=8, num_stages=3, num_warps=8, epilogue_subtile=None, loop_unroll_factor=None, flatten=None)
2025-11-07 14:40:30.252 | INFO     | __main__:main:556 - fused_multi_lora_dys_dyb: 0.119ms
2025-11-07 14:40:30.252 | INFO     | __main__:main:557 -   Config: LoRATritonConfig(block_size_m=128, block_size_n=None, block_size_k=128, group_size_m=8, num_stages=4, num_warps=8, epilogue_subtile=None, loop_unroll_factor=None, flatten=None)

Full log from running tune_kernels.py:
https://gist.github.com/Dr-Left/c1889749d27fabdb1ec966b0e4060d2b
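
To make the comparison concrete, the speedup of each tuned Triton kernel over its PyTorch baseline can be computed directly from the numbers in the summary above. This is a minimal sketch, assuming that each baseline entry corresponds to the Triton kernel with the same name lowercased and prefixed by "fused_" (the log does not state this pairing explicitly). The helper name speedups is hypothetical and not part of tune_kernels.py.

```python
# Latencies in milliseconds, copied from the tuning summary above.
PYTORCH_BASELINE_MS = {
    "multi_lora_xw_sb": 0.279,
    "multi_lora_dyw_dsa": 0.279,
    "multi_lora_dyw_dsa_tma": 0.277,
    "lora_XW_SB_TMA": 0.278,
    "lora_dyw_dsa": 0.277,
}

BEST_TRITON_MS = {
    "fused_lora_xw_sb": 0.293,
    "fused_lora_dyw_dsa": 0.284,
    "fused_lora_dys_dyb": 0.040,
    "fused_lora_xw_sb_tma": 0.272,
    "fused_lora_dyw_dsa_tma": 0.272,
    "fused_multi_lora_xw_sb": 0.511,
    "fused_multi_lora_dyw_dsa": 0.259,
    "fused_multi_lora_dys_dyb": 0.119,
}


def speedups(baseline, triton):
    """Return {baseline_name: baseline_ms / triton_ms} for kernels present in both.

    ASSUMPTION (hypothetical pairing): a baseline entry matches the Triton
    kernel named "fused_" + baseline_name.lower(). Baselines with no
    matching Triton entry are skipped.
    """
    result = {}
    for name, base_ms in baseline.items():
        key = "fused_" + name.lower()
        if key in triton:
            # A ratio > 1.0 means the tuned Triton kernel is faster than PyTorch.
            result[name] = base_ms / triton[key]
    return result


if __name__ == "__main__":
    for name, s in speedups(PYTORCH_BASELINE_MS, BEST_TRITON_MS).items():
        verdict = "Triton faster" if s > 1.0 else "PyTorch faster or equal"
        print(f"{name:24s} {s:5.2f}x  {verdict}")
```

With these numbers, only multi_lora_dyw_dsa (about 1.08x) and lora_XW_SB_TMA (about 1.02x) come out ahead of PyTorch, lora_dyw_dsa lands at about 0.98x, and multi_lora_xw_sb at about 0.55x. multi_lora_dyw_dsa_tma has no Triton entry in the summary, so it is skipped.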
