Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 18 additions & 128 deletions flash_sparse_attn/ops/triton/autotuner.py
Original file line number Diff line number Diff line change
@@ -1,113 +1,12 @@
import torch
import triton


def _get_max_shared_mem():
    """Return the shared-memory-per-block limit of the current CUDA device.

    Looks up the properties of ``torch.cuda.current_device()`` and returns
    ``shared_memory_per_block_optin`` when the properties object exposes it,
    otherwise ``shared_memory_per_block``.

    NOTE(review): the value is presumably in bytes (callers compare it with
    byte estimates) -- confirm against the PyTorch documentation.
    """
    device_index = torch.cuda.current_device()
    device_props = torch.cuda.get_device_properties(device_index)
    # The fallback attribute is read unconditionally, exactly like an eagerly
    # evaluated ``getattr`` default.
    baseline_limit = device_props.shared_memory_per_block
    return getattr(device_props, "shared_memory_per_block_optin", baseline_limit)


def _smem_bytes_fwd(tile_m, tile_n, tile_k, num_stages, dtype_bytes):
resident = tile_m * tile_k * dtype_bytes
pipelined = 2 * tile_n * tile_k * dtype_bytes * num_stages
return resident + pipelined


def _smem_bytes_bwd(tile_m, tile_n, tile_k, num_stages, dtype_bytes):
resident = 2 * tile_n * tile_k * dtype_bytes
pipelined = 2 * tile_m * tile_k * dtype_bytes * num_stages
return resident + pipelined


def _smem_bytes_dec(tile_m, tile_n, tile_k, num_stages, dtype_bytes):
resident = tile_m * tile_k * dtype_bytes
pipelined = 2 * tile_n * tile_k * dtype_bytes * num_stages
return resident + pipelined


def _prune_fwd_configs(configs, named_args, **kwargs):
    """Keep only the forward configs whose shared-memory estimate fits the device.

    Args:
        configs: Candidate configs; each must expose ``kwargs["TILE_M"]``,
            ``kwargs["TILE_N"]`` and ``num_stages``.
        named_args: Mapping of kernel argument names to values. ``"Q"`` must
            be present and provide ``element_size()``.
        **kwargs: Extra keyword values; ``TILE_K`` is read from here first.

    Returns:
        The configs whose ``_smem_bytes_fwd`` estimate is within the budget,
        in their original order. When none fit, a one-element list holding
        the config with the smallest estimate.
    """
    # TILE_K lookup order: kwargs, then named_args, then a default of 128.
    head_tile = kwargs.get("TILE_K", named_args.get("TILE_K", 128))
    elem_bytes = named_args["Q"].element_size()
    # 4 KiB is subtracted from the device limit (presumably as headroom).
    budget = _get_max_shared_mem() - 4 * 1024

    def footprint(cfg):
        return _smem_bytes_fwd(
            cfg.kwargs["TILE_M"],
            cfg.kwargs["TILE_N"],
            head_tile,
            cfg.num_stages,
            elem_bytes,
        )

    fitting = [cfg for cfg in configs if footprint(cfg) <= budget]
    if fitting:
        return fitting
    # Nothing fits: fall back to the single cheapest candidate.
    return [min(configs, key=footprint)]


def _prune_bwd_configs(configs, named_args, **kwargs):
    """Keep only the backward configs whose shared-memory estimate fits the device.

    Args:
        configs: Candidate configs; each must expose ``kwargs["TILE_M"]``,
            ``kwargs["TILE_N"]`` and ``num_stages``.
        named_args: Mapping of kernel argument names to values. ``"Q"`` must
            be present and provide ``element_size()``.
        **kwargs: Extra keyword values; ``TILE_K`` is read from here first.

    Returns:
        The configs whose ``_smem_bytes_bwd`` estimate is within the budget,
        in their original order. When none fit, a one-element list holding
        the config with the smallest estimate.
    """
    # TILE_K lookup order: kwargs, then named_args, then a default of 128.
    head_tile = kwargs.get("TILE_K", named_args.get("TILE_K", 128))
    elem_bytes = named_args["Q"].element_size()
    # 4 KiB is subtracted from the device limit (presumably as headroom).
    budget = _get_max_shared_mem() - 4 * 1024

    def footprint(cfg):
        return _smem_bytes_bwd(
            cfg.kwargs["TILE_M"],
            cfg.kwargs["TILE_N"],
            head_tile,
            cfg.num_stages,
            elem_bytes,
        )

    fitting = [cfg for cfg in configs if footprint(cfg) <= budget]
    if fitting:
        return fitting
    # Nothing fits: fall back to the single cheapest candidate.
    return [min(configs, key=footprint)]


