Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions flash_attn/cute/flash_bwd_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,8 @@ def __call__(
mdQ_semaphore: Optional[cute.Tensor] = None,
mdK_semaphore: Optional[cute.Tensor] = None,
mdV_semaphore: Optional[cute.Tensor] = None,
dQ_lock_values_mask: Optional[cute.Tensor] = None,
dQ_lock_values_full: Optional[cute.Tensor] = None,
):
assert all(x is None for x in (mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)), (
"Variable sequence length is not supported yet in FlashAttentionBackwardSm100"
Expand Down Expand Up @@ -694,6 +696,8 @@ class SharedStorage:
mdQ_semaphore,
mdK_semaphore,
mdV_semaphore,
dQ_lock_values_mask,
dQ_lock_values_full,
tma_atom_Q,
tma_atom_K,
tma_atom_V,
Expand Down Expand Up @@ -752,6 +756,8 @@ def kernel(
mdQ_semaphore: Optional[cute.Tensor],
mdK_semaphore: Optional[cute.Tensor],
mdV_semaphore: Optional[cute.Tensor],
dQ_lock_values_mask: Optional[cute.Tensor],
dQ_lock_values_full: Optional[cute.Tensor],
tma_atom_Q: cute.CopyAtom,
tma_atom_K: cute.CopyAtom,
tma_atom_V: cute.CopyAtom,
Expand Down Expand Up @@ -1182,6 +1188,8 @@ def kernel(
TileSchedulerCls,
blocksparse_tensors,
mdQ_semaphore,
dQ_lock_values_mask,
dQ_lock_values_full,
)

return
Expand Down Expand Up @@ -2348,6 +2356,8 @@ def dQacc_reduce(
TileSchedulerCls: Callable,
blocksparse_tensors: Optional[LinearBlockSparseTensors],
mdQ_semaphore: Optional[cute.Tensor],
dQ_lock_values_mask: Optional[cute.Tensor],
dQ_lock_values_full: Optional[cute.Tensor],
):
num_reduce_threads = cute.arch.WARP_SIZE * len(self.reduce_warp_ids)
tidx = cute.arch.thread_idx()[0] % num_reduce_threads
Expand Down Expand Up @@ -2470,6 +2480,15 @@ def dQacc_reduce(
),
)
lock_value = n_block_max_for_m_block - 1 - n_block
elif const_expr(self.use_block_sparsity):
lock_value = Int32(0)
if const_expr(dQ_lock_values_full is not None):
if iter_idx < curr_mask_cnt:
lock_value = dQ_lock_values_mask[curr_mask_offset + iter_idx]
else:
lock_value = dQ_lock_values_full[curr_full_offset + iter_idx - curr_mask_cnt]
else:
lock_value = dQ_lock_values_mask[curr_mask_offset + iter_idx]
else:
lock_value = n_block
barrier.wait_eq(
Expand Down
106 changes: 106 additions & 0 deletions flash_attn/cute/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,92 @@ def _flash_attn_fwd(
# Cache stored as an attribute on the function object itself; it starts empty.
# NOTE(review): presumably holds compiled kernels keyed by configuration -- confirm
# against the body of _flash_attn_fwd, which is not visible here.
_flash_attn_fwd.compile_cache = {}


def _compute_bwd_dQ_lock_values(block_sparse_tensors):
"""Precompute per-entry semaphore lock values for deterministic backward.

For block-sparse attention, not every n_block processes every m_block.
The original code uses ``lock_value = n_block`` which assumes dense
iteration and deadlocks when blocks are skipped. This function computes
the correct lock value for each CSR entry: the rank of the current
n_block among all n_blocks that process the same m_block.

Pure GPU implementation -- no ``.item()``, ``.cpu()``, or any other
CPU-GPU synchronisation.
"""
device = block_sparse_tensors.mask_block_idx.device

mask_cnt = block_sparse_tensors.mask_block_cnt
mask_off = block_sparse_tensors.mask_block_offset
mask_idx = block_sparse_tensors.mask_block_idx
has_full = block_sparse_tensors.full_block_cnt is not None

mask_len = mask_idx.shape[0]
num_n = mask_cnt.shape[0]

if has_full:
full_cnt = block_sparse_tensors.full_block_cnt
full_off = block_sparse_tensors.full_block_offset
full_idx = block_sparse_tensors.full_block_idx
full_len = full_idx.shape[0]
else:
full_len = 0

total = mask_len + full_len
if total == 0:
return (
torch.zeros_like(mask_idx),
torch.zeros_like(block_sparse_tensors.full_block_idx) if has_full else None,
)

total_per_n = mask_cnt.to(torch.int64)
if has_full:
total_per_n = total_per_n + full_cnt.to(torch.int64)
cum = torch.zeros(num_n + 1, dtype=torch.int64, device=device)
cum[1:] = torch.cumsum(total_per_n, dim=0)

mp = torch.arange(mask_len, dtype=torch.int64, device=device)
mo = mask_off.to(torch.int64)
mn = torch.searchsorted(mo, mp, right=True) - 1
mg = cum[mn] + (mp - mo[mn])

fg = None
if has_full and full_len > 0:
fp = torch.arange(full_len, dtype=torch.int64, device=device)
fo = full_off.to(torch.int64)
fn = torch.searchsorted(fo, fp, right=True) - 1
fg = cum[fn] + mask_cnt[fn].to(torch.int64) + (fp - fo[fn])

flat = torch.zeros(total, dtype=torch.int32, device=device)
flat.scatter_(0, mg, mask_idx)
if fg is not None:
flat.scatter_(0, fg, full_idx)

si = torch.argsort(flat.to(torch.int64), stable=True)
sv = flat[si]

pos = torch.arange(total, dtype=torch.int64, device=device)
boundary_pos = torch.full_like(pos, -1)
boundary_pos[0] = 0
if total > 1:
changed = sv[1:] != sv[:-1]
boundary_pos[1:] = torch.where(changed, pos[1:], torch.tensor(-1, dtype=torch.int64, device=device))
last_boundary, _ = torch.cummax(boundary_pos, dim=0)
sorted_rank = (pos - last_boundary).to(torch.int32)

flat_lock = torch.empty(total, dtype=torch.int32, device=device)
flat_lock[si] = sorted_rank

mask_lock = flat_lock[mg]
if has_full and full_len > 0:
full_lock = flat_lock[fg]
elif has_full:
full_lock = torch.zeros_like(block_sparse_tensors.full_block_idx)
else:
full_lock = None

return mask_lock, full_lock


def _flash_attn_bwd(
q: torch.Tensor,
k: torch.Tensor,
Expand Down Expand Up @@ -762,6 +848,12 @@ def _flash_attn_bwd(

dtype = torch2cute_dtype_map[q.dtype]
use_block_sparsity = block_sparse_tensors is not None

dQ_lock_values_mask = None
dQ_lock_values_full = None
if deterministic and use_block_sparsity:
dQ_lock_values_mask, dQ_lock_values_full = _compute_bwd_dQ_lock_values(block_sparse_tensors)

if deterministic:
dQ_semaphore = torch.zeros(batch_size, num_head, seqlen_q_rounded // m_block_size, 1, dtype=torch.int32, device="cuda")
else:
Expand Down Expand Up @@ -902,6 +994,16 @@ def _flash_attn_bwd(
cute_aux_tensors = None
if aux_tensors is not None:
cute_aux_tensors = [from_dlpack(buf, enable_tvm_ffi=True).mark_layout_dynamic(leading_dim=buf.ndim - 1) for buf in aux_tensors]
dQ_lock_mask_tensor = (
from_dlpack(dQ_lock_values_mask.detach(), assumed_align=4, enable_tvm_ffi=True)
.mark_layout_dynamic(leading_dim=dQ_lock_values_mask.ndim - 1)
if dQ_lock_values_mask is not None else None
)
dQ_lock_full_tensor = (
from_dlpack(dQ_lock_values_full.detach(), assumed_align=4, enable_tvm_ffi=True)
.mark_layout_dynamic(leading_dim=dQ_lock_values_full.ndim - 1)
if dQ_lock_values_full is not None else None
)

fa_bwd_sm80 = FlashAttentionBackwardSm80(
dtype,
Expand Down Expand Up @@ -983,6 +1085,8 @@ def _flash_attn_bwd(
mdQ_semaphore=dQ_semaphore_tensor,
mdK_semaphore=dK_semaphore_tensor,
mdV_semaphore=dV_semaphore_tensor,
dQ_lock_values_mask=dQ_lock_mask_tensor,
dQ_lock_values_full=dQ_lock_full_tensor,
options="--enable-tvm-ffi"
)
# Execute with torch tensors directly
Expand All @@ -1008,6 +1112,8 @@ def _flash_attn_bwd(
mdQ_semaphore=dQ_semaphore,
mdK_semaphore=dK_semaphore,
mdV_semaphore=dV_semaphore,
dQ_lock_values_mask=dQ_lock_values_mask,
dQ_lock_values_full=dQ_lock_values_full,
)

num_threads = 256 if compute_capability == 9 else 128
Expand Down