diff --git a/flash_sparse_attn/ops/triton/autotuner.py b/flash_sparse_attn/ops/triton/autotuner.py
new file mode 100644
index 0000000..9da796b
--- /dev/null
+++ b/flash_sparse_attn/ops/triton/autotuner.py
@@ -0,0 +1,166 @@
+"""Autotuning helpers for the flash attention Triton kernels.
+
+Each kernel family (forward, backward, decode) has one tile search space,
+shared by its dense, sparse and gated variants.
+"""
+
+import itertools
+
+import triton
+
+
+# Candidate (TILE_M values, TILE_N values) per kernel family.
+_FWD_TILES = ((64, 128, 256), (32, 64, 128))
+_BWD_TILES = ((32, 64, 128), (64, 128, 256))
+_DEC_TILES = ((16, 32, 64), (64, 128, 256))
+_NUM_WARPS = (4, 8)
+_NUM_STAGES = (1, 2)
+
+# Kernel argument names handed to triton.autotune as the tuning key.
+_AUTOTUNE_KEY = ("SEQLEN_Q_CACHE", "SEQLEN_K_CACHE", "TILE_K")
+
+
+def _build_configs(tile_ms, tile_ns):
+    """Return one triton.Config per tile / num_warps / num_stages combination."""
+    # itertools.product varies the last iterable fastest, which keeps the
+    # ordering of the equivalent nested loops (tile_m outermost).
+    return [
+        triton.Config(
+            {"TILE_M": tile_m, "TILE_N": tile_n},
+            num_warps=num_warps,
+            num_stages=num_stages,
+            num_ctas=1,
+        )
+        for tile_m, tile_n, num_warps, num_stages in itertools.product(
+            tile_ms, tile_ns, _NUM_WARPS, _NUM_STAGES
+        )
+    ]
+
+
+def _autotune(jit_kernel, configs):
+    """Wrap jit_kernel with triton.autotune over the given configs."""
+    return triton.autotune(configs=configs, key=list(_AUTOTUNE_KEY))(jit_kernel)
+
+
+def get_fwd_dense_autotune_configs():
+    """Return the autotuning search space for the dense forward kernel."""
+    return _build_configs(*_FWD_TILES)
+
+
+def get_fwd_sparse_autotune_configs():
+    """Return the autotuning search space for the sparse forward kernel."""
+    return _build_configs(*_FWD_TILES)
+
+
+def get_fwd_gated_autotune_configs():
+    """Return the autotuning search space for the gated forward kernel."""
+    return _build_configs(*_FWD_TILES)
+
+
+def get_bwd_dense_autotune_configs():
+    """Return the autotuning search space for the dense backward kernel."""
+    return _build_configs(*_BWD_TILES)
+
+
+def get_bwd_sparse_autotune_configs():
+    """Return the autotuning search space for the sparse backward kernel."""
+    return _build_configs(*_BWD_TILES)
+
+
+def get_bwd_gated_autotune_configs():
+    """Return the autotuning search space for the gated backward kernel."""
+    return _build_configs(*_BWD_TILES)
+
+
+def get_dec_dense_autotune_configs():
+    """Return the autotuning search space for the dense decode kernel."""
+    return _build_configs(*_DEC_TILES)
+
+
+def get_dec_sparse_autotune_configs():
+    """Return the autotuning search space for the sparse decode kernel."""
+    return _build_configs(*_DEC_TILES)
+
+
+def get_dec_gated_autotune_configs():
+    """Return the autotuning search space for the gated decode kernel."""
+    return _build_configs(*_DEC_TILES)
+
+
+def make_fwd_dense_autotuned_kernel(jit_kernel):
+    """Return jit_kernel autotuned over the dense forward configs."""
+    return _autotune(jit_kernel, get_fwd_dense_autotune_configs())
+
+
+def make_fwd_sparse_autotuned_kernel(jit_kernel):
+    """Return jit_kernel autotuned over the sparse forward configs."""
+    return _autotune(jit_kernel, get_fwd_sparse_autotune_configs())
+
+
+def make_fwd_gated_autotuned_kernel(jit_kernel):
+    """Return jit_kernel autotuned over the gated forward configs."""
+    return _autotune(jit_kernel, get_fwd_gated_autotune_configs())
+
+
+def make_bwd_dense_autotuned_kernel(jit_kernel):
+    """Return jit_kernel autotuned over the dense backward configs."""
+    return _autotune(jit_kernel, get_bwd_dense_autotune_configs())
+
+
+def make_bwd_sparse_autotuned_kernel(jit_kernel):
+    """Return jit_kernel autotuned over the sparse backward configs."""
+    return _autotune(jit_kernel, get_bwd_sparse_autotune_configs())
+
+
+def make_bwd_gated_autotuned_kernel(jit_kernel):
+    """Return jit_kernel autotuned over the gated backward configs."""
+    return _autotune(jit_kernel, get_bwd_gated_autotune_configs())
+
+
+def make_dec_dense_autotuned_kernel(jit_kernel):
+    """Return jit_kernel autotuned over the dense decode configs."""
+    return _autotune(jit_kernel, get_dec_dense_autotune_configs())
+
+
+def make_dec_sparse_autotuned_kernel(jit_kernel):
+    """Return jit_kernel autotuned over the sparse decode configs."""
+    return _autotune(jit_kernel, get_dec_sparse_autotune_configs())
+
+
+def make_dec_gated_autotuned_kernel(jit_kernel):
+    """Return jit_kernel autotuned over the gated decode configs."""
+    return _autotune(jit_kernel, get_dec_gated_autotune_configs())
+
+
+class AutotunedKernel:
+    """Proxy giving an autotuned kernel the ``kernel[grid](...)`` call shape.
+
+    Keyword arguments named in STRIP_KWARGS are dropped before the launch;
+    they are the values every triton.Config in this module already sets.
+    Any other attribute access is forwarded to the wrapped kernel.
+    """
+
+    STRIP_KWARGS = {"TILE_M", "TILE_N", "num_warps", "num_stages", "num_ctas"}
+
+    def __init__(self, autotuned_kernel):
+        self._autotuned = autotuned_kernel
+
+    def __getitem__(self, grid):
+        """Return a callable that launches the wrapped kernel on grid."""
+        autotuned = self._autotuned
+
+        def launch(*args, **kwargs):
+            for key in AutotunedKernel.STRIP_KWARGS:
+                kwargs.pop(key, None)
+            return autotuned[grid](*args, **kwargs)
+
+        return launch
+
+    def __getattr__(self, name):
+        """Forward attributes not found on the proxy to the wrapped kernel."""
+        # __getattr__ also runs when _autotuned itself is missing (copy and
+        # pickle build the instance without __init__); forwarding that lookup
+        # would recurse until RecursionError, so fail it directly.
+        if name == "_autotuned":
+            raise AttributeError(name)
+        return getattr(self._autotuned, name)
diff --git a/flash_sparse_attn/ops/triton/flash_dense_bwd.py b/flash_sparse_attn/ops/triton/flash_dense_bwd.py
index 7f67048..65112b8 100644
--- a/flash_sparse_attn/ops/triton/flash_dense_bwd.py
+++ b/flash_sparse_attn/ops/triton/flash_dense_bwd.py
@@ -17,6 +17,7 @@
flash_bwd_preprocess, flash_bwd_postprocess, kernel_repr, + autotuner, ) @@ -163,6 +164,8 @@ def _bwd_dense_kernel( seqlen_q, seqlen_k, head_dim, + SEQLEN_Q_CACHE: tl.constexpr, + SEQLEN_K_CACHE: tl.constexpr, QHEADS_PER_KVHEAD: tl.constexpr, TILE_M: tl.constexpr, TILE_N: tl.constexpr, @@ -639,6 +642,18 @@ def _bwd_dense_kernel( _bwd_dense_kernel = cache_utils.wrap_kernel(_bwd_dense_kernel) +_bwd_dense_kernel_autotuned = None + + +def _get_autotuned_kernel(): + global _bwd_dense_kernel_autotuned + if _bwd_dense_kernel_autotuned is None: + jit_kernel = _bwd_dense_kernel._kernel + autotuned = autotuner.make_bwd_dense_autotuned_kernel(jit_kernel) + _bwd_dense_kernel_autotuned = autotuner.AutotunedKernel(autotuned) + return _bwd_dense_kernel_autotuned + + def _flash_dense_attn_backward( query: torch.Tensor, key: torch.Tensor, @@ -649,6 +664,7 @@ def _flash_dense_attn_backward( is_causal: bool = False, softmax_scale: float = None, window_size: Tuple[int, int] = (None, None), + is_autotune: bool = False, skip_checks: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: device = query.device @@ -682,13 +698,19 @@ def _flash_dense_attn_backward( TILE_K = max(triton.next_power_of_2(head_dim), 16) - TILE_M, TILE_N, num_warps, num_stages, num_ctas = ( - launch_template.get_bwd_dense_launch_config( - tile_k=TILE_K, - device=device, - arch=arch, + if is_autotune: + kernel = _get_autotuned_kernel() + TILE_M = TILE_N = 64 + num_warps = num_stages = num_ctas = None + else: + kernel = _bwd_dense_kernel + TILE_M, TILE_N, num_warps, num_stages, num_ctas = ( + launch_template.get_bwd_dense_launch_config( + tile_k=TILE_K, + device=device, + arch=arch, + ) ) - ) seqlen_q_rounded = int(math.ceil(seqlen_q / TILE_M) * TILE_M) head_dim_rounded = int(math.ceil(head_dim / 32) * 32) @@ -746,7 +768,7 @@ def _flash_dense_attn_backward( batch_size=batch_size, ) - _bwd_dense_kernel[grid]( + kernel[grid]( query, key, value, @@ -789,9 +811,11 @@ def _flash_dense_attn_backward( None, 
None, None, - seqlen_q, - seqlen_k, - head_dim, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + head_dim=head_dim, + SEQLEN_Q_CACHE=seqlen_q // 1024, + SEQLEN_K_CACHE=seqlen_k // 1024, QHEADS_PER_KVHEAD=qhead_per_kvhead, TILE_M=TILE_M, TILE_N=TILE_N, @@ -840,6 +864,7 @@ def _flash_dense_attn_varlen_backward( window_size: Tuple[int, int] = (None, None), seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, + is_autotune: bool = False, skip_checks: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: device = query.device @@ -876,13 +901,19 @@ def _flash_dense_attn_varlen_backward( TILE_K = max(triton.next_power_of_2(head_dim), 16) - TILE_M, TILE_N, num_warps, num_stages, num_ctas = ( - launch_template.get_bwd_dense_launch_config( - tile_k=TILE_K, - device=device, - arch=arch, + if is_autotune: + kernel = _get_autotuned_kernel() + TILE_M = TILE_N = 64 + num_warps = num_stages = num_ctas = None + else: + kernel = _bwd_dense_kernel + TILE_M, TILE_N, num_warps, num_stages, num_ctas = ( + launch_template.get_bwd_dense_launch_config( + tile_k=TILE_K, + device=device, + arch=arch, + ) ) - ) total_q_rounded_padded = int( math.ceil((total_q + batch_size * TILE_M) / TILE_M) * TILE_M @@ -946,7 +977,7 @@ def _flash_dense_attn_varlen_backward( batch_size=batch_size, ) - _bwd_dense_kernel[grid]( + kernel[grid]( query, key, value, @@ -989,9 +1020,11 @@ def _flash_dense_attn_varlen_backward( cu_seqlens_k, seqused_q, seqused_k, - seqlen_q, - seqlen_k, - head_dim, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + head_dim=head_dim, + SEQLEN_Q_CACHE=seqlen_q // 1024, + SEQLEN_K_CACHE=seqlen_k // 1024, QHEADS_PER_KVHEAD=qhead_per_kvhead, TILE_M=TILE_M, TILE_N=TILE_N, diff --git a/flash_sparse_attn/ops/triton/flash_dense_dec.py b/flash_sparse_attn/ops/triton/flash_dense_dec.py index 3ba711d..fcd663d 100644 --- a/flash_sparse_attn/ops/triton/flash_dense_dec.py +++ b/flash_sparse_attn/ops/triton/flash_dense_dec.py @@ -17,6 +17,7 @@ mask, 
flash_dec_combine, kernel_repr, + autotuner, ) @@ -133,6 +134,8 @@ def _dec_dense_kernel( seqlen_q, seqlen_k, head_dim, + SEQLEN_Q_CACHE: tl.constexpr, + SEQLEN_K_CACHE: tl.constexpr, QHEADS_PER_KVHEAD_PACKGQA: tl.constexpr, TILE_M: tl.constexpr, TILE_N: tl.constexpr, @@ -480,6 +483,18 @@ def _dec_dense_kernel( _dec_dense_kernel = cache_utils.wrap_kernel(_dec_dense_kernel) +_dec_dense_autotuned_kernel = None + + +def _get_autotuned_kernel(): + global _dec_dense_autotuned_kernel + if _dec_dense_autotuned_kernel is None: + _dec_dense_autotuned_kernel = autotuner.AutotunedKernel( + autotuner.make_dec_dense_autotuned_kernel(_dec_dense_kernel.kernel) + ) + return _dec_dense_autotuned_kernel + + def _flash_dense_attn_decode( query: torch.Tensor, key: torch.Tensor, @@ -492,6 +507,7 @@ def _flash_dense_attn_decode( is_quant: bool = False, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, + is_autotune: bool = False, skip_checks: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: device = query.device @@ -524,14 +540,21 @@ def _flash_dense_attn_decode( TILE_K = max(triton.next_power_of_2(head_dim), 16) - TILE_M, TILE_N, num_warps, num_stages, num_ctas = ( - launch_template.get_dec_dense_launch_config( - qheads_per_kvhead=qheads_per_kvhead, - tile_k=TILE_K, - device=device, - arch=arch, + if is_autotune: + kernel = _get_autotuned_kernel() + TILE_M = max(triton.next_power_of_2(qheads_per_kvhead), 16) + TILE_N = 128 + num_warps = num_stages = num_ctas = None + else: + kernel = _dec_dense_kernel + TILE_M, TILE_N, num_warps, num_stages, num_ctas = ( + launch_template.get_dec_dense_launch_config( + qheads_per_kvhead=qheads_per_kvhead, + tile_k=TILE_K, + device=device, + arch=arch, + ) ) - ) num_splits = utils.num_splits_heuristic( seqlen_q=qheads_per_kvhead, @@ -581,7 +604,7 @@ def _flash_dense_attn_decode( num_splits=num_splits, ) - _dec_dense_kernel[grid]( + kernel[grid]( query, key, value, @@ -616,6 +639,8 @@ def _flash_dense_attn_decode( 
seqlen_q=qheads_per_kvhead, seqlen_k=seqlen_k, head_dim=head_dim, + SEQLEN_Q_CACHE=0, + SEQLEN_K_CACHE=seqlen_k // 1024, QHEADS_PER_KVHEAD_PACKGQA=qheads_per_kvhead, TILE_M=TILE_M, TILE_N=TILE_N, @@ -657,6 +682,7 @@ def _flash_dense_attn_varlen_decode( seqused_k: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, + is_autotune: bool = False, skip_checks: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: device = query.device @@ -690,14 +716,21 @@ def _flash_dense_attn_varlen_decode( TILE_K = max(triton.next_power_of_2(head_dim), 16) - TILE_M, TILE_N, num_warps, num_stages, num_ctas = ( - launch_template.get_dec_dense_launch_config( - qheads_per_kvhead=qheads_per_kvhead, - tile_k=TILE_K, - device=device, - arch=arch, + if is_autotune: + kernel = _get_autotuned_kernel() + TILE_M = max(triton.next_power_of_2(qheads_per_kvhead), 16) + TILE_N = 128 + num_warps = num_stages = num_ctas = None + else: + kernel = _dec_dense_kernel + TILE_M, TILE_N, num_warps, num_stages, num_ctas = ( + launch_template.get_dec_dense_launch_config( + qheads_per_kvhead=qheads_per_kvhead, + tile_k=TILE_K, + device=device, + arch=arch, + ) ) - ) num_splits = utils.num_splits_heuristic( seqlen_q=qheads_per_kvhead, @@ -747,7 +780,7 @@ def _flash_dense_attn_varlen_decode( num_splits=num_splits, ) - _dec_dense_kernel[grid]( + kernel[grid]( query, key, value, @@ -782,6 +815,8 @@ def _flash_dense_attn_varlen_decode( seqlen_q=qheads_per_kvhead, seqlen_k=seqlen_k, head_dim=head_dim, + SEQLEN_Q_CACHE=0, + SEQLEN_K_CACHE=seqlen_k // 1024, QHEADS_PER_KVHEAD_PACKGQA=qheads_per_kvhead, TILE_M=TILE_M, TILE_N=TILE_N, diff --git a/flash_sparse_attn/ops/triton/flash_dense_fwd.py b/flash_sparse_attn/ops/triton/flash_dense_fwd.py index baf79c0..ad4367e 100644 --- a/flash_sparse_attn/ops/triton/flash_dense_fwd.py +++ b/flash_sparse_attn/ops/triton/flash_dense_fwd.py @@ -17,6 +17,7 @@ mask, flash_fwd_combine, kernel_repr, + autotuner, ) @@ -134,6 +135,8 @@ def 
_fwd_dense_kernel( seqlen_q, seqlen_k, head_dim, + SEQLEN_Q_CACHE: tl.constexpr, + SEQLEN_K_CACHE: tl.constexpr, QHEADS_PER_KVHEAD_PACKGQA: tl.constexpr, TILE_M: tl.constexpr, TILE_N: tl.constexpr, @@ -621,6 +624,18 @@ def _fwd_dense_kernel( _fwd_dense_kernel = cache_utils.wrap_kernel(_fwd_dense_kernel) +_fwd_dense_kernel_autotuned = None + + +def _get_autotuned_kernel(): + global _fwd_dense_kernel_autotuned + if _fwd_dense_kernel_autotuned is None: + jit_kernel = _fwd_dense_kernel._kernel + autotuned = autotuner.make_fwd_dense_autotuned_kernel(jit_kernel) + _fwd_dense_kernel_autotuned = autotuner.AutotunedKernel(autotuned) + return _fwd_dense_kernel_autotuned + + def _flash_dense_attn_forward( query: torch.Tensor, key: torch.Tensor, @@ -636,6 +651,7 @@ def _flash_dense_attn_forward( pack_gqa: bool = False, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, + is_autotune: bool = False, skip_checks: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, float]: device = query.device @@ -671,16 +687,22 @@ def _flash_dense_attn_forward( TILE_K = max(triton.next_power_of_2(head_dim), 16) - TILE_M, TILE_N, num_warps, num_stages, num_ctas = ( - launch_template.get_fwd_dense_launch_config( - is_split_kv=is_split_kv, - pack_gqa=pack_gqa, - qheads_per_kvhead=qheads_per_kvhead, - tile_k=TILE_K, - device=device, - arch=arch, + if is_autotune: + kernel = _get_autotuned_kernel() + TILE_M = TILE_N = 64 + num_warps = num_stages = num_ctas = None + else: + kernel = _fwd_dense_kernel + TILE_M, TILE_N, num_warps, num_stages, num_ctas = ( + launch_template.get_fwd_dense_launch_config( + is_split_kv=is_split_kv, + pack_gqa=pack_gqa, + qheads_per_kvhead=qheads_per_kvhead, + tile_k=TILE_K, + device=device, + arch=arch, + ) ) - ) num_splits = ( utils.num_splits_heuristic( @@ -732,7 +754,7 @@ def _flash_dense_attn_forward( num_splits=num_splits, ) - _fwd_dense_kernel[grid]( + kernel[grid]( query, key, value, @@ -764,9 +786,11 @@ def _flash_dense_attn_forward( None, 
qheads_per_kvhead, num_splits, - seqlen_q, - seqlen_k, - head_dim, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + head_dim=head_dim, + SEQLEN_Q_CACHE=seqlen_q // 1024, + SEQLEN_K_CACHE=seqlen_k // 1024, QHEADS_PER_KVHEAD_PACKGQA=qheads_per_kvhead_packgqa, TILE_M=TILE_M, TILE_N=TILE_N, @@ -818,6 +842,7 @@ def _flash_dense_attn_varlen_forward( seqused_k: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, + is_autotune: bool = False, skip_checks: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, float]: device = query.device @@ -856,16 +881,22 @@ def _flash_dense_attn_varlen_forward( TILE_K = max(triton.next_power_of_2(head_dim), 16) - TILE_M, TILE_N, num_warps, num_stages, num_ctas = ( - launch_template.get_fwd_dense_launch_config( - is_split_kv=is_split_kv, - pack_gqa=pack_gqa, - qheads_per_kvhead=qheads_per_kvhead, - tile_k=TILE_K, - device=device, - arch=arch, + if is_autotune: + kernel = _get_autotuned_kernel() + TILE_M = TILE_N = 64 + num_warps = num_stages = num_ctas = None + else: + kernel = _fwd_dense_kernel + TILE_M, TILE_N, num_warps, num_stages, num_ctas = ( + launch_template.get_fwd_dense_launch_config( + is_split_kv=is_split_kv, + pack_gqa=pack_gqa, + qheads_per_kvhead=qheads_per_kvhead, + tile_k=TILE_K, + device=device, + arch=arch, + ) ) - ) num_splits = ( utils.num_splits_heuristic( @@ -913,7 +944,7 @@ def _flash_dense_attn_varlen_forward( num_splits=num_splits, ) - _fwd_dense_kernel[grid]( + kernel[grid]( query, key, value, @@ -945,9 +976,11 @@ def _flash_dense_attn_varlen_forward( seqused_k, qheads_per_kvhead, num_splits, - seqlen_q, - seqlen_k, - head_dim, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + head_dim=head_dim, + SEQLEN_Q_CACHE=seqlen_q // 1024, + SEQLEN_K_CACHE=seqlen_k // 1024, QHEADS_PER_KVHEAD_PACKGQA=qheads_per_kvhead_packgqa, TILE_M=TILE_M, TILE_N=TILE_N, diff --git a/flash_sparse_attn/ops/triton/flash_gated_bwd.py b/flash_sparse_attn/ops/triton/flash_gated_bwd.py index 
d7fa1be..a399490 100644 --- a/flash_sparse_attn/ops/triton/flash_gated_bwd.py +++ b/flash_sparse_attn/ops/triton/flash_gated_bwd.py @@ -17,6 +17,7 @@ flash_bwd_preprocess, flash_bwd_postprocess, kernel_repr, + autotuner, ) @@ -281,6 +282,8 @@ def _bwd_gated_kernel( seqlen_q, seqlen_k, head_dim, + SEQLEN_Q_CACHE: tl.constexpr, + SEQLEN_K_CACHE: tl.constexpr, QHEADS_PER_KVHEAD: tl.constexpr, TILE_M: tl.constexpr, TILE_N: tl.constexpr, @@ -1013,6 +1016,18 @@ def _bwd_gated_kernel( _bwd_gated_kernel = cache_utils.wrap_kernel(_bwd_gated_kernel) +_bwd_gated_kernel_autotuned = None + + +def _get_autotuned_kernel(): + global _bwd_gated_kernel_autotuned + if _bwd_gated_kernel_autotuned is None: + jit_kernel = _bwd_gated_kernel._kernel + autotuned = autotuner.make_bwd_gated_autotuned_kernel(jit_kernel) + _bwd_gated_kernel_autotuned = autotuner.AutotunedKernel(autotuned) + return _bwd_gated_kernel_autotuned + + def _flash_gated_attn_backward( query: torch.Tensor, key: torch.Tensor, @@ -1029,6 +1044,7 @@ def _flash_gated_attn_backward( is_logsigmoid_gate: bool = True, is_adapt_gate: bool = True, window_size: Tuple[int, int] = (None, None), + is_autotune: bool = False, skip_checks: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: device = query.device @@ -1066,13 +1082,19 @@ def _flash_gated_attn_backward( TILE_K = max(triton.next_power_of_2(head_dim), 16) - TILE_M, TILE_N, num_warps, num_stages, num_ctas = ( - launch_template.get_bwd_gated_launch_config( - tile_k=TILE_K, - device=device, - arch=arch, + if is_autotune: + kernel = _get_autotuned_kernel() + TILE_M = TILE_N = 64 + num_warps = num_stages = num_ctas = None + else: + kernel = _bwd_gated_kernel + TILE_M, TILE_N, num_warps, num_stages, num_ctas = ( + launch_template.get_bwd_gated_launch_config( + tile_k=TILE_K, + device=device, + arch=arch, + ) ) - ) seqlen_q_rounded = int(math.ceil(seqlen_q / TILE_M) * TILE_M) head_dim_rounded = int(math.ceil(head_dim / 32) * 32) @@ 
-1136,7 +1158,7 @@ def _flash_gated_attn_backward( batch_size=batch_size, ) - _bwd_gated_kernel[grid]( + kernel[grid]( query, key, value, @@ -1197,9 +1219,11 @@ def _flash_gated_attn_backward( None, None, None, - seqlen_q, - seqlen_k, - head_dim, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + head_dim=head_dim, + SEQLEN_Q_CACHE=seqlen_q // 1024, + SEQLEN_K_CACHE=seqlen_k // 1024, QHEADS_PER_KVHEAD=qhead_per_kvhead, TILE_M=TILE_M, TILE_N=TILE_N, @@ -1259,6 +1283,7 @@ def _flash_gated_attn_varlen_backward( window_size: Tuple[int, int] = (None, None), seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, + is_autotune: bool = False, skip_checks: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: device = query.device @@ -1299,13 +1324,19 @@ def _flash_gated_attn_varlen_backward( TILE_K = max(triton.next_power_of_2(head_dim), 16) - TILE_M, TILE_N, num_warps, num_stages, num_ctas = ( - launch_template.get_bwd_gated_launch_config( - tile_k=TILE_K, - device=device, - arch=arch, + if is_autotune: + kernel = _get_autotuned_kernel() + TILE_M = TILE_N = 64 + num_warps = num_stages = num_ctas = None + else: + kernel = _bwd_gated_kernel + TILE_M, TILE_N, num_warps, num_stages, num_ctas = ( + launch_template.get_bwd_gated_launch_config( + tile_k=TILE_K, + device=device, + arch=arch, + ) ) - ) total_q_rounded_padded = int( math.ceil((total_q + batch_size * TILE_M) / TILE_M) * TILE_M @@ -1374,7 +1405,7 @@ def _flash_gated_attn_varlen_backward( batch_size=batch_size, ) - _bwd_gated_kernel[grid]( + kernel[grid]( query, key, value, @@ -1435,9 +1466,11 @@ def _flash_gated_attn_varlen_backward( cu_seqlens_k, seqused_q, seqused_k, - seqlen_q, - seqlen_k, - head_dim, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + head_dim=head_dim, + SEQLEN_Q_CACHE=seqlen_q // 1024, + SEQLEN_K_CACHE=seqlen_k // 1024, QHEADS_PER_KVHEAD=qhead_per_kvhead, TILE_M=TILE_M, TILE_N=TILE_N, diff --git 
a/flash_sparse_attn/ops/triton/flash_gated_dec.py b/flash_sparse_attn/ops/triton/flash_gated_dec.py index 8d85519..6060b8c 100644 --- a/flash_sparse_attn/ops/triton/flash_gated_dec.py +++ b/flash_sparse_attn/ops/triton/flash_gated_dec.py @@ -17,6 +17,7 @@ mask, flash_dec_combine, kernel_repr, + autotuner, ) @@ -213,6 +214,8 @@ def _dec_gated_kernel( seqlen_q, seqlen_k, head_dim, + SEQLEN_Q_CACHE: tl.constexpr, + SEQLEN_K_CACHE: tl.constexpr, QHEADS_PER_KVHEAD_PACKGQA: tl.constexpr, TILE_M: tl.constexpr, TILE_N: tl.constexpr, @@ -771,6 +774,18 @@ def _dec_gated_kernel( _dec_gated_kernel = cache_utils.wrap_kernel(_dec_gated_kernel) +_dec_gated_kernel_autotuned = None + + +def _get_autotuned_kernel(): + global _dec_gated_kernel_autotuned + if _dec_gated_kernel_autotuned is None: + jit_kernel = _dec_gated_kernel.kernel + autotuned = autotuner.make_dec_gated_autotuned_kernel(jit_kernel) + _dec_gated_kernel_autotuned = autotuner.AutotunedKernel(autotuned) + return _dec_gated_kernel_autotuned + + def _flash_gated_attn_decode( query: torch.Tensor, key: torch.Tensor, @@ -788,6 +803,7 @@ def _flash_gated_attn_decode( is_quant: bool = False, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, + is_autotune: bool = False, skip_checks: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: device = query.device @@ -826,14 +842,21 @@ def _flash_gated_attn_decode( TILE_K = max(triton.next_power_of_2(head_dim), 16) - TILE_M, TILE_N, num_warps, num_stages, num_ctas = ( - launch_template.get_dec_gated_launch_config( - qheads_per_kvhead=qheads_per_kvhead, - tile_k=TILE_K, - device=device, - arch=arch, + if is_autotune: + kernel = _get_autotuned_kernel() + TILE_M = max(triton.next_power_of_2(qheads_per_kvhead), 16) + TILE_N = 128 + num_warps = num_stages = num_ctas = None + else: + kernel = _dec_gated_kernel + TILE_M, TILE_N, num_warps, num_stages, num_ctas = ( + launch_template.get_dec_gated_launch_config( + qheads_per_kvhead=qheads_per_kvhead, + tile_k=TILE_K, + 
device=device, + arch=arch, + ) ) - ) num_splits = utils.num_splits_heuristic( seqlen_q=qheads_per_kvhead, @@ -883,7 +906,7 @@ def _flash_gated_attn_decode( num_splits=num_splits, ) - _dec_gated_kernel[grid]( + kernel[grid]( query, key, value, @@ -928,6 +951,8 @@ def _flash_gated_attn_decode( seqlen_q=qheads_per_kvhead, seqlen_k=seqlen_k, head_dim=head_dim, + SEQLEN_Q_CACHE=0, + SEQLEN_K_CACHE=seqlen_k // 1024, QHEADS_PER_KVHEAD_PACKGQA=qheads_per_kvhead, TILE_M=TILE_M, TILE_N=TILE_N, @@ -975,6 +1000,7 @@ def _flash_gated_attn_varlen_decode( seqused_k: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, + is_autotune: bool = False, skip_checks: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: device = query.device @@ -1014,14 +1040,21 @@ def _flash_gated_attn_varlen_decode( TILE_K = max(triton.next_power_of_2(head_dim), 16) - TILE_M, TILE_N, num_warps, num_stages, num_ctas = ( - launch_template.get_dec_gated_launch_config( - qheads_per_kvhead=qheads_per_kvhead, - tile_k=TILE_K, - device=device, - arch=arch, + if is_autotune: + kernel = _get_autotuned_kernel() + TILE_M = max(triton.next_power_of_2(qheads_per_kvhead), 16) + TILE_N = 128 + num_warps = num_stages = num_ctas = None + else: + kernel = _dec_gated_kernel + TILE_M, TILE_N, num_warps, num_stages, num_ctas = ( + launch_template.get_dec_gated_launch_config( + qheads_per_kvhead=qheads_per_kvhead, + tile_k=TILE_K, + device=device, + arch=arch, + ) ) - ) num_splits = utils.num_splits_heuristic( seqlen_q=qheads_per_kvhead, @@ -1071,7 +1104,7 @@ def _flash_gated_attn_varlen_decode( num_splits=num_splits, ) - _dec_gated_kernel[grid]( + kernel[grid]( query, key, value, @@ -1116,6 +1149,8 @@ def _flash_gated_attn_varlen_decode( seqlen_q=qheads_per_kvhead, seqlen_k=seqlen_k, head_dim=head_dim, + SEQLEN_Q_CACHE=0, + SEQLEN_K_CACHE=seqlen_k // 1024, QHEADS_PER_KVHEAD_PACKGQA=qheads_per_kvhead, TILE_M=TILE_M, TILE_N=TILE_N, diff --git 
a/flash_sparse_attn/ops/triton/flash_gated_fwd.py b/flash_sparse_attn/ops/triton/flash_gated_fwd.py index 3320f87..4b284b6 100644 --- a/flash_sparse_attn/ops/triton/flash_gated_fwd.py +++ b/flash_sparse_attn/ops/triton/flash_gated_fwd.py @@ -17,6 +17,7 @@ mask, flash_dec_combine, kernel_repr, + autotuner, ) @@ -214,6 +215,8 @@ def _fwd_gated_kernel( seqlen_q, seqlen_k, head_dim, + SEQLEN_Q_CACHE: tl.constexpr, + SEQLEN_K_CACHE: tl.constexpr, QHEADS_PER_KVHEAD_PACKGQA: tl.constexpr, TILE_M: tl.constexpr, TILE_N: tl.constexpr, @@ -977,6 +980,18 @@ def _fwd_gated_kernel( _fwd_gated_kernel = cache_utils.wrap_kernel(_fwd_gated_kernel) +_fwd_gated_kernel_autotuned = None + + +def _get_autotuned_kernel(): + global _fwd_gated_kernel_autotuned + if _fwd_gated_kernel_autotuned is None: + jit_kernel = _fwd_gated_kernel._kernel + autotuned = autotuner.make_fwd_gated_autotuned_kernel(jit_kernel) + _fwd_gated_kernel_autotuned = autotuner.AutotunedKernel(autotuned) + return _fwd_gated_kernel_autotuned + + def _flash_gated_attn_forward( query: torch.Tensor, key: torch.Tensor, @@ -998,6 +1013,7 @@ def _flash_gated_attn_forward( pack_gqa: bool = False, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, + is_autotune: bool = False, skip_checks: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, float, float, float]: device = query.device @@ -1037,16 +1053,22 @@ def _flash_gated_attn_forward( TILE_K = max(triton.next_power_of_2(head_dim), 16) - TILE_M, TILE_N, num_warps, num_stages, num_ctas = ( - launch_template.get_fwd_gated_launch_config( - is_split_kv=is_split_kv, - pack_gqa=pack_gqa, - qheads_per_kvhead=qheads_per_kvhead, - tile_k=TILE_K, - device=device, - arch=arch, + if is_autotune: + kernel = _get_autotuned_kernel() + TILE_M = TILE_N = 64 + num_warps = num_stages = num_ctas = None + else: + kernel = _fwd_gated_kernel + TILE_M, TILE_N, num_warps, num_stages, num_ctas = ( + launch_template.get_fwd_gated_launch_config( + is_split_kv=is_split_kv, + 
pack_gqa=pack_gqa, + qheads_per_kvhead=qheads_per_kvhead, + tile_k=TILE_K, + device=device, + arch=arch, + ) ) - ) num_splits = ( utils.num_splits_heuristic( @@ -1098,7 +1120,7 @@ def _flash_gated_attn_forward( num_splits=num_splits, ) - _fwd_gated_kernel[grid]( + kernel[grid]( query, key, value, @@ -1140,9 +1162,11 @@ def _flash_gated_attn_forward( None, qheads_per_kvhead, num_splits, - seqlen_q, - seqlen_k, - head_dim, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + head_dim=head_dim, + SEQLEN_Q_CACHE=seqlen_q // 1024, + SEQLEN_K_CACHE=seqlen_k // 1024, QHEADS_PER_KVHEAD_PACKGQA=qheads_per_kvhead_packgqa, TILE_M=TILE_M, TILE_N=TILE_N, @@ -1202,6 +1226,7 @@ def _flash_gated_attn_varlen_forward( seqused_k: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, + is_autotune: bool = False, skip_checks: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, float, float, float]: device = query.device @@ -1244,16 +1269,22 @@ def _flash_gated_attn_varlen_forward( TILE_K = max(triton.next_power_of_2(head_dim), 16) - TILE_M, TILE_N, num_warps, num_stages, num_ctas = ( - launch_template.get_fwd_gated_launch_config( - is_split_kv=is_split_kv, - pack_gqa=pack_gqa, - qheads_per_kvhead=qheads_per_kvhead, - tile_k=TILE_K, - device=device, - arch=arch, + if is_autotune: + kernel = _get_autotuned_kernel() + TILE_M = TILE_N = 64 + num_warps = num_stages = num_ctas = None + else: + kernel = _fwd_gated_kernel + TILE_M, TILE_N, num_warps, num_stages, num_ctas = ( + launch_template.get_fwd_gated_launch_config( + is_split_kv=is_split_kv, + pack_gqa=pack_gqa, + qheads_per_kvhead=qheads_per_kvhead, + tile_k=TILE_K, + device=device, + arch=arch, + ) ) - ) num_splits = ( utils.num_splits_heuristic( @@ -1305,7 +1336,7 @@ def _flash_gated_attn_varlen_forward( num_splits=num_splits, ) - _fwd_gated_kernel[grid]( + kernel[grid]( query, key, value, @@ -1347,9 +1378,11 @@ def _flash_gated_attn_varlen_forward( None, qheads_per_kvhead, num_splits, - 
seqlen_q, - seqlen_k, - head_dim, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + head_dim=head_dim, + SEQLEN_Q_CACHE=seqlen_q // 1024, + SEQLEN_K_CACHE=seqlen_k // 1024, QHEADS_PER_KVHEAD_PACKGQA=qheads_per_kvhead_packgqa, TILE_M=TILE_M, TILE_N=TILE_N, diff --git a/flash_sparse_attn/ops/triton/flash_sparse_bwd.py b/flash_sparse_attn/ops/triton/flash_sparse_bwd.py index 70b4c10..ab3444b 100644 --- a/flash_sparse_attn/ops/triton/flash_sparse_bwd.py +++ b/flash_sparse_attn/ops/triton/flash_sparse_bwd.py @@ -17,6 +17,7 @@ flash_bwd_preprocess, flash_bwd_postprocess, kernel_repr, + autotuner, ) @@ -186,6 +187,8 @@ def _bwd_sparse_kernel( seqlen_q, seqlen_k, head_dim, + SEQLEN_Q_CACHE: tl.constexpr, + SEQLEN_K_CACHE: tl.constexpr, QHEADS_PER_KVHEAD: tl.constexpr, TILE_M: tl.constexpr, TILE_N: tl.constexpr, @@ -696,6 +699,18 @@ def _bwd_sparse_kernel( _bwd_sparse_kernel = cache_utils.wrap_kernel(_bwd_sparse_kernel) +_bwd_sparse_kernel_autotuned = None + + +def _get_autotuned_kernel(): + global _bwd_sparse_kernel_autotuned + if _bwd_sparse_kernel_autotuned is None: + jit_kernel = _bwd_sparse_kernel._kernel + autotuned = autotuner.make_bwd_sparse_autotuned_kernel(jit_kernel) + _bwd_sparse_kernel_autotuned = autotuner.AutotunedKernel(autotuned) + return _bwd_sparse_kernel_autotuned + + def _flash_sparse_attn_backward( query: torch.Tensor, key: torch.Tensor, @@ -707,6 +722,7 @@ def _flash_sparse_attn_backward( softmax_scale: float = None, softmax_threshold: float = None, window_size: Tuple[int, int] = (None, None), + is_autotune: bool = False, skip_checks: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: device = query.device @@ -741,13 +757,19 @@ def _flash_sparse_attn_backward( TILE_K = max(triton.next_power_of_2(head_dim), 16) - TILE_M, TILE_N, num_warps, num_stages, num_ctas = ( - launch_template.get_bwd_sparse_launch_config( - tile_k=TILE_K, - device=device, - arch=arch, + if is_autotune: + kernel = _get_autotuned_kernel() + TILE_M = TILE_N = 64 + 
num_warps = num_stages = num_ctas = None + else: + kernel = _bwd_sparse_kernel + TILE_M, TILE_N, num_warps, num_stages, num_ctas = ( + launch_template.get_bwd_sparse_launch_config( + tile_k=TILE_K, + device=device, + arch=arch, + ) ) - ) seqlen_q_rounded = int(math.ceil(seqlen_q / TILE_M) * TILE_M) head_dim_rounded = int(math.ceil(head_dim / 32) * 32) @@ -805,7 +827,7 @@ def _flash_sparse_attn_backward( batch_size=batch_size, ) - _bwd_sparse_kernel[grid]( + kernel[grid]( query, key, value, @@ -849,9 +871,11 @@ def _flash_sparse_attn_backward( None, None, None, - seqlen_q, - seqlen_k, - head_dim, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + head_dim=head_dim, + SEQLEN_Q_CACHE=seqlen_q // 1024, + SEQLEN_K_CACHE=seqlen_k // 1024, QHEADS_PER_KVHEAD=qhead_per_kvhead, TILE_M=TILE_M, TILE_N=TILE_N, @@ -901,6 +925,7 @@ def _flash_sparse_attn_varlen_backward( window_size: Tuple[int, int] = (None, None), seqused_q: Optional[torch.Tensor] = None, seqused_k: Optional[torch.Tensor] = None, + is_autotune: bool = False, skip_checks: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: device = query.device @@ -938,13 +963,19 @@ def _flash_sparse_attn_varlen_backward( TILE_K = max(triton.next_power_of_2(head_dim), 16) - TILE_M, TILE_N, num_warps, num_stages, num_ctas = ( - launch_template.get_bwd_sparse_launch_config( - tile_k=TILE_K, - device=device, - arch=arch, + if is_autotune: + kernel = _get_autotuned_kernel() + TILE_M = TILE_N = 64 + num_warps = num_stages = num_ctas = None + else: + kernel = _bwd_sparse_kernel + TILE_M, TILE_N, num_warps, num_stages, num_ctas = ( + launch_template.get_bwd_sparse_launch_config( + tile_k=TILE_K, + device=device, + arch=arch, + ) ) - ) total_q_rounded_padded = int( math.ceil((total_q + batch_size * TILE_M) / TILE_M) * TILE_M @@ -1008,7 +1039,7 @@ def _flash_sparse_attn_varlen_backward( batch_size=batch_size, ) - _bwd_sparse_kernel[grid]( + kernel[grid]( query, key, value, @@ -1052,9 +1083,11 @@ def 
_flash_sparse_attn_varlen_backward( cu_seqlens_k, seqused_q, seqused_k, - seqlen_q, - seqlen_k, - head_dim, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + head_dim=head_dim, + SEQLEN_Q_CACHE=seqlen_q // 1024, + SEQLEN_K_CACHE=seqlen_k // 1024, QHEADS_PER_KVHEAD=qhead_per_kvhead, TILE_M=TILE_M, TILE_N=TILE_N, diff --git a/flash_sparse_attn/ops/triton/flash_sparse_dec.py b/flash_sparse_attn/ops/triton/flash_sparse_dec.py index 32a1384..99e33f3 100644 --- a/flash_sparse_attn/ops/triton/flash_sparse_dec.py +++ b/flash_sparse_attn/ops/triton/flash_sparse_dec.py @@ -17,6 +17,7 @@ mask, flash_dec_combine, kernel_repr, + autotuner, ) @@ -140,6 +141,8 @@ def _dec_sparse_kernel( seqlen_q, seqlen_k, head_dim, + SEQLEN_Q_CACHE: tl.constexpr, + SEQLEN_K_CACHE: tl.constexpr, QHEADS_PER_KVHEAD_PACKGQA: tl.constexpr, TILE_M: tl.constexpr, TILE_N: tl.constexpr, @@ -501,6 +504,18 @@ def _dec_sparse_kernel( _dec_sparse_kernel = cache_utils.wrap_kernel(_dec_sparse_kernel) +_dec_sparse_kernel_autotuned = None + + +def _get_autotuned_kernel(): + global _dec_sparse_kernel_autotuned + if _dec_sparse_kernel_autotuned is None: + jit_kernel = _dec_sparse_kernel._kernel + autotuned = autotuner.make_dec_sparse_autotuned_kernel(jit_kernel) + _dec_sparse_kernel_autotuned = autotuner.AutotunedKernel(autotuned) + return _dec_sparse_kernel_autotuned + + def _flash_sparse_attn_decode( query: torch.Tensor, key: torch.Tensor, @@ -514,6 +529,7 @@ def _flash_sparse_attn_decode( is_quant: bool = False, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, + is_autotune: bool = False, skip_checks: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: device = query.device @@ -548,14 +564,21 @@ def _flash_sparse_attn_decode( TILE_K = max(triton.next_power_of_2(head_dim), 16) - TILE_M, TILE_N, num_warps, num_stages, num_ctas = ( - launch_template.get_dec_sparse_launch_config( - qheads_per_kvhead=qheads_per_kvhead, - tile_k=TILE_K, - device=device, - arch=arch, + if is_autotune: + kernel = 
_get_autotuned_kernel() + TILE_M = max(triton.next_power_of_2(qheads_per_kvhead), 16) + TILE_N = 128 + num_warps = num_stages = num_ctas = None + else: + kernel = _dec_sparse_kernel + TILE_M, TILE_N, num_warps, num_stages, num_ctas = ( + launch_template.get_dec_sparse_launch_config( + qheads_per_kvhead=qheads_per_kvhead, + tile_k=TILE_K, + device=device, + arch=arch, + ) ) - ) num_splits = utils.num_splits_heuristic( seqlen_q=qheads_per_kvhead, @@ -605,7 +628,7 @@ def _flash_sparse_attn_decode( num_splits=num_splits, ) - _dec_sparse_kernel[grid]( + kernel[grid]( query, key, value, @@ -641,6 +664,8 @@ def _flash_sparse_attn_decode( seqlen_q=qheads_per_kvhead, seqlen_k=seqlen_k, head_dim=head_dim, + SEQLEN_Q_CACHE=0, + SEQLEN_K_CACHE=seqlen_k // 1024, QHEADS_PER_KVHEAD_PACKGQA=qheads_per_kvhead, TILE_M=TILE_M, TILE_N=TILE_N, @@ -683,6 +708,7 @@ def _flash_sparse_attn_varlen_decode( seqused_k: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, + is_autotune: bool = False, skip_checks: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]: device = query.device @@ -718,14 +744,21 @@ def _flash_sparse_attn_varlen_decode( TILE_K = max(triton.next_power_of_2(head_dim), 16) - TILE_M, TILE_N, num_warps, num_stages, num_ctas = ( - launch_template.get_dec_sparse_launch_config( - qheads_per_kvhead=qheads_per_kvhead, - tile_k=TILE_K, - device=device, - arch=arch, + if is_autotune: + kernel = _get_autotuned_kernel() + TILE_M = max(triton.next_power_of_2(qheads_per_kvhead), 16) + TILE_N = 128 + num_warps = num_stages = num_ctas = None + else: + kernel = _dec_sparse_kernel + TILE_M, TILE_N, num_warps, num_stages, num_ctas = ( + launch_template.get_dec_sparse_launch_config( + qheads_per_kvhead=qheads_per_kvhead, + tile_k=TILE_K, + device=device, + arch=arch, + ) ) - ) num_splits = utils.num_splits_heuristic( seqlen_q=qheads_per_kvhead, @@ -775,7 +808,7 @@ def _flash_sparse_attn_varlen_decode( num_splits=num_splits, ) - 
_dec_sparse_kernel[grid]( + kernel[grid]( query, key, value, @@ -811,6 +844,8 @@ def _flash_sparse_attn_varlen_decode( seqlen_q=qheads_per_kvhead, seqlen_k=seqlen_k, head_dim=head_dim, + SEQLEN_Q_CACHE=0, + SEQLEN_K_CACHE=seqlen_k // 1024, QHEADS_PER_KVHEAD_PACKGQA=qheads_per_kvhead, TILE_M=TILE_M, TILE_N=TILE_N, diff --git a/flash_sparse_attn/ops/triton/flash_sparse_fwd.py b/flash_sparse_attn/ops/triton/flash_sparse_fwd.py index ad03306..98a2968 100644 --- a/flash_sparse_attn/ops/triton/flash_sparse_fwd.py +++ b/flash_sparse_attn/ops/triton/flash_sparse_fwd.py @@ -17,6 +17,7 @@ mask, flash_dec_combine, kernel_repr, + autotuner, ) @@ -141,6 +142,8 @@ def _fwd_sparse_kernel( seqlen_q, seqlen_k, head_dim, + SEQLEN_Q_CACHE: tl.constexpr, + SEQLEN_K_CACHE: tl.constexpr, QHEADS_PER_KVHEAD_PACKGQA: tl.constexpr, TILE_M: tl.constexpr, TILE_N: tl.constexpr, @@ -656,6 +659,18 @@ def _fwd_sparse_kernel( _fwd_sparse_kernel = cache_utils.wrap_kernel(_fwd_sparse_kernel) +_fwd_sparse_kernel_autotuned = None + + +def _get_autotuned_kernel(): + global _fwd_sparse_kernel_autotuned + if _fwd_sparse_kernel_autotuned is None: + jit_kernel = _fwd_sparse_kernel._kernel + autotuned = autotuner.make_fwd_sparse_autotuned_kernel(jit_kernel) + _fwd_sparse_kernel_autotuned = autotuner.AutotunedKernel(autotuned) + return _fwd_sparse_kernel_autotuned + + def _flash_sparse_attn_forward( query: torch.Tensor, key: torch.Tensor, @@ -672,6 +687,7 @@ def _flash_sparse_attn_forward( pack_gqa: bool = False, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, + is_autotune: bool = False, skip_checks: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, float, float]: device = query.device @@ -708,16 +724,22 @@ def _flash_sparse_attn_forward( TILE_K = max(triton.next_power_of_2(head_dim), 16) - TILE_M, TILE_N, num_warps, num_stages, num_ctas = ( - launch_template.get_fwd_sparse_launch_config( - is_split_kv=is_split_kv, - pack_gqa=pack_gqa, - qheads_per_kvhead=qheads_per_kvhead, - 
tile_k=TILE_K, - device=device, - arch=arch, + if is_autotune: + kernel = _get_autotuned_kernel() + TILE_M = TILE_N = 64 + num_warps = num_stages = num_ctas = None + else: + kernel = _fwd_sparse_kernel + TILE_M, TILE_N, num_warps, num_stages, num_ctas = ( + launch_template.get_fwd_sparse_launch_config( + is_split_kv=is_split_kv, + pack_gqa=pack_gqa, + qheads_per_kvhead=qheads_per_kvhead, + tile_k=TILE_K, + device=device, + arch=arch, + ) ) - ) num_splits = ( utils.num_splits_heuristic( @@ -769,7 +791,7 @@ def _flash_sparse_attn_forward( num_splits=num_splits, ) - _fwd_sparse_kernel[grid]( + kernel[grid]( query, key, value, @@ -802,9 +824,11 @@ def _flash_sparse_attn_forward( None, qheads_per_kvhead, num_splits, - seqlen_q, - seqlen_k, - head_dim, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + head_dim=head_dim, + SEQLEN_Q_CACHE=seqlen_q // 1024, + SEQLEN_K_CACHE=seqlen_k // 1024, QHEADS_PER_KVHEAD_PACKGQA=qheads_per_kvhead_packgqa, TILE_M=TILE_M, TILE_N=TILE_N, @@ -857,6 +881,7 @@ def _flash_sparse_attn_varlen_forward( seqused_k: Optional[torch.Tensor] = None, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, + is_autotune: bool = False, skip_checks: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, float, float]: device = query.device @@ -896,16 +921,22 @@ def _flash_sparse_attn_varlen_forward( TILE_K = max(triton.next_power_of_2(head_dim), 16) - TILE_M, TILE_N, num_warps, num_stages, num_ctas = ( - launch_template.get_fwd_sparse_launch_config( - is_split_kv=is_split_kv, - pack_gqa=pack_gqa, - qheads_per_kvhead=qheads_per_kvhead, - tile_k=TILE_K, - device=device, - arch=arch, + if is_autotune: + kernel = _get_autotuned_kernel() + TILE_M = TILE_N = 64 + num_warps = num_stages = num_ctas = None + else: + kernel = _fwd_sparse_kernel + TILE_M, TILE_N, num_warps, num_stages, num_ctas = ( + launch_template.get_fwd_sparse_launch_config( + is_split_kv=is_split_kv, + pack_gqa=pack_gqa, + qheads_per_kvhead=qheads_per_kvhead, + tile_k=TILE_K, + 
device=device, + arch=arch, + ) ) - ) num_splits = ( utils.num_splits_heuristic( @@ -957,7 +988,7 @@ def _flash_sparse_attn_varlen_forward( num_splits=num_splits, ) - _fwd_sparse_kernel[grid]( + kernel[grid]( query, key, value, @@ -990,9 +1021,11 @@ def _flash_sparse_attn_varlen_forward( None, qheads_per_kvhead, num_splits, - seqlen_q, - seqlen_k, - head_dim, + seqlen_q=seqlen_q, + seqlen_k=seqlen_k, + head_dim=head_dim, + SEQLEN_Q_CACHE=seqlen_q // 1024, + SEQLEN_K_CACHE=seqlen_k // 1024, QHEADS_PER_KVHEAD_PACKGQA=qheads_per_kvhead_packgqa, TILE_M=TILE_M, TILE_N=TILE_N, diff --git a/flash_sparse_attn/ops/triton/interface.py b/flash_sparse_attn/ops/triton/interface.py index c8366a5..c7c7f7b 100644 --- a/flash_sparse_attn/ops/triton/interface.py +++ b/flash_sparse_attn/ops/triton/interface.py @@ -59,13 +59,16 @@ def forward( is_quant: bool = False, is_split_kv: bool = False, pack_gqa: bool = False, - return_lse: bool = False, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, + is_autotune: bool = False, + skip_checks: bool = False, + return_lse: bool = False, ): # Set is_causal to False if sequence length is 1 to avoid unnecessary masking overhead is_causal = False if query.shape[1] == 1 else is_causal + query_orig, key_orig, value_orig = query, key, value if is_quant and ( query_scale is None or key_scale is None or value_scale is None ): @@ -88,12 +91,16 @@ def forward( pack_gqa=pack_gqa, out=out, lse=lse, + is_autotune=is_autotune, + skip_checks=skip_checks, ) - ctx.save_for_backward(query, key, value, out, lse) + ctx.save_for_backward(query_orig, key_orig, value_orig, out, lse) ctx.is_causal = is_causal ctx.softmax_scale = softmax_scale ctx.window_size = window_size + ctx.is_autotune = is_autotune + ctx.skip_checks = skip_checks if return_lse: # LSE gradient is not supported yet @@ -116,6 +123,8 @@ def backward(ctx, dout: torch.Tensor, *args): is_causal=ctx.is_causal, softmax_scale=ctx.softmax_scale, window_size=ctx.window_size, + 
is_autotune=ctx.is_autotune, + skip_checks=ctx.skip_checks, ) return dq, dk, dv, *((None,) * 20) @@ -144,13 +153,16 @@ def forward( is_quant: bool = False, is_split_kv: bool = False, pack_gqa: bool = False, - return_lse: bool = False, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, + is_autotune: bool = False, + skip_checks: bool = False, + return_lse: bool = False, ): # Set is_causal to False if sequence length is 1 to avoid unnecessary masking overhead is_causal = False if max_seqlen_q == 1 else is_causal + query_orig, key_orig, value_orig = query, key, value if is_quant and ( query_scale is None or key_scale is None or value_scale is None ): @@ -179,12 +191,14 @@ def forward( seqused_k=seqused_k, out=out, lse=lse, + is_autotune=is_autotune, + skip_checks=skip_checks, ) ctx.save_for_backward( - query, - key, - value, + query_orig, + key_orig, + value_orig, out, lse, cu_seqlens_q, @@ -197,6 +211,8 @@ def forward( ctx.window_size = window_size ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_k = max_seqlen_k + ctx.is_autotune = is_autotune + ctx.skip_checks = skip_checks if return_lse: # LSE gradient is not supported yet @@ -235,6 +251,8 @@ def backward(ctx, dout: torch.Tensor, *args): window_size=ctx.window_size, seqused_q=seqused_q, seqused_k=seqused_k, + is_autotune=ctx.is_autotune, + skip_checks=ctx.skip_checks, ) return dq, dk, dv, *((None,) * 20) @@ -258,13 +276,16 @@ def forward( is_quant: bool = False, is_split_kv: bool = False, pack_gqa: bool = False, - return_lse: bool = False, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, + is_autotune: bool = False, + skip_checks: bool = False, + return_lse: bool = False, ): # Set is_causal to False if sequence length is 1 to avoid unnecessary masking overhead is_causal = False if query.shape[1] == 1 else is_causal + query_orig, key_orig, value_orig = query, key, value if is_quant and ( query_scale is None or key_scale is None or value_scale is None ): @@ -288,13 +309,17 
@@ def forward( pack_gqa=pack_gqa, out=out, lse=lse, + is_autotune=is_autotune, + skip_checks=skip_checks, ) - ctx.save_for_backward(query, key, value, out, lse) + ctx.save_for_backward(query_orig, key_orig, value_orig, out, lse) ctx.is_causal = is_causal ctx.softmax_scale = softmax_scale ctx.softmax_threshold = softmax_threshold ctx.window_size = window_size + ctx.is_autotune = is_autotune + ctx.skip_checks = skip_checks if return_lse: # LSE gradient is not supported yet @@ -318,6 +343,8 @@ def backward(ctx, dout: torch.Tensor, *args): softmax_scale=ctx.softmax_scale, softmax_threshold=ctx.softmax_threshold, window_size=ctx.window_size, + is_autotune=ctx.is_autotune, + skip_checks=ctx.skip_checks, ) return dq, dk, dv, *((None,) * 20) @@ -347,13 +374,16 @@ def forward( is_quant: bool = False, is_split_kv: bool = False, pack_gqa: bool = False, - return_lse: bool = False, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, + is_autotune: bool = False, + skip_checks: bool = False, + return_lse: bool = False, ): # Set is_causal to False if sequence length is 1 to avoid unnecessary masking overhead is_causal = False if max_seqlen_q == 1 else is_causal + query_orig, key_orig, value_orig = query, key, value if is_quant and ( query_scale is None or key_scale is None or value_scale is None ): @@ -383,12 +413,14 @@ def forward( seqused_k=seqused_k, out=out, lse=lse, + is_autotune=is_autotune, + skip_checks=skip_checks, ) ctx.save_for_backward( - query, - key, - value, + query_orig, + key_orig, + value_orig, out, lse, cu_seqlens_q, @@ -402,6 +434,8 @@ def forward( ctx.window_size = window_size ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_k = max_seqlen_k + ctx.is_autotune = is_autotune + ctx.skip_checks = skip_checks if return_lse: # LSE gradient is not supported yet @@ -441,6 +475,8 @@ def backward(ctx, dout: torch.Tensor, *args): window_size=ctx.window_size, seqused_q=seqused_q, seqused_k=seqused_k, + is_autotune=ctx.is_autotune, + 
skip_checks=ctx.skip_checks, ) return dq, dk, dv, *((None,) * 20) @@ -469,13 +505,16 @@ def forward( is_quant: bool = False, is_split_kv: bool = False, pack_gqa: bool = False, - return_lse: bool = False, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, + is_autotune: bool = False, + skip_checks: bool = False, + return_lse: bool = False, ): # Set is_causal to False if sequence length is 1 to avoid unnecessary masking overhead is_causal = False if query.shape[1] == 1 else is_causal + query_orig, key_orig, value_orig = query, key, value if is_quant and ( query_scale is None or key_scale is None or value_scale is None ): @@ -505,10 +544,12 @@ def forward( pack_gqa=pack_gqa, out=out, lse=lse, + is_autotune=is_autotune, + skip_checks=skip_checks, ) ) - ctx.save_for_backward(query, key, value, alpha, delta, out, lse) + ctx.save_for_backward(query_orig, key_orig, value_orig, alpha, delta, out, lse) ctx.is_causal = is_causal ctx.softmax_scale = softmax_scale ctx.softmax_threshold = softmax_threshold @@ -516,6 +557,8 @@ def forward( ctx.is_logsigmoid_gate = is_logsigmoid_gate ctx.is_adapt_gate = is_adapt_gate ctx.window_size = window_size + ctx.is_autotune = is_autotune + ctx.skip_checks = skip_checks if return_lse: # LSE gradient is not supported yet @@ -544,6 +587,8 @@ def backward(ctx, dout: torch.Tensor, *args): is_logsigmoid_gate=ctx.is_logsigmoid_gate, is_adapt_gate=ctx.is_adapt_gate, window_size=ctx.window_size, + is_autotune=ctx.is_autotune, + skip_checks=ctx.skip_checks, ) return dq, dk, dv, da, dd, *((None,) * 20) @@ -578,13 +623,16 @@ def forward( is_quant: bool = False, is_split_kv: bool = False, pack_gqa: bool = False, - return_lse: bool = False, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, + is_autotune: bool = False, + skip_checks: bool = False, + return_lse: bool = False, ): # Set is_causal to False if sequence length is 1 to avoid unnecessary masking overhead is_causal = False if max_seqlen_q == 1 else 
is_causal + query_orig, key_orig, value_orig = query, key, value if is_quant and ( query_scale is None or key_scale is None or value_scale is None ): @@ -620,13 +668,15 @@ def forward( seqused_k=seqused_k, out=out, lse=lse, + is_autotune=is_autotune, + skip_checks=skip_checks, ) ) ctx.save_for_backward( - query, - key, - value, + query_orig, + key_orig, + value_orig, alpha, delta, out, @@ -645,6 +695,8 @@ def forward( ctx.window_size = window_size ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_k = max_seqlen_k + ctx.is_autotune = is_autotune + ctx.skip_checks = skip_checks if return_lse: # LSE gradient is not supported yet @@ -691,6 +743,8 @@ def backward(ctx, dout: torch.Tensor, *args): window_size=ctx.window_size, seqused_q=seqused_q, seqused_k=seqused_k, + is_autotune=ctx.is_autotune, + skip_checks=ctx.skip_checks, ) return dq, dk, dv, da, dd, *((None,) * 20) @@ -709,6 +763,8 @@ def flash_dense_attn_func( is_quant: bool = False, is_split_kv: bool = False, pack_gqa: bool = False, + is_autotune: bool = False, + skip_checks: bool = False, return_lse: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ @@ -726,6 +782,8 @@ def flash_dense_attn_func( :param is_quant: Whether to quantize inputs to FP8 for attention computation. If True, query_scale, key_scale, and value_scale must be provided or will be computed from the input tensors. :param is_split_kv: Whether to enable split-KV for occupancy. :param pack_gqa: Whether to pack grouped-query attention. + :param is_autotune: Whether to use Triton autotuner for kernel launch configuration. + :param skip_checks: Whether to skip input validation checks for faster performance. :param return_lse: Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out. :returns: If return_lse is False, returns out with shape [batch_size, seqlen_q, num_heads, head_dim]. 
If return_lse is True, returns a tuple (out, lse), where lse has shape [batch_size, num_heads, seqlen_q]. @@ -743,6 +801,8 @@ def flash_dense_attn_func( is_quant, is_split_kv, pack_gqa, + is_autotune, + skip_checks, return_lse, ) @@ -757,10 +817,11 @@ def flash_dense_attn_with_kvcache_func( value_scale: Optional[torch.Tensor] = None, window_size: Tuple[Optional[int], Optional[int]] = (None, None), is_quant: bool = False, - return_lse: bool = False, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, + is_autotune: bool = False, skip_checks: bool = False, + return_lse: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ Flash dense attention function for decoding with KV cache that computes the attention output and optionally the logsumexp. @@ -774,10 +835,11 @@ def flash_dense_attn_with_kvcache_func( :param value_scale: Optional per-tensor scale for FP8 value dequantization. :param window_size: Optional tuple (window_size_q, window_size_k) for local attention. If None, no local masking is applied. :param is_quant: Whether the inputs are quantized in FP8. If True, query_scale, key_scale, and value_scale must be provided for dequantization. - :param return_lse: Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out. :param out: Optional preallocated output tensor with shape [batch_size, num_heads, head_dim]. :param lse: Optional preallocated logsumexp tensor with shape [batch_size, num_heads]. + :param is_autotune: Whether to use Triton autotuner for kernel launch configuration. :param skip_checks: Whether to skip input validation checks for faster performance. + :param return_lse: Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out. :returns: If return_lse is False, returns out with shape [batch_size, num_heads, head_dim]. 
If return_lse is True, returns a tuple (out, lse), where lse has shape [batch_size, num_heads]. """ @@ -799,6 +861,7 @@ def flash_dense_attn_with_kvcache_func( is_quant=is_quant, out=out, lse=lse, + is_autotune=is_autotune, skip_checks=skip_checks, ) @@ -826,6 +889,8 @@ def flash_dense_attn_varlen_func( seqused_k: Optional[torch.Tensor] = None, is_split_kv: bool = False, pack_gqa: bool = False, + is_autotune: bool = False, + skip_checks: bool = False, return_lse: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ @@ -849,6 +914,8 @@ def flash_dense_attn_varlen_func( :param seqused_k: Optional tensor of shape [total_seqlen_k] indicating the actual sequence lengths for keys/values. If provided, overrides cu_seqlens_k for masking. :param is_split_kv: Whether to enable split-KV for occupancy. :param pack_gqa: Whether to pack grouped-query attention. + :param is_autotune: Whether to use Triton autotuner for kernel launch configuration. + :param skip_checks: Whether to skip input validation checks for faster performance. :param return_lse: Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out. :returns: If return_lse is False, returns out with shape [total_seqlen_q, num_heads_q, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [total_seqlen_q, num_heads_q]. 
@@ -872,6 +939,8 @@ def flash_dense_attn_varlen_func( seqused_k, is_split_kv, pack_gqa, + is_autotune, + skip_checks, return_lse, ) @@ -889,10 +958,11 @@ def flash_dense_attn_varlen_with_kvcache_func( window_size: Tuple[Optional[int], Optional[int]] = (None, None), is_quant: bool = False, seqused_k: Optional[torch.Tensor] = None, - return_lse: bool = False, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, + is_autotune: bool = False, skip_checks: bool = False, + return_lse: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ Flash dense attention function for variable-length decoding with KV cache that computes the attention output and optionally the logsumexp. @@ -909,10 +979,11 @@ def flash_dense_attn_varlen_with_kvcache_func( :param window_size: Optional tuple (window_size_q, window_size_k) for local attention. If None, no local masking is applied. :param is_quant: Whether the inputs are quantized in FP8. If True, query_scale, key_scale, and value_scale must be provided for dequantization. :param seqused_k: Optional tensor indicating the actual sequence lengths for keys/values. - :param return_lse: Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out. :param out: Optional preallocated output tensor with shape [batch_size, num_heads_q, head_dim]. :param lse: Optional preallocated logsumexp tensor with shape [batch_size, num_heads_q]. + :param is_autotune: Whether to use Triton autotuner for kernel launch configuration. :param skip_checks: Whether to skip input validation checks for faster performance. + :param return_lse: Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out. :returns: If return_lse is False, returns out with shape [batch_size, num_heads_q, head_dim]. 
If return_lse is True, returns a tuple (out, lse), where lse has shape [batch_size, num_heads_q]. """ @@ -937,6 +1008,7 @@ def flash_dense_attn_varlen_with_kvcache_func( seqused_k=seqused_k, out=out, lse=lse, + is_autotune=is_autotune, skip_checks=skip_checks, ) @@ -959,6 +1031,8 @@ def flash_sparse_attn_func( is_quant: bool = False, is_split_kv: bool = False, pack_gqa: bool = False, + is_autotune: bool = False, + skip_checks: bool = False, return_lse: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ @@ -977,6 +1051,8 @@ def flash_sparse_attn_func( :param is_quant: Whether to quantize inputs to FP8 for attention computation. If True, query_scale, key_scale, and value_scale must be provided or will be computed from the input tensors. :param is_split_kv: Whether to enable split-KV for occupancy. :param pack_gqa: Whether to pack grouped-query attention. + :param is_autotune: Whether to use Triton autotuner for kernel launch configuration. + :param skip_checks: Whether to skip input validation checks for faster performance. :param return_lse: Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out. :returns: If return_lse is False, returns out with shape [batch_size, seqlen_q, num_heads, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [batch_size, num_heads, seqlen_q]. 
@@ -995,6 +1071,8 @@ def flash_sparse_attn_func( is_quant, is_split_kv, pack_gqa, + is_autotune, + skip_checks, return_lse, ) @@ -1010,10 +1088,11 @@ def flash_sparse_attn_with_kvcache_func( value_scale: Optional[torch.Tensor] = None, window_size: Tuple[Optional[int], Optional[int]] = (None, None), is_quant: bool = False, - return_lse: bool = False, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, + is_autotune: bool = False, skip_checks: bool = False, + return_lse: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ Flash sparse attention function for decoding with KV cache that computes the attention output and optionally the logsumexp. @@ -1028,10 +1107,11 @@ def flash_sparse_attn_with_kvcache_func( :param value_scale: Optional per-tensor scale for FP8 value dequantization. :param window_size: Optional tuple (window_size_q, window_size_k) for local attention. If None, no local masking is applied. :param is_quant: Whether the inputs are quantized in FP8. If True, query_scale, key_scale, and value_scale must be provided for dequantization. - :param return_lse: Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out. :param out: Optional preallocated output tensor with shape [batch_size, num_heads, head_dim]. :param lse: Optional preallocated logsumexp tensor with shape [batch_size, num_heads]. + :param is_autotune: Whether to use Triton autotuner for kernel launch configuration. :param skip_checks: Whether to skip input validation checks for faster performance. + :param return_lse: Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out. :returns: If return_lse is False, returns out with shape [batch_size, num_heads, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [batch_size, num_heads]. 
""" @@ -1054,6 +1134,7 @@ def flash_sparse_attn_with_kvcache_func( is_quant=is_quant, out=out, lse=lse, + is_autotune=is_autotune, skip_checks=skip_checks, ) @@ -1082,6 +1163,8 @@ def flash_sparse_attn_varlen_func( seqused_k: Optional[torch.Tensor] = None, is_split_kv: bool = False, pack_gqa: bool = False, + is_autotune: bool = False, + skip_checks: bool = False, return_lse: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ @@ -1106,6 +1189,8 @@ def flash_sparse_attn_varlen_func( :param seqused_k: Optional tensor of shape [total_seqlen_k] indicating the actual sequence lengths for keys/values. If provided, overrides cu_seqlens_k for masking. :param is_split_kv: Whether to enable split-KV for occupancy. :param pack_gqa: Whether to pack grouped-query attention. + :param is_autotune: Whether to use Triton autotuner for kernel launch configuration. + :param skip_checks: Whether to skip input validation checks for faster performance. :param return_lse: Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out. :returns: If return_lse is False, returns out with shape [total_seqlen_q, num_heads_q, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [total_seqlen_q, num_heads_q]. 
@@ -1130,6 +1215,8 @@ def flash_sparse_attn_varlen_func( seqused_k, is_split_kv, pack_gqa, + is_autotune, + skip_checks, return_lse, ) @@ -1148,10 +1235,11 @@ def flash_sparse_attn_varlen_with_kvcache_func( window_size: Tuple[Optional[int], Optional[int]] = (None, None), is_quant: bool = False, seqused_k: Optional[torch.Tensor] = None, - return_lse: bool = False, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, + is_autotune: bool = False, skip_checks: bool = False, + return_lse: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ Flash sparse attention function for variable-length decoding with KV cache that computes the attention output and optionally the logsumexp. @@ -1169,10 +1257,11 @@ def flash_sparse_attn_varlen_with_kvcache_func( :param window_size: Optional tuple (window_size_q, window_size_k) for local attention. If None, no local masking is applied. :param is_quant: Whether the inputs are quantized in FP8. If True, query_scale, key_scale, and value_scale must be provided for dequantization. :param seqused_k: Optional tensor indicating the actual sequence lengths for keys/values. - :param return_lse: Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out. :param out: Optional preallocated output tensor with shape [batch_size, num_heads_q, head_dim]. :param lse: Optional preallocated logsumexp tensor with shape [batch_size, num_heads_q]. + :param is_autotune: Whether to use Triton autotuner for kernel launch configuration. :param skip_checks: Whether to skip input validation checks for faster performance. + :param return_lse: Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out. :returns: If return_lse is False, returns out with shape [batch_size, num_heads_q, head_dim]. 
If return_lse is True, returns a tuple (out, lse), where lse has shape [batch_size, num_heads_q]. """ @@ -1198,6 +1287,7 @@ def flash_sparse_attn_varlen_with_kvcache_func( seqused_k=seqused_k, out=out, lse=lse, + is_autotune=is_autotune, skip_checks=skip_checks, ) @@ -1225,6 +1315,8 @@ def flash_gated_attn_func( is_quant: bool = False, is_split_kv: bool = False, pack_gqa: bool = False, + is_autotune: bool = False, + skip_checks: bool = False, return_lse: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ @@ -1248,6 +1340,8 @@ def flash_gated_attn_func( :param is_quant: Whether to quantize inputs to FP8 for attention computation. If True, query_scale, key_scale, and value_scale must be provided or will be computed from the input tensors. :param is_split_kv: Whether to enable split-KV for occupancy. :param pack_gqa: Whether to pack grouped-query attention. + :param is_autotune: Whether to use Triton autotuner for kernel launch configuration. + :param skip_checks: Whether to skip input validation checks for faster performance. :param return_lse: Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out. :returns: If return_lse is False, returns out with shape [batch_size, seqlen_q, num_heads, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [batch_size, num_heads, seqlen_q]. 
@@ -1271,6 +1365,8 @@ def flash_gated_attn_func( is_quant, is_split_kv, pack_gqa, + is_autotune, + skip_checks, return_lse, ) @@ -1290,10 +1386,11 @@ def flash_gated_attn_with_kvcache_func( value_scale: Optional[torch.Tensor] = None, window_size: Tuple[Optional[int], Optional[int]] = (None, None), is_quant: bool = False, - return_lse: bool = False, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, + is_autotune: bool = False, skip_checks: bool = False, + return_lse: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ Flash gated attention function for decoding with KV cache that computes the attention output and optionally the logsumexp. @@ -1312,10 +1409,11 @@ def flash_gated_attn_with_kvcache_func( :param value_scale: Optional per-tensor scale for FP8 value dequantization. :param window_size: Optional tuple (window_size_q, window_size_k) for local attention. If None, no local masking is applied. :param is_quant: Whether the inputs are quantized in FP8. If True, query_scale, key_scale, and value_scale must be provided for dequantization. - :param return_lse: Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out. :param out: Optional preallocated output tensor with shape [batch_size, num_heads, head_dim]. :param lse: Optional preallocated logsumexp tensor with shape [batch_size, num_heads]. + :param is_autotune: Whether to use Triton autotuner for kernel launch configuration. :param skip_checks: Whether to skip input validation checks for faster performance. + :param return_lse: Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out. :returns: If return_lse is False, returns out with shape [batch_size, num_heads, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [batch_size, num_heads]. 
""" @@ -1342,6 +1440,7 @@ def flash_gated_attn_with_kvcache_func( is_quant=is_quant, out=out, lse=lse, + is_autotune=is_autotune, skip_checks=skip_checks, ) @@ -1375,6 +1474,8 @@ def flash_gated_attn_varlen_func( seqused_k: Optional[torch.Tensor] = None, is_split_kv: bool = False, pack_gqa: bool = False, + is_autotune: bool = False, + skip_checks: bool = False, return_lse: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ @@ -1404,6 +1505,8 @@ def flash_gated_attn_varlen_func( :param seqused_k: Optional tensor of shape [total_seqlen_k] indicating the actual sequence lengths for keys/values. If provided, overrides cu_seqlens_k for masking. :param is_split_kv: Whether to enable split-KV for occupancy. :param pack_gqa: Whether to pack grouped-query attention. + :param is_autotune: Whether to use Triton autotuner for kernel launch configuration. + :param skip_checks: Whether to skip input validation checks for faster performance. :param return_lse: Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out. :returns: If return_lse is False, returns out with shape [total_seqlen_q, num_heads_q, head_dim]. If return_lse is True, returns a tuple (out, lse), where lse has shape [total_seqlen_q, num_heads_q]. 
@@ -1433,6 +1536,8 @@ def flash_gated_attn_varlen_func( seqused_k, is_split_kv, pack_gqa, + is_autotune, + skip_checks, return_lse, ) @@ -1455,10 +1560,11 @@ def flash_gated_attn_varlen_with_kvcache_func( window_size: Tuple[Optional[int], Optional[int]] = (None, None), is_quant: bool = False, seqused_k: Optional[torch.Tensor] = None, - return_lse: bool = False, out: Optional[torch.Tensor] = None, lse: Optional[torch.Tensor] = None, + is_autotune: bool = False, skip_checks: bool = False, + return_lse: bool = False, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ Flash gated attention function for variable-length decoding with KV cache that computes the attention output and optionally the logsumexp. @@ -1480,10 +1586,11 @@ def flash_gated_attn_varlen_with_kvcache_func( :param window_size: Optional tuple (window_size_q, window_size_k) for local attention. If None, no local masking is applied. :param is_quant: Whether the inputs are quantized in FP8. If True, query_scale, key_scale, and value_scale must be provided for dequantization. :param seqused_k: Optional tensor indicating the actual sequence lengths for keys/values. - :param return_lse: Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out. :param out: Optional preallocated output tensor with shape [batch_size, num_heads_q, head_dim]. :param lse: Optional preallocated logsumexp tensor with shape [batch_size, num_heads_q]. + :param is_autotune: Whether to use Triton autotuner for kernel launch configuration. :param skip_checks: Whether to skip input validation checks for faster performance. + :param return_lse: Whether to return the logsumexp tensor for numerical stability analysis. If True, returns a tuple (out, lse). If False, returns only out. :returns: If return_lse is False, returns out with shape [batch_size, num_heads_q, head_dim]. 
If return_lse is True, returns a tuple (out, lse), where lse has shape [batch_size, num_heads_q]. """ @@ -1513,6 +1620,7 @@ def flash_gated_attn_varlen_with_kvcache_func( seqused_k=seqused_k, out=out, lse=lse, + is_autotune=is_autotune, skip_checks=skip_checks, ) diff --git a/tests/benchmark_backward.py b/tests/benchmark_backward.py index 2a43906..74c8190 100644 --- a/tests/benchmark_backward.py +++ b/tests/benchmark_backward.py @@ -43,7 +43,7 @@ def benchmark_triton_dense_backward( device=device, dtype=dtype, layout="bshd", - input_source="llm", + input_source="random", ) q = q.requires_grad_(True) k = k.requires_grad_(True) @@ -57,6 +57,8 @@ def benchmark_triton_dense_backward( is_causal=cfg.is_causal, softmax_scale=softmax_scale, window_size=(None, None), + is_autotune=True, + skip_checks=True, ) dout = torch.randn_like(out) @@ -77,7 +79,7 @@ def benchmark_triton_sparse_backward( device=device, dtype=dtype, layout="bshd", - input_source="llm", + input_source="random", ) q = q.requires_grad_(True) k = k.requires_grad_(True) @@ -93,6 +95,8 @@ def benchmark_triton_sparse_backward( softmax_scale=softmax_scale, softmax_threshold=softmax_threshold, window_size=(None, None), + is_autotune=True, + skip_checks=True, ) dout = torch.randn_like(out) @@ -113,7 +117,7 @@ def benchmark_triton_gated_backward( device=device, dtype=dtype, layout="bshd", - input_source="llm", + input_source="random", ) q = q.requires_grad_(True) k = k.requires_grad_(True) @@ -141,6 +145,8 @@ def benchmark_triton_gated_backward( is_logsigmoid_gate=False, is_adapt_gate=False, window_size=(None, None), + is_autotune=True, + skip_checks=True, ) dout = torch.randn_like(out) @@ -163,7 +169,7 @@ def benchmark_fa_dense_backward( device=device, dtype=dtype, layout="bhsd", - input_source="llm", + input_source="random", ) q = q.requires_grad_(True) k = k.requires_grad_(True) @@ -202,7 +208,7 @@ def benchmark_cudnn_dense_backward( device=device, dtype=dtype, layout="bhsd", - input_source="llm", + 
input_source="random", ) q = q.requires_grad_(True) k = k.requires_grad_(True) diff --git a/tests/benchmark_decode.py b/tests/benchmark_decode.py index e9f9946..475ab35 100644 --- a/tests/benchmark_decode.py +++ b/tests/benchmark_decode.py @@ -53,6 +53,8 @@ def fn(): v, softmax_scale=softmax_scale, window_size=(None, None), + is_autotune=True, + skip_checks=True, ) return do_bench(fn, warmup=20, rep=100) @@ -88,6 +90,8 @@ def fn(): value_scale=v_scale, window_size=(None, None), is_quant=True, + is_autotune=True, + skip_checks=True, ) return do_bench(fn, warmup=20, rep=100) @@ -115,78 +119,84 @@ def fn(): softmax_scale=softmax_scale, softmax_threshold=softmax_threshold, window_size=(None, None), + is_autotune=True, + skip_checks=True, ) return do_bench(fn, warmup=20, rep=100) -def benchmark_triton_gated_decode( - cfg: BenchmarkConfig, device: str = "cuda", dtype=torch.bfloat16 +def benchmark_triton_sparse_decode_fp8( + cfg: BenchmarkConfig, device: str = "cuda" ) -> float: q, k, v = generate_inputs( cfg, device=device, - dtype=dtype, + dtype=torch.bfloat16, layout="bshd", input_source="random", ) q = q.squeeze(1) - alpha = torch.randn(cfg.batch_size, cfg.num_heads, device=device, dtype=dtype) - delta = torch.randn( - cfg.batch_size, cfg.seqlen_k, cfg.num_kv_heads, device=device, dtype=dtype - ) + + q_fp8, q_scale = quant.quantize_fp8(q) + k_fp8, k_scale = quant.quantize_fp8(k) + v_fp8, v_scale = quant.quantize_fp8(v) + softmax_scale = cfg.head_dim**-0.5 softmax_threshold = 1.0 - gate_threshold = 1.0 def fn(): - flash_gated_attn_with_kvcache_func( - q, - k, - v, - alpha, - delta, + flash_sparse_attn_with_kvcache_func( + q_fp8, + k_fp8, + v_fp8, softmax_scale=softmax_scale, softmax_threshold=softmax_threshold, - gate_threshold=gate_threshold, - is_logsigmoid_gate=False, + query_scale=q_scale, + key_scale=k_scale, + value_scale=v_scale, window_size=(None, None), + is_quant=True, + is_autotune=True, + skip_checks=True, ) return do_bench(fn, warmup=20, rep=100) -def 
benchmark_triton_sparse_decode_fp8( - cfg: BenchmarkConfig, device: str = "cuda" +def benchmark_triton_gated_decode( + cfg: BenchmarkConfig, device: str = "cuda", dtype=torch.bfloat16 ) -> float: q, k, v = generate_inputs( cfg, device=device, - dtype=torch.bfloat16, + dtype=dtype, layout="bshd", input_source="random", ) q = q.squeeze(1) - - q_fp8, q_scale = quant.quantize_fp8(q) - k_fp8, k_scale = quant.quantize_fp8(k) - v_fp8, v_scale = quant.quantize_fp8(v) - + alpha = torch.randn(cfg.batch_size, cfg.num_heads, device=device, dtype=dtype) + delta = torch.randn( + cfg.batch_size, cfg.seqlen_k, cfg.num_kv_heads, device=device, dtype=dtype + ) softmax_scale = cfg.head_dim**-0.5 softmax_threshold = 1.0 + gate_threshold = 1.0 def fn(): - flash_sparse_attn_with_kvcache_func( - q_fp8, - k_fp8, - v_fp8, + flash_gated_attn_with_kvcache_func( + q, + k, + v, + alpha, + delta, softmax_scale=softmax_scale, softmax_threshold=softmax_threshold, - query_scale=q_scale, - key_scale=k_scale, - value_scale=v_scale, + gate_threshold=gate_threshold, + is_logsigmoid_gate=False, window_size=(None, None), - is_quant=True, + is_autotune=True, + skip_checks=True, ) return do_bench(fn, warmup=20, rep=100) @@ -238,6 +248,8 @@ def fn(): value_scale=v_scale, window_size=(None, None), is_quant=True, + is_autotune=True, + skip_checks=True, ) return do_bench(fn, warmup=20, rep=100) diff --git a/tests/benchmark_forward.py b/tests/benchmark_forward.py index ff6b2ca..fe39669 100644 --- a/tests/benchmark_forward.py +++ b/tests/benchmark_forward.py @@ -44,7 +44,7 @@ def benchmark_triton_dense_forward( device=device, dtype=dtype, layout="bshd", - input_source="llm", + input_source="random", ) softmax_scale = cfg.head_dim**-0.5 @@ -56,6 +56,8 @@ def fn(): is_causal=cfg.is_causal, softmax_scale=softmax_scale, window_size=(None, None), + is_autotune=True, + skip_checks=True, ) return do_bench(fn, warmup=20, rep=100) @@ -69,7 +71,7 @@ def benchmark_triton_dense_forward_fp8( device=device, 
dtype=torch.bfloat16, layout="bshd", - input_source="llm", + input_source="random", ) q_fp8, q_scale = quant.quantize_fp8(q) @@ -90,6 +92,8 @@ def fn(): value_scale=v_scale, window_size=(None, None), is_quant=True, + is_autotune=True, + skip_checks=True, ) return do_bench(fn, warmup=20, rep=100) @@ -103,7 +107,7 @@ def benchmark_triton_sparse_forward( device=device, dtype=dtype, layout="bshd", - input_source="llm", + input_source="random", ) softmax_scale = cfg.head_dim**-0.5 softmax_threshold = 1.0 @@ -117,6 +121,8 @@ def fn(): softmax_scale=softmax_scale, softmax_threshold=softmax_threshold, window_size=(None, None), + is_autotune=True, + skip_checks=True, ) return do_bench(fn, warmup=20, rep=100) @@ -130,7 +136,7 @@ def benchmark_triton_sparse_forward_fp8( device=device, dtype=torch.bfloat16, layout="bshd", - input_source="llm", + input_source="random", ) q_fp8, q_scale = quant.quantize_fp8(q) @@ -153,6 +159,8 @@ def fn(): value_scale=v_scale, window_size=(None, None), is_quant=True, + is_autotune=True, + skip_checks=True, ) return do_bench(fn, warmup=20, rep=100) @@ -166,7 +174,7 @@ def benchmark_triton_gated_forward( device=device, dtype=dtype, layout="bshd", - input_source="llm", + input_source="random", ) alpha = torch.randn( cfg.batch_size, cfg.seqlen_q, cfg.num_heads, device=device, dtype=dtype @@ -192,6 +200,8 @@ def fn(): is_logsigmoid_gate=False, is_adapt_gate=False, window_size=(None, None), + is_autotune=True, + skip_checks=True, ) return do_bench(fn, warmup=20, rep=100) @@ -205,7 +215,7 @@ def benchmark_triton_gated_forward_fp8( device=device, dtype=torch.bfloat16, layout="bshd", - input_source="llm", + input_source="random", ) q_fp8, q_scale = quant.quantize_fp8(q) @@ -244,6 +254,8 @@ def fn(): value_scale=v_scale, window_size=(None, None), is_quant=True, + is_autotune=True, + skip_checks=True, ) return do_bench(fn, warmup=20, rep=100) @@ -257,7 +269,7 @@ def benchmark_fa_dense_forward( device=device, dtype=dtype, layout="bhsd", - 
input_source="llm", + input_source="random", ) softmax_scale = cfg.head_dim**-0.5 @@ -286,7 +298,7 @@ def benchmark_cudnn_dense_forward( device=device, dtype=dtype, layout="bhsd", - input_source="llm", + input_source="random", ) softmax_scale = cfg.head_dim**-0.5