def _prune_dec_configs(configs, named_args, **kwargs):
    """Keep only the decode configs whose shared-memory estimate fits the device.

    Args:
        configs: Candidate configs; each must expose ``kwargs["TILE_M"]``,
            ``kwargs["TILE_N"]`` and ``num_stages``.
        named_args: Mapping of kernel argument names to values. ``"Q"`` must
            be present and provide ``element_size()``.
        **kwargs: Extra keyword values; ``TILE_K`` is read from here first.

    Returns:
        The configs whose ``_smem_bytes_dec`` estimate is within the budget,
        in their original order. When none fit, a one-element list holding
        the config with the smallest estimate.
    """
    # TILE_K lookup order: kwargs, then named_args, then a default of 128.
    head_tile = kwargs.get("TILE_K", named_args.get("TILE_K", 128))
    elem_bytes = named_args["Q"].element_size()
    # 4 KiB is subtracted from the device limit (presumably as headroom).
    budget = _get_max_shared_mem() - 4 * 1024

    def footprint(cfg):
        return _smem_bytes_dec(
            cfg.kwargs["TILE_M"],
            cfg.kwargs["TILE_N"],
            head_tile,
            cfg.num_stages,
            elem_bytes,
        )

    fitting = [cfg for cfg in configs if footprint(cfg) <= budget]
    if fitting:
        return fitting
    # Nothing fits: fall back to the single cheapest candidate.
    return [min(configs, key=footprint)]


