Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions flash_sparse_attn/ops/triton/flash_bwd_postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@
import triton
import triton.language as tl

from flash_sparse_attn.ops.triton import cache_utils, launch_grid, seqlen_info
from flash_sparse_attn.ops.triton import (
cache_utils,
launch_grid,
seqlen_info,
kernel_repr,
)


@triton.jit
@triton.jit(repr=kernel_repr.bwd_postprocess_repr)
def _bwd_postprocess_kernel(
dQaccum,
dQ,
Expand Down
3 changes: 2 additions & 1 deletion flash_sparse_attn/ops/triton/flash_bwd_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@
launch_grid,
seqlen_info,
activations,
kernel_repr,
)


@triton.jit
@triton.jit(repr=kernel_repr.bwd_preprocess_repr)
def _bwd_preprocess_kernel(
Out,
dO,
Expand Down
3 changes: 2 additions & 1 deletion flash_sparse_attn/ops/triton/flash_dec_combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
launch_template,
launch_grid,
seqlen_info,
kernel_repr,
)


@triton.jit
@triton.jit(repr=kernel_repr.dec_combine_repr)
def _dec_combine_kernel(
Out_partial,
Lse_partial,
Expand Down
23 changes: 12 additions & 11 deletions flash_sparse_attn/ops/triton/flash_dense_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
mask,
flash_bwd_preprocess,
flash_bwd_postprocess,
kernel_repr,
)


