Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,4 @@ training/data
# ck modules
csrc/composable_kernel
csrc/cutlass
.analysis
.amd
169 changes: 140 additions & 29 deletions flash_attn/flash_attn_triton_amd/fwd_prefill.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,26 +382,97 @@ def _attn_fwd_mask(acc, l_i, m_i,


@triton.jit
def compute_masking(seqlen_k, seqlen_q, start_m,
IS_CAUSAL: tl.constexpr, USE_SLIDING_WINDOW: tl.constexpr,
WINDOW_SIZE_LEFT: tl.constexpr, WINDOW_SIZE_RIGHT: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
"""
Classify K blocks for attention computation with sliding window support.
def compute_window_bounds(q_start, q_end, diag, seqlen_k,
                          WINDOW_SIZE_LEFT: tl.constexpr,
                          WINDOW_SIZE_RIGHT: tl.constexpr,
                          IS_CAUSAL: tl.constexpr):
    """Calculate the K-column window boundaries for a query block.

    Args:
        q_start: first query row of the block.
        q_end: last query row of the block (inclusive).
        diag: offset added to a query row to obtain its diagonal K column
            (the caller passes ``seqlen_k - seqlen_q``).
        seqlen_k: number of K positions; right bounds are clamped to
            ``seqlen_k - 1``.
        WINDOW_SIZE_LEFT: left window size; a negative value pins the left
            bound to column 0.
        WINDOW_SIZE_RIGHT: right window size; only read when not causal.
        IS_CAUSAL: when set, the right bound is the diagonal itself.

    Returns:
        ``(left_min, left_max, right_min, right_max)`` where the ``*_min``
        values are the bounds for row ``q_start`` and the ``*_max`` values
        are the bounds for row ``q_end``.
    """
    # Left boundary
    if WINDOW_SIZE_LEFT < 0:
        left_min = 0
        left_max = 0
    else:
        left_min = tl.maximum(0, q_start + diag - WINDOW_SIZE_LEFT)
        left_max = tl.maximum(0, q_end + diag - WINDOW_SIZE_LEFT)

    # Right boundary
    if IS_CAUSAL:
        # Causal cap: col <= row + diag
        right_min = tl.minimum(seqlen_k - 1, q_start + diag)
        right_max = tl.minimum(seqlen_k - 1, q_end + diag)
    else:
        # Non-causal: the window extends WINDOW_SIZE_RIGHT columns past the
        # diagonal.
        # NOTE(review): negative and non-negative WINDOW_SIZE_RIGHT were
        # previously handled by two identical branches; they are merged here
        # with no behavior change. If a negative right size is meant to be
        # "unbounded" (as a negative left size is), it needs its own clamp to
        # seqlen_k - 1 -- confirm against the mask function before changing.
        right_min = tl.minimum(seqlen_k - 1, q_start + diag + WINDOW_SIZE_RIGHT)
        right_max = tl.minimum(seqlen_k - 1, q_end + diag + WINDOW_SIZE_RIGHT)

    return left_min, left_max, right_min, right_max

@triton.jit
def classify_window_blocks(left_min, left_max, right_min, right_max,
                           BLOCK_N: tl.constexpr):
    """Bucket K blocks into skip / front-masked / full / back-masked groups.

    Args:
        left_min, left_max: left window bound for the first and last row of
            the query block.
        right_min, right_max: right window bound for the first and last row
            of the query block.
        BLOCK_N: number of K columns per block.

    Returns:
        ``(n_front_skip_blocks, n_front_masked_blocks, n_full_blocks,
        n_back_masked_blocks, clipped_left)``; ``clipped_left`` is the block
        index at which the full range starts and is returned so the caller
        can handle a padded last block.
    """
    # Outermost blocks that overlap the window for at least one row.
    first_visible = left_min // BLOCK_N
    last_visible = right_max // BLOCK_N

    # Round left_max up to a block boundary: the first block that lies fully
    # inside the window for every row. Never start past last_visible + 1.
    full_begin = left_max // BLOCK_N + (left_max % BLOCK_N != 0)
    clipped_left = tl.minimum(full_begin, last_visible + 1)

    # The block containing right_min is fully visible only if right_min is
    # that block's final column; otherwise step back one block.
    full_end = right_min // BLOCK_N
    if right_min < (full_end + 1) * BLOCK_N - 1:
        full_end -= 1
    # Keep the full range from ending before (clipped_left - 1), so an empty
    # range yields a count of zero rather than a negative one.
    full_end = tl.maximum(full_end, clipped_left - 1)

    # Convert block indices into per-bucket counts.
    n_front_skip_blocks = first_visible
    n_front_masked_blocks = tl.maximum(0, clipped_left - first_visible)
    n_full_blocks = tl.maximum(0, full_end - clipped_left + 1)
    n_back_masked_blocks = tl.maximum(0, last_visible - full_end)

    return (n_front_skip_blocks, n_front_masked_blocks,
            n_full_blocks, n_back_masked_blocks,
            clipped_left)

@triton.jit
def handle_padded_last_block(n_extra_tokens, last_block, total_k_blocks,
                             clipped_left, n_front_masked_blocks,
                             n_full_blocks, n_back_masked_blocks):
    """Ensure a padded last K-block is never classified as 'full'.

    We move the padded last block (if visible) into the back-masked bucket.
    If it's already back-masked, we do nothing. If it was counted in the
    front-masked range, we decrement front-masked; if it was counted as full,
    we decrement full. Then we increment back-masked.

    Args:
        n_extra_tokens: padding indicator for the last K block; zero means
            the last block is not padded.
        last_block: index of the last K block the window overlaps.
        total_k_blocks: total number of K blocks in the key sequence.
        clipped_left: block index at which the full range starts.
        n_front_masked_blocks, n_full_blocks, n_back_masked_blocks: current
            bucket counts.

    Returns:
        ``(n_front_masked_blocks, n_full_blocks, n_back_masked_blocks)``
        after the adjustment.
    """
    # Only relevant when the window actually reaches the final K block and
    # that block carries padding.
    padded_last_k = (n_extra_tokens != 0) & (last_block == total_k_blocks - 1)

    if padded_last_k:
        # Current right edge of the 'full' range.
        full_right_block = clipped_left + n_full_blocks - 1

        # If last_block is beyond full_right_block it is already back-masked
        # and nothing needs to change.
        if last_block <= full_right_block:
            if clipped_left > last_block:
                # The window starts past last_block: it was counted as
                # front-masked.
                n_front_masked_blocks = tl.maximum(0, n_front_masked_blocks - 1)
            else:
                # Otherwise it was counted as 'full': move it out of full.
                n_full_blocks = tl.maximum(0, n_full_blocks - 1)
            # In both cases we need one more back-masked block.
            n_back_masked_blocks = n_back_masked_blocks + 1

    return n_front_masked_blocks, n_full_blocks, n_back_masked_blocks

@triton.jit
def compute_padding_info(seqlen_k, BLOCK_N: tl.constexpr):
"""Calculate padding information for the last K block."""
# check if we will need to do masking due to either BLOCK_N being bigger than seqlen_k or seqlen_k not being a multiple of BLOCK_N
# n_extra_tokens = 10 % 4 = 2
# This means the last K block has 2 valid tokens and 2 padding positions
Expand All @@ -415,15 +486,60 @@ def compute_masking(seqlen_k, seqlen_q, start_m,
elif seqlen_k % BLOCK_N:
n_extra_tokens = seqlen_k % BLOCK_N
else:
n_extra_tokens = 0
n_extra_tokens = 0
return n_extra_tokens

@triton.jit
def compute_block_masking(seqlen_k, seqlen_q, start_m,
IS_CAUSAL: tl.constexpr, USE_SLIDING_WINDOW: tl.constexpr,
WINDOW_SIZE_LEFT: tl.constexpr, WINDOW_SIZE_RIGHT: tl.constexpr,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
"""
Classify K blocks for attention computation with sliding window support.

Returns:
- n_front_skip_blocks: Blocks completely before the window
- n_front_masked_blocks: Blocks partially overlapping window front
- n_full_blocks: Blocks completely inside the window
- n_back_masked_blocks: Blocks partially overlapping window back
- n_extra_tokens: Padding tokens in last K block
"""

# common
q_start = start_m * BLOCK_M
q_end = tl.minimum((start_m + 1) * BLOCK_M - 1, seqlen_q - 1)
diag = seqlen_k - seqlen_q
total_k_blocks = tl.cdiv(seqlen_k, BLOCK_N)
n_extra_tokens = compute_padding_info(seqlen_k, BLOCK_N)

if USE_SLIDING_WINDOW:
# TODO: Optimize by computing which blocks can be fully skipped
# For now, process all blocks with the mask function
if IS_CAUSAL:
return 0, 0, 0, total_k_blocks, n_extra_tokens
else:
return 0, 0, 0, total_k_blocks, n_extra_tokens
# get window bounds
left_min, left_max, right_min, right_max = compute_window_bounds(
q_start, q_end, diag, seqlen_k,
WINDOW_SIZE_LEFT, WINDOW_SIZE_RIGHT, IS_CAUSAL
)

# window vanishes → early exit
if right_max < left_min:
return 0, 0, 0, 0, n_extra_tokens

# classify blocks
(n_front_skip_blocks, n_front_masked_blocks,
n_full_blocks, n_back_masked_blocks,
clipped_left) = classify_window_blocks(
left_min, left_max, right_min, right_max, BLOCK_N
)

# handle padded last block if needed
if n_extra_tokens != 0:
last_block = right_max // BLOCK_N
n_front_masked_blocks, n_full_blocks, n_back_masked_blocks = handle_padded_last_block(
n_extra_tokens, last_block, total_k_blocks,
clipped_left, n_front_masked_blocks,
n_full_blocks, n_back_masked_blocks
)
return (n_front_skip_blocks, n_front_masked_blocks,
n_full_blocks, n_back_masked_blocks, n_extra_tokens)
else:
if IS_CAUSAL:
# ========== CAUSAL MODE: Classify K Blocks ==========
Expand All @@ -444,11 +560,6 @@ def compute_masking(seqlen_k, seqlen_q, start_m,
# 1. figure out, in tokens, the right-most K position
# this Q-block may attend to
# ------------------------------------------------------------
q_start = start_m * BLOCK_M
q_end = tl.minimum((start_m + 1) * BLOCK_M - 1, seqlen_q - 1)

# causal diagonal offset between the two streams
diag = seqlen_k - seqlen_q # 0 when |Q| == |K|
k_max_token = q_end + diag # last visible K index

# this Q-block is entirely above the diagonal ⇒ nothing to do
Expand Down Expand Up @@ -575,7 +686,7 @@ def attn_fwd(Q, K, V, bias,


# figure out masking pattern
n_front_skip_blocks, n_front_masked_blocks, n_full_blocks, n_back_masked_blocks, n_extra_tokens = compute_masking(
n_front_skip_blocks, n_front_masked_blocks, n_full_blocks, n_back_masked_blocks, n_extra_tokens = compute_block_masking(
seqlen_k, seqlen_q, start_m, IS_CAUSAL, USE_SLIDING_WINDOW,
WINDOW_SIZE_LEFT, WINDOW_SIZE_RIGHT, BLOCK_M, BLOCK_N
)
Expand Down
Loading