Skip to content

fix: deterministic backward deadlock with block-sparse attention (sm100)#1

Open
littsk wants to merge 1 commit into
jiayus-nvidia:arbitrary_maskfrom
littsk:fix/deterministic-blocksparse-deadlock
Open

fix: deterministic backward deadlock with block-sparse attention (sm100)#1
littsk wants to merge 1 commit into
jiayus-nvidia:arbitrary_maskfrom
littsk:fix/deterministic-blocksparse-deadlock

Conversation

@littsk

@littsk littsk commented Apr 7, 2026

Copy link
Copy Markdown

Summary / 概述

Bug: When deterministic=True and block-sparse attention is active (arbitrary=True, func_num > 1) on sm100 (Blackwell), the backward kernel hangs indefinitely due to a GPU deadlock in dQacc_reduce.

Bug: 在 sm100 (Blackwell) GPU 上,当同时启用 deterministic=True 和 block-sparse attention (arbitrary=True, func_num > 1) 时,backward kernel 由于 dQacc_reduce 中的 GPU 死锁而无限挂起。


1. Background / 背景

1.1 Flash Attention 4 (FA4)

Flash Attention 4 targets NVIDIA Blackwell (sm100), written in CUTLASS Cute DSL, JIT-compiled to efficient CUDA kernels. The backward pass (FlashAttentionBackwardSm100) has three parallel phases:

  • Compute warps (4–11): dQ/dK/dV matrix multiply
  • Reduce warps (0–3): reduce dQ intermediate results from TMEM to global memory
  • Load warps: TMA data prefetch

1.2 Deterministic Mode / 确定性模式

Multiple CTAs write to the same dQ[m_block] in non-deterministic order → floating-point accumulation differs across runs. Deterministic mode uses semaphore-based ordering to force all CTAs to accumulate dQ in a fixed sequence, ensuring bit-level reproducibility.

由于 GPU 的并行特性,不同 CTA 写入同一 dQ[m_block] 的顺序在不同运行之间不确定。Deterministic 模式通过信号量机制强制所有 CTA 按固定顺序累加 dQ,确保结果比特级可复现。

1.3 Block-Sparse Attention

In real scenarios (HSTU, varlen sequence packing), the attention matrix is not fully dense. Block-Sparse Attention uses CSR structures to describe block-level sparsity, skipping unnecessary block pairs:

LinearBlockSparseTensorsTorch:
    mask_block_cnt:    [n_blocks]        # per-n_block count of mask m_blocks
    mask_block_offset: [n_blocks + 1]    # CSR offset array
    mask_block_idx:    [nnz_mask]        # actual m_block indices
    full_block_cnt:    [n_blocks]        # per-n_block count of full-attention m_blocks
    full_block_offset: [n_blocks + 1]
    full_block_idx:    [nnz_full]
  • n_block (K-dim tiles): outer iteration dimension in backward, each CTA handles one n_block
  • m_block (Q-dim tiles): inner iteration, each n_block only processes m_blocks listed in its CSR

2. Original Design: Deterministic dQ Reduce Semaphore Protocol

2.1 Protocol

The backward reduce phase accumulates dQ contributions from multiple n_blocks into dQ_accum[m_block]. Deterministic mode uses a per-m_block int32 semaphore array, initialized to 0:

For each (n_block, m_block) pair:
  1. Acquire: barrier.wait_eq(semaphore[m_block], lock_value)
  2. Execute: dQ_accum[m_block] += contribution
  3. Release: barrier.arrive_inc(semaphore[m_block], 1)

2.2 Dense Case

In non-block-sparse (dense) mode, every n_block processes every m_block, so lock_value = n_block works correctly:

Step n_block Action semaphore[M]
0 0 wait_eq(0) ✅ → reduce → arrive_inc(1) 0 → 1
1 1 wait_eq(1) ✅ → reduce → arrive_inc(1) 1 → 2
2 2 wait_eq(2) ✅ → reduce → arrive_inc(1) 2 → 3

The corresponding code in flash_bwd_sm100.py :: dQacc_reduce:

