From 3b585984d98067e763aec6a0fcd53a54a31844fe Mon Sep 17 00:00:00 2001 From: umiswing Date: Fri, 15 May 2026 19:52:00 +0800 Subject: [PATCH 1/9] [Cute,Sm100] allow for zero length sequences in hdim 256 kernels pick branch jshah/hdim256-varlen-zero-lengths, commit: 75db52f --- .../flash_mask/flash_attn_v4/copy_utils.py | 89 +------ .../sm100_hd256_2cta_fmha_backward.py | 4 + ...100_hd256_2cta_fmha_backward_dkdvkernel.py | 48 ++-- ...sm100_hd256_2cta_fmha_backward_dqkernel.py | 251 ++++++++++-------- .../sm100_hd256_2cta_fmha_forward.py | 23 +- 5 files changed, 207 insertions(+), 208 deletions(-) diff --git a/flashmask/flash_mask/flash_attn_v4/copy_utils.py b/flashmask/flash_mask/flash_attn_v4/copy_utils.py index 7b7b86eb4a5..d8c6083c8cc 100644 --- a/flashmask/flash_mask/flash_attn_v4/copy_utils.py +++ b/flashmask/flash_mask/flash_attn_v4/copy_utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. import math from typing import Optional, Type, Callable @@ -90,21 +90,19 @@ def tiled_copy_1d( def tiled_copy_2d( - dtype: Type[cutlass.Numeric], - threads_per_row: int, - num_threads: int, - num_copy_elems: int = 1, - is_async: bool = False, + dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = False ) -> cute.TiledCopy: - num_copy_bits = num_copy_elems * dtype.width + num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width + copy_elems = num_copy_bits // dtype.width copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) - assert num_threads % threads_per_row == 0 + gmem_threads_per_row = major_mode_size // copy_elems + assert num_threads % gmem_threads_per_row == 0 thr_layout = cute.make_ordered_layout( - (num_threads // threads_per_row, threads_per_row), + (num_threads // gmem_threads_per_row, gmem_threads_per_row), order=(1, 0), ) - val_layout = cute.make_layout((1, num_copy_elems)) + val_layout = cute.make_layout((1, copy_elems)) return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) @@ -372,74 +370,3 @@ def copy_fn(src_idx, producer_state: cutlass.pipeline.PipelineState, **new_kwarg ) return copy_fn - - -# --- Vendored from quack/copy_utils.py (BSD-3, Tri Dao et al.) --- -# Required by SM100 code paths that cannot depend on quack. - -BIG_INT = 2**30 -MAX_INT = 2**31 - 1 -BIG_INT_INV = 2**64 // BIG_INT - - -@dsl_user_op -def create_ragged_tensor_for_tma( - T: cute.Tensor, - ragged_dim: int = 0, - ptr_shift: bool = False, - *, - loc=None, - ip=None, -) -> cute.Tensor: - rank = cute.rank(T) - if ragged_dim < 0: - ragged_dim += rank - if ptr_shift: - assert rank <= 4, "ptr_shift ragged tensor only supports up to 4 dimensions" - new_shape = T.shape[:ragged_dim] + (BIG_INT,) + T.shape[ragged_dim + 1 :] + (MAX_INT,) - new_stride = T.stride + (T.stride[ragged_dim],) - ptr_offset = (None,) * ragged_dim + (-BIG_INT,) + (None,) * (rank - ragged_dim - 1) - new_ptr = cute.domain_offset(ptr_offset, T).iterator - return cute.make_tensor(new_ptr, cute.make_layout(new_shape, stride=new_stride)) - else: - assert rank <= 3, "non-ptr_shift ragged tensor only supports up to 3 dimensions" - stride_r = T.stride[ragged_dim] - new_shape = ( - T.shape[:ragged_dim] + (BIG_INT,) + T.shape[ragged_dim + 1 :] + (MAX_INT, MAX_INT) - ) - new_stride = ( - T.stride[:ragged_dim] - + (stride_r,) - + T.stride[ragged_dim + 1 :] - + (BIG_INT_INV - stride_r, stride_r) - ) - return cute.make_tensor(T.iterator, cute.make_layout(new_shape, stride=new_stride)) - - -@dsl_user_op -def offset_ragged_tensor( - T: cute.Tensor, - offset: Int32, - length: Int32, - ragged_dim: int = 0, - ptr_shift: bool = False, - *, - loc=None, - ip=None, -) -> cute.Tensor: - rank = cute.rank(T) - if ragged_dim < 0: - ragged_dim += rank - big_int = cute.size(T, mode=[ragged_dim]) - offset_val = big_int - length - if ptr_shift: - # 1-extra-dim: rank = original_rank + 1 - assert rank >= ragged_dim + 2 - offset_tuple = (None,) * ragged_dim + (offset_val,) + (None,) * (rank - ragged_dim - 2) - index_tuple = (None,) * (rank - 1) + (offset + length,) - else: - # 2-extra-dim: rank = original_rank + 2, last 2 modes are the wraparound dims - assert rank >= ragged_dim + 3 - offset_tuple = (None,) * ragged_dim + (offset_val,) + (None,) * (rank - ragged_dim - 3) - index_tuple = (None,) * (rank - 2) + (big_int, offset + length) - return cute.domain_offset(offset_tuple, T[index_tuple]) diff --git a/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_backward.py b/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_backward.py index 915aed0d25b..3b6cf2f462f 100644 --- a/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_backward.py +++ b/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_backward.py @@ -23,6 +23,8 @@ from flash_mask.flash_attn_v4.sm100_hd256_2cta_fmha_backward_dkdvkernel import ( BlackwellFusedMultiHeadAttentionBackwardDKDVKernel, ) +from flash_mask.flash_attn_v4.cute_dsl_utils import assume_tensor_aligned + def _as_bshkrd_tensor( tensor: cute.Tensor, @@ -252,6 +254,8 @@ def __call__( else: b = Q.shape[0] + Q, K, V, dQ, dK, dV, dO = [assume_tensor_aligned(t) for t in (Q, K, V, dQ, dK, dV, dO)] + Q = _as_bshkrd_tensor(Q, h_k, h_r, varlen) K = _as_bshkrd_tensor(K, h_k, 1, varlen) V = _as_bshkrd_tensor(V, h_k, 1, varlen) diff --git a/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_backward_dkdvkernel.py b/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_backward_dkdvkernel.py index 29b8f6750ad..2fa89fd9ef8 100644 --- a/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_backward_dkdvkernel.py +++ b/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_backward_dkdvkernel.py @@ -32,6 +32,7 @@ Sm100FmhaStaticTileSchedulerParams as FmhaStaticTileSchedulerParams, ) +import flash_mask.flash_attn_v4.copy_utils as fa_copy_utils LAYOUT_RANK_CONSTANT = 3 @@ -2811,13 +2812,11 @@ def epilogue_clear( dK.iterator + mdK_offset, cute.make_layout((K, self.tile_shape_dQ_K, HB), stride=dK.stride), ) - gdK = cute.local_tile( - mdK, (self.dSQ_mma_tiler[0], self.dSQ_mma_tiler[1]), (None, None, None) - ) + gdK = cute.local_tile(mdK, (self.cta_tiler[1], self.cta_tiler[2]), (None, None, None)) gdK = gdK[None, None, blk_coord_k, 0, blk_coord_batch] cdK = cute.domain_offset( (blk_coord_k * self.tile_shape_K, 0), - cute.make_identity_tensor((self.dSQ_mma_tiler[0], self.dSQ_mma_tiler[1])), + cute.make_identity_tensor((self.cta_tiler[1], self.cta_tiler[2])), ) mdV_offset = cute.assume(blk_offset[1] * dV.stride[0], divby=64) @@ -2825,24 +2824,41 @@ def epilogue_clear( dV.iterator + mdV_offset, cute.make_layout((K, self.tile_shape_dV_dO, HB), stride=dV.stride), ) - gdV = cute.local_tile( - mdV, (self.PdO_mma_tiler[0], self.PdO_mma_tiler[1]), (None, None, None) - ) + gdV = cute.local_tile(mdV, (self.cta_tiler[1], self.cta_tiler[2]), (None, None, None)) gdV = gdV[None, None, blk_coord_k, 0, blk_coord_batch] cdV = cute.domain_offset( (blk_coord_k * self.tile_shape_K, 0), - cute.make_identity_tensor((self.PdO_mma_tiler[0], self.PdO_mma_tiler[1])), + cute.make_identity_tensor((self.cta_tiler[1], self.cta_tiler[2])), ) - for i in cutlass.range(tidx * 8, cute.size(gdK), block_dim_x * 8): - if cute.elem_less(cdK[i], cute.select(problem_shape, mode=[1, 2])): - gdK_i = cute.make_tensor(gdK.iterator + cute.assume(i, divby=8), (8)) - gdK_i.fill(0) + num_zero_epi_threads = 256 + + tiled_copy_r2g = fa_copy_utils.tiled_copy_2d( + dK.element_type, self.cta_tiler[2], num_zero_epi_threads + ) + + thr_copy_r2g = tiled_copy_r2g.get_slice(tidx) + + tRG_gdK = thr_copy_r2g.partition_D(gdK) + tRG_cdK = thr_copy_r2g.partition_D(cdK) + tRG_gdV = thr_copy_r2g.partition_D(gdV) + tRG_cdV = thr_copy_r2g.partition_D(cdV) + + zero_frg = cute.make_rmem_tensor_like(tRG_gdK[None, 0, None]) + zero_frg.fill(dK.element_type(0.0)) + + # check we don't need zero fragment duplication + V_frg_size = cute.size(tRG_gdV[None, 0, None]) + assert cute.size(zero_frg) == V_frg_size + + if tidx < num_zero_epi_threads: + for n in cutlass.range(cute.size(tRG_gdK.shape[1]), unroll_full=True): + if cute.elem_less(tRG_cdK[0, n, 0][0], problem_shape[1]): + cute.copy(tiled_copy_r2g, zero_frg, tRG_gdK[None, n, None]) - for i in cutlass.range(tidx * 8, cute.size(gdV), block_dim_x * 8): - if cute.elem_less(cdV[i], cute.select(problem_shape, mode=[1, 2])): - gdV_i = cute.make_tensor(gdV.iterator + cute.assume(i, divby=8), (8)) - gdV_i.fill(0) + for n in cutlass.range(cute.size(tRG_gdV.shape[1]), unroll_full=True): + if cute.elem_less(tRG_cdV[0, n, 0][0], problem_shape[1]): + cute.copy(tiled_copy_r2g, zero_frg, tRG_gdV[None, n, None]) @cute.jit def epilogue( diff --git a/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_backward_dqkernel.py b/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_backward_dqkernel.py index d5be6b40f4d..0774fd23385 100644 --- a/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_backward_dqkernel.py +++ b/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_backward_dqkernel.py @@ -29,6 +29,7 @@ Sm100FusedMask as FusedMask, ) from flash_mask.flash_attn_v4.tile_scheduler import SM100_TMEM_CAPACITY_COLUMNS +import flash_mask.flash_attn_v4.copy_utils as fa_copy_utils class BlackwellFusedMultiHeadAttentionBackwardDQKernel: @@ -924,36 +925,45 @@ def kernel( curr_block_coord[1], curr_block_coord[2], ) - continue_cond = False batch_coord = curr_block_coord[2][1] seqlen_q = mQ_qdl.shape[0] seqlen_k = mK_kdl.shape[0] cuseqlen_q = Int32(0) cuseqlen_k = Int32(0) - block_offset = ( - Int32(0), - Int32(0), - Int32(0), - ((Int32(0), Int32(0)), Int32(0)), - ) + is_valid_q = True if cutlass.const_expr(cum_seqlen_q is not None): cuseqlen_q = cum_seqlen_q[batch_coord] seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q - if cutlass.const_expr(cum_seqlen_k is not None): - cuseqlen_k = cum_seqlen_k[batch_coord] - seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + is_valid_q = FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.qk_mma_tiler[0], + mma_block_coord[0], + seqlen_q, + ) + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + seqlen_kv_loop_start, seqlen_kv_loop_steps = ( + FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + ) + is_valid_k = seqlen_kv_loop_steps > 0 + has_work = is_valid_q and is_valid_k + + if has_work: block_offset = ( cuseqlen_q, cuseqlen_k, Int32(0), ((Int32(0), Int32(0)), Int32(0)), ) - continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( - self.qk_mma_tiler[0], - mma_block_coord[0], - seqlen_q, - ) - if not continue_cond: mQ_qdl_ = cute.domain_offset(cute.select(block_offset, mode=[0, 2, 3]), mQ_qdl) mK_kdl_ = cute.domain_offset(cute.select(block_offset, mode=[1, 2, 3]), mK_kdl) mdO_qdl_ = cute.domain_offset( @@ -1057,18 +1067,6 @@ def kernel( # ((atom_v, rest_v), RestN, RestK) tKTgKT = tKgK_dkl[None, None, None, mma_block_coord[2]] - seqlen_kv_loop_start, seqlen_kv_loop_steps = ( - FusedMask.get_trip_start_count_via_block_info( - mma_block_coord, - self.qk_mma_tiler, - seqlen_q, - seqlen_k, - self.is_causal, - self.is_local, - window_size_left, - window_size_right, - ) - ) # LSE lse_handle = load_lse_producer.acquire_and_advance() # 32 threads loading 128 values of 32b each @@ -1197,6 +1195,9 @@ def kernel( if warp_idx == self.mma_warp_id: cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + is_leader_cta = cta_rank_in_cluster % 2 == 0 + while work_tile.is_valid_tile: curr_block_coord = work_tile.tile_idx mma_block_coord = ( @@ -1204,41 +1205,37 @@ def kernel( curr_block_coord[1], curr_block_coord[2], ) - continue_cond = False seqlen_q = mQ_qdl.shape[0] seqlen_k = mK_kdl.shape[0] batch_coord = curr_block_coord[2][1] + is_valid_q = True if cutlass.const_expr(cum_seqlen_q is not None): cuseqlen_q = cum_seqlen_q[batch_coord] seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q - continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + is_valid_q = FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( self.qk_mma_tiler[0], mma_block_coord[0], seqlen_q, ) - - if not continue_cond: - if cutlass.const_expr(cum_seqlen_k is not None): - cuseqlen_k = cum_seqlen_k[batch_coord] - seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k - - seqlen_kv_loop_start, seqlen_kv_loop_steps = ( - FusedMask.get_trip_start_count_via_block_info( - mma_block_coord, - self.qk_mma_tiler, - seqlen_q, - seqlen_k, - self.is_causal, - self.is_local, - window_size_left, - window_size_right, - ) + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + seqlen_kv_loop_start, seqlen_kv_loop_steps = ( + FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, ) + ) + is_valid_k = seqlen_kv_loop_steps > 0 + has_work = is_valid_q and is_valid_k - cta_rank_in_cluster = cute.arch.make_warp_uniform( - cute.arch.block_idx_in_cluster() - ) - is_leader_cta = cta_rank_in_cluster % 2 == 0 + if has_work: # dq_handle = mma_dq_producer.acquire_and_advance() load_q_releaser = load_q_consumer.clone() load_do_releaser = load_do_consumer.clone() @@ -1836,33 +1833,35 @@ def kernel( curr_block_coord[2], ) batch_coord = curr_block_coord[2][1] - continue_cond = False seqlen_q = mQ_qdl.shape[0] seqlen_k = mK_kdl.shape[0] cuseqlen_q = Int32(0) + is_valid_q = True if cutlass.const_expr(cum_seqlen_q is not None): cuseqlen_q = cum_seqlen_q[batch_coord] seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q - continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + is_valid_q = FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( self.qk_mma_tiler[0], mma_block_coord[0], seqlen_q, ) - if not continue_cond: - if cutlass.const_expr(cum_seqlen_k is not None): - cuseqlen_k = cum_seqlen_k[batch_coord] - seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + start_count, trip_count = FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + is_valid_k = trip_count > 0 + has_work = is_valid_q and is_valid_k - start_count, trip_count = FusedMask.get_trip_start_count_via_block_info( - mma_block_coord, - self.qk_mma_tiler, - seqlen_q, - seqlen_k, - self.is_causal, - self.is_local, - window_size_left, - window_size_right, - ) + if has_work: end_count = start_count + trip_count if cutlass.const_expr(self.use_semantic_trip_range): n_block_min_causal_local_mask, n_block_min_before_local_mask = ( @@ -1932,6 +1931,7 @@ def kernel( ) lse_handle.release() sum_odo_handle.release() + work_tile = tile_sched.advance_to_next_work() ds_mma_producer.tail() @@ -1952,61 +1952,75 @@ def kernel( # cute.printf("batch_coord={}", batch_coord) seqlen_q = mQ_qdl.shape[0] seqlen_k = mK_kdl.shape[0] - continue_cond = False cuseqlen_q = Int32(0) + is_valid_q = True if cutlass.const_expr(cum_seqlen_q is not None): cuseqlen_q = cum_seqlen_q[batch_coord] seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q - continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + is_valid_q = FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( self.qk_mma_tiler[0], mma_block_coord[0], seqlen_q, ) + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + seqlen_kv_loop_start, seqlen_kv_loop_steps = ( + FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + ) + is_valid_k = seqlen_kv_loop_steps > 0 + has_work = is_valid_q and is_valid_k - if not continue_cond: - if cutlass.const_expr(cum_seqlen_k is not None): - cuseqlen_k = cum_seqlen_k[batch_coord] - seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + mdQ_qdl_eff = mdQ_qdl + if cutlass.const_expr(cum_seqlen_q is not None): + block_offset_dQ = (cuseqlen_q,) + (None,) * 2 + mdQ_qdl_eff = cute.domain_offset(block_offset_dQ, mdQ_qdl) - mdQ_qdl_eff = mdQ_qdl - if cutlass.const_expr(cum_seqlen_q is not None): - block_offset_dQ = ( - cuseqlen_q, - Int32(0), - Int32(0), - ((Int32(0), Int32(0)), Int32(0)), - ) - mdQ_qdl_eff = cute.domain_offset( - cute.select(block_offset_dQ, mode=[0, 2, 3]), mdQ_qdl - ) + # (bM, bN, loopM, loopN, loopL) + gdQ_qdl = cute.flat_divide( + mdQ_qdl_eff, cute.select(self.dsk_block_tiler, mode=[0, 1]) + ) + cdQ_qdl = cute.flat_divide( + cute.make_identity_tensor(mdQ_qdl_eff.shape), + cute.select(self.dsk_block_tiler, mode=[0, 1]), + ) - # (bM, bN, loopM, loopN, loopL) - gdQ_qdl = cute.flat_divide( - mdQ_qdl_eff, cute.select(self.dsk_block_tiler, mode=[0, 1]) - ) - cdQ_qdl = cute.flat_divide( - cute.make_identity_tensor(mdQ_qdl_eff.shape), - cute.select(self.dsk_block_tiler, mode=[0, 1]), - ) + gdQ_staged = gdQ_qdl[None, None, curr_block_coord[0], None, curr_block_coord[2]] + cdQ_staged = cdQ_qdl[None, None, curr_block_coord[0], None, curr_block_coord[2]] + gdQ_tma_staged = gdQ_staged - gdQ_staged = gdQ_qdl[None, None, curr_block_coord[0], None, curr_block_coord[2]] - cdQ_staged = cdQ_qdl[None, None, curr_block_coord[0], None, curr_block_coord[2]] - gdQ_tma_staged = gdQ_staged - if cutlass.const_expr(not varlen): - gdQ_tma_qdl = cute.flat_divide( - mdQ_tma, cute.select(self.dsk_block_tiler, mode=[0, 1]) - ) - gdQ_tma_staged = gdQ_tma_qdl[ - None, None, curr_block_coord[0], None, curr_block_coord[2] - ] + if cutlass.const_expr(not varlen): + gdQ_tma_qdl = cute.flat_divide( + mdQ_tma, cute.select(self.dsk_block_tiler, mode=[0, 1]) + ) + gdQ_tma_staged = gdQ_tma_qdl[ + None, None, curr_block_coord[0], None, curr_block_coord[2] + ] + if has_work: # dQ TMEM to GMEM mma_dq_consumer = self.dQ_epilogue( - (seqlen_q, cuseqlen_q, mQ_qdl.shape[0], batch_coord), + seqlen_q, (mma_dq_consumer, gdQ_staged, cdQ_staged, tdQtdQ_staged), self.epi_tile, (tma_atom_dQ, gdQ_tma_staged, s_epi_dQ, varlen), ) + else: + self.dQ_epilogue_write_zero( + seqlen_q, + gdQ_staged, + cdQ_staged, + ) + work_tile = tile_sched.advance_to_next_work() # NOTE: tmem.free() moved to kernel end to enable cluster-wide sync @@ -2181,12 +2195,11 @@ def compute_step( @cute.jit def dQ_epilogue( self, - value_args: Tuple, + seqlen_q: int, dq_args: Tuple, epi_tile: cute.Tile, tma_args: Tuple, ) -> Tuple[pipeline.PipelineConsumer, pipeline.PipelineProducer]: - seqlen_q, cuseqlen_q, total_q, batch_coord = value_args (mma_dq_consumer, gdQ_staged, cdQ_staged, tdQtdQ_staged) = dq_args tma_atom_dQ, gdQ_tma_staged, s_epi_dQ, varlen = tma_args dq_handle = mma_dq_consumer.wait_and_advance() @@ -2274,3 +2287,31 @@ def dQ_epilogue( cute.autovec_copy(tSMrdQ, tTMEM_LOADgdQ_i) dq_handle.release() return mma_dq_consumer + + @cute.jit + def dQ_epilogue_write_zero( + self, + seqlen_q, + gdQ_staged, + cdQ_staged, + ): + num_epi_threads = self.threads_per_warp * len(self.epilogue_warp_ids) + tidx = cute.arch.thread_idx()[0] % num_epi_threads + + tiled_copy_r2g = fa_copy_utils.tiled_copy_2d( + self.dq_dtype, cute.size(gdQ_staged.shape[1]), num_epi_threads + ) + + thr_copy_r2g = tiled_copy_r2g.get_slice(tidx) + tdQgdQ_staged = thr_copy_r2g.partition_D(gdQ_staged) + tdQcdQ_staged = thr_copy_r2g.partition_D(cdQ_staged) + + tdQrdQ = cute.make_rmem_tensor_like(tdQgdQ_staged[None, 0, None, 0]) + tdQrdQ.fill(self.dq_dtype(0.0)) + + for iter in cutlass.range(self.iterations_dsk, unroll_full=True): + tdQgdQ = tdQgdQ_staged[None, None, None, iter] + tdQcdQ = tdQcdQ_staged[None, None, None, iter] + for m in cutlass.range(cute.size(tdQgdQ.shape[1]), unroll_full=True): + if cute.elem_less(tdQcdQ[0, m, 0][0], seqlen_q): + cute.copy(tiled_copy_r2g, tdQrdQ, tdQgdQ[None, m, None]) diff --git a/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_forward.py b/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_forward.py index 1ab8879375d..0c7b005f811 100644 --- a/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_forward.py +++ b/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_forward.py @@ -1030,6 +1030,9 @@ def kernel( if warp_idx == self.mma_warp_id: cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + is_leader_cta = cta_rank_in_cluster % 2 == 0 + while work_tile.is_valid_tile: curr_block_coord = work_tile.tile_idx mma_block_coord = ( @@ -1071,10 +1074,6 @@ def kernel( ) seqlen_kv_loop_end = seqlen_kv_loop_start + seqlen_kv_loop_steps - cta_rank_in_cluster = cute.arch.make_warp_uniform( - cute.arch.block_idx_in_cluster() - ) - is_leader_cta = cta_rank_in_cluster % 2 == 0 load_q_releaser = load_q_consumer.clone() pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) if seqlen_kv_loop_steps > 1: @@ -1323,6 +1322,11 @@ def kernel( window_size_right, ) end_count = start_count + trip_count + # require at least one softmax iteration for zero trip_count case; + # rely on masking this iteration for correctness + if end_count <= start_count: + start_count = 0 + end_count = 1 if cutlass.const_expr(self.use_semantic_trip_range): n_block_min_causal_local_mask, n_block_min_before_local_mask = ( FusedMask.get_trip_mask_bounds_via_block_info( @@ -1349,6 +1353,7 @@ def kernel( need_apply_mask = ( step >= n_block_min_causal_local_mask or step < n_block_min_before_local_mask + or step == end_count - 1 ) else: # Residual path only needs seqlen masking on the last K tile. @@ -1797,7 +1802,8 @@ def correction_epilog( row_sum = sSum[thread_idx] cute.arch.fence_view_async_shared() sum_handle.release() - scale = scale_output / row_sum + row_sum_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum + scale = scale_output / row_sum if not row_sum_is_zero_or_nan else 0.0 o_handle = mma_o_consumer.wait_and_advance() for iter in cutlass.range(self.iterations_pv): gO = gO_staged[None, None, iter] @@ -1855,6 +1861,7 @@ def store_sum_max( sSum[thread_idx] = row_sum cute.arch.fence_view_async_shared() sum_handle.commit() + row_sum_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum if cutlass.const_expr(mLSE is not None): q_idx = current_block_coord[0] * self.cta_tiler[0] + tidx @@ -1863,7 +1870,11 @@ def store_sum_max( if cutlass.const_expr(cum_seqlen_q is not None) else current_block_coord[2] ) - lse_value = scale_softmax * row_max + cute.math.log(row_sum, fastmath=True) + lse_value = ( + scale_softmax * row_max + cute.math.log(row_sum, fastmath=True) + if not row_sum_is_zero_or_nan + else -Float32.inf + ) if cute.elem_less(q_idx, seqlen_q): global_q_idx = ( q_idx + cuseqlen_q if cutlass.const_expr(cum_seqlen_q is not None) else q_idx From 3b97140db8793af57cfefd7c567f1fc6e8b95199 Mon Sep 17 00:00:00 2001 From: umiswing Date: Mon, 18 May 2026 16:14:27 +0800 Subject: [PATCH 2/9] fix: use quack tiled_copy_2d in bwd postprocess --- .../flash_mask/flash_attn_v4/copy_utils.py | 19 +++++++++++++++++++ .../flash_attn_v4/flash_bwd_postprocess.py | 2 +- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/flashmask/flash_mask/flash_attn_v4/copy_utils.py b/flashmask/flash_mask/flash_attn_v4/copy_utils.py index d8c6083c8cc..04f9bb15ed5 100644 --- a/flashmask/flash_mask/flash_attn_v4/copy_utils.py +++ b/flashmask/flash_mask/flash_attn_v4/copy_utils.py @@ -105,6 +105,25 @@ def tiled_copy_2d( val_layout = cute.make_layout((1, copy_elems)) return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) +# TODO(umiswing): Support calling the Quack kernel directly +# from Paddle rather than duplicating it here. +def quack_tiled_copy_2d( + dtype: Type[cutlass.Numeric], + threads_per_row: int, + num_threads: int, + num_copy_elems: int = 1, + is_async: bool = False, +) -> cute.TiledCopy: + num_copy_bits = num_copy_elems * dtype.width + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + assert num_threads % threads_per_row == 0 + thr_layout = cute.make_ordered_layout( + (num_threads // threads_per_row, threads_per_row), + order=(1, 0), + ) + val_layout = cute.make_layout((1, num_copy_elems)) + return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) @dsl_user_op def atomic_add_fp32x4( diff --git a/flashmask/flash_mask/flash_attn_v4/flash_bwd_postprocess.py b/flashmask/flash_mask/flash_attn_v4/flash_bwd_postprocess.py index 2f7d2af50ca..ffe2384943b 100644 --- a/flashmask/flash_mask/flash_attn_v4/flash_bwd_postprocess.py +++ b/flashmask/flash_mask/flash_attn_v4/flash_bwd_postprocess.py @@ -177,7 +177,7 @@ def _setup_attributes(self): num_copy_elems = 128 // self.dtype.width threads_per_row = math.gcd(128, self.tile_hdim) // num_copy_elems - self.gmem_tiled_copy_dQ = copy_utils.tiled_copy_2d( + self.gmem_tiled_copy_dQ = copy_utils.quack_tiled_copy_2d( self.dtype, threads_per_row, self.num_threads, num_copy_elems ) # /////////////////////////////////////////////////////////////////////////////// From 6ea6e56e129df41696db423a69a2466d03ef0317 Mon Sep 17 00:00:00 2001 From: umiswing Date: Mon, 18 May 2026 20:27:45 +0800 Subject: [PATCH 3/9] fix fwd quack tiled_copy_2d --- flashmask/flash_mask/flash_attn_v4/flash_fwd_sm100.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flashmask/flash_mask/flash_attn_v4/flash_fwd_sm100.py b/flashmask/flash_mask/flash_attn_v4/flash_fwd_sm100.py index e46348ce6bb..d14f916c5f4 100644 --- a/flashmask/flash_mask/flash_attn_v4/flash_fwd_sm100.py +++ b/flashmask/flash_mask/flash_attn_v4/flash_fwd_sm100.py @@ -579,7 +579,7 @@ def __call__( async_copy_elems = 128 // self.q_dtype.width num_load_threads = cute.arch.WARP_SIZE * len(self.load_warp_ids) threads_per_row = math.gcd(self.head_dim_padded // async_copy_elems, num_load_threads) - gmem_tiled_copy_Q = copy_utils.tiled_copy_2d( + gmem_tiled_copy_Q = copy_utils.quack_tiled_copy_2d( self.q_dtype, threads_per_row, num_load_threads, async_copy_elems, is_async=True ) From 2f2a41af219daf91f3ab75da75e94571ad94045e Mon Sep 17 00:00:00 2001 From: umiswing Date: Tue, 19 May 2026 22:08:40 +0800 Subject: [PATCH 4/9] refine --- .../flash_mask/flash_attn_v4/copy_utils.py | 73 ++++++++++++++++++- 1 file changed, 72 insertions(+), 1 deletion(-) diff --git a/flashmask/flash_mask/flash_attn_v4/copy_utils.py b/flashmask/flash_mask/flash_attn_v4/copy_utils.py index 04f9bb15ed5..b6e5c6911fe 100644 --- a/flashmask/flash_mask/flash_attn_v4/copy_utils.py +++ b/flashmask/flash_mask/flash_attn_v4/copy_utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. +# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. import math from typing import Optional, Type, Callable @@ -389,3 +389,74 @@ def copy_fn(src_idx, producer_state: cutlass.pipeline.PipelineState, **new_kwarg ) return copy_fn + + +# --- Vendored from quack/copy_utils.py (BSD-3, Tri Dao et al.) --- +# Required by SM100 code paths that cannot depend on quack. + +BIG_INT = 2**30 +MAX_INT = 2**31 - 1 +BIG_INT_INV = 2**64 // BIG_INT + + +@dsl_user_op +def create_ragged_tensor_for_tma( + T: cute.Tensor, + ragged_dim: int = 0, + ptr_shift: bool = False, + *, + loc=None, + ip=None, +) -> cute.Tensor: + rank = cute.rank(T) + if ragged_dim < 0: + ragged_dim += rank + if ptr_shift: + assert rank <= 4, "ptr_shift ragged tensor only supports up to 4 dimensions" + new_shape = T.shape[:ragged_dim] + (BIG_INT,) + T.shape[ragged_dim + 1 :] + (MAX_INT,) + new_stride = T.stride + (T.stride[ragged_dim],) + ptr_offset = (None,) * ragged_dim + (-BIG_INT,) + (None,) * (rank - ragged_dim - 1) + new_ptr = cute.domain_offset(ptr_offset, T).iterator + return cute.make_tensor(new_ptr, cute.make_layout(new_shape, stride=new_stride)) + else: + assert rank <= 3, "non-ptr_shift ragged tensor only supports up to 3 dimensions" + stride_r = T.stride[ragged_dim] + new_shape = ( + T.shape[:ragged_dim] + (BIG_INT,) + T.shape[ragged_dim + 1 :] + (MAX_INT, MAX_INT) + ) + new_stride = ( + T.stride[:ragged_dim] + + (stride_r,) + + T.stride[ragged_dim + 1 :] + + (BIG_INT_INV - stride_r, stride_r) + ) + return cute.make_tensor(T.iterator, cute.make_layout(new_shape, stride=new_stride)) + + +@dsl_user_op +def offset_ragged_tensor( + T: cute.Tensor, + offset: Int32, + length: Int32, + ragged_dim: int = 0, + ptr_shift: bool = False, + *, + loc=None, + ip=None, +) -> cute.Tensor: + rank = cute.rank(T) + if ragged_dim < 0: + ragged_dim += rank + big_int = cute.size(T, mode=[ragged_dim]) + offset_val = big_int - length + if ptr_shift: + # 1-extra-dim: rank = original_rank + 1 + assert rank >= ragged_dim + 2 + offset_tuple = (None,) * ragged_dim + (offset_val,) + (None,) * (rank - ragged_dim - 2) + index_tuple = (None,) * (rank - 1) + (offset + length,) + else: + # 2-extra-dim: rank = original_rank + 2, last 2 modes are the wraparound dims + assert rank >= ragged_dim + 3 + offset_tuple = (None,) * ragged_dim + (offset_val,) + (None,) * (rank - ragged_dim - 3) + index_tuple = (None,) * (rank - 2) + (big_int, offset + length) + return cute.domain_offset(offset_tuple, T[index_tuple]) From c5430b1fb03fd84f85ba125c757da52113be32f4 Mon Sep 17 00:00:00 2001 From: umiswing Date: Wed, 20 May 2026 16:07:37 +0800 Subject: [PATCH 5/9] replace output padding set zero with padding doc at tail (zero-length-doc) this fix the bug of output padding method when causal=False. when causal=False, set zero to output padding can not mask out the padding region of attn score at column dimension, while the causal=True branch happen to bypass this problem --- flashmask/flash_mask/interface.py | 100 ++++++++++++++++++------------ 1 file changed, 59 insertions(+), 41 deletions(-) diff --git a/flashmask/flash_mask/interface.py b/flashmask/flash_mask/interface.py index 02c471b8999..ca7759871c3 100644 --- a/flashmask/flash_mask/interface.py +++ b/flashmask/flash_mask/interface.py @@ -100,12 +100,12 @@ def convert_to_varlen( # Real ends per batch real_ends_np = s_np.max(axis=1) # (b,) - real_ends = real_ends_np.tolist() - needs_padding_fixup = bool(np.any(real_ends_np < skv)) # Per-batch boundary extraction (b is typically 1-2, loop is trivial) - cu_seqlens_list = [] - max_doc_len = 0 + cu_seqlens_q_list = [] + cu_seqlens_k_list = [] + max_doc_len_q = 0 + max_doc_len_k = 0 for bi in range(b): change_idx = np.nonzero(diffs[bi])[0].astype(np.int32) + 1 boundaries = np.concatenate([ @@ -113,15 +113,59 @@ def convert_to_varlen( change_idx, np.array([skv], dtype=np.int32), ]) - doc_lens = boundaries[1:] - boundaries[:-1] - max_doc_len = max(max_doc_len, int(doc_lens.max())) - cu_seqlens_list.append(boundaries[:-1] + np.int32(bi * skv)) + real_end = int(real_ends_np[bi]) + + if real_end < skv: + # Has padding: only include real documents, then add two + # padding documents at tail. + # Find the boundary index where padding starts + real_boundaries = boundaries[boundaries <= real_end] + if real_boundaries[-1] != real_end: + real_boundaries = np.concatenate([ + real_boundaries, np.array([real_end], dtype=np.int32) + ]) + + doc_lens = real_boundaries[1:] - real_boundaries[:-1] + max_doc_len_q = max(max_doc_len_q, int(doc_lens.max())) + max_doc_len_k = max(max_doc_len_k, int(doc_lens.max())) + + pad_len = skv - real_end + + # cu_seqlens_q: [...real_doc_starts, real_end, real_end, real_end+pad_len] + # -> docs: real docs, then (0 seqlen_q, pad_len seqlen_k), then (pad_len seqlen_q, 0 seqlen_k) + # cu_seqlens_k: [...real_doc_starts, real_end, real_end+pad_len, real_end+pad_len] + offset = np.int32(bi * skv) + cu_seqlens_q_list.append(np.concatenate([ + real_boundaries[:-1] + offset, + np.array([real_end, real_end], dtype=np.int32) + offset, + ])) + cu_seqlens_k_list.append(np.concatenate([ + real_boundaries[:-1] + offset, + np.array([real_end, real_end + pad_len], dtype=np.int32) + offset, + ])) + + print(f"wsm debug {cu_seqlens_q_list=}") + print(f"wsm debug {cu_seqlens_k_list=}") + + max_doc_len_k = max(max_doc_len_k, pad_len) + max_doc_len_q = max(max_doc_len_q, pad_len) + else: + # No padding + doc_lens = boundaries[1:] - boundaries[:-1] + max_doc_len_q = max(max_doc_len_q, int(doc_lens.max())) + max_doc_len_k = max(max_doc_len_k, int(doc_lens.max())) + cu_seqlens_q_list.append(boundaries[:-1] + np.int32(bi * skv)) + cu_seqlens_k_list.append(boundaries[:-1] + np.int32(bi * skv)) # Build cu_seqlens: one numpy concat, one paddle tensor creation - cu_seqlens_np = np.concatenate( - cu_seqlens_list + [np.array([b * skv], dtype=np.int32)] + cu_seqlens_q_np = np.concatenate( + cu_seqlens_q_list + [np.array([b * skv], dtype=np.int32)] + ) + cu_seqlens_k_np = np.concatenate( + cu_seqlens_k_list + [np.array([b * skv], dtype=np.int32)] ) - cu_seqlens = paddle.to_tensor(cu_seqlens_np) + cu_seqlens_q = paddle.to_tensor(cu_seqlens_q_np) + cu_seqlens_k = paddle.to_tensor(cu_seqlens_k_np) # ── Flatten q, k, v: (batch, seqlen, heads, dim) -> (total, heads, dim) q_varlen = query.reshape([b * sq, hq, d]) @@ -142,36 +186,13 @@ def convert_to_varlen( "q": q_varlen, "k": k_varlen, "v": v_varlen, - "cu_seqlens_q": cu_seqlens, - "cu_seqlens_k": cu_seqlens, - "max_seqlen_q": max_doc_len, - "max_seqlen_k": max_doc_len, + "cu_seqlens_q": cu_seqlens_q, + "cu_seqlens_k": cu_seqlens_k, + "max_seqlen_q": max_doc_len_q, + "max_seqlen_k": max_doc_len_k, "causal": varlen_causal, } - # For non-causal masks with trailing padding: padding rows attend to - # nothing in flashmask (zero output), but varlen computes non-zero - # output for them. Zero out padding rows per batch to match flashmask. - # Note: real_end can differ across batch items. - if needs_padding_fixup: - _b, _sq = b, sq - _real_ends = real_ends - - def output_to_padded(out_varlen_pt): - nh = out_varlen_pt.shape[1] - dv_out = out_varlen_pt.shape[2] - out_padded = out_varlen_pt.reshape(_b, _sq, nh, dv_out) - # Vectorised per-batch zeroing - row_idx = paddle.arange(_sq) - real_end_t = paddle.to_tensor( - _real_ends, dtype=paddle.int64, - ).unsqueeze(1) - padding_mask = row_idx.unsqueeze(0) >= real_end_t # (b, sq) - out_padded[padding_mask] = 0 - return out_padded - - result["output_to_padded"] = output_to_padded - return result def flashmask_attention( @@ -243,10 +264,7 @@ def flashmask_attention( return_lse=return_softmax_lse, ) - if "output_to_padded" in varlen_args: - out = varlen_args["output_to_padded"](out) - else: - out = out.reshape([batch_size, seqlen_q, nheads, dv]) + out = out.reshape([batch_size, seqlen_q, nheads, dv]) if return_softmax_lse: return [out, lse] From 18a3763ca983e187f385234b022c6fc9f68db660 Mon Sep 17 00:00:00 2001 From: umiswing Date: Wed, 20 May 2026 18:13:15 +0800 Subject: [PATCH 6/9] remove debug print --- flashmask/flash_mask/interface.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/flashmask/flash_mask/interface.py b/flashmask/flash_mask/interface.py index ca7759871c3..c2376fa5bf0 100644 --- a/flashmask/flash_mask/interface.py +++ b/flashmask/flash_mask/interface.py @@ -144,9 +144,6 @@ def convert_to_varlen( np.array([real_end, real_end + pad_len], dtype=np.int32) + offset, ])) - print(f"wsm debug {cu_seqlens_q_list=}") - print(f"wsm debug {cu_seqlens_k_list=}") - max_doc_len_k = max(max_doc_len_k, pad_len) max_doc_len_q = max(max_doc_len_q, pad_len) else: From 02d412f8fa45e47076ecec88b4d2bfc6e7774c5b Mon Sep 17 00:00:00 2001 From: umiswing Date: Wed, 20 May 2026 19:22:34 +0800 Subject: [PATCH 7/9] fix bwd preprocess quack tiled_copy_2d --- flashmask/flash_mask/flash_attn_v4/flash_bwd_preprocess.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flashmask/flash_mask/flash_attn_v4/flash_bwd_preprocess.py b/flashmask/flash_mask/flash_attn_v4/flash_bwd_preprocess.py index 76946556d4c..13f9c5c1da5 100644 --- a/flashmask/flash_mask/flash_attn_v4/flash_bwd_preprocess.py +++ b/flashmask/flash_mask/flash_attn_v4/flash_bwd_preprocess.py @@ -111,7 +111,7 @@ def _setup_attributes(self): ) num_copy_elems = 128 // self.dtype.width threads_per_row = gmem_k_block_size // num_copy_elems - self.gmem_tiled_copy_O = copy_utils.tiled_copy_2d( + self.gmem_tiled_copy_O = copy_utils.quack_tiled_copy_2d( self.dtype, threads_per_row, self.num_threads, num_copy_elems ) universal_copy_bits = 128 @@ -367,4 +367,4 @@ def kernel( gLSElog2 = cute.local_tile(mLSElog2_cur, (self.tile_m,), (m_block,)) LOG2_E = math.log2(math.e) if tidx < seqlen_q_rounded - m_block * self.tile_m: - gLSElog2[tidx] = lse * LOG2_E if lse != -Float32.inf else 0.0 \ No newline at end of file + gLSElog2[tidx] = lse * LOG2_E if lse != -Float32.inf else 0.0 From d435b7436f91669e0e30903d600c7ed6b09a60df Mon Sep 17 00:00:00 2001 From: umiswing Date: Wed, 20 May 2026 21:44:12 +0800 Subject: [PATCH 8/9] refine convert_to_varlen --- flashmask/flash_mask/interface.py | 68 +++++++++++++------------------ 1 file changed, 29 insertions(+), 39 deletions(-) diff --git a/flashmask/flash_mask/interface.py b/flashmask/flash_mask/interface.py index c2376fa5bf0..84dfd9fc0cc 100644 --- a/flashmask/flash_mask/interface.py +++ b/flashmask/flash_mask/interface.py @@ -76,15 +76,13 @@ flash_attn_combine = _mod.flash_attn_combine def convert_to_varlen( - query: paddle.Tensor, - key: paddle.Tensor, - value: paddle.Tensor, + batch_size: int, + seqlen_q: int, + seqlen_kv: int, startend_row_indices: paddle.Tensor, causal: bool, ): - b, sq, hq, d = query.shape - _, skv, hkv, dv = value.shape - assert sq == skv + assert seqlen_q == seqlen_kv _, hfm, _, bound_num = startend_row_indices.shape assert hfm == 1 @@ -95,27 +93,27 @@ def convert_to_varlen( s_np = sri_np[:, 0, :, 0] # (batch, seqlen_k) # ── Vectorised boundary detection in numpy ── - # Compare consecutive elements per batch: (b, skv-1) + # Compare consecutive elements per batch: (batch_size, seqlen_kv-1) diffs = s_np[:, 1:] != s_np[:, :-1] # Real ends per batch - real_ends_np = s_np.max(axis=1) # (b,) + real_ends_np = s_np.max(axis=1) # (batch_size,) - # Per-batch boundary extraction (b is typically 1-2, loop is trivial) + # Per-batch boundary extraction (batch_size is typically 1-2, loop is trivial) cu_seqlens_q_list = [] cu_seqlens_k_list = [] max_doc_len_q = 0 max_doc_len_k = 0 - for bi in range(b): + for bi in range(batch_size): change_idx = np.nonzero(diffs[bi])[0].astype(np.int32) + 1 boundaries = np.concatenate([ np.zeros(1, dtype=np.int32), change_idx, - np.array([skv], dtype=np.int32), + np.array([seqlen_kv], dtype=np.int32), ]) real_end = int(real_ends_np[bi]) - if real_end < skv: + if real_end < seqlen_kv: # Has padding: only include real documents, then add two # padding documents at tail. # Find the boundary index where padding starts @@ -129,12 +127,12 @@ def convert_to_varlen( max_doc_len_q = max(max_doc_len_q, int(doc_lens.max())) max_doc_len_k = max(max_doc_len_k, int(doc_lens.max())) - pad_len = skv - real_end + pad_len = seqlen_kv - real_end # cu_seqlens_q: [...real_doc_starts, real_end, real_end, real_end+pad_len] # -> docs: real docs, then (0 seqlen_q, pad_len seqlen_k), then (pad_len seqlen_q, 0 seqlen_k) # cu_seqlens_k: [...real_doc_starts, real_end, real_end+pad_len, real_end+pad_len] - offset = np.int32(bi * skv) + offset = np.int32(bi * seqlen_kv) cu_seqlens_q_list.append(np.concatenate([ real_boundaries[:-1] + offset, np.array([real_end, real_end], dtype=np.int32) + offset, @@ -151,38 +149,30 @@ def convert_to_varlen( doc_lens = boundaries[1:] - boundaries[:-1] max_doc_len_q = max(max_doc_len_q, int(doc_lens.max())) max_doc_len_k = max(max_doc_len_k, int(doc_lens.max())) - cu_seqlens_q_list.append(boundaries[:-1] + np.int32(bi * skv)) - cu_seqlens_k_list.append(boundaries[:-1] + np.int32(bi * skv)) + cu_seqlens_q_list.append(boundaries[:-1] + np.int32(bi * seqlen_kv)) + cu_seqlens_k_list.append(boundaries[:-1] + np.int32(bi * seqlen_kv)) # Build cu_seqlens: one numpy concat, one paddle tensor creation cu_seqlens_q_np = np.concatenate( - cu_seqlens_q_list + [np.array([b * skv], dtype=np.int32)] + cu_seqlens_q_list + [np.array([batch_size * seqlen_kv], dtype=np.int32)] ) cu_seqlens_k_np = np.concatenate( - cu_seqlens_k_list + [np.array([b * skv], dtype=np.int32)] + cu_seqlens_k_list + [np.array([batch_size * seqlen_kv], dtype=np.int32)] ) cu_seqlens_q = paddle.to_tensor(cu_seqlens_q_np) cu_seqlens_k = paddle.to_tensor(cu_seqlens_k_np) - # ── Flatten q, k, v: (batch, seqlen, heads, dim) -> (total, heads, dim) - q_varlen = query.reshape([b * sq, hq, d]) - k_varlen = key.reshape([b * skv, hkv, d]) - v_varlen = value.reshape([b * skv, hkv, dv]) - # ── Detect simulated causal masks (numpy) ──────────────────────── varlen_causal = causal if not causal and bound_num == 2: - lts_all = s_np # (b, skv), already extracted - ute_all = sri_np[:, 0, :, 1] # (b, skv) - arange_ref = np.arange(skv, dtype=np.int32).reshape(1, skv) + lts_all = s_np # (batch_size, seqlen_kv), already extracted + ute_all = sri_np[:, 0, :, 1] # (batch_size, seqlen_kv) + arange_ref = np.arange(seqlen_kv, dtype=np.int32).reshape(1, seqlen_kv) expected_causal_ute = np.minimum(arange_ref, lts_all) if np.array_equal(ute_all, expected_causal_ute): varlen_causal = True result = { - "q": q_varlen, - "k": k_varlen, - "v": v_varlen, "cu_seqlens_q": cu_seqlens_q, "cu_seqlens_k": cu_seqlens_k, "max_seqlen_q": max_doc_len_q, @@ -228,10 +218,10 @@ def flashmask_attention( ) batch_size, seqlen_q, nheads, d = query.shape - seqlen_k = key.shape[1] - assert seqlen_q == seqlen_k, ( - f"use_varlen requires seqlen_q == seqlen_k, ", - f"currently seqlen_q={seqlen_q}, seqlen_k={seqlen_k}", + _, seqlen_kv, nheads_kv, dv = value.shape + assert seqlen_q == seqlen_kv, ( + f"use_varlen requires seqlen_q == seqlen_kv, ", + f"currently seqlen_q={seqlen_q}, seqlen_kv={seqlen_kv}", ) dv = value.shape[-1] @@ -243,15 +233,15 @@ def flashmask_attention( ) varlen_args = convert_to_varlen( - query=query, - key=key, - value=value, + batch_size=batch_size, + seqlen_q=seqlen_q, + seqlen_kv=seqlen_kv, startend_row_indices=startend_row_indices, causal=causal) out, lse = flash_attn_varlen_func( - q=varlen_args["q"], - k=varlen_args["k"], - v=varlen_args["v"], + q=query.reshape(batch_size * seqlen_q, nheads, d), + k=key.reshape(batch_size * seqlen_kv, nheads_kv, d), + v=value.reshape(batch_size * seqlen_kv, nheads_kv, dv), cu_seqlens_q=varlen_args["cu_seqlens_q"], cu_seqlens_k=varlen_args["cu_seqlens_k"], max_seqlen_q=varlen_args["max_seqlen_q"], From 9fdfe28e41b05f6b227180ae9ef66fda4482d6df Mon Sep 17 00:00:00 2001 From: umiswing Date: Thu, 21 May 2026 14:50:19 +0800 Subject: [PATCH 9/9] refine and split the convert for more flexible use --- flashmask/flash_mask/interface.py | 55 ++++++++++++++++++++++++------- 1 file changed, 44 insertions(+), 11 deletions(-) diff --git a/flashmask/flash_mask/interface.py b/flashmask/flash_mask/interface.py index 84dfd9fc0cc..af4307fa197 100644 --- a/flashmask/flash_mask/interface.py +++ b/flashmask/flash_mask/interface.py @@ -75,22 +75,22 @@ flash_attn_varlen_func = _mod.flash_attn_varlen_func flash_attn_combine = _mod.flash_attn_combine -def convert_to_varlen( +def convert_to_varlen_np( batch_size: int, seqlen_q: int, seqlen_kv: int, - startend_row_indices: paddle.Tensor, + startend_row_indices: np.ndarray, causal: bool, ): + assert isinstance(startend_row_indices, np.ndarray), ( + f"Expected np.ndarray, got {type(startend_row_indices)}" + ) assert seqlen_q == seqlen_kv _, hfm, _, bound_num = startend_row_indices.shape assert hfm == 1 assert bound_num == 1 or bound_num == 2 - - # ── Move startend_row_indices to numpy (single GPU→CPU transfer) ── - sri_np = startend_row_indices.numpy() # (batch, hfm, seqlen_k, bound_num) - s_np = sri_np[:, 0, :, 0] # (batch, seqlen_k) + s_np = startend_row_indices[:, 0, :, 0] # (batch, seqlen_k) # ── Vectorised boundary detection in numpy ── # Compare consecutive elements per batch: (batch_size, seqlen_kv-1) @@ -159,22 +159,20 @@ def convert_to_varlen( cu_seqlens_k_np = np.concatenate( cu_seqlens_k_list + [np.array([batch_size * seqlen_kv], dtype=np.int32)] ) - cu_seqlens_q = paddle.to_tensor(cu_seqlens_q_np) - cu_seqlens_k = paddle.to_tensor(cu_seqlens_k_np) # ── Detect simulated causal masks (numpy) ──────────────────────── varlen_causal = causal if not causal and bound_num == 2: lts_all = s_np # (batch_size, seqlen_kv), already extracted - ute_all = sri_np[:, 0, :, 1] # (batch_size, seqlen_kv) + ute_all = startend_row_indices[:, 0, :, 1] # (batch_size, seqlen_kv) arange_ref = np.arange(seqlen_kv, dtype=np.int32).reshape(1, seqlen_kv) expected_causal_ute = np.minimum(arange_ref, lts_all) if np.array_equal(ute_all, expected_causal_ute): varlen_causal = True result = { - "cu_seqlens_q": cu_seqlens_q, - "cu_seqlens_k": cu_seqlens_k, + "cu_seqlens_q": cu_seqlens_q_np, + "cu_seqlens_k": cu_seqlens_k_np, "max_seqlen_q": max_doc_len_q, "max_seqlen_k": max_doc_len_k, "causal": varlen_causal, @@ -182,6 +180,41 @@ def convert_to_varlen( return result +def convert_to_varlen( + batch_size: int, + seqlen_q: int, + seqlen_kv: int, + startend_row_indices: paddle.Tensor, + causal: bool, +): + assert seqlen_q == seqlen_kv + + _, hfm, _, bound_num = startend_row_indices.shape + assert hfm == 1 + assert bound_num == 1 or bound_num == 2 + + # ── Move startend_row_indices to numpy (single GPU→CPU transfer) ── + sri_np = startend_row_indices.numpy() # (batch, hfm, seqlen_k, bound_num) + + result = convert_to_varlen_np( + batch_size, + seqlen_q, + seqlen_kv, + sri_np, + causal, + ) + + cu_seqlens_q_np = result["cu_seqlens_q"] + cu_seqlens_k_np = result["cu_seqlens_k"] + + cu_seqlens_q = paddle.to_tensor(cu_seqlens_q_np) + cu_seqlens_k = paddle.to_tensor(cu_seqlens_k_np) + + result["cu_seqlens_q"] = cu_seqlens_q + result["cu_seqlens_k"] = cu_seqlens_k + + return result + def flashmask_attention( query: paddle.Tensor, key: paddle.Tensor,