Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
7b08b9c
Refactor block min/max functions to use window size parameters directly
LoserCheems May 16, 2026
44f8925
Refactor apply_mask function to streamline window size parameters and…
LoserCheems May 16, 2026
9c4fbd8
Fix normalization logic in combine kernels to handle zero sums correctly
LoserCheems May 16, 2026
0f2cb55
Fix parameter naming for query heads in softmax and gate threshold fu…
LoserCheems May 16, 2026
c19e4c1
Enhance num_splits_heuristic to ensure a minimum return value of 1 an…
LoserCheems May 16, 2026
00f9eec
Fix parameter naming for query heads in launch configuration functions
LoserCheems May 16, 2026
faa0946
Add window size parameters to block min/max functions and update retu…
LoserCheems May 16, 2026
ba8c30f
Refactor masking logic to prevent simultaneous application of causal …
LoserCheems May 16, 2026
8b3b0eb
Refactor kernel parameters to use window size variables and correct q…
LoserCheems May 16, 2026
3cb47af
Refactor kernel parameters to include window size variables and corre…
LoserCheems May 16, 2026
51afda4
Refactor kernel parameters to include window size variables and corre…
LoserCheems May 16, 2026
d12554a
Refactor _bwd_inner_dense_kernel to use window size parameters directly
LoserCheems May 16, 2026
dba9f28
Refactor _bwd_inner_sparse_kernel to use window size parameters directly
LoserCheems May 16, 2026
cbfe34d
Refactor _bwd_inner_gated_kernel to use window size parameters directly
LoserCheems May 16, 2026
b96e248
Refactor _fwd_dense_kernel and related functions to incorporate windo…
LoserCheems May 16, 2026
4df37c6
Refactor _fwd_sparse_kernel and _flash_sparse_attn functions to integ…
LoserCheems May 16, 2026
7125da9
Refactor _fwd_gated_kernel and _flash_gated_attn functions to incorpo…
LoserCheems May 16, 2026
ca02613
Refactor _bwd_dense_kernel and related functions to incorporate windo…
LoserCheems May 16, 2026
efdfac8
Refactor _bwd_sparse_kernel and related functions to incorporate wind…
LoserCheems May 16, 2026
7c9bbb2
Refactor _bwd_gated_kernel and _flash_gated_attn functions to incorpo…
LoserCheems May 16, 2026
ebffbfe
Refactor attention functions to replace window_size parameter with is…
LoserCheems May 16, 2026
3bf67e9
Fix condition in _fwd_sparse_kernel to correctly handle local attenti…
LoserCheems May 16, 2026
d1f5b64
Enhance local attention handling in get_n_block_min_max by updating n…
LoserCheems May 18, 2026
fdab34f
Refactor _dec_inner_dense_kernel to replace window size constants wit…
LoserCheems May 18, 2026
bb33953
Refactor _dec_inner_sparse_kernel to replace window size constants wi…
LoserCheems May 18, 2026
7ba6377
Refactor _dec_inner_gated_kernel to replace window size constants wit…
LoserCheems May 18, 2026
afbfdf1
Refactor window_sizes_heuristic to add equal_bandwidth parameter for …
LoserCheems May 18, 2026
5a92d3c
Enhance get_n_block_min_max and get_n_block_min_before_local_mask by …
LoserCheems May 19, 2026
83480b5
Refactor _dec_dense_kernel and _flash_dense_attn_decode to incorporat…
LoserCheems May 19, 2026
cca319b
Refactor _dec_sparse_kernel and _flash_sparse_attn_decode to incorpor…
LoserCheems May 19, 2026
d310248
Refactor _dec_gated_kernel and _flash_gated_attn_decode to support dy…
LoserCheems May 19, 2026
a415eae
Refactor test cases to support is_local parameter for enhanced flexib…
LoserCheems May 19, 2026
8eabb00
Refactor reference score functions to support local attention handlin…
LoserCheems May 19, 2026
4997a9f
Remove window_size parameter from benchmark functions for consistency…
LoserCheems May 19, 2026
d69f1c3
Refactor autotuner functions to support dynamic memory pruning for fo…
LoserCheems May 20, 2026
5a19ef8
Refactor test cases to standardize parameters for dense and sparse at…
LoserCheems May 20, 2026
0364b44
Refactor get_m_block_min_max function to support split configurations…
LoserCheems May 24, 2026
5d5fde3
Enhance backward kernel functions to support split configurations for…
LoserCheems May 24, 2026
9e20ab0
Enhance backward kernel functions to support split configurations for…
LoserCheems May 24, 2026
e703354
Enhance backward kernel functions to support split configurations for…
LoserCheems May 24, 2026
6bf22d7
Add support for split-QO configuration in attention functions
LoserCheems May 24, 2026
e417ffe
Refactor grid functions to enhance backward and forward kernel handling
LoserCheems May 24, 2026
689039c
Add IS_CAUSAL and IS_LOCAL keys to autotuned kernel configurations
LoserCheems May 24, 2026
9fbf2a2
Remove static buffer pool and adjust compiled kernel cache size
LoserCheems May 24, 2026
3e4bd7d
Refactor launch configuration management for Triton kernels
LoserCheems May 24, 2026
7a0f357
Refactor kernel launch configuration for backward, forward, and decod…
LoserCheems May 24, 2026
873e5bd
Update is_autotune parameter description to clarify cache behavior
LoserCheems May 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 128 additions & 18 deletions flash_sparse_attn/ops/triton/autotuner.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,113 @@
import torch
import triton