# Original code (buggy)
if const_expr(self.deterministic and stage == 0):
    if const_expr(self.spt):
        lock_value = n_block_max_for_m_block - 1 - n_block  # causal path
    else:
        lock_value = n_block   # ← assumes dense iteration
    barrier.wait_eq(semaphore[m_block], lock_value)

3. Bug Analysis: Semaphore Deadlock under Block-Sparsity

3.1 Root Cause / 问题根源

With block sparsity, the reduce loop only iterates over m_blocks in the current n_block's CSR:

if const_expr(self.use_block_sparsity):
    (curr_mask_cnt, curr_mask_offset, curr_full_cnt, curr_full_offset, loop_count
    ) = get_block_sparse_iteration_info_bwd(blocksparse_tensors, batch_idx, head_idx, n_block)
else:
    loop_count = m_block_max - m_block_min  # dense: all m_blocks

for iter_idx in cutlass.range(loop_count, unroll=1):
    m_block = get_m_block_from_iter_bwd(iter_idx, ...)
    # ... semaphore acquire, reduce, release ...

But lock_value = n_block was NOT updated. It still assumes every n_block processes every m_block.

启用 block sparsity 后,reduce 循环只遍历当前 n_block 的 CSR 中列出的 m_blocks,lock_value = n_block 这一计算并未随之改变。

3.2 Deadlock Mechanism / 死锁形成机制

Example: m_block M is processed only by n_blocks {0, 2, 5} (skipping 1, 3, 4):

Step n_block lock_value semaphore[M] Result
0 0 0 0 wait_eq(0) ✅ → reduce → arrive_inc → sem=1
1 1 1 Skips M (not in CSR), doesn't touch semaphore
2 2 2 1 wait_eq(2) ❌ but sem=1deadlock forever
n_block:  0     1     2     3     4     5     6     7
          ┌─────┐     ┌─────┐                 ┌─────┐
Process:  │ YES │ NO  │ YES │ NO    NO        │ YES │
          └──┬──┘     └──┬──┘                 └──┬──┘
             │           │                       │
sem[M]:   0→1         wait(2)                    │
                      sem=1≠2                    │
                      ╔═══════╗                  │
                      ║ HANG! ║                  │
                      ╚═══════╝                  │
                    never reaches here ───────────┘

3.3 Trigger Conditions / 触发条件

All must be true:

  1. sm100 (Blackwell) GPU — deterministic backward only implemented for sm100
  2. deterministic=True — enables semaphore synchronization
  3. arbitrary=True with func_num > 1 — triggers block-sparse attention (HSTU-encoded varlen masks)
  4. Non-trivial block sparsity — some (n_block, m_block) pairs are skipped

3.4 Why Hard to Discover / 为什么难以发现

  • Non-sparse tests (full_4k, causal_4k) work perfectly — no block sparsity
  • deterministic=False works perfectly — no semaphore
  • GPU shows 100% utilization during hang — easily mistaken for long computation
  • py-spy only shows torch.equal() / cuda.synchronize() — real deadlock is inside GPU kernel

3.5 Minimal Repro Trigger Analysis / 最小复现 Case 的触发分析

3.5.1 Scenario / 场景构造

The repro simulates varlen_full_1k: 10 disjoint full-attention segments packed into seqlen=1024:

RANGES = [
    (0, 366),     # segment 0: 366 tokens
    (366, 391),   # segment 1: 25 tokens
    (391, 471),   # segment 2: 80 tokens
    (471, 835),   # segment 3: 364 tokens
    (835, 984),   # segment 4: 149 tokens
    (984, 1005),  # segment 5: 21 tokens
    (1005, 1017), # segment 6: 12 tokens
    (1017, 1020), # segment 7: 3 tokens
    (1020, 1023), # segment 8: 3 tokens
    (1023, 1024), # segment 9: 1 token
]

Each segment has full attention within itself; no attention between segments → block-diagonal pattern.

3.5.2 From Ranges to Block-Sparse CSR / 从 Ranges 到 CSR 的推导

