fix: deterministic backward deadlock with block-sparse attention (sm100) #1
Open
littsk wants to merge 1 commit into
When `deterministic=True` and block-sparse attention is active (arbitrary mask with `func_num > 1`), the dQ_semaphore in `dQacc_reduce` uses `lock_value = n_block`, which assumes every n_block processes every m_block (dense iteration). With block sparsity, skipped n_blocks never increment the semaphore, causing downstream CTAs to wait forever on `barrier.wait_eq`.
Fix: precompute per-CSR-entry lock values on the host side (the rank of each n_block among all n_blocks that process the same m_block), pass them as `dQ_lock_values_mask` / `dQ_lock_values_full` tensors to the kernel, and read from these tensors instead of using raw `n_block`.
The release side (`arrive_inc(sem, 1)`) needs no change since rank-based values naturally form a 0, 1, 2, ... sequence matching +1 increments.
Verified on B200 (sm100) across all 12 attention mask configurations (full, varlen_full, causal, varlen_causal, sliding_window, varlen_sliding_window × 1k/4k) with bfloat16.
Made-with: Cursor
Summary
Bug: When `deterministic=True` and block-sparse attention is active (`arbitrary=True`, `func_num > 1`) on sm100 (Blackwell), the backward kernel hangs indefinitely due to a GPU deadlock in `dQacc_reduce`.
1. Background
1.1 Flash Attention 4 (FA4)
Flash Attention 4 targets NVIDIA Blackwell (sm100). It is written in the CUTLASS Cute DSL and JIT-compiled to CUDA kernels. The backward pass (`FlashAttentionBackwardSm100`) has three parallel phases.
1.2 Deterministic Mode
Multiple CTAs write to the same `dQ[m_block]` in a non-deterministic order, so floating-point accumulation differs across runs. Deterministic mode uses semaphore-based ordering to force all CTAs to accumulate dQ in a fixed sequence, ensuring bit-level reproducibility.
1.3 Block-Sparse Attention
In real scenarios (HSTU, varlen sequence packing), the attention matrix is not fully dense. Block-Sparse Attention uses CSR structures to describe block-level sparsity, skipping unnecessary block pairs:
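As a rough illustration of the CSR idea, the sketch below converts a boolean block mask into a k2q CSR, i.e. for each n_block the list of m_blocks it touches. The helper name `block_mask_to_k2q_csr` and the `(row_ptr, col_idx)` layout are assumptions made for this example; the layout actually used by `LinearBlockSparseTensors` is not shown here and may differ.

```python
from typing import List, Tuple


def block_mask_to_k2q_csr(mask: List[List[bool]]) -> Tuple[List[int], List[int]]:
    """Build a k2q CSR from mask[m][n] (True = tile must be computed).

    Returns (row_ptr, col_idx). The m_blocks touched by n_block n are
    col_idx[row_ptr[n]:row_ptr[n + 1]], in ascending order.
    """
    num_m = len(mask)
    num_n = len(mask[0]) if mask else 0
    row_ptr: List[int] = [0]
    col_idx: List[int] = []
    for n in range(num_n):          # one CSR row per n_block
        for m in range(num_m):      # keep only the non-empty tiles
            if mask[m][n]:
                col_idx.append(m)
        row_ptr.append(len(col_idx))
    return row_ptr, col_idx


# Hypothetical 3x3 block mask (made up, not the repro's data).
mask = [
    [True, False, False],
    [True, True, False],
    [False, False, True],
]
row_ptr, col_idx = block_mask_to_k2q_csr(mask)
```

Only the tiles listed in `col_idx` are visited, which is what lets the kernel skip empty block pairs.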
2. Original Design: Deterministic dQ Reduce Semaphore Protocol
2.1 Protocol
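The backward reduce phase accumulates dQ contributions from multiple n_blocks into `dQ_accum[m_block]`. Deterministic mode uses a per-m_block int32 semaphore array, initialized to 0. Each CTA waits (`barrier.wait_eq`) until the semaphore of its m_block equals its `lock_value`, accumulates its contribution, and then releases with `barrier.arrive_inc(semaphore[m_block], 1)`.
2.2 Dense Case
In non-block-sparse (dense) mode, every n_block processes every m_block, so `lock_value = n_block` works correctly: n_block 0 waits for 0, n_block 1 waits for 1, and so on, matching the +1 increments.
The corresponding code is in `flash_bwd_sm100.py :: dQacc_reduce`.
3. Bug Analysis: Semaphore Deadlock under Block-Sparsity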
3.1 Root Cause
With block sparsity, the reduce loop only iterates over the m_blocks listed in the current n_block's CSR. However, `lock_value = n_block` was not updated: it still assumes every n_block processes every m_block.
3.2 Deadlock Mechanism
Example: m_block M is processed only by n_blocks {0, 2, 5} (skipping 1, 3, 4):
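The stall in this example can be replayed with a small host-side simulation. This is only a sketch: `trace_semaphore` is a hypothetical helper, and it models `wait_eq` / `arrive_inc` as a plain integer comparison and increment rather than calling the real barrier primitives.

```python
from typing import Callable, Dict, Iterable, List, Tuple


def trace_semaphore(
    processors: Iterable[int],
    lock_value_of: Callable[[int, int], int],
) -> Tuple[List[int], int, List[int]]:
    """Replay the semaphore of a single m_block.

    processors: n_blocks that process this m_block.
    lock_value_of(rank, n_block): the value that CTA waits for.
    Returns (n_blocks that finished, final semaphore, n_blocks stuck waiting).
    """
    ordered = sorted(processors)
    waiting: Dict[int, int] = {
        n: lock_value_of(rank, n) for rank, n in enumerate(ordered)
    }
    sem = 0
    done: List[int] = []
    while waiting:
        # A CTA may proceed only when the semaphore equals its lock value.
        ready = [n for n, lock in waiting.items() if lock == sem]
        if not ready:
            break  # no wait can ever succeed: deadlock
        n = ready[0]
        done.append(n)
        del waiting[n]
        sem += 1  # models arrive_inc(sem, 1)
    return done, sem, sorted(waiting)


# m_block M is processed only by n_blocks {0, 2, 5}.
buggy = trace_semaphore({0, 2, 5}, lambda rank, n: n)     # lock_value = n_block
fixed = trace_semaphore({0, 2, 5}, lambda rank, n: rank)  # lock_value = rank
```

With `lock_value = n_block`, n_block 0 completes and raises the semaphore to 1, but n_block 2 waits for 2 and n_block 5 waits for 5, so neither can proceed. With rank-based values all three complete.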
3.3 Trigger Conditions
All must be true:
- `deterministic=True`: enables semaphore synchronization
- `arbitrary=True` with `func_num > 1`: triggers block-sparse attention (HSTU-encoded varlen masks)
3.4 Why Hard to Discover
- Dense configurations (`full_4k`, `causal_4k`) work perfectly: no block sparsity
- `deterministic=False` works perfectly: no semaphore
- `py-spy` only shows `torch.equal()` / `cuda.synchronize()`: the real deadlock is inside the GPU kernel
3.5 Minimal Repro Trigger Analysis
3.5.1 Scenario
The repro simulates `varlen_full_1k`: 10 disjoint full-attention segments packed into seqlen=1024.
Each segment has full attention within itself and there is no attention between segments, which gives a block-diagonal pattern.
3.5.2 From Ranges to Block-Sparse CSR
With tile_size=128, Q/K each have 1024/128 = 8 blocks (0–7):
Step 1: Range coverage per tile
Segment boundaries don't align with tile boundaries:
Step 2: 8×8 attention tile matrix
For each tile pair (m, n), classify as Full (F), Mask (M), or empty (.):
Key examples:
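The classification rule can be sketched in a few lines of Python. The segment boundaries below are hypothetical (the repro's actual ranges are not reproduced here), and the rule itself is an assumption: a tile is Full when every (query, key) pair in it falls inside one segment, empty when none does, and Mask otherwise.

```python
from typing import List, Tuple


def classify_tiles(segments: List[Tuple[int, int]], seqlen: int, tile: int) -> List[str]:
    """Return one string per m_block; character n is 'F', 'M' or '.'.

    segments: disjoint half-open [start, end) ranges with full attention
    inside each range and no attention across ranges.
    """
    def overlap(lo: int, hi: int, start: int, end: int) -> int:
        return max(0, min(hi, end) - max(lo, start))

    num_blocks = seqlen // tile
    grid: List[str] = []
    for m in range(num_blocks):
        row = ""
        for n in range(num_blocks):
            # Count attended (q, k) pairs in the tile; segments are disjoint,
            # so summing per-segment rectangles never double counts.
            attended = sum(
                overlap(m * tile, (m + 1) * tile, s, e)
                * overlap(n * tile, (n + 1) * tile, s, e)
                for s, e in segments
            )
            if attended == tile * tile:
                row += "F"
            elif attended > 0:
                row += "M"
            else:
                row += "."
        grid.append(row)
    return grid


# Hypothetical example: 3 segments in seqlen=512 with tile=128.
# The boundary at 300 does not align with a tile boundary.
grid = classify_tiles([(0, 128), (128, 300), (300, 512)], seqlen=512, tile=128)
```

In this made-up example the misaligned boundary at 300 turns the tiles around m_block 2 into Mask tiles, while the aligned segments produce Full tiles on the diagonal.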
Step 3: Transpose to k2q CSR (backward direction)
Reading the matrix by columns, separating mask/full:
This is exactly the CSR data hardcoded in the repro's `build_bwd_block_sparse()`. No n_block processes all 8 m_blocks.
3.5.3 Deadlock Points
Transposing the CSR shows which n_blocks process each m_block.
With `lock_value = n_block`, the first n_block to reach an m_block finds the semaphore still at 0:
- `lock_value = 0`, sem = 0 → OK ✅
- `lock_value = 2`, but n_blocks 0 and 1 skip → sem stuck at 0 ≠ 2 → DEADLOCK ❌
- `lock_value = 3`, n_blocks 0, 1, 2 skip → sem = 0 ≠ 3 → DEADLOCK ❌
- `lock_value = 6`, n_blocks 0–5 all skip → sem = 0 ≠ 6 → DEADLOCK ❌
5 out of 8 m_blocks deadlock (62.5%), a guaranteed hang.
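The same check can be run over a whole k2q CSR to list the m_blocks that would stall. This is a sketch under assumptions: `find_deadlocked_m_blocks` is a hypothetical helper, it uses a single `(row_ptr, col_idx)` CSR instead of separate mask/full CSRs, and the example data is made up rather than taken from the repro.

```python
from collections import defaultdict
from typing import Dict, List


def find_deadlocked_m_blocks(row_ptr: List[int], col_idx: List[int]) -> List[int]:
    """List m_blocks that stall when lock_value = n_block.

    With +1 increments, the i-th processor of an m_block sees sem == i,
    so the wait succeeds only if that processor's n_block index equals i.
    """
    processors: Dict[int, List[int]] = defaultdict(list)
    for n in range(len(row_ptr) - 1):
        for m in col_idx[row_ptr[n]:row_ptr[n + 1]]:
            processors[m].append(n)  # appended in ascending n_block order
    return [
        m
        for m, ns in sorted(processors.items())
        if ns != list(range(len(ns)))  # any gap in 0, 1, 2, ... stalls
    ]


# Hypothetical CSR with 4 n_blocks:
#   n_block 0 -> m_blocks [0, 1]
#   n_block 1 -> m_blocks [0, 1, 2]
#   n_block 2 -> m_blocks [2, 3]
#   n_block 3 -> m_blocks [3]
row_ptr = [0, 2, 5, 7, 8]
col_idx = [0, 1, 0, 1, 2, 2, 3, 3]
stuck = find_deadlocked_m_blocks(row_ptr, col_idx)
```

Here m_blocks 2 and 3 stall because their first processors are n_blocks 1 and 2, while the semaphore is still 0.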
3.5.4 Detailed Trace: m_block 3
Before fix (`lock_value = n_block`):
After fix (`lock_value = rank among processors`):
3.5.5 Why This Config
`varlen_full_1k` is the simplest natural scenario that triggers the bug:
- `func_num=3` (HSTU encoding) ensures the `arbitrary=True` path
With this config, 5/8 m_blocks deadlock (62.5%): a guaranteed, always-reproducible bug.
4. Fix
4.1 Core Idea
Replace `lock_value = n_block` (the global n_block index) with the rank of the current n_block among all n_blocks that process the same m_block.
The semaphore then increases monotonically (0→1→2→3) with no gaps.
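A CPU reference for the rank rule might look like the following. It is not the PR's `_compute_bwd_dQ_lock_values()`, which is described as running on the GPU; `compute_rank_lock_values` and the single `(row_ptr, col_idx)` CSR layout are assumptions made for illustration.

```python
from typing import List


def compute_rank_lock_values(
    row_ptr: List[int], col_idx: List[int], num_m_blocks: int
) -> List[int]:
    """Return one lock value per CSR entry, aligned with col_idx.

    Walks the k2q CSR in n_block order. The lock value of an entry is the
    number of earlier n_blocks that already listed the same m_block.
    """
    seen = [0] * num_m_blocks  # processors counted so far, per m_block
    lock_values: List[int] = []
    for n in range(len(row_ptr) - 1):
        for m in col_idx[row_ptr[n]:row_ptr[n + 1]]:
            lock_values.append(seen[m])
            seen[m] += 1
    return lock_values


# One m_block (index 0) processed only by n_blocks {0, 2, 5} out of 6.
row_ptr = [0, 1, 1, 2, 2, 2, 3]
col_idx = [0, 0, 0]
lock_values = compute_rank_lock_values(row_ptr, col_idx, num_m_blocks=1)
```

The three entries receive lock values 0, 1, 2 instead of 0, 2, 5, so each wait matches the semaphore value left by the previous processor.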
4.2 Host-Side Precomputation (`interface.py`)
The new function `_compute_bwd_dQ_lock_values()` traverses the k2q CSR in n_block order. For each entry, lock_value is the count of previously seen processors for that m_block. It is a pure GPU computation, O(nnz), with no CPU-GPU sync.
4.3 Kernel-Side Read (`flash_bwd_sm100.py`)
Thread `dQ_lock_values_mask` / `dQ_lock_values_full` through `__call__` → `kernel` → `dQacc_reduce`.
Cute DSL constraint: runtime `if` branches cannot first-define a variable. `lock_value = Int32(0)` must be initialized before the conditional overwrites (the DSL compiles to MLIR/PTX with static type constraints).
4.4 Release Side: No Change Needed
`barrier.arrive_inc(semaphore[m_block], 1)` always increments by 1, naturally matching the rank-based values.
5. Minimal Reproduction Script
The script has no external dependencies except torch and flash_attn.
6. Files Changed
- `flash_attn/cute/interface.py`: `_compute_bwd_dQ_lock_values()` traverses the CSR to precompute lock values; `_flash_attn_bwd` threads the results through compile/execute
- `flash_attn/cute/flash_bwd_sm100.py`: `dQ_lock_values_mask` / `dQ_lock_values_full` params added to `__call__`, `kernel`, `dQacc_reduce`
+125 lines, 0 deletions, 0 modifications to existing logic. Non-block-sparse, causal, and dense paths are completely untouched. No changes to CSR data structures (`LinearBlockSparseTensors`); all existing callers are unaffected.
7. Verification
Minimal Repro
Configurations: `deterministic=False`, `deterministic=True`
Full Test Suite (NVIDIA B200 sm100, bfloat16)
Configurations: `full_1k`, `full_4k`, `varlen_full_1k`, `varlen_full_4k`, `causal_1k`, `causal_4k`, `varlen_causal_1k`, `varlen_causal_4k`, `sliding_window_1k`, `sliding_window_4k`, `varlen_sliding_window_1k`, `varlen_sliding_window_4k`
`varlen_full_1k` includes `deterministic=True` + GQA (`qhead_per_kvhead=8`), the exact config that previously deadlocked, and now passes.
8. Lessons Learned
8.1 Dense assumptions are a classic sparsity pitfall
The root cause is a classic sparsification error: loop range changed from dense to sparse, but synchronization logic dependent on dense semantics was not updated. In GPU kernels, such errors manifest as deadlocks rather than crashes, making debugging extremely difficult.
8.2 GPU semaphore deadlock debugging
- `py-spy dump --pid <PID>` quickly locates Python-level blocking
- `nvidia-smi` showing GPU 100% while Python is stuck on `synchronize()` indicates a classic GPU kernel deadlock
8.3 Cute DSL programming constraints
During the fix, a Cute DSL limitation was encountered: runtime `if` branches cannot first-define a variable ("lock_value is None prior to this if"). Variables must be initialized outside the branch first. This differs from standard Python: it is a static type constraint that applies when the DSL compiles to MLIR/PTX.
8.4 Precompiled cache versioning
After modifying kernel signatures (adding parameters), stale precompiled caches cause `TypeError: unexpected keyword argument`. Production environments should clear or version compilation caches after signature changes.
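One way to version a cache is to fold the kernel's parameter list into the cache key, so that adding a parameter automatically invalidates old entries. The sketch below is generic Python and does not reflect how flash_attn or the Cute DSL actually manage their caches; `signature_cache_key` and the two stand-in kernel functions are hypothetical.

```python
import hashlib
import inspect
from typing import Callable


def signature_cache_key(kernel_name: str, fn: Callable) -> str:
    """Derive a cache key that changes whenever fn's parameters change."""
    params = ",".join(inspect.signature(fn).parameters)  # ordered names
    digest = hashlib.sha256(f"{kernel_name}({params})".encode()).hexdigest()
    return digest[:16]


# Stand-ins for a kernel entry point before and after a signature change.
def kernel_before(q, k, v, deterministic=False):
    pass


def kernel_after(q, k, v, deterministic=False,
                 dQ_lock_values_mask=None, dQ_lock_values_full=None):
    pass


old_key = signature_cache_key("flash_bwd", kernel_before)
new_key = signature_cache_key("flash_bwd", kernel_after)
```

Because the two keys differ, a cache indexed this way would recompile instead of calling a stale binary with the new keyword arguments.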