Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion flashmask/flash_mask/cute/flash_fwd_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ def __call__(
mV: cute.Tensor, # (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table
mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
mLSE: Optional[cute.Tensor],
mMaxLogit: Optional[cute.Tensor],
softmax_scale: Float32,
stream: cuda.CUstream,
mCuSeqlensQ: Optional[cute.Tensor] = None,
Expand Down Expand Up @@ -337,6 +338,11 @@ def __call__(
if const_expr(mLSE is not None)
else None
)
mMaxLogit = (
cute.make_tensor(mMaxLogit.iterator, cute.select(mMaxLogit.layout, mode=LSE_layout_transpose))
if const_expr(mMaxLogit is not None)
else None
)
# (s, d, h, b) -> (d, s, h, b)
V_layout_transpose = [1, 0, 2, 3] if const_expr(mCuSeqlensK is None) else [1, 0, 2]
mV = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=V_layout_transpose))
Expand Down Expand Up @@ -500,6 +506,10 @@ def __call__(
mLSE = cute.make_tensor(
mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed)
)
if const_expr(mMaxLogit is not None):
mMaxLogit = cute.make_tensor(
mMaxLogit.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed)
)

self.tma_copy_bytes = {
name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1, 2]))
Expand Down Expand Up @@ -718,6 +728,7 @@ class SharedStorage:
mV,
mO,
mLSE,
mMaxLogit,
mCuSeqlensQ,
mCuSeqlensK,
mSeqUsedQ,
Expand Down Expand Up @@ -764,6 +775,7 @@ def kernel(
mV: cute.Tensor, # (d, s_k, h_k, b_k) or (d, total_k, h_k) if there is cu_seqlens_k or (d, page_size, h_k, num_pages) if there is page_table
mO: cute.Tensor,
mLSE: Optional[cute.Tensor],
mMaxLogit: Optional[cute.Tensor],
mCuSeqlensQ: Optional[cute.Tensor],
mCuSeqlensK: Optional[cute.Tensor],
mSeqUsedQ: Optional[cute.Tensor],
Expand Down Expand Up @@ -1156,6 +1168,7 @@ def kernel(
thr_mma_qk=thr_mma_qk,
sScale=sScale,
mLSE=mLSE,
mMaxLogit=mMaxLogit,
learnable_sink=learnable_sink,
mbar_ptr=mbar_ptr,
block_info=block_info,
Expand Down Expand Up @@ -1206,6 +1219,7 @@ def kernel(
sScale,
mO,
mLSE,
mMaxLogit,
sO,
learnable_sink,
gmem_tiled_copy_O,
Expand Down Expand Up @@ -2203,6 +2217,7 @@ def softmax_loop(
tStSi: cute.Tensor,
sScale: cute.Tensor,
mLSE: Optional[cute.Tensor],
mMaxLogit: Optional[cute.Tensor],
learnable_sink: Optional[cute.Tensor],
mbar_ptr: cute.Pointer,
block_info: BlockInfo,
Expand Down Expand Up @@ -2348,7 +2363,7 @@ def softmax_loop(

softmax = SoftmaxSm100.create(
softmax_scale_log2,
rescale_threshold=8.0 if const_expr(self.q_dtype.width == 16) else 0.0,
rescale_threshold=8.0 if const_expr(self.q_dtype.width == 16) and const_expr(mMaxLogit is None) else 0.0,
softmax_scale=softmax_scale,
)
softmax.reset()
Expand Down Expand Up @@ -2833,6 +2848,7 @@ def correction_loop(
sScale: cute.Tensor,
mO: cute.Tensor,
mLSE: cute.Tensor,
mMaxLogit: cute.Tensor,
sO: cute.Tensor,
learnable_sink: Optional[cute.Tensor],
gmem_tiled_copy_O: cute.TiledCopy,
Expand Down Expand Up @@ -2979,6 +2995,19 @@ def correction_loop(
else:
row_max = None
cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + stage)

if const_expr(mMaxLogit is not None):
if const_expr(not seqlen.has_cu_seqlens_q):
mRM_cur = mMaxLogit[None, head_idx, batch_idx]
else:
mRM_cur = cute.domain_offset((seqlen.offset_q,), mMaxLogit[None, head_idx])

gRM = cute.local_tile(mRM_cur, (self.m_block_size,), (self.q_stage * m_block + stage,))

# 边界检查并写回
if tidx < seqlen.seqlen_q - (self.q_stage * m_block + stage) * self.m_block_size:
gRM[tidx] = row_max

if const_expr(learnable_sink is not None):
LOG2_E = math.log2(math.e)
sink_val = learnable_sink_val[stage]
Expand Down
46 changes: 40 additions & 6 deletions flashmask/flash_mask/cute/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,10 @@ def _flash_attn_fwd(
mask_mod: Optional[Callable] = None,
block_sparse_tensors: Optional[BlockSparseTensorsPaddle] = None,
return_lse: bool = False,
return_max_logit: bool = False,
out: Optional[paddle.Tensor] = None,
lse: Optional[paddle.Tensor] = None,
max_logit: Optional[paddle.Tensor] = None,
aux_tensors: Optional[list[paddle.Tensor]] = None,
startend_row_indices: Optional[paddle.Tensor] = None,
) -> Tuple[paddle.Tensor, paddle.Tensor]:
Expand All @@ -129,6 +131,7 @@ def _flash_attn_fwd(
mask_mod: A callable that takes token position information and selectively masks
block_sparse_tensors: A tuple of tensors used for block sparsity.
return_lse: Whether to return the log softmax of the attention scores. If set to True will always calculate
return_max_logit: Whether to return the maximum logit of the attention scores.
out: Optional pre-allocated output tensor. If None, will be allocated internally.
lse: Optional pre-allocated log-sum-exp tensor. If None, will be allocated when needed.
aux_tensors: Some score_mods will want to read from global aux_tensors. This is how we thread them through to the inner kernel.
Expand Down Expand Up @@ -284,6 +287,9 @@ def _flash_attn_fwd(
)
assert lse.place.is_gpu_place(), "lse tensor must be on CUDA device"

if return_max_logit and max_logit is None:
max_logit = paddle.full(shape=lse_shape, fill_value=float('-inf'), dtype=paddle.float32)

dtype = paddle2cute_dtype_map[q.dtype]
(
cu_seqlens_q_tensor,
Expand Down Expand Up @@ -394,6 +400,8 @@ def _flash_attn_fwd(
shape=[num_splits, *q_batch_seqlen_shape, num_head, head_dim_v], dtype=paddle.float32
)
lse_partial = paddle.empty(shape=[num_splits, *lse_shape], dtype=paddle.float32)
if return_max_logit:
max_logit_partial = paddle.empty(shape=[num_splits, *lse_shape], fill_value=float('-inf'), dtype=paddle.float32)

q_tensor, k_tensor, v_tensor, o_tensor = [
from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1)
Expand All @@ -403,13 +411,23 @@ def _flash_attn_fwd(
lse_tensor = from_dlpack(lse_partial.detach(), assumed_align=4).mark_layout_dynamic(
leading_dim=lse_partial.ndim - 1
)
if return_max_logit:
max_logit_tensor = from_dlpack(max_logit_partial.detach(), assumed_align=4).mark_layout_dynamic(
leading_dim=max_logit_partial.ndim - 1
)
elif lse is not None:
lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic(
leading_dim=lse.ndim - 1
)
if return_max_logit:
max_logit_tensor = from_dlpack(max_logit.detach(), assumed_align=4).mark_layout_dynamic(
leading_dim=max_logit.ndim - 1
)
else:
lse_tensor = None

if not return_max_logit:
max_logit_tensor = None
# hash score and mask mods for compile cache
score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False
mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else False
Expand Down Expand Up @@ -469,6 +487,7 @@ def _flash_attn_fwd(
use_block_sparsity,
len(aux_tensors) if aux_tensors is not None else 0,
lse is None,
max_logit is None,
cu_seqlens_q is None,
cu_seqlens_k is None,
seqused_q is None,
Expand All @@ -493,6 +512,7 @@ def _flash_attn_fwd(
if compute_capability == 9:
assert page_table is None, "paged KV not supported on SM 9.0"
assert not is_split_kv, "SplitKV not supported on SM 9.0"
assert not return_max_logit, "return_max_logit not supported on SM 9.0"
# fa_fwd = FlashAttentionForwardSm80(
fa_fwd = FlashAttentionForwardSm90(
dtype,
Expand Down Expand Up @@ -548,6 +568,7 @@ def _flash_attn_fwd(
v_tensor,
o_tensor,
lse_tensor,
max_logit_tensor,
softmax_scale,
current_stream,
cu_seqlens_q_tensor,
Expand All @@ -568,6 +589,7 @@ def _flash_attn_fwd(
v_tensor,
o_tensor,
lse_tensor,
max_logit_tensor,
softmax_scale,
current_stream,
cu_seqlens_q_tensor,
Expand All @@ -591,7 +613,8 @@ def _flash_attn_fwd(
cu_seqlens_q,
seqused_q,
)
return out, lse
paddle.max(max_logit_partial, axis=0, keepdim=False, out=max_logit)
return out, lse, max_logit


_flash_attn_fwd.compile_cache = {}
Expand Down Expand Up @@ -1632,21 +1655,23 @@ def forward(
softmax_scale: float | None = None,
startend_row_indices: paddle.Tensor | None = None,
block_mask: paddle.Tensor | None = None,
return_max_logit: bool | None = False,
) -> paddle.Tensor | Tuple[paddle.Tensor, paddle.Tensor]:
out, lse = _flash_attn_fwd(
out, lse, max_logit = _flash_attn_fwd(
query,
key,
value,
causal=causal,
softmax_scale=softmax_scale,
return_lse=True,
return_max_logit=return_max_logit,
startend_row_indices=startend_row_indices,
pack_gqa=False,
)
ctx.save_for_backward(query, key, value, startend_row_indices, out, lse)
ctx.softmax_scale = softmax_scale
ctx.causal = causal
return [out, lse]
return [out, lse, max_logit]

@staticmethod
def backward(ctx, dout, *args) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
Expand Down Expand Up @@ -1689,6 +1714,7 @@ def flashmask_attention(
name: str | None = None,
softmax_scale: float | None = None,
block_mask: paddle.Tensor | None = None,
return_max_logit: bool = False,
):
if (
paddle.base.framework.get_flags(["FLAGS_flash_attn_version"])["FLAGS_flash_attn_version"] == 4
Expand Down Expand Up @@ -1762,21 +1788,29 @@ def flashmask_attention(
raise ValueError(
f"Invalid shape of startend_row_indices, when causal is False, the last dimension should be either 2 or 4 but got {startend_row_indices.shape[-1]}"
)
if return_max_logit:
assert return_softmax_lse, "max_logit requires lse to be calculated. Please set return_softmax_lse=True."

# Note(wusiming): when softmax_scale is None, it will be set to 1.0 / math.sqrt(head_dim) in _flash_attn_fwd
out, lse = FlashMaskFunc.apply(
res = FlashMaskFunc.apply(
query,
key,
value,
causal=causal,
softmax_scale=softmax_scale,
startend_row_indices=startend_row_indices,
return_max_logit=return_max_logit,
)
if return_softmax_lse:
return [out, lse]
outputs = [res[0], res[1]] # 基础列表 [out, lse]
if return_max_logit:
outputs += [res[2]] # 动态拼接 [out, lse, rowmax]
return outputs
else:
return out
return res[0]

else:
assert return_max_logit == False, "Only flashmask-v4 support return_max_logit"
original_flash_attn_version = paddle.base.framework.get_flags(["FLAGS_flash_attn_version"])["FLAGS_flash_attn_version"]
if original_flash_attn_version == 4:
paddle.set_flags({"FLAGS_flash_attn_version": 2})
Expand Down
Loading