From e865186640fd178d00ef5fbb02d6d3d453c610f5 Mon Sep 17 00:00:00 2001 From: taozewei Date: Tue, 7 Apr 2026 10:38:15 +0800 Subject: [PATCH] fix: deterministic backward deadlock with block-sparse attention (sm100) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When `deterministic=True` and block-sparse attention is active (arbitrary mask with func_num > 1), the dQ_semaphore in `dQacc_reduce` uses `lock_value = n_block`, which assumes every n_block processes every m_block (dense iteration). With block sparsity, skipped n_blocks never increment the semaphore, causing downstream CTAs to wait forever on `barrier.wait_eq`. Fix: precompute per-CSR-entry lock values on the host side (the rank of each n_block among all n_blocks that process the same m_block), pass them as `dQ_lock_values_mask` / `dQ_lock_values_full` tensors to the kernel, and read from these tensors instead of using raw `n_block`. The release side (`arrive_inc(sem, 1)`) needs no change since rank-based values naturally form a 0, 1, 2, ... sequence matching +1 increments. Verified on B200 (sm100) across all 12 attention mask configurations (full, varlen_full, causal, varlen_causal, sliding_window, varlen_sliding_window × 1k/4k) with bfloat16. Made-with: Cursor --- flash_attn/cute/flash_bwd_sm100.py | 19 ++++++ flash_attn/cute/interface.py | 106 +++++++++++++++++++++++++++++ 2 files changed, 125 insertions(+) diff --git a/flash_attn/cute/flash_bwd_sm100.py b/flash_attn/cute/flash_bwd_sm100.py index 479906bf199..e9496f14ffa 100644 --- a/flash_attn/cute/flash_bwd_sm100.py +++ b/flash_attn/cute/flash_bwd_sm100.py @@ -380,6 +380,8 @@ def __call__( mdQ_semaphore: Optional[cute.Tensor] = None, mdK_semaphore: Optional[cute.Tensor] = None, mdV_semaphore: Optional[cute.Tensor] = None, + dQ_lock_values_mask: Optional[cute.Tensor] = None, + dQ_lock_values_full: Optional[cute.Tensor] = None, ): assert all(x is None for x in (mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, mSeqUsedK)), ( "Variable sequence length is not supported yet in FlashAttentionBackwardSm100" @@ -694,6 +696,8 @@ class SharedStorage: mdQ_semaphore, mdK_semaphore, mdV_semaphore, + dQ_lock_values_mask, + dQ_lock_values_full, tma_atom_Q, tma_atom_K, tma_atom_V, @@ -752,6 +756,8 @@ def kernel( mdQ_semaphore: Optional[cute.Tensor], mdK_semaphore: Optional[cute.Tensor], mdV_semaphore: Optional[cute.Tensor], + dQ_lock_values_mask: Optional[cute.Tensor], + dQ_lock_values_full: Optional[cute.Tensor], tma_atom_Q: cute.CopyAtom, tma_atom_K: cute.CopyAtom, tma_atom_V: cute.CopyAtom, @@ -1182,6 +1188,8 @@ def kernel( TileSchedulerCls, blocksparse_tensors, mdQ_semaphore, + dQ_lock_values_mask, + dQ_lock_values_full, ) return @@ -2348,6 +2356,8 @@ def dQacc_reduce( TileSchedulerCls: Callable, blocksparse_tensors: Optional[LinearBlockSparseTensors], mdQ_semaphore: Optional[cute.Tensor], + dQ_lock_values_mask: Optional[cute.Tensor], + dQ_lock_values_full: Optional[cute.Tensor], ): num_reduce_threads = cute.arch.WARP_SIZE * len(self.reduce_warp_ids) tidx = cute.arch.thread_idx()[0] % num_reduce_threads @@ -2470,6 +2480,15 @@ def dQacc_reduce( ), ) lock_value = n_block_max_for_m_block - 1 - n_block + elif const_expr(self.use_block_sparsity): + lock_value = Int32(0) + if const_expr(dQ_lock_values_full is not None): + if iter_idx < curr_mask_cnt: + lock_value = dQ_lock_values_mask[curr_mask_offset + iter_idx] + else: + lock_value = dQ_lock_values_full[curr_full_offset + iter_idx - curr_mask_cnt] + else: + lock_value = dQ_lock_values_mask[curr_mask_offset + iter_idx] else: lock_value = n_block barrier.wait_eq( diff --git a/flash_attn/cute/interface.py b/flash_attn/cute/interface.py index b6000c6fc1e..9499b25e6e1 100644 --- a/flash_attn/cute/interface.py +++ b/flash_attn/cute/interface.py @@ -549,6 +549,92 @@ def _flash_attn_fwd( _flash_attn_fwd.compile_cache = {} +def _compute_bwd_dQ_lock_values(block_sparse_tensors): + """Precompute per-entry semaphore lock values for deterministic backward. + + For block-sparse attention, not every n_block processes every m_block. + The original code uses ``lock_value = n_block`` which assumes dense + iteration and deadlocks when blocks are skipped. This function computes + the correct lock value for each CSR entry: the rank of the current + n_block among all n_blocks that process the same m_block. + + Pure GPU implementation -- no ``.item()``, ``.cpu()``, or any other + CPU-GPU synchronisation. + """ + device = block_sparse_tensors.mask_block_idx.device + + mask_cnt = block_sparse_tensors.mask_block_cnt + mask_off = block_sparse_tensors.mask_block_offset + mask_idx = block_sparse_tensors.mask_block_idx + has_full = block_sparse_tensors.full_block_cnt is not None + + mask_len = mask_idx.shape[0] + num_n = mask_cnt.shape[0] + + if has_full: + full_cnt = block_sparse_tensors.full_block_cnt + full_off = block_sparse_tensors.full_block_offset + full_idx = block_sparse_tensors.full_block_idx + full_len = full_idx.shape[0] + else: + full_len = 0 + + total = mask_len + full_len + if total == 0: + return ( + torch.zeros_like(mask_idx), + torch.zeros_like(block_sparse_tensors.full_block_idx) if has_full else None, + ) + + total_per_n = mask_cnt.to(torch.int64) + if has_full: + total_per_n = total_per_n + full_cnt.to(torch.int64) + cum = torch.zeros(num_n + 1, dtype=torch.int64, device=device) + cum[1:] = torch.cumsum(total_per_n, dim=0) + + mp = torch.arange(mask_len, dtype=torch.int64, device=device) + mo = mask_off.to(torch.int64) + mn = torch.searchsorted(mo, mp, right=True) - 1 + mg = cum[mn] + (mp - mo[mn]) + + fg = None + if has_full and full_len > 0: + fp = torch.arange(full_len, dtype=torch.int64, device=device) + fo = full_off.to(torch.int64) + fn = torch.searchsorted(fo, fp, right=True) - 1 + fg = cum[fn] + mask_cnt[fn].to(torch.int64) + (fp - fo[fn]) + + flat = torch.zeros(total, dtype=torch.int32, device=device) + flat.scatter_(0, mg, mask_idx) + if fg is not None: + flat.scatter_(0, fg, full_idx) + + si = torch.argsort(flat.to(torch.int64), stable=True) + sv = flat[si] + + pos = torch.arange(total, dtype=torch.int64, device=device) + boundary_pos = torch.full_like(pos, -1) + boundary_pos[0] = 0 + if total > 1: + changed = sv[1:] != sv[:-1] + boundary_pos[1:] = torch.where(changed, pos[1:], torch.tensor(-1, dtype=torch.int64, device=device)) + last_boundary, _ = torch.cummax(boundary_pos, dim=0) + sorted_rank = (pos - last_boundary).to(torch.int32) + + flat_lock = torch.empty(total, dtype=torch.int32, device=device) + flat_lock[si] = sorted_rank + + mask_lock = flat_lock[mg] + if has_full and full_len > 0: + full_lock = flat_lock[fg] + elif has_full: + full_lock = torch.zeros_like(block_sparse_tensors.full_block_idx) + else: + full_lock = None + + return mask_lock, full_lock + + def _flash_attn_bwd( q: torch.Tensor, k: torch.Tensor, @@ -762,6 +848,12 @@ def _flash_attn_bwd( dtype = torch2cute_dtype_map[q.dtype] use_block_sparsity = block_sparse_tensors is not None + + dQ_lock_values_mask = None + dQ_lock_values_full = None + if deterministic and use_block_sparsity: + dQ_lock_values_mask, dQ_lock_values_full = _compute_bwd_dQ_lock_values(block_sparse_tensors) + if deterministic: dQ_semaphore = torch.zeros(batch_size, num_head, seqlen_q_rounded // m_block_size, 1, dtype=torch.int32, device="cuda") else: @@ -902,6 +994,16 @@ def _flash_attn_bwd( cute_aux_tensors = None if aux_tensors is not None: cute_aux_tensors = [from_dlpack(buf, enable_tvm_ffi=True).mark_layout_dynamic(leading_dim=buf.ndim - 1) for buf in aux_tensors] + dQ_lock_mask_tensor = ( + from_dlpack(dQ_lock_values_mask.detach(), assumed_align=4, enable_tvm_ffi=True) + .mark_layout_dynamic(leading_dim=dQ_lock_values_mask.ndim - 1) + if dQ_lock_values_mask is not None else None + ) + dQ_lock_full_tensor = ( + from_dlpack(dQ_lock_values_full.detach(), assumed_align=4, enable_tvm_ffi=True) + .mark_layout_dynamic(leading_dim=dQ_lock_values_full.ndim - 1) + if dQ_lock_values_full is not None else None + ) fa_bwd_sm80 = FlashAttentionBackwardSm80( dtype, @@ -983,6 +1085,8 @@ def _flash_attn_bwd( mdQ_semaphore=dQ_semaphore_tensor, mdK_semaphore=dK_semaphore_tensor, mdV_semaphore=dV_semaphore_tensor, + dQ_lock_values_mask=dQ_lock_mask_tensor, + dQ_lock_values_full=dQ_lock_full_tensor, options="--enable-tvm-ffi" ) # Execute with torch tensors directly @@ -1008,6 +1112,8 @@ def _flash_attn_bwd( mdQ_semaphore=dQ_semaphore, mdK_semaphore=dK_semaphore, mdV_semaphore=dV_semaphore, + dQ_lock_values_mask=dQ_lock_values_mask, + dQ_lock_values_full=dQ_lock_values_full, ) num_threads = 256 if compute_capability == 9 else 128