With tile_size=128, Q/K each have 1024/128 = 8 blocks (0–7):

tile 0: [0, 128)    tile 1: [128, 256)   tile 2: [256, 384)   tile 3: [384, 512)
tile 4: [512, 640)   tile 5: [640, 768)   tile 6: [768, 896)   tile 7: [896, 1024)

Step 1: Range coverage per tile / 每个 Range 在 tile 上的覆盖

Segment boundaries don't align with tile boundaries:

Range Interval Tiles Coverage
0 [0, 366) {0, 1, 2} tile 0: FULL, tile 1: FULL, tile 2: partial(256–365)
1 [366, 391) {2, 3} tile 2: partial(366–383), tile 3: partial(384–390)
2 [391, 471) {3} tile 3: partial(391–470)
3 [471, 835) {3, 4, 5, 6} tile 3: partial(471–511), tile 4: FULL, tile 5: FULL, tile 6: partial(768–834)
4 [835, 984) {6, 7} tile 6: partial(835–895), tile 7: partial(896–983)
5–9 [984, 1024) {7} tile 7: partial

Step 2: 8×8 attention tile matrix

For each tile pair (m, n), classify as Full (F), Mask (M), or empty (.):

          K: n=0  n=1  n=2  n=3  n=4  n=5  n=6  n=7
    Q: m=0:  F    F    M    .    .    .    .    .
       m=1:  F    F    M    .    .    .    .    .
       m=2:  M    M    M    M    .    .    .    .
       m=3:  .    .    M    M    M    M    M    .
       m=4:  .    .    .    M    F    F    M    .
       m=5:  .    .    .    M    F    F    M    .
       m=6:  .    .    .    M    M    M    M    M
       m=7:  .    .    .    .    .    .    M    M

Key examples:

  • (m=0, n=0) = F: Range 0 fully covers both tile 0 and tile 0
  • (m=0, n=2) = M: Range 0 fully covers tile 0 but only partially covers tile 2 (to 365)
  • (m=4, n=4) = F: Range 3 fully covers tile 4
  • (m=0, n=3) = .: No range covers both tile 0 and tile 3

Step 3: Transpose to k2q CSR (backward direction)

Reading the matrix by columns, separating mask/full:

n_block 0: mask=[m2]              full=[m0, m1]    → processes {0, 1, 2}
n_block 1: mask=[m2]              full=[m0, m1]    → processes {0, 1, 2}
n_block 2: mask=[m0, m1, m2, m3]  full=[]          → processes {0, 1, 2, 3}
n_block 3: mask=[m2,m3,m4,m5,m6]  full=[]          → processes {2, 3, 4, 5, 6}
n_block 4: mask=[m3, m6]          full=[m4, m5]    → processes {3, 4, 5, 6}
n_block 5: mask=[m3, m6]          full=[m4, m5]    → processes {3, 4, 5, 6}
n_block 6: mask=[m3,m4,m5,m6,m7]  full=[]          → processes {3, 4, 5, 6, 7}
n_block 7: mask=[m6, m7]          full=[]          → processes {6, 7}

This is exactly the CSR data hardcoded in the repro's build_bwd_block_sparse(). No n_block processes all 8 m_blocks.

3.5.3 Deadlock Points / 死锁点精确定位

Transposing the CSR to see which n_blocks process each m_block:

m_block Processing n_blocks Skipped n_blocks Deadlocks?
0 {0, 1, 2} {3, 4, 5, 6, 7} ✅ OK
1 {0, 1, 2} {3, 4, 5, 6, 7} ✅ OK
2 {0, 1, 2, 3} {4, 5, 6, 7} ✅ OK
3 {2, 3, 4, 5, 6} {0, 1, 7} DEADLOCK
4 {3, 4, 5, 6} {0, 1, 2, 7} DEADLOCK
5 {3, 4, 5, 6} {0, 1, 2, 7} DEADLOCK
6 {3, 4, 5, 6, 7} {0, 1, 2} DEADLOCK
7 {6, 7} {0, 1, 2, 3, 4, 5} DEADLOCK