def _get_max_shared_mem():
    """Return the per-block shared-memory limit reported for the current CUDA device.

    Reads the device properties of ``torch.cuda.current_device()`` and returns
    its ``shared_memory_per_block_optin`` attribute when the properties object
    exposes one, otherwise ``shared_memory_per_block``.
    """
    device_index = torch.cuda.current_device()
    device_props = torch.cuda.get_device_properties(device_index)
    # The fallback attribute is read unconditionally, exactly as a getattr
    # default would be evaluated.
    fallback_limit = device_props.shared_memory_per_block
    return getattr(device_props, "shared_memory_per_block_optin", fallback_limit)


def _smem_bytes_fwd(tile_m, tile_n, tile_k, num_stages, dtype_bytes):
resident = tile_m * tile_k * dtype_bytes
pipelined = 2 * tile_n * tile_k * dtype_bytes * num_stages
return resident + pipelined


def _smem_bytes_bwd(tile_m, tile_n, tile_k, num_stages, dtype_bytes):
resident = 2 * tile_n * tile_k * dtype_bytes
pipelined = 2 * tile_m * tile_k * dtype_bytes * num_stages
return resident + pipelined


def _smem_bytes_dec(tile_m, tile_n, tile_k, num_stages, dtype_bytes):
resident = tile_m * tile_k * dtype_bytes
pipelined = 2 * tile_n * tile_k * dtype_bytes * num_stages
return resident + pipelined


def _prune_fwd_configs(configs, named_args, **kwargs):
    """Keep only the forward configs whose ``_smem_bytes_fwd`` figure fits the device.

    Args:
        configs: Candidate configs. Each must expose ``kwargs["TILE_M"]``,
            ``kwargs["TILE_N"]`` and ``num_stages``.
        named_args: Mapping of kernel argument names to values. ``"Q"`` must
            provide ``element_size()``; ``"TILE_K"`` is read from here when it
            is not present in ``kwargs``.
        **kwargs: Extra keyword values; a ``TILE_K`` entry takes precedence.

    Returns:
        The fitting configs in their original order. If none fit, a
        one-element list holding the config with the smallest figure.
    """
    # TILE_K lookup order: kwargs, then named_args, then 128.
    tile_k = kwargs.get("TILE_K", named_args.get("TILE_K", 128))
    dtype_bytes = named_args["Q"].element_size()
    # 4 KiB is withheld from the device limit (presumably headroom for usage
    # the estimate does not model -- confirm).
    budget = _get_max_shared_mem() - 4 * 1024

    def footprint(cfg):
        return _smem_bytes_fwd(
            cfg.kwargs["TILE_M"],
            cfg.kwargs["TILE_N"],
            tile_k,
            cfg.num_stages,
            dtype_bytes,
        )

    fitting = [cfg for cfg in configs if footprint(cfg) <= budget]
    if fitting:
        return fitting
    # Nothing fits: fall back to the single cheapest candidate.
    return [min(configs, key=footprint)]


