Skip to content
102 changes: 62 additions & 40 deletions backends/cuda/triton/kernels/sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,21 +422,22 @@ def _sdpa_fwd_kernel_body(

offs_n_init = tl.arange(0, BLOCK_N)

# Window-aware early-exit. A KV block that is fully masked (sliding-window
# or causal) contributes nothing to the online softmax — every entry is
# -inf, so p=0 and m_i/l_i/acc are left unchanged. We detect such blocks up
# front and skip their K/V loads and both matmuls. This is exact: it only
# skips work the mask would have zeroed out anyway. At seq=2048 the 50
# sliding-window(1024) layers and the 10 causal layers each leave roughly
# half (or more) of their KV blocks fully masked, so this is a large cut to
# the dominant prefill cost. The skip condition is a CTA-wide reduction, so
# the branch is uniform and turns into a real skip (not predication).
if IS_CAUSAL:
max_seq_pos = tl.max(seq_pos)

for start_n in tl.range(0, Lk, BLOCK_N):
offs_n = start_n + offs_n_init

# K load: uniform (single KV head, shared across all Q heads in tile)
k_ptrs = K_ptr + (
b * stride_kb
+ h_kv * stride_kh
+ (offs_n[:, None] * stride_kn)
+ (offs_d[None, :] * stride_kd)
)
k_mask = (offs_n[:, None] < Lk) & (offs_d[None, :] < HEAD_DIM)
k = tl.load(k_ptrs, mask=k_mask, other=0.0).to(tl.bfloat16)

qk = (tl.dot(q, tl.trans(k)).to(tl.float32) * sm_scale).to(tl.float32)

# Decide whether any row in this tile actually attends to this KV block.
if HAS_MASK:
mask_ptrs = Mask_ptr + (
b * stride_mb
Expand All @@ -445,39 +446,60 @@ def _sdpa_fwd_kernel_body(
)
mn_mask = row_valid[:, None] & (offs_n[None, :] < Lk)
mask_block = tl.load(mask_ptrs, mask=mn_mask, other=False)
qk = tl.where(
mask_block, qk, tl.full(qk.shape, -float("inf"), dtype=tl.float32)
block_active = tl.sum(mask_block.to(tl.int32)) > 0
elif IS_CAUSAL:
# Block is entirely in the future for every row -> skip.
block_active = start_n <= max_seq_pos
else:
block_active = True

if block_active:
# K load: uniform (single KV head, shared across Q heads in tile)
k_ptrs = K_ptr + (
b * stride_kb
+ h_kv * stride_kh
+ (offs_n[:, None] * stride_kn)
+ (offs_d[None, :] * stride_kd)
)
k_mask = (offs_n[:, None] < Lk) & (offs_d[None, :] < HEAD_DIM)
k = tl.load(k_ptrs, mask=k_mask, other=0.0).to(tl.bfloat16)

if IS_CAUSAL:
causal = offs_n[None, :] > seq_pos[:, None]
qk = tl.where(
causal, tl.full(qk.shape, -float("inf"), dtype=tl.float32), qk
)
qk = (tl.dot(q, tl.trans(k)).to(tl.float32) * sm_scale).to(tl.float32)

m_ij = tl.maximum(m_i, tl.max(qk, axis=1).to(tl.float32))
safe_diff = tl.where(
m_ij[:, None] > -float("inf"), qk - m_ij[:, None], -float("inf")
)
p_f32 = tl.exp(safe_diff).to(tl.float32)
l_ij = tl.sum(p_f32, axis=1).to(tl.float32)
safe_alpha_diff = tl.where(m_ij > -float("inf"), m_i - m_ij, 0.0)
alpha = tl.exp(safe_alpha_diff).to(tl.float32)
if HAS_MASK:
qk = tl.where(
mask_block, qk, tl.full(qk.shape, -float("inf"), dtype=tl.float32)
)

# V load: uniform (single KV head)
v_ptrs = V_ptr + (
b * stride_vb
+ h_kv * stride_vh
+ (offs_n[:, None] * stride_vn)
+ (offs_d[None, :] * stride_vd)
)
v_mask = (offs_n[:, None] < Lk) & (offs_d[None, :] < HEAD_DIM)
v = tl.load(v_ptrs, mask=v_mask, other=0.0).to(tl.bfloat16)
if IS_CAUSAL:
causal = offs_n[None, :] > seq_pos[:, None]
qk = tl.where(
causal, tl.full(qk.shape, -float("inf"), dtype=tl.float32), qk
)

p_bf16 = p_f32.to(tl.bfloat16)
acc = (acc * alpha[:, None] + tl.dot(p_bf16, v)).to(tl.float32)
l_i = (l_i * alpha + l_ij).to(tl.float32)
m_i = m_ij
m_ij = tl.maximum(m_i, tl.max(qk, axis=1).to(tl.float32))
safe_diff = tl.where(
m_ij[:, None] > -float("inf"), qk - m_ij[:, None], -float("inf")
)
p_f32 = tl.exp(safe_diff).to(tl.float32)
l_ij = tl.sum(p_f32, axis=1).to(tl.float32)
safe_alpha_diff = tl.where(m_ij > -float("inf"), m_i - m_ij, 0.0)
alpha = tl.exp(safe_alpha_diff).to(tl.float32)

# V load: uniform (single KV head)
v_ptrs = V_ptr + (
b * stride_vb
+ h_kv * stride_vh
+ (offs_n[:, None] * stride_vn)
+ (offs_d[None, :] * stride_vd)
)
v_mask = (offs_n[:, None] < Lk) & (offs_d[None, :] < HEAD_DIM)
v = tl.load(v_ptrs, mask=v_mask, other=0.0).to(tl.bfloat16)

p_bf16 = p_f32.to(tl.bfloat16)
acc = (acc * alpha[:, None] + tl.dot(p_bf16, v)).to(tl.float32)
l_i = (l_i * alpha + l_ij).to(tl.float32)
m_i = m_ij

inv_l_i = tl.where(l_i > 0, 1.0 / l_i, 0.0)
acc = acc * inv_l_i[:, None]
Expand Down
Loading