With lock_value = n_block:

  • m_block 0–2: First processor is n_block 0 → lock_value = 0, sem=0 → OK ✅
  • m_block 3: First processor is n_block 2lock_value = 2, but n_blocks 0,1 skip → sem stuck at 0 ≠ 2 → DEADLOCK
  • m_block 4–5: First processor is n_block 3lock_value = 3, n_blocks 0,1,2 skip → sem = 0 ≠ 3 → DEADLOCK
  • m_block 6: First processor is n_block 3 → same → DEADLOCK
  • m_block 7: First processor is n_block 6lock_value = 6, n_blocks 0–5 all skip → sem = 0 ≠ 6 → DEADLOCK

5 out of 8 m_blocks deadlock (62.5%) — guaranteed hang.

3.5.4 Detailed Trace: m_block 3 / 以 m_block 3 为例的详细追踪

Before fix (lock_value = n_block):

Initial: semaphore[3] = 0

n_block 0: m_block 3 not in CSR → skip, sem unchanged → sem = 0
n_block 1: m_block 3 not in CSR → skip, sem unchanged → sem = 0
n_block 2: m_block 3 in CSR → acquire:
           lock_value = n_block = 2
           wait_eq(semaphore[3], 2) → sem = 0 ≠ 2
           ╔══════════════════════════════════════════════╗
           ║  GPU CTA spins forever, never satisfied     ║
           ╚══════════════════════════════════════════════╝
→ entire backward kernel hangs

After fix (lock_value = rank among processors):

Initial: semaphore[3] = 0

n_block 2: rank 0 among m_block 3's processors → lock_value = 0
           wait_eq(semaphore[3], 0) → sem = 0 == 0 ✅
           reduce → arrive_inc → sem = 1

n_block 3: rank 1 → lock_value = 1
           wait_eq(semaphore[3], 1) → sem = 1 == 1 ✅
           reduce → arrive_inc → sem = 2

n_block 4: rank 2 → lock_value = 2
           wait_eq(semaphore[3], 2) → sem = 2 == 2 ✅
           ...
→ all processors complete successfully, no deadlock

3.5.5 Why This Config / 为什么选择这个配置

varlen_full_1k is the simplest natural scenario that triggers the bug:

  1. Varlen packing produces multiple full-attention segments; boundaries don't align with tile edges → naturally sparse block pattern
  2. seqlen=1024 gives only 8 blocks — small enough to analyze, complex enough to produce gaps
  3. GQA (qhead_per_kvhead=8) and func_num=3 (HSTU encoding) ensure the arbitrary=True path
  4. 10 unequal-length segments (366 tokens down to 1 token) ensure boundaries cross multiple tile edges, producing rich mask/full mixtures

With this config, 5/8 m_blocks deadlock (62.5%) — a guaranteed, always-reproducible bug.


4. Fix / 修复方案

4.1 Core Idea / 核心思路

Replace lock_value = n_block (global n_block index) with rank of the current n_block among all n_blocks that process the same m_block:

n_block Original lock_value Fixed lock_value
0 0 0 (rank 0)
2 2 ← ❌ 1 (rank 1)
5 5 ← ❌ 2 (rank 2)

Semaphore 0→1→2→3 monotonically increases with no gaps.

4.2 Host-Side Precomputation (interface.py)

New function _compute_bwd_dQ_lock_values() traverses k2q CSR in n_block order. For each entry, lock_value = count of previously seen processors for that m_block. Pure GPU computation, O(nnz), no CPU-GPU sync.

def _compute_bwd_dQ_lock_values(block_sparse_tensors):
    """For each CSR entry (n_block, m_block), compute lock_value =
    count of n_blocks in [0, n_block) that also process this m_block."""
    m_counter = {}  # m_block → count of seen processors
    for n in range(num_n_blocks):
        for m_block in n's mask entries:
            rank = m_counter.get(m_block, 0)
            mask_lock_values[offset + i] = rank
            m_counter[m_block] = rank + 1
        for m_block in n's full entries:
            rank = m_counter.get(m_block, 0)
            full_lock_values[offset + i] = rank
            m_counter[m_block] = rank + 1
    return mask_lock_values, full_lock_values