@triton.jit
def _bwd_inner_dense_base_kernel(
def _bwd_inner_dense_kernel(
acc_dk,
acc_dv,
k_tile,
Expand Down Expand Up @@ -115,8 +116,8 @@ def _bwd_inner_dense_base_kernel(
return acc_dk, acc_dv, q_ptrs, do_ptrs, lse_ptrs, dpsum_ptrs


@triton.jit
def _bwd_dense_base_kernel(
@triton.jit(repr=kernel_repr.bwd_dense_repr)
def _bwd_dense_kernel(
Q,
K,
V,
Expand Down Expand Up @@ -444,7 +445,7 @@ def _bwd_dense_base_kernel(
)

acc_dk, acc_dv, q_ptrs, do_ptrs, lse_ptrs, dpsum_ptrs = (
_bwd_inner_dense_base_kernel(
_bwd_inner_dense_kernel(
acc_dk=acc_dk,
acc_dv=acc_dv,
k_tile=k_tile,
Expand Down Expand Up @@ -514,7 +515,7 @@ def _bwd_dense_base_kernel(
)

acc_dk, acc_dv, q_ptrs, do_ptrs, lse_ptrs, dpsum_ptrs = (
_bwd_inner_dense_base_kernel(
_bwd_inner_dense_kernel(
acc_dk=acc_dk,
acc_dv=acc_dv,
k_tile=k_tile,
Expand Down Expand Up @@ -584,7 +585,7 @@ def _bwd_dense_base_kernel(
)

acc_dk, acc_dv, q_ptrs, do_ptrs, lse_ptrs, dpsum_ptrs = (
_bwd_inner_dense_base_kernel(
_bwd_inner_dense_kernel(
acc_dk=acc_dk,
acc_dv=acc_dv,
k_tile=k_tile,
Expand Down Expand Up @@ -635,10 +636,10 @@ def _bwd_dense_base_kernel(
tl.store(dk_ptrs, acc_dk, boundary_check=(0, 1), cache_modifier=".wb")


_bwd_dense_base_kernel = cache_utils.wrap_kernel(_bwd_dense_base_kernel)
_bwd_dense_kernel = cache_utils.wrap_kernel(_bwd_dense_kernel)


def _flash_dense_attn_base_backward(
def _flash_dense_attn_backward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
Expand Down Expand Up @@ -745,7 +746,7 @@ def _flash_dense_attn_base_backward(
batch_size=batch_size,
)

_bwd_dense_base_kernel[grid](
_bwd_dense_kernel[grid](
query,
key,
value,
Expand Down Expand Up @@ -823,7 +824,7 @@ def _flash_dense_attn_base_backward(
return dq, dk, dv


def _flash_dense_attn_varlen_base_backward(
def _flash_dense_attn_varlen_backward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
Expand Down Expand Up @@ -945,7 +946,7 @@ def _flash_dense_attn_varlen_base_backward(
batch_size=batch_size,
)

_bwd_dense_base_kernel[grid](
_bwd_dense_kernel[grid](
query,
key,
value,
Expand Down
163 changes: 79 additions & 84 deletions flash_sparse_attn/ops/triton/flash_dense_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
activations,
mask,
flash_fwd_combine,
kernel_repr,
)


@triton.jit
def _fwd_inner_dense_base_kernel(
def _fwd_inner_dense_kernel(
q_tile,
k_tile,
k_ptrs,
Expand Down Expand Up @@ -97,8 +98,8 @@ def _fwd_inner_dense_base_kernel(
return k_tile, k_ptrs, v_ptrs, acc_o, row_max, row_sum


@triton.jit
def _fwd_dense_base_kernel(
@triton.jit(repr=kernel_repr.fwd_dense_repr)
def _fwd_dense_kernel(
Q,
K,
V,
Expand Down Expand Up @@ -415,37 +416,35 @@ def _fwd_dense_base_kernel(
# Process n_blocks with masking
if IS_CAUSAL or IS_LOCAL:
for n_block in tl.range(n_block_max - 1, n_block_max_no_mask - 1, -1):
k_tile, k_ptrs, v_ptrs, acc_o, row_max, row_sum = (
_fwd_inner_dense_base_kernel(
q_tile=q_tile,
k_tile=k_tile,
k_ptrs=k_ptrs,
v_ptrs=v_ptrs,
acc_o=acc_o,
row_max=row_max,
row_sum=row_sum,
softmax_scale_log2=softmax_scale_log2,
m_block=m_block,
n_block=n_block,
n_block_min=n_block_max_no_mask,
actual_seqlen_q=actual_seqlen_q,
actual_seqlen_k=actual_seqlen_k,
TILE_M=TILE_M,
TILE_N=TILE_N,
WINDOW_SIZE_LEFT=WINDOW_SIZE_LEFT,
WINDOW_SIZE_RIGHT=WINDOW_SIZE_RIGHT,
QHEADS_PER_KVHEAD_PACKGQA=QHEADS_PER_KVHEAD_PACKGQA,
IS_MASK=True,
MASK_CAUSAL=IS_CAUSAL,
MASK_LOCAL=IS_LOCAL,
CHECK_INF=True,
)
k_tile, k_ptrs, v_ptrs, acc_o, row_max, row_sum = _fwd_inner_dense_kernel(
q_tile=q_tile,
k_tile=k_tile,
k_ptrs=k_ptrs,
v_ptrs=v_ptrs,
acc_o=acc_o,
row_max=row_max,
row_sum=row_sum,
softmax_scale_log2=softmax_scale_log2,
m_block=m_block,
n_block=n_block,
n_block_min=n_block_max_no_mask,
actual_seqlen_q=actual_seqlen_q,
actual_seqlen_k=actual_seqlen_k,
TILE_M=TILE_M,
TILE_N=TILE_N,
WINDOW_SIZE_LEFT=WINDOW_SIZE_LEFT,
WINDOW_SIZE_RIGHT=WINDOW_SIZE_RIGHT,
QHEADS_PER_KVHEAD_PACKGQA=QHEADS_PER_KVHEAD_PACKGQA,
IS_MASK=True,
MASK_CAUSAL=IS_CAUSAL,
MASK_LOCAL=IS_LOCAL,
CHECK_INF=True,
)
else:
# First iteration with seqlen masking
n_block = n_block_max - 1

k_tile, k_ptrs, v_ptrs, acc_o, row_max, row_sum = _fwd_inner_dense_base_kernel(
k_tile, k_ptrs, v_ptrs, acc_o, row_max, row_sum = _fwd_inner_dense_kernel(
q_tile=q_tile,
k_tile=k_tile,
k_ptrs=k_ptrs,
Expand Down Expand Up @@ -493,31 +492,29 @@ def _fwd_dense_base_kernel(
)
k_tile = tl.load(k_ptrs, boundary_check=(0, 1), cache_modifier=".cg")
for n_block in tl.range(n_block_max_no_mask - 1, n_block_min_no_mask - 1, -1):
k_tile, k_ptrs, v_ptrs, acc_o, row_max, row_sum = (
_fwd_inner_dense_base_kernel(
q_tile=q_tile,
k_tile=k_tile,
k_ptrs=k_ptrs,
v_ptrs=v_ptrs,
acc_o=acc_o,
row_max=row_max,
row_sum=row_sum,
softmax_scale_log2=softmax_scale_log2,
m_block=m_block,
n_block=n_block,
n_block_min=n_block_min_no_mask,
actual_seqlen_q=actual_seqlen_q,
actual_seqlen_k=actual_seqlen_k,
TILE_M=TILE_M,
TILE_N=TILE_N,
WINDOW_SIZE_LEFT=WINDOW_SIZE_LEFT,
WINDOW_SIZE_RIGHT=WINDOW_SIZE_RIGHT,
QHEADS_PER_KVHEAD_PACKGQA=QHEADS_PER_KVHEAD_PACKGQA,
IS_MASK=IS_LOCAL,
MASK_CAUSAL=False,
MASK_LOCAL=False,
CHECK_INF=IS_LOCAL,
)
k_tile, k_ptrs, v_ptrs, acc_o, row_max, row_sum = _fwd_inner_dense_kernel(
q_tile=q_tile,
k_tile=k_tile,
k_ptrs=k_ptrs,
v_ptrs=v_ptrs,
acc_o=acc_o,
row_max=row_max,
row_sum=row_sum,
softmax_scale_log2=softmax_scale_log2,
m_block=m_block,
n_block=n_block,
n_block_min=n_block_min_no_mask,
actual_seqlen_q=actual_seqlen_q,
actual_seqlen_k=actual_seqlen_k,
TILE_M=TILE_M,
TILE_N=TILE_N,
WINDOW_SIZE_LEFT=WINDOW_SIZE_LEFT,
WINDOW_SIZE_RIGHT=WINDOW_SIZE_RIGHT,
QHEADS_PER_KVHEAD_PACKGQA=QHEADS_PER_KVHEAD_PACKGQA,
IS_MASK=IS_LOCAL,
MASK_CAUSAL=False,
MASK_LOCAL=False,
CHECK_INF=IS_LOCAL,
)

# Process n_blocks with masking
Expand All @@ -540,31 +537,29 @@ def _fwd_dense_base_kernel(
)
k_tile = tl.load(k_ptrs, boundary_check=(0, 1), cache_modifier=".cg")
for n_block in tl.range(n_block_min_no_mask - 1, n_block_min - 1, -1):
k_tile, k_ptrs, v_ptrs, acc_o, row_max, row_sum = (
_fwd_inner_dense_base_kernel(
q_tile=q_tile,
k_tile=k_tile,
k_ptrs=k_ptrs,
v_ptrs=v_ptrs,
acc_o=acc_o,
row_max=row_max,
row_sum=row_sum,
softmax_scale_log2=softmax_scale_log2,
m_block=m_block,
n_block=n_block,
n_block_min=n_block_min,
actual_seqlen_q=actual_seqlen_q,
actual_seqlen_k=actual_seqlen_k,
TILE_M=TILE_M,
TILE_N=TILE_N,
WINDOW_SIZE_LEFT=WINDOW_SIZE_LEFT,
WINDOW_SIZE_RIGHT=WINDOW_SIZE_RIGHT,
QHEADS_PER_KVHEAD_PACKGQA=QHEADS_PER_KVHEAD_PACKGQA,
IS_MASK=True,
MASK_CAUSAL=False,
MASK_LOCAL=True,
CHECK_INF=True,
)
k_tile, k_ptrs, v_ptrs, acc_o, row_max, row_sum = _fwd_inner_dense_kernel(
q_tile=q_tile,
k_tile=k_tile,
k_ptrs=k_ptrs,
v_ptrs=v_ptrs,
acc_o=acc_o,
row_max=row_max,
row_sum=row_sum,
softmax_scale_log2=softmax_scale_log2,
m_block=m_block,
n_block=n_block,
n_block_min=n_block_min,
actual_seqlen_q=actual_seqlen_q,
actual_seqlen_k=actual_seqlen_k,
TILE_M=TILE_M,
TILE_N=TILE_N,
WINDOW_SIZE_LEFT=WINDOW_SIZE_LEFT,
WINDOW_SIZE_RIGHT=WINDOW_SIZE_RIGHT,
QHEADS_PER_KVHEAD_PACKGQA=QHEADS_PER_KVHEAD_PACKGQA,
IS_MASK=True,
MASK_CAUSAL=False,
MASK_LOCAL=True,
CHECK_INF=True,
)

# Finalize softmax
Expand Down Expand Up @@ -608,10 +603,10 @@ def _fwd_dense_base_kernel(
tl.store(out_ptrs, acc_o, boundary_check=(0, 1), cache_modifier=".wb")


_fwd_dense_base_kernel = cache_utils.wrap_kernel(_fwd_dense_base_kernel)
_fwd_dense_kernel = cache_utils.wrap_kernel(_fwd_dense_kernel)


def _flash_dense_attn_base_forward(
def _flash_dense_attn_forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
Expand Down Expand Up @@ -703,7 +698,7 @@ def _flash_dense_attn_base_forward(
num_splits=num_splits,
)

_fwd_dense_base_kernel[grid](
_fwd_dense_kernel[grid](
query,
key,
value,
Expand Down Expand Up @@ -765,7 +760,7 @@ def _flash_dense_attn_base_forward(
return out, lse, softmax_scale


def _flash_dense_attn_varlen_base_forward(
def _flash_dense_attn_varlen_forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
Expand Down Expand Up @@ -866,7 +861,7 @@ def _flash_dense_attn_varlen_base_forward(
num_splits=num_splits,
)

_fwd_dense_base_kernel[grid](
_fwd_dense_kernel[grid](
query,
key,
value,
Expand Down
3 changes: 2 additions & 1 deletion flash_sparse_attn/ops/triton/flash_fwd_combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
launch_template,
launch_grid,
seqlen_info,
kernel_repr,
)


@triton.jit
@triton.jit(repr=kernel_repr.fwd_combine_repr)
def _fwd_combine_kernel(
Out_partial,
Lse_partial,
Expand Down
Loading
Loading