From de667f79e290d20af30a909764291cdd9ca8eef0 Mon Sep 17 00:00:00 2001 From: xxyux <1650459510@qq.com> Date: Thu, 5 Feb 2026 17:30:26 +0800 Subject: [PATCH] feat: implement max_logits support for flashmask --- flashmask/flash_mask/cute/flash_fwd_sm100.py | 31 ++- flashmask/flash_mask/cute/interface.py | 46 +++- flashmask/tests/test_flashmask_max_logit.py | 227 +++++++++++++++++++ flashmask/tests/test_flashmask_util.py | 12 + 4 files changed, 309 insertions(+), 7 deletions(-) create mode 100644 flashmask/tests/test_flashmask_max_logit.py diff --git a/flashmask/flash_mask/cute/flash_fwd_sm100.py b/flashmask/flash_mask/cute/flash_fwd_sm100.py index d909008dea0..9d54290d004 100644 --- a/flashmask/flash_mask/cute/flash_fwd_sm100.py +++ b/flashmask/flash_mask/cute/flash_fwd_sm100.py @@ -269,6 +269,7 @@ def __call__( mV: cute.Tensor, # (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages, page_size, h_k, dv) if there is page_table mO: cute.Tensor, # (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q mLSE: Optional[cute.Tensor], + mMaxLogit: Optional[cute.Tensor], softmax_scale: Float32, stream: cuda.CUstream, mCuSeqlensQ: Optional[cute.Tensor] = None, @@ -337,6 +338,11 @@ def __call__( if const_expr(mLSE is not None) else None ) + mMaxLogit = ( + cute.make_tensor(mMaxLogit.iterator, cute.select(mMaxLogit.layout, mode=LSE_layout_transpose)) + if const_expr(mMaxLogit is not None) + else None + ) # (s, d, h, b) -> (d, s, h, b) V_layout_transpose = [1, 0, 2, 3] if const_expr(mCuSeqlensK is None) else [1, 0, 2] mV = cute.make_tensor(mV.iterator, cute.select(mV.layout, mode=V_layout_transpose)) @@ -500,6 +506,10 @@ def __call__( mLSE = cute.make_tensor( mLSE.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed) ) + if const_expr(mMaxLogit is not None): + mMaxLogit = cute.make_tensor( + mMaxLogit.iterator, cute.make_layout(shape_LSE_packed, stride=stride_LSE_packed) + ) self.tma_copy_bytes = { name: cute.size_in_bytes(mX.element_type, cute.select(layout, mode=[0, 1, 2])) @@ -718,6 +728,7 @@ class SharedStorage: mV, mO, mLSE, + mMaxLogit, mCuSeqlensQ, mCuSeqlensK, mSeqUsedQ, @@ -764,6 +775,7 @@ def kernel( mV: cute.Tensor, # (d, s_k, h_k, b_k) or (d, total_k, h_k) if there is cu_seqlens_k or (d, page_size, h_k, num_pages) if there is page_table mO: cute.Tensor, mLSE: Optional[cute.Tensor], + mMaxLogit: Optional[cute.Tensor], mCuSeqlensQ: Optional[cute.Tensor], mCuSeqlensK: Optional[cute.Tensor], mSeqUsedQ: Optional[cute.Tensor], @@ -1156,6 +1168,7 @@ def kernel( thr_mma_qk=thr_mma_qk, sScale=sScale, mLSE=mLSE, + mMaxLogit=mMaxLogit, learnable_sink=learnable_sink, mbar_ptr=mbar_ptr, block_info=block_info, @@ -1206,6 +1219,7 @@ def kernel( sScale, mO, mLSE, + mMaxLogit, sO, learnable_sink, gmem_tiled_copy_O, @@ -2203,6 +2217,7 @@ def softmax_loop( tStSi: cute.Tensor, sScale: cute.Tensor, mLSE: Optional[cute.Tensor], + mMaxLogit: Optional[cute.Tensor], learnable_sink: Optional[cute.Tensor], mbar_ptr: cute.Pointer, block_info: BlockInfo, @@ -2348,7 +2363,7 @@ def softmax_loop( softmax = SoftmaxSm100.create( softmax_scale_log2, - rescale_threshold=8.0 if const_expr(self.q_dtype.width == 16) else 0.0, + rescale_threshold=8.0 if const_expr(self.q_dtype.width == 16) and const_expr(mMaxLogit is None) else 0.0, softmax_scale=softmax_scale, ) softmax.reset() @@ -2833,6 +2848,7 @@ def correction_loop( sScale: cute.Tensor, mO: cute.Tensor, mLSE: cute.Tensor, + mMaxLogit: cute.Tensor, sO: cute.Tensor, learnable_sink: Optional[cute.Tensor], gmem_tiled_copy_O: cute.TiledCopy, @@ -2979,6 +2995,19 @@ def correction_loop( else: row_max = None cute.arch.mbarrier_arrive(mbar_ptr + self.mbar_softmax_corr_empty_offset + stage) + + if const_expr(mMaxLogit is not None): + if const_expr(not seqlen.has_cu_seqlens_q): + mRM_cur = mMaxLogit[None, head_idx, batch_idx] + else: + mRM_cur = cute.domain_offset((seqlen.offset_q,), mMaxLogit[None, head_idx]) + + gRM = cute.local_tile(mRM_cur, (self.m_block_size,), (self.q_stage * m_block + stage,)) + + # 边界检查并写回 + if tidx < seqlen.seqlen_q - (self.q_stage * m_block + stage) * self.m_block_size: + gRM[tidx] = row_max + if const_expr(learnable_sink is not None): LOG2_E = math.log2(math.e) sink_val = learnable_sink_val[stage] diff --git a/flashmask/flash_mask/cute/interface.py b/flashmask/flash_mask/cute/interface.py index ec126b28866..f5728fbf4c3 100644 --- a/flashmask/flash_mask/cute/interface.py +++ b/flashmask/flash_mask/cute/interface.py @@ -116,8 +116,10 @@ def _flash_attn_fwd( mask_mod: Optional[Callable] = None, block_sparse_tensors: Optional[BlockSparseTensorsPaddle] = None, return_lse: bool = False, + return_max_logit: bool = False, out: Optional[paddle.Tensor] = None, lse: Optional[paddle.Tensor] = None, + max_logit: Optional[paddle.Tensor] = None, aux_tensors: Optional[list[paddle.Tensor]] = None, startend_row_indices: Optional[paddle.Tensor] = None, ) -> Tuple[paddle.Tensor, paddle.Tensor]: @@ -129,6 +131,7 @@ def _flash_attn_fwd( mask_mod: A callable that takes token position information and selectively masks block_sparse_tensors: A tuple of tensors used for block sparsity. return_lse: Whether to return the log softmax of the attention scores. If set to True will always calculate + return_max_logit: Whether to return the maximum logit of the attention scores. out: Optional pre-allocated output tensor. If None, will be allocated internally. lse: Optional pre-allocated log-sum-exp tensor. If None, will be allocated when needed. aux_tensors: Some score_mods will want to read from global aux_tensors. This is how we thread them through to the inner kernel. @@ -284,6 +287,9 @@ def _flash_attn_fwd( ) assert lse.place.is_gpu_place(), "lse tensor must be on CUDA device" + if return_max_logit and max_logit is None: + max_logit = paddle.full(shape=lse_shape, fill_value=float('-inf'), dtype=paddle.float32) + dtype = paddle2cute_dtype_map[q.dtype] ( cu_seqlens_q_tensor, @@ -394,6 +400,8 @@ def _flash_attn_fwd( shape=[num_splits, *q_batch_seqlen_shape, num_head, head_dim_v], dtype=paddle.float32 ) lse_partial = paddle.empty(shape=[num_splits, *lse_shape], dtype=paddle.float32) + if return_max_logit: + max_logit_partial = paddle.empty(shape=[num_splits, *lse_shape], fill_value=float('-inf'), dtype=paddle.float32) q_tensor, k_tensor, v_tensor, o_tensor = [ from_dlpack(t.detach(), assumed_align=16).mark_layout_dynamic(leading_dim=t.ndim - 1) @@ -403,13 +411,23 @@ def _flash_attn_fwd( lse_tensor = from_dlpack(lse_partial.detach(), assumed_align=4).mark_layout_dynamic( leading_dim=lse_partial.ndim - 1 ) + if return_max_logit: + max_logit_tensor = from_dlpack(max_logit_partial.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=max_logit_partial.ndim - 1 + ) elif lse is not None: lse_tensor = from_dlpack(lse.detach(), assumed_align=4).mark_layout_dynamic( leading_dim=lse.ndim - 1 ) + if return_max_logit: + max_logit_tensor = from_dlpack(max_logit.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=max_logit.ndim - 1 + ) else: lse_tensor = None + if not return_max_logit: + max_logit_tensor = None # hash score and mask mods for compile cache score_mod_hash = utils.hash_callable(score_mod) if score_mod is not None else False mask_mod_hash = utils.hash_callable(mask_mod) if mask_mod is not None else False @@ -469,6 +487,7 @@ def _flash_attn_fwd( use_block_sparsity, len(aux_tensors) if aux_tensors is not None else 0, lse is None, + max_logit is None, cu_seqlens_q is None, cu_seqlens_k is None, seqused_q is None, @@ -493,6 +512,7 @@ def _flash_attn_fwd( if compute_capability == 9: assert page_table is None, "paged KV not supported on SM 9.0" assert not is_split_kv, "SplitKV not supported on SM 9.0" + assert not return_max_logit, "return_max_logit not supported on SM 9.0" # fa_fwd = FlashAttentionForwardSm80( fa_fwd = FlashAttentionForwardSm90( dtype, @@ -548,6 +568,7 @@ def _flash_attn_fwd( v_tensor, o_tensor, lse_tensor, + max_logit_tensor, softmax_scale, current_stream, cu_seqlens_q_tensor, @@ -568,6 +589,7 @@ def _flash_attn_fwd( v_tensor, o_tensor, lse_tensor, + max_logit_tensor, softmax_scale, current_stream, cu_seqlens_q_tensor, @@ -591,7 +613,8 @@ def _flash_attn_fwd( cu_seqlens_q, seqused_q, ) - return out, lse + paddle.max(max_logit_partial, axis=0, keepdim=False, out=max_logit) + return out, lse, max_logit _flash_attn_fwd.compile_cache = {} @@ -1632,21 +1655,23 @@ def forward( softmax_scale: float | None = None, startend_row_indices: paddle.Tensor | None = None, block_mask: paddle.Tensor | None = None, + return_max_logit: bool | None = False, ) -> paddle.Tensor | Tuple[paddle.Tensor, paddle.Tensor]: - out, lse = _flash_attn_fwd( + out, lse, max_logit = _flash_attn_fwd( query, key, value, causal=causal, softmax_scale=softmax_scale, return_lse=True, + return_max_logit=return_max_logit, startend_row_indices=startend_row_indices, pack_gqa=False, ) ctx.save_for_backward(query, key, value, startend_row_indices, out, lse) ctx.softmax_scale = softmax_scale ctx.causal = causal - return [out, lse] + return [out, lse, max_logit] @staticmethod def backward(ctx, dout, *args) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]: @@ -1689,6 +1714,7 @@ def flashmask_attention( name: str | None = None, softmax_scale: float | None = None, block_mask: paddle.Tensor | None = None, + return_max_logit: bool = False, ): if ( paddle.base.framework.get_flags(["FLAGS_flash_attn_version"])["FLAGS_flash_attn_version"] == 4 @@ -1762,21 +1788,29 @@ def flashmask_attention( raise ValueError( f"Invalid shape of startend_row_indices, when causal is False, the last dimension should be either 2 or 4 but got {startend_row_indices.shape[-1]}" ) + if return_max_logit: + assert return_softmax_lse, "max_logit requires lse to be calculated. Please set return_softmax_lse=True." # Note(wusiming): when softmax_scale is None, it will be set to 1.0 / math.sqrt(head_dim) in _flash_attn_fwd - out, lse = FlashMaskFunc.apply( + res = FlashMaskFunc.apply( query, key, value, causal=causal, softmax_scale=softmax_scale, startend_row_indices=startend_row_indices, + return_max_logit=return_max_logit, ) if return_softmax_lse: - return [out, lse] + outputs = [res[0], res[1]] # 基础列表 [out, lse] + if return_max_logit: + outputs += [res[2]] # 动态拼接 [out, lse, rowmax] + return outputs else: - return out + return res[0] + else: + assert return_max_logit == False, "Only flashmask-v4 support return_max_logit" original_flash_attn_version = paddle.base.framework.get_flags(["FLAGS_flash_attn_version"])["FLAGS_flash_attn_version"] if original_flash_attn_version == 4: paddle.set_flags({"FLAGS_flash_attn_version": 2}) diff --git a/flashmask/tests/test_flashmask_max_logit.py b/flashmask/tests/test_flashmask_max_logit.py new file mode 100644 index 00000000000..552967f220f --- /dev/null +++ b/flashmask/tests/test_flashmask_max_logit.py @@ -0,0 +1,227 @@ +import os +import math +import itertools +import pytest +from einops import rearrange, repeat +import paddle +from flash_mask.cute.interface import flashmask_attention +from tests.generate_startend_row_indices import ( + startend_row_indices_to_attn_bias, + generate_none_mask, + generate_sliding_window_mask, + generate_causal_document_mask, + generate_document_mask, + generate_share_question_mask, + generate_global_sliding_window_mask, + generate_causal_blockwise_mask, + generate_prefix_lm_document_mask, + generate_prefix_lm_causal_mask, + generate_qk_sparse_mask, + generate_random_eviction_mask, +) +from functools import partial +from tests.test_flashmask_util import attention_ref + +# batch_size, seqlen_q, seqlen_k, nheads, nheads_kv +shape_cases = ( + [ + (2840, 32, 32, 16, 4), + (1, 300, 300, 16, 16), + (2, 8192, 8192, 14, 1), + (2, 16384, 16384, 4, 1), + (1, 128, 127, 1, 1), + (1, 127, 128, 1, 1), + (2, 16383, 16384, 4, 1), + (2, 16384, 16383, 4, 1), + (2, 16383, 16385, 4, 1), + (2, 16385, 16383, 4, 1), + ] + # tridao case + + list(itertools.product( + [9], # batch_size + [1, 64, 128, 256, 239, 799, 113, 113, 128, 113, 108, 256, 384, 640, 512, 1024, 1023, 1024,], # seqlen_q + [128, 192, 256, 203, 128, 217, 211, 256, 512, 256, 128, 256, 1024, 1024, 1023,], # seqlen_k + [6], # nheads + [6, 2, 1], # nheads_kv + )) + + list(itertools.product( + [2], # batch_size + [4096, 4224], # seqlen_q + [4096, 4224], # seqlen_k + [6], # nheads + [6, 2, 1], # nheads_kv + )) +) +# shape_cases = ( +# [ +# (2, 16384, 16384, 4, 1), +# ] +# ) + +# Generate all combinations for second param +def generate_shapes(): + for batch_size, seqlen_q, seqlen_k, nheads, nheads_kv in shape_cases: + nheads_startend_row_indices_values = [1, nheads_kv] + for nheads_startend_row_indices in nheads_startend_row_indices_values: + yield ( + batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, nheads_startend_row_indices + ) + +@pytest.mark.parametrize("dtype", [paddle.bfloat16]) +@pytest.mark.parametrize("fa_version", [4]) +@pytest.mark.parametrize("d, dv", + [ + (64, 64), + # (80, 80), + (128, 128), + # (192, 192), + # (256, 256), + ]) +@pytest.mark.parametrize( + "batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, nheads_startend_row_indices", + list(generate_shapes()) +) +@pytest.mark.parametrize( + "gen_startend_row_indices", + [ + partial(generate_none_mask, causal=False), # full + partial(generate_none_mask, causal=True), # causal + partial(generate_sliding_window_mask), # sliding window + partial(generate_causal_document_mask), # causal document mask + partial(generate_document_mask), # document mask + partial(generate_share_question_mask), # share question mask + #partial(generate_global_sliding_window_mask), # global sliding window + partial(generate_causal_blockwise_mask), # causal blockwise mask + partial(generate_prefix_lm_document_mask), # prefix lm document mask + partial(generate_prefix_lm_causal_mask), # prefix lm causal mask + partial(generate_qk_sparse_mask), # qk-sparse mask + partial(generate_random_eviction_mask), # random eviction mask + # ###################################################################################### + ], +) +@pytest.mark.timeout(300) +def test_flashmask( + batch_size, seqlen_q, seqlen_k, nheads, nheads_kv, d, dv, nheads_startend_row_indices, fa_version, dtype, gen_startend_row_indices, softcap=0.0 +): + paddle.seed(2024) + assert nheads % nheads_kv == 0 + q_ref = paddle.randn(shape=[batch_size, seqlen_q, nheads, d], dtype=dtype) + k_ref = paddle.randn(shape=[batch_size, seqlen_k, nheads_kv, d], dtype=dtype) + v_ref = paddle.randn(shape=[batch_size, seqlen_k, nheads_kv, dv], dtype=dtype) + + q_ref.stop_gradient = False + k_ref.stop_gradient = False + v_ref.stop_gradient = False + + q_bf16, k_bf16, v_bf16 = [x.detach().clone() for x in (q_ref, k_ref, v_ref)] + + q_bf16.stop_gradient = False + k_bf16.stop_gradient = False + v_bf16.stop_gradient = False + + q, k, v = [x.detach().clone() for x in (q_ref, k_ref, v_ref)] + + q.stop_gradient = False + k.stop_gradient = False + v.stop_gradient = False + + startend_row_indices, causal = gen_startend_row_indices(batch_size, seqlen_q, seqlen_k, nheads_startend_row_indices) + + attn_bias = startend_row_indices_to_attn_bias(startend_row_indices, seqlen_q, nheads, dtype, causal) + + out_ref, attn_ref, max_logit_ref = attention_ref( + q_ref, + k_ref, + v_ref, + causal=causal, + attn_bias=attn_bias, + return_max_logit=True + ) + + out_bf16, attn_bf16 = attention_ref( + q_bf16, + k_bf16, + v_bf16, + causal=causal, + attn_bias=attn_bias, + upcast=False, + reorder_ops=True + ) + + # # Numerical error if we just do any arithmetic on out_ref + fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item() + assert softcap == 0.0 + rtol = 2 if softcap == 0.0 else 3 + + print(f"Paddle naive bf16 Output max diff: {(out_bf16 - out_ref).abs().max().item()}") + print(f"Paddle naive bf16 Output mean diff: {(out_bf16 - out_ref).abs().mean().item()}") + + if fa_version == 2: + paddle.set_flags({'FLAGS_flash_attn_version': 2}) + elif fa_version == 3: + paddle.set_flags({'FLAGS_flash_attn_version': 3}) + elif fa_version == 4: + paddle.set_flags({'FLAGS_flash_attn_version': 4}) + else: + raise ValueError( + f"Invalid flash attention version: {fa_version}" + ) + + out, lse, max_logit = flashmask_attention( + q, + k, + v, + startend_row_indices=startend_row_indices, + causal=causal, + return_softmax_lse=True, + return_max_logit=True + ) + + print(f"flashmask Output max diff: {(out - out_ref).abs().max().item()}") + print(f"flashmask Output mean diff: {(out - out_ref).abs().mean().item()}") + + assert (out - out_ref).abs().max().item() <= rtol * (out_bf16 - out_ref).abs().max().item() + fwd_atol + + g = paddle.randn(shape=out.shape, dtype=out.dtype) + out.backward(g) + out_ref.backward(g) + out_bf16.backward(g) + + print(f"flashmask dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}") + print(f"flashmask dK max diff: {(k.grad - k_ref.grad).abs().max().item()}") + print(f"flashmask dV max diff: {(v.grad - v_ref.grad).abs().max().item()}") + print(f"flashmask dQ mean diff: {(q.grad - q_ref.grad).abs().mean().item()}") + print(f"flashmask dK mean diff: {(k.grad - k_ref.grad).abs().mean().item()}") + print(f"flashmask dV mean diff: {(v.grad - v_ref.grad).abs().mean().item()}") + + print(f"Paddle naive bf16 dQ max diff: {(q_bf16.grad - q_ref.grad).abs().max().item()}") + print(f"Paddle naive bf16 dK max diff: {(k_bf16.grad - k_ref.grad).abs().max().item()}") + print(f"Paddle naive bf16 dV max diff: {(v_bf16.grad - v_ref.grad).abs().max().item()}") + print(f"Paddle naive bf16 dQ mean diff: {(q_bf16.grad - q_ref.grad).abs().mean().item()}") + print(f"Paddle naive bf16 dK mean diff: {(k_bf16.grad - k_ref.grad).abs().mean().item()}") + print(f"Paddle naive bf16 dV mean diff: {(v_bf16.grad - v_ref.grad).abs().mean().item()}") + + dq_atol = 2 * (q_ref.grad + 0.3 - 0.3 - q_ref.grad).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (q.grad - q_ref.grad).abs().max().item() <= rtol * (q_bf16.grad - q_ref.grad).abs().max().item() + dq_atol + dk_atol = 2 * (k_ref.grad + 0.3 - 0.3 - k_ref.grad).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (k.grad - k_ref.grad).abs().max().item() <= rtol * (k_bf16.grad - k_ref.grad).abs().max().item() + dk_atol + dv_atol = 2 * (v_ref.grad + 0.3 - 0.3 - v_ref.grad).abs().max().item() + (0 if softcap == 0 else 3e-4) + assert (v.grad - v_ref.grad).abs().max().item() <= rtol * (v_bf16.grad - v_ref.grad).abs().max().item() + dv_atol + + print("--- Verifying Max Logit ---") + softmax_scale = 1.0 / math.sqrt(d) + max_logit_ref_unscaled = max_logit_ref / softmax_scale + + mask_threshold = -1e9 + + valid_mask = (max_logit > mask_threshold) & (max_logit_ref_unscaled > mask_threshold) + + diff = (max_logit - max_logit_ref_unscaled).abs() + masked_diff = paddle.where(valid_mask, diff, paddle.zeros_like(diff)) + + max_logit_diff = masked_diff.max().item() + + print(f"flashmask Max Logit max diff: {max_logit_diff}") + + # 6. 断言 + assert max_logit_diff <= 1e-4, f"Max Logit mismatch! Diff: {max_logit_diff}" \ No newline at end of file diff --git a/flashmask/tests/test_flashmask_util.py b/flashmask/tests/test_flashmask_util.py index 3ced1714339..3674778553d 100644 --- a/flashmask/tests/test_flashmask_util.py +++ b/flashmask/tests/test_flashmask_util.py @@ -73,6 +73,7 @@ def attention_ref( upcast=True, reorder_ops=False, intermediate_dtype=None, + return_max_logit=False, ): """ Arguments: @@ -186,6 +187,13 @@ def attention_ref( all_inf_mask = (attn_bias == -np.inf).all(axis=-1, keepdim=True) scores = paddle.where(all_inf_mask, paddle.full_like(scores, -1e9), scores) + ref_max_logit = None + if return_max_logit: + # scores_max = scores[...,-128:] + scores_max = scores.clone() + # 此时 scores 已经应用了所有 mask + ref_max_logit = scores_max.max(axis=-1) + attention = paddle.nn.functional.softmax(scores, axis=-1).cast(v.dtype) if attn_bias is not None: @@ -221,6 +229,10 @@ def attention_ref( output = paddle.transpose(output, [0, 2, 1, 3]) if query_padding_mask is not None: output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0) + + if return_max_logit: + return output.cast(dtype=dtype_og), attention.cast(dtype=dtype_og), ref_max_logit + return output.cast(dtype=dtype_og), attention.cast(dtype=dtype_og)