4.3 Kernel-Side Read (flash_bwd_sm100.py)

Thread dQ_lock_values_mask / dQ_lock_values_full through __call__kerneldQacc_reduce:

# Fixed code
if const_expr(self.deterministic and stage == 0):
    if const_expr(self.spt):
        lock_value = n_block_max_for_m_block - 1 - n_block  # causal: unchanged
    elif const_expr(self.use_block_sparsity):
        lock_value = Int32(0)  # Must initialize (Cute DSL constraint)
        if const_expr(dQ_lock_values_full is not None):
            if iter_idx < curr_mask_cnt:
                lock_value = dQ_lock_values_mask[curr_mask_offset + iter_idx]
            else:
                lock_value = dQ_lock_values_full[curr_full_offset + iter_idx - curr_mask_cnt]
        else:
            lock_value = dQ_lock_values_mask[curr_mask_offset + iter_idx]
    else:
        lock_value = n_block  # dense: unchanged
    barrier.wait_eq(semaphore[m_block], lock_value)

Cute DSL constraint: Runtime if branches cannot first-define a variable. Must initialize lock_value = Int32(0) before conditional overwrites (DSL compiles to MLIR/PTX with static type constraints).

4.4 Release Side — No Change Needed

barrier.arrive_inc(semaphore[m_block], 1) always increments by 1, naturally matching rank-based values:

  • Rank 0 releases: sem = 0 + 1 = 1
  • Rank 1 waits for 1 ✅, releases: sem = 1 + 1 = 2
  • Rank 2 waits for 2 ✅, releases: sem = 2 + 1 = 3

5. Minimal Reproduction Script / 最小复现脚本

Click to expand (no external dependency except torch + flash_attn)
#!/usr/bin/env python3
"""Minimal repro: FA4 deterministic backward hangs with block-sparse attention.

Expected:
  - deterministic=False → OK (~1s)
  - deterministic=True  → HANG indefinitely (before fix)

Usage: CUDA_VISIBLE_DEVICES=0 python repro.py
"""
import signal, sys, time, torch
from flash_attn.cute.interface import _flash_attn_fwd, _flash_attn_bwd
from flash_attn.cute.block_sparsity import LinearBlockSparseTensorsTorch

SEQLEN_Q = SEQLEN_K = 1024
NUM_HEADS_Q, NUM_HEADS_KV, HEAD_DIM = 32, 4, 128
DTYPE, PAD, TIMEOUT = torch.bfloat16, 256, 60

RANGES = [
    (0, 366), (366, 391), (391, 471), (471, 835), (835, 984),
    (984, 1005), (1005, 1017), (1017, 1020), (1020, 1023), (1023, 1024),
]

def build_aux_tensor(ranges, seqlen, pad):
    total = seqlen + pad
    func = torch.zeros(3, total, dtype=torch.int32)
    func[0, 0:366] = 366
    for qs, qe in ranges:
        func[1, qs:qe] = qs
    func[1, 0:366] = 0
    for qs, qe in ranges:
        func[2, qs:qe] = qe
    func[2, 0:366] = 0
    return func.unsqueeze(0).unsqueeze(0).cuda()

def build_bwd_block_sparse():
    return LinearBlockSparseTensorsTorch(
        mask_block_cnt=torch.tensor([1,1,4,5,2,2,5,2], dtype=torch.int32).cuda(),
        mask_block_offset=torch.tensor([0,1,2,6,11,13,15,20,22], dtype=torch.int32).cuda(),
        mask_block_idx=torch.tensor([2,2,0,1,2,3,2,3,4,5,6,3,6,3,6,3,4,5,6,7,6,7], dtype=torch.int32).cuda(),
        full_block_cnt=torch.tensor([2,2,0,0,2,2,0,0], dtype=torch.int32).cuda(),
        full_block_offset=torch.tensor([0,2,4,4,4,6,8,8,8], dtype=torch.int32).cuda(),
        full_block_idx=torch.tensor([0,1,0,1,4,5,4,5], dtype=torch.int32).cuda(),
    )