def _prune_bwd_configs(configs, named_args, **kwargs):
    """Keep only the backward configs whose ``_smem_bytes_bwd`` figure fits the device.

    Args:
        configs: Candidate configs. Each must expose ``kwargs["TILE_M"]``,
            ``kwargs["TILE_N"]`` and ``num_stages``.
        named_args: Mapping of kernel argument names to values. ``"Q"`` must
            provide ``element_size()``; ``"TILE_K"`` is read from here when it
            is not present in ``kwargs``.
        **kwargs: Extra keyword values; a ``TILE_K`` entry takes precedence.

    Returns:
        The fitting configs in their original order. If none fit, a
        one-element list holding the config with the smallest figure.
    """
    # TILE_K lookup order: kwargs, then named_args, then 128.
    tile_k = kwargs.get("TILE_K", named_args.get("TILE_K", 128))
    dtype_bytes = named_args["Q"].element_size()
    # 4 KiB is withheld from the device limit (presumably headroom for usage
    # the estimate does not model -- confirm).
    budget = _get_max_shared_mem() - 4 * 1024

    def footprint(cfg):
        return _smem_bytes_bwd(
            cfg.kwargs["TILE_M"],
            cfg.kwargs["TILE_N"],
            tile_k,
            cfg.num_stages,
            dtype_bytes,
        )

    fitting = [cfg for cfg in configs if footprint(cfg) <= budget]
    if fitting:
        return fitting
    # Nothing fits: fall back to the single cheapest candidate.
    return [min(configs, key=footprint)]


def _prune_dec_configs(configs, named_args, **kwargs):
    """Keep only the decode configs whose ``_smem_bytes_dec`` figure fits the device.

    Args:
        configs: Candidate configs. Each must expose ``kwargs["TILE_M"]``,
            ``kwargs["TILE_N"]`` and ``num_stages``.
        named_args: Mapping of kernel argument names to values. ``"Q"`` must
            provide ``element_size()``; ``"TILE_K"`` is read from here when it
            is not present in ``kwargs``.
        **kwargs: Extra keyword values; a ``TILE_K`` entry takes precedence.

    Returns:
        The fitting configs in their original order. If none fit, a
        one-element list holding the config with the smallest figure.
    """
    # TILE_K lookup order: kwargs, then named_args, then 128.
    tile_k = kwargs.get("TILE_K", named_args.get("TILE_K", 128))
    dtype_bytes = named_args["Q"].element_size()
    # 4 KiB is withheld from the device limit (presumably headroom for usage
    # the estimate does not model -- confirm).
    budget = _get_max_shared_mem() - 4 * 1024

    def footprint(cfg):
        return _smem_bytes_dec(
            cfg.kwargs["TILE_M"],
            cfg.kwargs["TILE_N"],
            tile_k,
            cfg.num_stages,
            dtype_bytes,
        )

    fitting = [cfg for cfg in configs if footprint(cfg) <= budget]
    if fitting:
        return fitting
    # Nothing fits: fall back to the single cheapest candidate.
    return [min(configs, key=footprint)]


def get_fwd_dense_autotune_configs():
configs = []
for tile_m in [64, 128, 256]:
for tile_n in [32, 64, 128]:
for num_warps in [4, 8]:
for num_stages in [1, 2]:
for num_stages in [1, 2, 3]:
configs.append(
triton.Config(
{"TILE_M": tile_m, "TILE_N": tile_n},
Expand All @@ -23,7 +124,7 @@ def get_fwd_sparse_autotune_configs():
for tile_m in [64, 128, 256]:
for tile_n in [32, 64, 128]:
for num_warps in [4, 8]:
for num_stages in [1, 2]:
for num_stages in [1, 2, 3]:
configs.append(
triton.Config(
{"TILE_M": tile_m, "TILE_N": tile_n},
Expand All @@ -40,7 +141,7 @@ def get_fwd_gated_autotune_configs():
for tile_m in [64, 128, 256]:
for tile_n in [32, 64, 128]:
for num_warps in [4, 8]:
for num_stages in [1, 2]:
for num_stages in [1, 2, 3]:
configs.append(
triton.Config(
{"TILE_M": tile_m, "TILE_N": tile_n},
Expand All @@ -57,7 +158,7 @@ def get_bwd_dense_autotune_configs():
for tile_m in [32, 64, 128]:
for tile_n in [64, 128, 256]:
for num_warps in [4, 8]:
for num_stages in [1, 2]:
for num_stages in [1, 2, 3]:
configs.append(
triton.Config(
{"TILE_M": tile_m, "TILE_N": tile_n},
Expand All @@ -74,7 +175,7 @@ def get_bwd_sparse_autotune_configs():
for tile_m in [32, 64, 128]:
for tile_n in [64, 128, 256]:
for num_warps in [4, 8]:
for num_stages in [1, 2]:
for num_stages in [1, 2, 3]:
configs.append(
triton.Config(
{"TILE_M": tile_m, "TILE_N": tile_n},
Expand All @@ -91,7 +192,7 @@ def get_bwd_gated_autotune_configs():
for tile_m in [32, 64, 128]:
for tile_n in [64, 128, 256]:
for num_warps in [4, 8]:
for num_stages in [1, 2]:
for num_stages in [1, 2, 3]:
configs.append(
triton.Config(
{"TILE_M": tile_m, "TILE_N": tile_n},
Expand All @@ -108,7 +209,7 @@ def get_dec_dense_autotune_configs():
for tile_m in [16, 32, 64]:
for tile_n in [64, 128, 256]:
for num_warps in [4, 8]:
for num_stages in [1, 2]:
for num_stages in [1, 2, 3]:
configs.append(
triton.Config(
{"TILE_M": tile_m, "TILE_N": tile_n},
Expand All @@ -125,7 +226,7 @@ def get_dec_sparse_autotune_configs():
for tile_m in [16, 32, 64]:
for tile_n in [64, 128, 256]:
for num_warps in [4, 8]:
for num_stages in [1, 2]:
for num_stages in [1, 2, 3]:
configs.append(
triton.Config(
{"TILE_M": tile_m, "TILE_N": tile_n},
Expand All @@ -142,7 +243,7 @@ def get_dec_gated_autotune_configs():
for tile_m in [16, 32, 64]:
for tile_n in [64, 128, 256]:
for num_warps in [4, 8]:
for num_stages in [1, 2]:
for num_stages in [1, 2, 3]:
configs.append(
triton.Config(
{"TILE_M": tile_m, "TILE_N": tile_n},
Expand All @@ -158,71 +259,80 @@ def make_fwd_dense_autotuned_kernel(jit_kernel):
configs = get_fwd_dense_autotune_configs()
return triton.autotune(
configs=configs,
key=["SEQLEN_Q_CACHE", "SEQLEN_K_CACHE", "TILE_K"],
key=["SEQLEN_Q_CACHE", "SEQLEN_K_CACHE", "TILE_K", "IS_CAUSAL", "IS_LOCAL"],
prune_configs_by={"early_config_prune": _prune_fwd_configs},
)(jit_kernel)


def make_fwd_sparse_autotuned_kernel(jit_kernel):
    """Wrap ``jit_kernel`` with ``triton.autotune`` using the sparse forward configs.

    Autotune results are keyed on ``SEQLEN_Q_CACHE``, ``SEQLEN_K_CACHE``,
    ``TILE_K``, ``IS_CAUSAL`` and ``IS_LOCAL``. Candidate configs are passed
    through ``_prune_fwd_configs`` as the ``early_config_prune`` hook.

    Args:
        jit_kernel: The kernel object to decorate.

    Returns:
        The object returned by applying the ``triton.autotune`` decorator.
    """
    configs = get_fwd_sparse_autotune_configs()
    # A call may carry only one ``key=``; the earlier three-entry list is
    # superseded by the five-entry list below.
    return triton.autotune(
        configs=configs,
        key=["SEQLEN_Q_CACHE", "SEQLEN_K_CACHE", "TILE_K", "IS_CAUSAL", "IS_LOCAL"],
        prune_configs_by={"early_config_prune": _prune_fwd_configs},
    )(jit_kernel)


def make_fwd_gated_autotuned_kernel(jit_kernel):
    """Wrap ``jit_kernel`` with ``triton.autotune`` using the gated forward configs.

    Autotune results are keyed on ``SEQLEN_Q_CACHE``, ``SEQLEN_K_CACHE``,
    ``TILE_K``, ``IS_CAUSAL`` and ``IS_LOCAL``. Candidate configs are passed
    through ``_prune_fwd_configs`` as the ``early_config_prune`` hook.

    Args:
        jit_kernel: The kernel object to decorate.

    Returns:
        The object returned by applying the ``triton.autotune`` decorator.
    """
    configs = get_fwd_gated_autotune_configs()
    # A call may carry only one ``key=``; the earlier three-entry list is
    # superseded by the five-entry list below.
    return triton.autotune(
        configs=configs,
        key=["SEQLEN_Q_CACHE", "SEQLEN_K_CACHE", "TILE_K", "IS_CAUSAL", "IS_LOCAL"],
        prune_configs_by={"early_config_prune": _prune_fwd_configs},
    )(jit_kernel)


def make_bwd_dense_autotuned_kernel(jit_kernel):
    """Wrap ``jit_kernel`` with ``triton.autotune`` using the dense backward configs.

    Autotune results are keyed on ``SEQLEN_Q_CACHE``, ``SEQLEN_K_CACHE``,
    ``TILE_K``, ``IS_CAUSAL`` and ``IS_LOCAL``. Candidate configs are passed
    through ``_prune_bwd_configs`` as the ``early_config_prune`` hook.

    Args:
        jit_kernel: The kernel object to decorate.

    Returns:
        The object returned by applying the ``triton.autotune`` decorator.
    """
    configs = get_bwd_dense_autotune_configs()
    # A call may carry only one ``key=``; the earlier three-entry list is
    # superseded by the five-entry list below.
    return triton.autotune(
        configs=configs,
        key=["SEQLEN_Q_CACHE", "SEQLEN_K_CACHE", "TILE_K", "IS_CAUSAL", "IS_LOCAL"],
        prune_configs_by={"early_config_prune": _prune_bwd_configs},
    )(jit_kernel)


def make_bwd_sparse_autotuned_kernel(jit_kernel):
    """Wrap ``jit_kernel`` with ``triton.autotune`` using the sparse backward configs.

    Autotune results are keyed on ``SEQLEN_Q_CACHE``, ``SEQLEN_K_CACHE``,
    ``TILE_K``, ``IS_CAUSAL`` and ``IS_LOCAL``. Candidate configs are passed
    through ``_prune_bwd_configs`` as the ``early_config_prune`` hook.

    Args:
        jit_kernel: The kernel object to decorate.

    Returns:
        The object returned by applying the ``triton.autotune`` decorator.
    """
    configs = get_bwd_sparse_autotune_configs()
    # A call may carry only one ``key=``; the earlier three-entry list is
    # superseded by the five-entry list below.
    return triton.autotune(
        configs=configs,
        key=["SEQLEN_Q_CACHE", "SEQLEN_K_CACHE", "TILE_K", "IS_CAUSAL", "IS_LOCAL"],
        prune_configs_by={"early_config_prune": _prune_bwd_configs},
    )(jit_kernel)


def make_bwd_gated_autotuned_kernel(jit_kernel):
    """Wrap ``jit_kernel`` with ``triton.autotune`` using the gated backward configs.

    Autotune results are keyed on ``SEQLEN_Q_CACHE``, ``SEQLEN_K_CACHE``,
    ``TILE_K``, ``IS_CAUSAL`` and ``IS_LOCAL``. Candidate configs are passed
    through ``_prune_bwd_configs`` as the ``early_config_prune`` hook.

    Args:
        jit_kernel: The kernel object to decorate.

    Returns:
        The object returned by applying the ``triton.autotune`` decorator.
    """
    configs = get_bwd_gated_autotune_configs()
    # A call may carry only one ``key=``; the earlier three-entry list is
    # superseded by the five-entry list below.
    return triton.autotune(
        configs=configs,
        key=["SEQLEN_Q_CACHE", "SEQLEN_K_CACHE", "TILE_K", "IS_CAUSAL", "IS_LOCAL"],
        prune_configs_by={"early_config_prune": _prune_bwd_configs},
    )(jit_kernel)


def make_dec_dense_autotuned_kernel(jit_kernel):
    """Wrap ``jit_kernel`` with ``triton.autotune`` using the dense decode configs.

    Autotune results are keyed on ``SEQLEN_Q_CACHE``, ``SEQLEN_K_CACHE``,
    ``TILE_K`` and ``IS_LOCAL``. Candidate configs are passed through
    ``_prune_dec_configs`` as the ``early_config_prune`` hook.

    Args:
        jit_kernel: The kernel object to decorate.

    Returns:
        The object returned by applying the ``triton.autotune`` decorator.
    """
    configs = get_dec_dense_autotune_configs()
    # A call may carry only one ``key=``; the earlier three-entry list is
    # superseded by the four-entry list below.
    return triton.autotune(
        configs=configs,
        key=["SEQLEN_Q_CACHE", "SEQLEN_K_CACHE", "TILE_K", "IS_LOCAL"],
        prune_configs_by={"early_config_prune": _prune_dec_configs},
    )(jit_kernel)


def make_dec_sparse_autotuned_kernel(jit_kernel):
    """Wrap ``jit_kernel`` with ``triton.autotune`` using the sparse decode configs.

    Autotune results are keyed on ``SEQLEN_Q_CACHE``, ``SEQLEN_K_CACHE``,
    ``TILE_K`` and ``IS_LOCAL``. Candidate configs are passed through
    ``_prune_dec_configs`` as the ``early_config_prune`` hook.

    Args:
        jit_kernel: The kernel object to decorate.

    Returns:
        The object returned by applying the ``triton.autotune`` decorator.
    """
    configs = get_dec_sparse_autotune_configs()
    # A call may carry only one ``key=``; the earlier three-entry list is
    # superseded by the four-entry list below.
    return triton.autotune(
        configs=configs,
        key=["SEQLEN_Q_CACHE", "SEQLEN_K_CACHE", "TILE_K", "IS_LOCAL"],
        prune_configs_by={"early_config_prune": _prune_dec_configs},
    )(jit_kernel)


def make_dec_gated_autotuned_kernel(jit_kernel):
    """Wrap ``jit_kernel`` with ``triton.autotune`` using the gated decode configs.

    Autotune results are keyed on ``SEQLEN_Q_CACHE``, ``SEQLEN_K_CACHE``,
    ``TILE_K`` and ``IS_LOCAL``. Candidate configs are passed through
    ``_prune_dec_configs`` as the ``early_config_prune`` hook.

    Args:
        jit_kernel: The kernel object to decorate.

    Returns:
        The object returned by applying the ``triton.autotune`` decorator.
    """
    configs = get_dec_gated_autotune_configs()
    # A call may carry only one ``key=``; the earlier three-entry list is
    # superseded by the four-entry list below.
    return triton.autotune(
        configs=configs,
        key=["SEQLEN_Q_CACHE", "SEQLEN_K_CACHE", "TILE_K", "IS_LOCAL"],
        prune_configs_by={"early_config_prune": _prune_dec_configs},
    )(jit_kernel)


Expand Down
Loading
Loading