def get_fwd_dense_autotune_configs():
configs = []
for tile_m in [64, 128, 256]:
for tile_n in [32, 64, 128]:
for num_warps in [4, 8]:
for num_stages in [1, 2, 3]:
for num_stages in [1, 2]:
configs.append(
triton.Config(
{"TILE_M": tile_m, "TILE_N": tile_n},
Expand All @@ -124,7 +23,7 @@ def get_fwd_sparse_autotune_configs():
for tile_m in [64, 128, 256]:
for tile_n in [32, 64, 128]:
for num_warps in [4, 8]:
for num_stages in [1, 2, 3]:
for num_stages in [1, 2]:
configs.append(
triton.Config(
{"TILE_M": tile_m, "TILE_N": tile_n},
Expand All @@ -141,7 +40,7 @@ def get_fwd_gated_autotune_configs():
for tile_m in [64, 128, 256]:
for tile_n in [32, 64, 128]:
for num_warps in [4, 8]:
for num_stages in [1, 2, 3]:
for num_stages in [1, 2]:
configs.append(
triton.Config(
{"TILE_M": tile_m, "TILE_N": tile_n},
Expand All @@ -158,7 +57,7 @@ def get_bwd_dense_autotune_configs():
for tile_m in [32, 64, 128]:
for tile_n in [64, 128, 256]:
for num_warps in [4, 8]:
for num_stages in [1, 2, 3]:
for num_stages in [1, 2]:
configs.append(
triton.Config(
{"TILE_M": tile_m, "TILE_N": tile_n},
Expand All @@ -175,7 +74,7 @@ def get_bwd_sparse_autotune_configs():
for tile_m in [32, 64, 128]:
for tile_n in [64, 128, 256]:
for num_warps in [4, 8]:
for num_stages in [1, 2, 3]:
for num_stages in [1, 2]:
configs.append(
triton.Config(
{"TILE_M": tile_m, "TILE_N": tile_n},
Expand All @@ -192,7 +91,7 @@ def get_bwd_gated_autotune_configs():
for tile_m in [32, 64, 128]:
for tile_n in [64, 128, 256]:
for num_warps in [4, 8]:
for num_stages in [1, 2, 3]:
for num_stages in [1, 2]:
configs.append(
triton.Config(
{"TILE_M": tile_m, "TILE_N": tile_n},
Expand All @@ -209,7 +108,7 @@ def get_dec_dense_autotune_configs():
for tile_m in [16, 32, 64]:
for tile_n in [64, 128, 256]:
for num_warps in [4, 8]:
for num_stages in [1, 2, 3]:
for num_stages in [1, 2]:
configs.append(
triton.Config(
{"TILE_M": tile_m, "TILE_N": tile_n},
Expand All @@ -226,7 +125,7 @@ def get_dec_sparse_autotune_configs():
for tile_m in [16, 32, 64]:
for tile_n in [64, 128, 256]:
for num_warps in [4, 8]:
for num_stages in [1, 2, 3]:
for num_stages in [1, 2]:
configs.append(
triton.Config(
{"TILE_M": tile_m, "TILE_N": tile_n},
Expand All @@ -243,7 +142,7 @@ def get_dec_gated_autotune_configs():
for tile_m in [16, 32, 64]:
for tile_n in [64, 128, 256]:
for num_warps in [4, 8]:
for num_stages in [1, 2, 3]:
for num_stages in [1, 2]:
configs.append(
triton.Config(
{"TILE_M": tile_m, "TILE_N": tile_n},
def make_fwd_dense_autotuned_kernel(jit_kernel):
    """Wrap ``jit_kernel`` with ``triton.autotune`` over the forward dense configs.

    Args:
        jit_kernel: The kernel object to decorate (presumably a ``triton.jit``
            function -- confirm against callers).

    Returns:
        The result of applying the ``triton.autotune`` decorator to
        ``jit_kernel``.
    """
    configs = get_fwd_dense_autotune_configs()
    return triton.autotune(
        configs=configs,
        # ``key`` may be passed only once. Per the Triton docs, a change in
        # the value of any listed argument re-evaluates the configs.
        key=["SEQLEN_Q_CACHE", "SEQLEN_K_CACHE", "TILE_K"],
    )(jit_kernel)


def make_fwd_sparse_autotuned_kernel(jit_kernel):
    """Wrap ``jit_kernel`` with ``triton.autotune`` over the forward sparse configs.

    Args:
        jit_kernel: The kernel object to decorate (presumably a ``triton.jit``
            function -- confirm against callers).

    Returns:
        The result of applying the ``triton.autotune`` decorator to
        ``jit_kernel``.
    """
    configs = get_fwd_sparse_autotune_configs()
    return triton.autotune(
        configs=configs,
        # ``key`` may be passed only once. Per the Triton docs, a change in
        # the value of any listed argument re-evaluates the configs.
        key=["SEQLEN_Q_CACHE", "SEQLEN_K_CACHE", "TILE_K"],
    )(jit_kernel)


def make_fwd_gated_autotuned_kernel(jit_kernel):
    """Wrap ``jit_kernel`` with ``triton.autotune`` over the forward gated configs.

    Args:
        jit_kernel: The kernel object to decorate (presumably a ``triton.jit``
            function -- confirm against callers).

    Returns:
        The result of applying the ``triton.autotune`` decorator to
        ``jit_kernel``.
    """
    configs = get_fwd_gated_autotune_configs()
    return triton.autotune(
        configs=configs,
        # ``key`` may be passed only once. Per the Triton docs, a change in
        # the value of any listed argument re-evaluates the configs.
        key=["SEQLEN_Q_CACHE", "SEQLEN_K_CACHE", "TILE_K"],
    )(jit_kernel)


def make_bwd_dense_autotuned_kernel(jit_kernel):
    """Wrap ``jit_kernel`` with ``triton.autotune`` over the backward dense configs.

    Args:
        jit_kernel: The kernel object to decorate (presumably a ``triton.jit``
            function -- confirm against callers).

    Returns:
        The result of applying the ``triton.autotune`` decorator to
        ``jit_kernel``.
    """
    configs = get_bwd_dense_autotune_configs()
    return triton.autotune(
        configs=configs,
        # ``key`` may be passed only once. Per the Triton docs, a change in
        # the value of any listed argument re-evaluates the configs.
        key=["SEQLEN_Q_CACHE", "SEQLEN_K_CACHE", "TILE_K"],
    )(jit_kernel)


def make_bwd_sparse_autotuned_kernel(jit_kernel):
    """Wrap ``jit_kernel`` with ``triton.autotune`` over the backward sparse configs.

    Args:
        jit_kernel: The kernel object to decorate (presumably a ``triton.jit``
            function -- confirm against callers).

    Returns:
        The result of applying the ``triton.autotune`` decorator to
        ``jit_kernel``.
    """
    configs = get_bwd_sparse_autotune_configs()
    return triton.autotune(
        configs=configs,
        # ``key`` may be passed only once. Per the Triton docs, a change in
        # the value of any listed argument re-evaluates the configs.
        key=["SEQLEN_Q_CACHE", "SEQLEN_K_CACHE", "TILE_K"],
    )(jit_kernel)


def make_bwd_gated_autotuned_kernel(jit_kernel):
    """Wrap ``jit_kernel`` with ``triton.autotune`` over the backward gated configs.

    Args:
        jit_kernel: The kernel object to decorate (presumably a ``triton.jit``
            function -- confirm against callers).

    Returns:
        The result of applying the ``triton.autotune`` decorator to
        ``jit_kernel``.
    """
    configs = get_bwd_gated_autotune_configs()
    return triton.autotune(
        configs=configs,
        # ``key`` may be passed only once. Per the Triton docs, a change in
        # the value of any listed argument re-evaluates the configs.
        key=["SEQLEN_Q_CACHE", "SEQLEN_K_CACHE", "TILE_K"],
    )(jit_kernel)


def make_dec_dense_autotuned_kernel(jit_kernel):
    """Wrap ``jit_kernel`` with ``triton.autotune`` over the decode dense configs.

    Args:
        jit_kernel: The kernel object to decorate (presumably a ``triton.jit``
            function -- confirm against callers).

    Returns:
        The result of applying the ``triton.autotune`` decorator to
        ``jit_kernel``.
    """
    configs = get_dec_dense_autotune_configs()
    return triton.autotune(
        configs=configs,
        # ``key`` may be passed only once. Per the Triton docs, a change in
        # the value of any listed argument re-evaluates the configs.
        key=["SEQLEN_Q_CACHE", "SEQLEN_K_CACHE", "TILE_K"],
    )(jit_kernel)


def make_dec_sparse_autotuned_kernel(jit_kernel):
    """Wrap ``jit_kernel`` with ``triton.autotune`` over the decode sparse configs.

    Args:
        jit_kernel: The kernel object to decorate (presumably a ``triton.jit``
            function -- confirm against callers).

    Returns:
        The result of applying the ``triton.autotune`` decorator to
        ``jit_kernel``.
    """
    configs = get_dec_sparse_autotune_configs()
    return triton.autotune(
        configs=configs,
        # ``key`` may be passed only once. Per the Triton docs, a change in
        # the value of any listed argument re-evaluates the configs.
        key=["SEQLEN_Q_CACHE", "SEQLEN_K_CACHE", "TILE_K"],
    )(jit_kernel)


def make_dec_gated_autotuned_kernel(jit_kernel):
    """Wrap ``jit_kernel`` with ``triton.autotune`` over the decode gated configs.

    Args:
        jit_kernel: The kernel object to decorate (presumably a ``triton.jit``
            function -- confirm against callers).

    Returns:
        The result of applying the ``triton.autotune`` decorator to
        ``jit_kernel``.
    """
    configs = get_dec_gated_autotune_configs()
    return triton.autotune(
        configs=configs,
        # ``key`` may be passed only once. Per the Triton docs, a change in
        # the value of any listed argument re-evaluates the configs.
        key=["SEQLEN_Q_CACHE", "SEQLEN_K_CACHE", "TILE_K"],
    )(jit_kernel)


Expand Down
Loading
Loading