def build_fwd_block_sparse():
    q2k = LinearBlockSparseTensorsTorch(
        mask_block_cnt=torch.tensor([1,7,2,5], dtype=torch.int32).cuda(),
        mask_block_offset=torch.tensor([0,1,8,10,15], dtype=torch.int32).cuda(),
        mask_block_idx=torch.tensor([2,0,1,2,3,4,5,6,3,6,3,4,5,6,7], dtype=torch.int32).cuda(),
        full_block_cnt=torch.tensor([2,0,2,0], dtype=torch.int32).cuda(),
        full_block_offset=torch.tensor([0,2,2,4,4], dtype=torch.int32).cuda(),
        full_block_idx=torch.tensor([0,1,4,5], dtype=torch.int32).cuda(),
    )
    k2q = LinearBlockSparseTensorsTorch(
        mask_block_cnt=torch.tensor([1,1,4,5,2,2,5,2], dtype=torch.int32).cuda(),
        mask_block_offset=torch.tensor([0,1,2,6,11,13,15,20,22], dtype=torch.int32).cuda(),
        mask_block_idx=torch.tensor([2,2,0,1,2,3,2,3,4,5,6,3,6,3,6,3,4,5,6,7,6,7], dtype=torch.int32).cuda(),
        full_block_cnt=torch.tensor([2,2,0,0,2,2,0,0], dtype=torch.int32).cuda(),
        full_block_offset=torch.tensor([0,2,4,4,4,6,8,8,8], dtype=torch.int32).cuda(),
        full_block_idx=torch.tensor([0,1,0,1,4,5,4,5], dtype=torch.int32).cuda(),
    )
    return q2k, k2q

def main():
    torch.manual_seed(42)
    dev = torch.device("cuda:0")
    assert torch.cuda.get_device_capability(dev)[0] == 10, "Requires sm100"

    q = torch.randn(1, SEQLEN_Q, NUM_HEADS_Q, HEAD_DIM, dtype=DTYPE, device=dev)
    k = torch.randn(1, SEQLEN_K, NUM_HEADS_KV, HEAD_DIM, dtype=DTYPE, device=dev)
    v = torch.randn(1, SEQLEN_K, NUM_HEADS_KV, HEAD_DIM, dtype=DTYPE, device=dev)
    dout = torch.randn_like(q)

    aux = build_aux_tensor(RANGES, SEQLEN_Q, PAD)
    q2k, k2q_fwd = build_fwd_block_sparse()
    bwd_sparse = build_bwd_block_sparse()

    # Forward
    out, lse = _flash_attn_fwd(q, k, v, softmax_scale=HEAD_DIM**-0.5, causal=False,
        arbitrary=True, softcap=0.0, num_splits=1, pack_gqa=False,
        block_sparse_tensors=q2k, aux_tensors=[aux], return_lse=True)
    torch.cuda.synchronize()

    # Backward warmup
    _flash_attn_bwd(q, k, v, out=out, dout=dout, lse=lse, softmax_scale=HEAD_DIM**-0.5,
        causal=False, arbitrary=True, softcap=0.0,
        block_sparse_tensors=bwd_sparse, aux_tensors=[aux], deterministic=False)
    torch.cuda.synchronize()

    # Backward deterministic — expected to HANG before fix
    print(f"Backward (deterministic=True, timeout={TIMEOUT}s) ...", flush=True)
    signal.signal(signal.SIGALRM, lambda *_: (print("HANG DETECTED"), sys.exit(1)))
    signal.alarm(TIMEOUT)
    t0 = time.time()
    _flash_attn_bwd(q, k, v, out=out, dout=dout, lse=lse, softmax_scale=HEAD_DIM**-0.5,
        causal=False, arbitrary=True, softcap=0.0,
        block_sparse_tensors=bwd_sparse, aux_tensors=[aux], deterministic=True)
    torch.cuda.synchronize()
    signal.alarm(0)
    print(f"OK ({time.time()-t0:.2f}s)")

if __name__ == "__main__":
    main()

6. Files Changed / 修改文件

File Change Type Description
flash_attn/cute/interface.py New function _compute_bwd_dQ_lock_values(): traverse CSR to precompute lock values
flash_attn/cute/interface.py Modified Call precompute in _flash_attn_bwd, thread results through compile/execute
flash_attn/cute/flash_bwd_sm100.py Modified Add dQ_lock_values_mask/full params to __call__, kernel, dQacc_reduce
flash_attn/cute/flash_bwd_sm100.py Modified Block-sparse reduce path reads precomputed lock_value

+125 lines, 0 deletions, 0 modifications to existing logic. Non-block-sparse, causal, and dense paths are completely untouched. No changes to CSR data structures (LinearBlockSparseTensors), all existing callers unaffected.


7. Verification / 验证

Minimal Repro

Scenario Before Fix After Fix
deterministic=False ✅ 0.00s ✅ 0.00s
deterministic=True ❌ hang forever ✅ ~4s (incl. JIT)

Full Test Suite (NVIDIA B200 sm100, bfloat16)

Config Result Time
full_1k ✅ PASSED 7.6s
full_4k ✅ PASSED 31.9s
varlen_full_1k ✅ PASSED 19.7s
varlen_full_4k ✅ PASSED 22.8s
causal_1k ✅ PASSED 7.4s
causal_4k ✅ PASSED 7.3s
varlen_causal_1k ✅ PASSED 7.4s
varlen_causal_4k ✅ PASSED 7.4s
sliding_window_1k ✅ PASSED 7.4s
sliding_window_4k ✅ PASSED 7.7s
varlen_sliding_window_1k ✅ PASSED 7.8s
varlen_sliding_window_4k ✅ PASSED 7.7s

varlen_full_1k includes deterministic=True + GQA (qhead_per_kvhead=8) — the exact config that previously deadlocked — now passes.


8. Lessons Learned / 经验与启示

8.1 Dense assumptions are a classic sparsity pitfall

The root cause is a classic sparsification error: loop range changed from dense to sparse, but synchronization logic dependent on dense semantics was not updated. In GPU kernels, such errors manifest as deadlocks rather than crashes, making debugging extremely difficult.

8.2 GPU semaphore deadlock debugging

  • py-spy dump --pid <PID> quickly locates Python-level blocking
  • nvidia-smi showing GPU 100% + Python stuck on synchronize() = classic GPU kernel deadlock
  • Most effective: write a minimal reproduction script removing all framework dependencies

8.3 Cute DSL programming constraints

During the fix, a Cute DSL limitation was encountered: runtime if branches cannot first-define a variable ("lock_value is None prior to this if"). Variables must be initialized outside the branch first. This differs from standard Python — it's a static type constraint when DSL compiles to MLIR/PTX.

8.4 Precompiled cache versioning

After modifying kernel signatures (adding parameters), stale precompiled caches cause TypeError: unexpected keyword argument. Production environments should clear or version-control compilation caches after signature changes.

When `deterministic=True` and block-sparse attention is active (arbitrary
mask with func_num > 1), the dQ_semaphore in `dQacc_reduce` uses
`lock_value = n_block`, which assumes every n_block processes every
m_block (dense iteration). With block sparsity, skipped n_blocks never
increment the semaphore, causing downstream CTAs to wait forever on
`barrier.wait_eq`.

Fix: precompute per-CSR-entry lock values on the host side (the rank of
each n_block among all n_blocks that process the same m_block), pass them
as `dQ_lock_values_mask` / `dQ_lock_values_full` tensors to the kernel,
and read from these tensors instead of using raw `n_block`.

The release side (`arrive_inc(sem, 1)`) needs no change since rank-based
values naturally form a 0, 1, 2, ... sequence matching +1 increments.

Verified on B200 (sm100) across all 12 attention mask configurations
(full, varlen_full, causal, varlen_causal, sliding_window,
varlen_sliding_window × 1k/4k) with bfloat16.

Made-with: Cursor
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant