diff --git a/flashmask/flash_mask/flash_attn_v4/copy_utils.py b/flashmask/flash_mask/flash_attn_v4/copy_utils.py index 7b7b86eb4a5..04f9bb15ed5 100644 --- a/flashmask/flash_mask/flash_attn_v4/copy_utils.py +++ b/flashmask/flash_mask/flash_attn_v4/copy_utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. +# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. import math from typing import Optional, Type, Callable @@ -90,6 +90,24 @@ def tiled_copy_1d( def tiled_copy_2d( + dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = False +) -> cute.TiledCopy: + num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width + copy_elems = num_copy_bits // dtype.width + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + gmem_threads_per_row = major_mode_size // copy_elems + assert num_threads % gmem_threads_per_row == 0 + thr_layout = cute.make_ordered_layout( + (num_threads // gmem_threads_per_row, gmem_threads_per_row), + order=(1, 0), + ) + val_layout = cute.make_layout((1, copy_elems)) + return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) + +# TODO(umiswing): Support calling the Quack kernel directly +# from Paddle rather than duplicating it here. +def quack_tiled_copy_2d( dtype: Type[cutlass.Numeric], threads_per_row: int, num_threads: int, @@ -107,7 +125,6 @@ def tiled_copy_2d( val_layout = cute.make_layout((1, num_copy_elems)) return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) - @dsl_user_op def atomic_add_fp32x4( a: Float32, b: Float32, c: Float32, d: Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None @@ -372,74 +389,3 @@ def copy_fn(src_idx, producer_state: cutlass.pipeline.PipelineState, **new_kwarg ) return copy_fn - - -# --- Vendored from quack/copy_utils.py (BSD-3, Tri Dao et al.) 
--- -# Required by SM100 code paths that cannot depend on quack. - -BIG_INT = 2**30 -MAX_INT = 2**31 - 1 -BIG_INT_INV = 2**64 // BIG_INT - - -@dsl_user_op -def create_ragged_tensor_for_tma( - T: cute.Tensor, - ragged_dim: int = 0, - ptr_shift: bool = False, - *, - loc=None, - ip=None, -) -> cute.Tensor: - rank = cute.rank(T) - if ragged_dim < 0: - ragged_dim += rank - if ptr_shift: - assert rank <= 4, "ptr_shift ragged tensor only supports up to 4 dimensions" - new_shape = T.shape[:ragged_dim] + (BIG_INT,) + T.shape[ragged_dim + 1 :] + (MAX_INT,) - new_stride = T.stride + (T.stride[ragged_dim],) - ptr_offset = (None,) * ragged_dim + (-BIG_INT,) + (None,) * (rank - ragged_dim - 1) - new_ptr = cute.domain_offset(ptr_offset, T).iterator - return cute.make_tensor(new_ptr, cute.make_layout(new_shape, stride=new_stride)) - else: - assert rank <= 3, "non-ptr_shift ragged tensor only supports up to 3 dimensions" - stride_r = T.stride[ragged_dim] - new_shape = ( - T.shape[:ragged_dim] + (BIG_INT,) + T.shape[ragged_dim + 1 :] + (MAX_INT, MAX_INT) - ) - new_stride = ( - T.stride[:ragged_dim] - + (stride_r,) - + T.stride[ragged_dim + 1 :] - + (BIG_INT_INV - stride_r, stride_r) - ) - return cute.make_tensor(T.iterator, cute.make_layout(new_shape, stride=new_stride)) - - -@dsl_user_op -def offset_ragged_tensor( - T: cute.Tensor, - offset: Int32, - length: Int32, - ragged_dim: int = 0, - ptr_shift: bool = False, - *, - loc=None, - ip=None, -) -> cute.Tensor: - rank = cute.rank(T) - if ragged_dim < 0: - ragged_dim += rank - big_int = cute.size(T, mode=[ragged_dim]) - offset_val = big_int - length - if ptr_shift: - # 1-extra-dim: rank = original_rank + 1 - assert rank >= ragged_dim + 2 - offset_tuple = (None,) * ragged_dim + (offset_val,) + (None,) * (rank - ragged_dim - 2) - index_tuple = (None,) * (rank - 1) + (offset + length,) - else: - # 2-extra-dim: rank = original_rank + 2, last 2 modes are the wraparound dims - assert rank >= ragged_dim + 3 - offset_tuple = (None,) * 
ragged_dim + (offset_val,) + (None,) * (rank - ragged_dim - 3) - index_tuple = (None,) * (rank - 2) + (big_int, offset + length) - return cute.domain_offset(offset_tuple, T[index_tuple]) diff --git a/flashmask/flash_mask/flash_attn_v4/flash_bwd_postprocess.py b/flashmask/flash_mask/flash_attn_v4/flash_bwd_postprocess.py index 2f7d2af50ca..ffe2384943b 100644 --- a/flashmask/flash_mask/flash_attn_v4/flash_bwd_postprocess.py +++ b/flashmask/flash_mask/flash_attn_v4/flash_bwd_postprocess.py @@ -177,7 +177,7 @@ def _setup_attributes(self): num_copy_elems = 128 // self.dtype.width threads_per_row = math.gcd(128, self.tile_hdim) // num_copy_elems - self.gmem_tiled_copy_dQ = copy_utils.tiled_copy_2d( + self.gmem_tiled_copy_dQ = copy_utils.quack_tiled_copy_2d( self.dtype, threads_per_row, self.num_threads, num_copy_elems ) # /////////////////////////////////////////////////////////////////////////////// diff --git a/flashmask/flash_mask/flash_attn_v4/flash_fwd_sm100.py b/flashmask/flash_mask/flash_attn_v4/flash_fwd_sm100.py index e46348ce6bb..d14f916c5f4 100644 --- a/flashmask/flash_mask/flash_attn_v4/flash_fwd_sm100.py +++ b/flashmask/flash_mask/flash_attn_v4/flash_fwd_sm100.py @@ -579,7 +579,7 @@ def __call__( async_copy_elems = 128 // self.q_dtype.width num_load_threads = cute.arch.WARP_SIZE * len(self.load_warp_ids) threads_per_row = math.gcd(self.head_dim_padded // async_copy_elems, num_load_threads) - gmem_tiled_copy_Q = copy_utils.tiled_copy_2d( + gmem_tiled_copy_Q = copy_utils.quack_tiled_copy_2d( self.q_dtype, threads_per_row, num_load_threads, async_copy_elems, is_async=True ) diff --git a/flashmask/flash_mask/flash_attn_v4/paddle/interface.py b/flashmask/flash_mask/flash_attn_v4/paddle/interface.py index fc374abb2d7..7d2199d7241 100644 --- a/flashmask/flash_mask/flash_attn_v4/paddle/interface.py +++ b/flashmask/flash_mask/flash_attn_v4/paddle/interface.py @@ -2285,6 +2285,9 @@ def flash_attn_varlen_func( min_seqlen_k: for varlen, specifies the minimum kv sequence 
length for any batch. Used with gather_kv_indices to determine if we need oob masking. """ + print(f"wsm debug {cu_seqlens_q=}") + print(f"wsm debug {cu_seqlens_k=}") + print(f"wsm debug {causal=}") return FlashAttnVarlenFunc.apply( q, k, @@ -2552,4 +2555,4 @@ def flash_attn_combine( seqused, varlen_batch_idx=varlen_batch_idx, ) - return out, lse \ No newline at end of file + return out, lse diff --git a/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_backward.py b/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_backward.py index 915aed0d25b..3b6cf2f462f 100644 --- a/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_backward.py +++ b/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_backward.py @@ -23,6 +23,8 @@ from flash_mask.flash_attn_v4.sm100_hd256_2cta_fmha_backward_dkdvkernel import ( BlackwellFusedMultiHeadAttentionBackwardDKDVKernel, ) +from flash_mask.flash_attn_v4.cute_dsl_utils import assume_tensor_aligned + def _as_bshkrd_tensor( tensor: cute.Tensor, @@ -252,6 +254,8 @@ def __call__( else: b = Q.shape[0] + Q, K, V, dQ, dK, dV, dO = [assume_tensor_aligned(t) for t in (Q, K, V, dQ, dK, dV, dO)] + Q = _as_bshkrd_tensor(Q, h_k, h_r, varlen) K = _as_bshkrd_tensor(K, h_k, 1, varlen) V = _as_bshkrd_tensor(V, h_k, 1, varlen) diff --git a/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_backward_dkdvkernel.py b/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_backward_dkdvkernel.py index 29b8f6750ad..2fa89fd9ef8 100644 --- a/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_backward_dkdvkernel.py +++ b/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_backward_dkdvkernel.py @@ -32,6 +32,7 @@ Sm100FmhaStaticTileSchedulerParams as FmhaStaticTileSchedulerParams, ) +import flash_mask.flash_attn_v4.copy_utils as fa_copy_utils LAYOUT_RANK_CONSTANT = 3 @@ -2811,13 +2812,11 @@ def epilogue_clear( dK.iterator + mdK_offset, cute.make_layout((K, self.tile_shape_dQ_K, HB), stride=dK.stride), ) - gdK = 
cute.local_tile( - mdK, (self.dSQ_mma_tiler[0], self.dSQ_mma_tiler[1]), (None, None, None) - ) + gdK = cute.local_tile(mdK, (self.cta_tiler[1], self.cta_tiler[2]), (None, None, None)) gdK = gdK[None, None, blk_coord_k, 0, blk_coord_batch] cdK = cute.domain_offset( (blk_coord_k * self.tile_shape_K, 0), - cute.make_identity_tensor((self.dSQ_mma_tiler[0], self.dSQ_mma_tiler[1])), + cute.make_identity_tensor((self.cta_tiler[1], self.cta_tiler[2])), ) mdV_offset = cute.assume(blk_offset[1] * dV.stride[0], divby=64) @@ -2825,24 +2824,41 @@ def epilogue_clear( dV.iterator + mdV_offset, cute.make_layout((K, self.tile_shape_dV_dO, HB), stride=dV.stride), ) - gdV = cute.local_tile( - mdV, (self.PdO_mma_tiler[0], self.PdO_mma_tiler[1]), (None, None, None) - ) + gdV = cute.local_tile(mdV, (self.cta_tiler[1], self.cta_tiler[2]), (None, None, None)) gdV = gdV[None, None, blk_coord_k, 0, blk_coord_batch] cdV = cute.domain_offset( (blk_coord_k * self.tile_shape_K, 0), - cute.make_identity_tensor((self.PdO_mma_tiler[0], self.PdO_mma_tiler[1])), + cute.make_identity_tensor((self.cta_tiler[1], self.cta_tiler[2])), ) - for i in cutlass.range(tidx * 8, cute.size(gdK), block_dim_x * 8): - if cute.elem_less(cdK[i], cute.select(problem_shape, mode=[1, 2])): - gdK_i = cute.make_tensor(gdK.iterator + cute.assume(i, divby=8), (8)) - gdK_i.fill(0) + num_zero_epi_threads = 256 + + tiled_copy_r2g = fa_copy_utils.tiled_copy_2d( + dK.element_type, self.cta_tiler[2], num_zero_epi_threads + ) + + thr_copy_r2g = tiled_copy_r2g.get_slice(tidx) + + tRG_gdK = thr_copy_r2g.partition_D(gdK) + tRG_cdK = thr_copy_r2g.partition_D(cdK) + tRG_gdV = thr_copy_r2g.partition_D(gdV) + tRG_cdV = thr_copy_r2g.partition_D(cdV) + + zero_frg = cute.make_rmem_tensor_like(tRG_gdK[None, 0, None]) + zero_frg.fill(dK.element_type(0.0)) + + # check we don't need zero fragment duplication + V_frg_size = cute.size(tRG_gdV[None, 0, None]) + assert cute.size(zero_frg) == V_frg_size + + if tidx < num_zero_epi_threads: + for n in 
cutlass.range(cute.size(tRG_gdK.shape[1]), unroll_full=True): + if cute.elem_less(tRG_cdK[0, n, 0][0], problem_shape[1]): + cute.copy(tiled_copy_r2g, zero_frg, tRG_gdK[None, n, None]) - for i in cutlass.range(tidx * 8, cute.size(gdV), block_dim_x * 8): - if cute.elem_less(cdV[i], cute.select(problem_shape, mode=[1, 2])): - gdV_i = cute.make_tensor(gdV.iterator + cute.assume(i, divby=8), (8)) - gdV_i.fill(0) + for n in cutlass.range(cute.size(tRG_gdV.shape[1]), unroll_full=True): + if cute.elem_less(tRG_cdV[0, n, 0][0], problem_shape[1]): + cute.copy(tiled_copy_r2g, zero_frg, tRG_gdV[None, n, None]) @cute.jit def epilogue( diff --git a/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_backward_dqkernel.py b/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_backward_dqkernel.py index d5be6b40f4d..0774fd23385 100644 --- a/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_backward_dqkernel.py +++ b/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_backward_dqkernel.py @@ -29,6 +29,7 @@ Sm100FusedMask as FusedMask, ) from flash_mask.flash_attn_v4.tile_scheduler import SM100_TMEM_CAPACITY_COLUMNS +import flash_mask.flash_attn_v4.copy_utils as fa_copy_utils class BlackwellFusedMultiHeadAttentionBackwardDQKernel: @@ -924,36 +925,45 @@ def kernel( curr_block_coord[1], curr_block_coord[2], ) - continue_cond = False batch_coord = curr_block_coord[2][1] seqlen_q = mQ_qdl.shape[0] seqlen_k = mK_kdl.shape[0] cuseqlen_q = Int32(0) cuseqlen_k = Int32(0) - block_offset = ( - Int32(0), - Int32(0), - Int32(0), - ((Int32(0), Int32(0)), Int32(0)), - ) + is_valid_q = True if cutlass.const_expr(cum_seqlen_q is not None): cuseqlen_q = cum_seqlen_q[batch_coord] seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q - if cutlass.const_expr(cum_seqlen_k is not None): - cuseqlen_k = cum_seqlen_k[batch_coord] - seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + is_valid_q = FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.qk_mma_tiler[0], + 
mma_block_coord[0], + seqlen_q, + ) + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + seqlen_kv_loop_start, seqlen_kv_loop_steps = ( + FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + ) + is_valid_k = seqlen_kv_loop_steps > 0 + has_work = is_valid_q and is_valid_k + + if has_work: block_offset = ( cuseqlen_q, cuseqlen_k, Int32(0), ((Int32(0), Int32(0)), Int32(0)), ) - continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( - self.qk_mma_tiler[0], - mma_block_coord[0], - seqlen_q, - ) - if not continue_cond: mQ_qdl_ = cute.domain_offset(cute.select(block_offset, mode=[0, 2, 3]), mQ_qdl) mK_kdl_ = cute.domain_offset(cute.select(block_offset, mode=[1, 2, 3]), mK_kdl) mdO_qdl_ = cute.domain_offset( @@ -1057,18 +1067,6 @@ def kernel( # ((atom_v, rest_v), RestN, RestK) tKTgKT = tKgK_dkl[None, None, None, mma_block_coord[2]] - seqlen_kv_loop_start, seqlen_kv_loop_steps = ( - FusedMask.get_trip_start_count_via_block_info( - mma_block_coord, - self.qk_mma_tiler, - seqlen_q, - seqlen_k, - self.is_causal, - self.is_local, - window_size_left, - window_size_right, - ) - ) # LSE lse_handle = load_lse_producer.acquire_and_advance() # 32 threads loading 128 values of 32b each @@ -1197,6 +1195,9 @@ def kernel( if warp_idx == self.mma_warp_id: cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + is_leader_cta = cta_rank_in_cluster % 2 == 0 + while work_tile.is_valid_tile: curr_block_coord = work_tile.tile_idx mma_block_coord = ( @@ -1204,41 +1205,37 @@ def kernel( curr_block_coord[1], curr_block_coord[2], ) - continue_cond = False seqlen_q = mQ_qdl.shape[0] seqlen_k = mK_kdl.shape[0] batch_coord = curr_block_coord[2][1] + is_valid_q = 
True if cutlass.const_expr(cum_seqlen_q is not None): cuseqlen_q = cum_seqlen_q[batch_coord] seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q - continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + is_valid_q = FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( self.qk_mma_tiler[0], mma_block_coord[0], seqlen_q, ) - - if not continue_cond: - if cutlass.const_expr(cum_seqlen_k is not None): - cuseqlen_k = cum_seqlen_k[batch_coord] - seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k - - seqlen_kv_loop_start, seqlen_kv_loop_steps = ( - FusedMask.get_trip_start_count_via_block_info( - mma_block_coord, - self.qk_mma_tiler, - seqlen_q, - seqlen_k, - self.is_causal, - self.is_local, - window_size_left, - window_size_right, - ) + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + seqlen_kv_loop_start, seqlen_kv_loop_steps = ( + FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, ) + ) + is_valid_k = seqlen_kv_loop_steps > 0 + has_work = is_valid_q and is_valid_k - cta_rank_in_cluster = cute.arch.make_warp_uniform( - cute.arch.block_idx_in_cluster() - ) - is_leader_cta = cta_rank_in_cluster % 2 == 0 + if has_work: # dq_handle = mma_dq_producer.acquire_and_advance() load_q_releaser = load_q_consumer.clone() load_do_releaser = load_do_consumer.clone() @@ -1836,33 +1833,35 @@ def kernel( curr_block_coord[2], ) batch_coord = curr_block_coord[2][1] - continue_cond = False seqlen_q = mQ_qdl.shape[0] seqlen_k = mK_kdl.shape[0] cuseqlen_q = Int32(0) + is_valid_q = True if cutlass.const_expr(cum_seqlen_q is not None): cuseqlen_q = cum_seqlen_q[batch_coord] seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q - continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + is_valid_q = 
FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( self.qk_mma_tiler[0], mma_block_coord[0], seqlen_q, ) - if not continue_cond: - if cutlass.const_expr(cum_seqlen_k is not None): - cuseqlen_k = cum_seqlen_k[batch_coord] - seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + start_count, trip_count = FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + is_valid_k = trip_count > 0 + has_work = is_valid_q and is_valid_k - start_count, trip_count = FusedMask.get_trip_start_count_via_block_info( - mma_block_coord, - self.qk_mma_tiler, - seqlen_q, - seqlen_k, - self.is_causal, - self.is_local, - window_size_left, - window_size_right, - ) + if has_work: end_count = start_count + trip_count if cutlass.const_expr(self.use_semantic_trip_range): n_block_min_causal_local_mask, n_block_min_before_local_mask = ( @@ -1932,6 +1931,7 @@ def kernel( ) lse_handle.release() sum_odo_handle.release() + work_tile = tile_sched.advance_to_next_work() ds_mma_producer.tail() @@ -1952,61 +1952,75 @@ def kernel( # cute.printf("batch_coord={}", batch_coord) seqlen_q = mQ_qdl.shape[0] seqlen_k = mK_kdl.shape[0] - continue_cond = False cuseqlen_q = Int32(0) + is_valid_q = True if cutlass.const_expr(cum_seqlen_q is not None): cuseqlen_q = cum_seqlen_q[batch_coord] seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q - continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + is_valid_q = FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( self.qk_mma_tiler[0], mma_block_coord[0], seqlen_q, ) + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + seqlen_kv_loop_start, seqlen_kv_loop_steps = ( 
+ FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + ) + is_valid_k = seqlen_kv_loop_steps > 0 + has_work = is_valid_q and is_valid_k - if not continue_cond: - if cutlass.const_expr(cum_seqlen_k is not None): - cuseqlen_k = cum_seqlen_k[batch_coord] - seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + mdQ_qdl_eff = mdQ_qdl + if cutlass.const_expr(cum_seqlen_q is not None): + block_offset_dQ = (cuseqlen_q,) + (None,) * 2 + mdQ_qdl_eff = cute.domain_offset(block_offset_dQ, mdQ_qdl) - mdQ_qdl_eff = mdQ_qdl - if cutlass.const_expr(cum_seqlen_q is not None): - block_offset_dQ = ( - cuseqlen_q, - Int32(0), - Int32(0), - ((Int32(0), Int32(0)), Int32(0)), - ) - mdQ_qdl_eff = cute.domain_offset( - cute.select(block_offset_dQ, mode=[0, 2, 3]), mdQ_qdl - ) + # (bM, bN, loopM, loopN, loopL) + gdQ_qdl = cute.flat_divide( + mdQ_qdl_eff, cute.select(self.dsk_block_tiler, mode=[0, 1]) + ) + cdQ_qdl = cute.flat_divide( + cute.make_identity_tensor(mdQ_qdl_eff.shape), + cute.select(self.dsk_block_tiler, mode=[0, 1]), + ) - # (bM, bN, loopM, loopN, loopL) - gdQ_qdl = cute.flat_divide( - mdQ_qdl_eff, cute.select(self.dsk_block_tiler, mode=[0, 1]) - ) - cdQ_qdl = cute.flat_divide( - cute.make_identity_tensor(mdQ_qdl_eff.shape), - cute.select(self.dsk_block_tiler, mode=[0, 1]), - ) + gdQ_staged = gdQ_qdl[None, None, curr_block_coord[0], None, curr_block_coord[2]] + cdQ_staged = cdQ_qdl[None, None, curr_block_coord[0], None, curr_block_coord[2]] + gdQ_tma_staged = gdQ_staged - gdQ_staged = gdQ_qdl[None, None, curr_block_coord[0], None, curr_block_coord[2]] - cdQ_staged = cdQ_qdl[None, None, curr_block_coord[0], None, curr_block_coord[2]] - gdQ_tma_staged = gdQ_staged - if cutlass.const_expr(not varlen): - gdQ_tma_qdl = cute.flat_divide( - mdQ_tma, cute.select(self.dsk_block_tiler, mode=[0, 1]) - ) - gdQ_tma_staged = gdQ_tma_qdl[ - None, None, 
curr_block_coord[0], None, curr_block_coord[2] - ] + if cutlass.const_expr(not varlen): + gdQ_tma_qdl = cute.flat_divide( + mdQ_tma, cute.select(self.dsk_block_tiler, mode=[0, 1]) + ) + gdQ_tma_staged = gdQ_tma_qdl[ + None, None, curr_block_coord[0], None, curr_block_coord[2] + ] + if has_work: # dQ TMEM to GMEM mma_dq_consumer = self.dQ_epilogue( - (seqlen_q, cuseqlen_q, mQ_qdl.shape[0], batch_coord), + seqlen_q, (mma_dq_consumer, gdQ_staged, cdQ_staged, tdQtdQ_staged), self.epi_tile, (tma_atom_dQ, gdQ_tma_staged, s_epi_dQ, varlen), ) + else: + self.dQ_epilogue_write_zero( + seqlen_q, + gdQ_staged, + cdQ_staged, + ) + work_tile = tile_sched.advance_to_next_work() # NOTE: tmem.free() moved to kernel end to enable cluster-wide sync @@ -2181,12 +2195,11 @@ def compute_step( @cute.jit def dQ_epilogue( self, - value_args: Tuple, + seqlen_q: int, dq_args: Tuple, epi_tile: cute.Tile, tma_args: Tuple, ) -> Tuple[pipeline.PipelineConsumer, pipeline.PipelineProducer]: - seqlen_q, cuseqlen_q, total_q, batch_coord = value_args (mma_dq_consumer, gdQ_staged, cdQ_staged, tdQtdQ_staged) = dq_args tma_atom_dQ, gdQ_tma_staged, s_epi_dQ, varlen = tma_args dq_handle = mma_dq_consumer.wait_and_advance() @@ -2274,3 +2287,31 @@ def dQ_epilogue( cute.autovec_copy(tSMrdQ, tTMEM_LOADgdQ_i) dq_handle.release() return mma_dq_consumer + + @cute.jit + def dQ_epilogue_write_zero( + self, + seqlen_q, + gdQ_staged, + cdQ_staged, + ): + num_epi_threads = self.threads_per_warp * len(self.epilogue_warp_ids) + tidx = cute.arch.thread_idx()[0] % num_epi_threads + + tiled_copy_r2g = fa_copy_utils.tiled_copy_2d( + self.dq_dtype, cute.size(gdQ_staged.shape[1]), num_epi_threads + ) + + thr_copy_r2g = tiled_copy_r2g.get_slice(tidx) + tdQgdQ_staged = thr_copy_r2g.partition_D(gdQ_staged) + tdQcdQ_staged = thr_copy_r2g.partition_D(cdQ_staged) + + tdQrdQ = cute.make_rmem_tensor_like(tdQgdQ_staged[None, 0, None, 0]) + tdQrdQ.fill(self.dq_dtype(0.0)) + + for iter in cutlass.range(self.iterations_dsk, 
unroll_full=True): + tdQgdQ = tdQgdQ_staged[None, None, None, iter] + tdQcdQ = tdQcdQ_staged[None, None, None, iter] + for m in cutlass.range(cute.size(tdQgdQ.shape[1]), unroll_full=True): + if cute.elem_less(tdQcdQ[0, m, 0][0], seqlen_q): + cute.copy(tiled_copy_r2g, tdQrdQ, tdQgdQ[None, m, None]) diff --git a/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_forward.py b/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_forward.py index 1ab8879375d..0c7b005f811 100644 --- a/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_forward.py +++ b/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_forward.py @@ -1030,6 +1030,9 @@ def kernel( if warp_idx == self.mma_warp_id: cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + is_leader_cta = cta_rank_in_cluster % 2 == 0 + while work_tile.is_valid_tile: curr_block_coord = work_tile.tile_idx mma_block_coord = ( @@ -1071,10 +1074,6 @@ def kernel( ) seqlen_kv_loop_end = seqlen_kv_loop_start + seqlen_kv_loop_steps - cta_rank_in_cluster = cute.arch.make_warp_uniform( - cute.arch.block_idx_in_cluster() - ) - is_leader_cta = cta_rank_in_cluster % 2 == 0 load_q_releaser = load_q_consumer.clone() pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) if seqlen_kv_loop_steps > 1: @@ -1323,6 +1322,11 @@ def kernel( window_size_right, ) end_count = start_count + trip_count + # require at least one softmax iteration for zero trip_count case; + # rely on masking this iteration for correctness + if end_count <= start_count: + start_count = 0 + end_count = 1 if cutlass.const_expr(self.use_semantic_trip_range): n_block_min_causal_local_mask, n_block_min_before_local_mask = ( FusedMask.get_trip_mask_bounds_via_block_info( @@ -1349,6 +1353,7 @@ def kernel( need_apply_mask = ( step >= n_block_min_causal_local_mask or step < n_block_min_before_local_mask + or step == end_count - 1 ) else: # Residual path only needs seqlen masking on 
the last K tile. @@ -1797,7 +1802,8 @@ def correction_epilog( row_sum = sSum[thread_idx] cute.arch.fence_view_async_shared() sum_handle.release() - scale = scale_output / row_sum + row_sum_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum + scale = scale_output / row_sum if not row_sum_is_zero_or_nan else 0.0 o_handle = mma_o_consumer.wait_and_advance() for iter in cutlass.range(self.iterations_pv): gO = gO_staged[None, None, iter] @@ -1855,6 +1861,7 @@ def store_sum_max( sSum[thread_idx] = row_sum cute.arch.fence_view_async_shared() sum_handle.commit() + row_sum_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum if cutlass.const_expr(mLSE is not None): q_idx = current_block_coord[0] * self.cta_tiler[0] + tidx @@ -1863,7 +1870,11 @@ def store_sum_max( if cutlass.const_expr(cum_seqlen_q is not None) else current_block_coord[2] ) - lse_value = scale_softmax * row_max + cute.math.log(row_sum, fastmath=True) + lse_value = ( + scale_softmax * row_max + cute.math.log(row_sum, fastmath=True) + if not row_sum_is_zero_or_nan + else -Float32.inf + ) if cute.elem_less(q_idx, seqlen_q): global_q_idx = ( q_idx + cuseqlen_q if cutlass.const_expr(cum_seqlen_q is not None) else q_idx diff --git a/flashmask/flash_mask/interface.py b/flashmask/flash_mask/interface.py index 02c471b8999..d47972fcba9 100644 --- a/flashmask/flash_mask/interface.py +++ b/flashmask/flash_mask/interface.py @@ -75,35 +75,567 @@ flash_attn_varlen_func = _mod.flash_attn_varlen_func flash_attn_combine = _mod.flash_attn_combine -def convert_to_varlen( +def _find_dual_split(ute_doc, q_start, q_end, sk): + """Try to find a valid 2-way causal split for a document. + + When a document's ute pattern doesn't match a single causal formula, + this function checks if it can be split into two sub-calls where each + sub-call is individually causal. This arises in dual-chunk context + parallel where two Q chunks with different causal diagonals share the + same K range. 
+ + Args: + ute_doc: ute values for K columns in this document (length sk). + q_start: start of Q range for this document. + q_end: end of Q range for this document. + sk: number of K columns in this document. + + Returns: + (q_mid, k_split) on success, None on failure. + - q_mid: the Q split point. Call 0 uses Q[q_start:q_mid], Call 1 uses Q[q_mid:q_end]. + - k_split: the K split point. Call 0 uses K[0:k_split], Call 1 uses K[0:sk] (full K). + """ + sq = q_end - q_start + if sq <= 1 or sk <= 1: + return None + + # Find plateau start: first j where ute[j] == ute[j+1] and ute[j] > q_start. + # The plateau value is q_mid (the boundary between the two Q chunks). + q_mid = None + k_split = None + + for j in range(sk - 1): + if int(ute_doc[j]) == int(ute_doc[j + 1]) and int(ute_doc[j]) > q_start: + q_mid = int(ute_doc[j]) + k_split = j + break + + if q_mid is None or q_mid >= q_end: + return None + + sq1 = q_mid - q_start + sk1 = k_split + sq2 = q_end - q_mid + sk2 = sk + + if sq1 <= 0 or sk1 <= 0 or sq2 <= 0: + return None + + # Verify sub-call 1: Q[q_start:q_mid] vs K[0:k_split], causal + for j in range(sk1): + expected = q_start + max(0, j - (sk1 - sq1)) + if int(ute_doc[j]) != expected: + return None + + # Verify sub-call 2: Q[q_mid:q_end] vs K[0:sk], causal + # effective_ute[j] = max(0, ute[j] - q_mid) should equal max(0, j - (sk2 - sq2)) + for j in range(sk2): + effective_ute = max(0, int(ute_doc[j]) - q_mid) + expected = max(0, j - (sk2 - sq2)) + if effective_ute != expected: + return None + + # Also verify Q[q_start:q_mid] does NOT attend to K[k_split:sk]: + # ute[j] >= q_mid for all j >= k_split + for j in range(k_split, sk): + if int(ute_doc[j]) < q_mid: + return None + + return (q_mid, k_split) + + +def _convert_to_varlen_bound2( query: paddle.Tensor, key: paddle.Tensor, value: paddle.Tensor, startend_row_indices: paddle.Tensor, causal: bool, ): + """Convert padded tensors to varlen format using lts+ute boundary detection. 
+ + Supports asymmetric seqlen_q != seqlen_k and zero-length documents. + Requires bound_num == 2 (both lts and ute available). + + Handles Q padding: when last doc's lts < seqlen_q, merges the Q padding + into the last document and zeros out the padding Q positions via + output_to_padded. This avoids creating documents with seqlen_k=0. + + When a document's ute pattern cannot be expressed as a single causal call, + attempts a dual-split into two separate varlen calls (for dual-chunk CP). + Returns a result with "multi_call": True in that case. + """ b, sq, hq, d = query.shape - _, skv, hkv, dv = value.shape - assert sq == skv + _, sk, hkv, dv = value.shape - _, hfm, _, bound_num = startend_row_indices.shape - assert hfm == 1 - assert bound_num == 1 or bound_num == 2 + sri_np = startend_row_indices.numpy() # (batch, 1, sk, 2) + lts_np = sri_np[:, 0, :, 0] # (batch, sk) + ute_np = sri_np[:, 0, :, 1] # (batch, sk) + + cu_seqlens_q_list = [] + cu_seqlens_k_list = [] + max_doc_sq = 0 + max_doc_sk = 0 + needs_q_padding = False + q_padding_starts = [] + + for bi in range(b): + lts = lts_np[bi] # (sk,) + ute = ute_np[bi] # (sk,) + + # Detect K-side document boundaries: + # 1. Where lts changes (different document in the original mask) + # 2. Where dead/active status transitions (ute >= lts vs ute < lts) + # Dead K columns have no Q rows attending; they form zero-length Q docs. 
+ k_starts = [0] + for j in range(1, sk): + if lts[j] != lts[j - 1]: + k_starts.append(j) + elif (ute[j - 1] >= lts[j - 1]) != (ute[j] >= lts[j]): + k_starts.append(j) + + # Build q_bounds and k_bounds + q_bounds = [np.int32(0)] + k_bounds = [np.int32(k) for k in k_starts] + [np.int32(sk)] + + for doc_idx in range(len(k_starts)): + k_start = k_starts[doc_idx] + if ute[k_start] >= lts[k_start]: + # Dead segment: no Q rows attend, zero-length Q doc + q_bounds.append(q_bounds[-1]) + else: + doc_lts = np.int32(lts[k_start]) + q_bounds.append(doc_lts) + + # Handle Q padding: if last doc doesn't cover all Q positions, + # merge Q padding into the last doc (like _convert_to_varlen_bound1). + q_pad_start = int(q_bounds[-1]) + q_padding_starts.append(q_pad_start) + if q_pad_start < sq: + needs_q_padding = True + q_bounds[-1] = np.int32(sq) + + # Compute per-doc max lengths + for i in range(len(q_bounds) - 1): + doc_sq = int(q_bounds[i + 1]) - int(q_bounds[i]) + doc_sk = int(k_bounds[i + 1]) - int(k_bounds[i]) + max_doc_sq = max(max_doc_sq, doc_sq) + max_doc_sk = max(max_doc_sk, doc_sk) + + # Add batch offsets + cu_seqlens_q_list.append( + np.array(q_bounds[:-1], dtype=np.int32) + np.int32(bi * sq) + ) + cu_seqlens_k_list.append( + np.array(k_bounds[:-1], dtype=np.int32) + np.int32(bi * sk) + ) + + # Concatenate with final sentinel + cu_seqlens_q_np = np.concatenate( + cu_seqlens_q_list + [np.array([b * sq], dtype=np.int32)] + ) + cu_seqlens_k_np = np.concatenate( + cu_seqlens_k_list + [np.array([b * sk], dtype=np.int32)] + ) + cu_seqlens_q = paddle.to_tensor(cu_seqlens_q_np) + cu_seqlens_k = paddle.to_tensor(cu_seqlens_k_np) + + # Flatten tensors + q_varlen = query.reshape([b * sq, hq, d]) + k_varlen = key.reshape([b * sk, hkv, d]) + v_varlen = value.reshape([b * sk, hkv, dv]) + + # Causal detection: for causal=False, check if ute matches causal pattern. + # Uses the MERGED q_bounds/k_bounds (after Q padding merge) so that the + # causal formula uses the correct doc sizes. 
+ varlen_causal = causal + if not causal: + is_causal = True + for bi in range(b): + lts = lts_np[bi] + ute = ute_np[bi] + # Reconstruct merged bounds for this batch (same logic as above) + k_starts_bi = [0] + for j in range(1, sk): + if lts[j] != lts[j - 1]: + k_starts_bi.append(j) + elif (ute[j - 1] >= lts[j - 1]) != (ute[j] >= lts[j]): + k_starts_bi.append(j) + # q_bounds_bi: dead segments get zero-length Q, active get lts + # Use PRE-padding q_bounds for causal check (padding inflates last + # doc's sq beyond sk, corrupting the causal formula). + q_bounds_bi = [0] + for idx in range(len(k_starts_bi)): + k_start_idx = k_starts_bi[idx] + if ute[k_start_idx] >= lts[k_start_idx]: + q_bounds_bi.append(q_bounds_bi[-1]) + else: + q_bounds_bi.append(int(lts[k_start_idx])) + # Do NOT apply Q padding merge here — causal check uses real Q extents. + k_bounds_bi = k_starts_bi + [sk] + + for doc_idx in range(len(k_starts_bi)): + k_start = k_bounds_bi[doc_idx] + k_end = k_bounds_bi[doc_idx + 1] + doc_sk = k_end - k_start + q_offset = q_bounds_bi[doc_idx] + doc_sq = q_bounds_bi[doc_idx + 1] - q_offset + + if doc_sq == 0: + # Dead doc: vacuously causal, skip + continue + else: + for j_local in range(doc_sk): + # Skip dead K columns (ute == lts means no Q attends) + if int(ute[k_start + j_local]) == int(lts[k_start + j_local]): + continue + expected = q_offset + max(0, j_local - (doc_sk - doc_sq)) + if int(ute[k_start + j_local]) != expected: + is_causal = False + break + if not is_causal: + break + if not is_causal: + break + if is_causal: + varlen_causal = True + + # When causal is detected AND Q padding exists, the padding merge + # inflates the last doc's sq beyond sk, shifting the causal diagonal + # incorrectly. Fix: undo the merge, trim Q to real tokens only, + # and pad output back to full sq with zeros afterward. + # NOTE: different batches may have different q_padding_starts, so + # we handle per-batch Q lengths and accumulate offsets properly. 
+ if needs_q_padding: + # Rebuild cu_seqlens without padding: use real q_bounds, + # with per-batch Q offsets accumulated from actual real Q lengths. + cu_seqlens_q_list_nopad = [] + cu_seqlens_k_list_nopad = [] + max_doc_sq_nopad = 0 + max_doc_sk_nopad = 0 + q_offset_acc = 0 # accumulated Q offset across batches + + for bi in range(b): + lts = lts_np[bi] + ute = ute_np[bi] + k_starts_tmp = [0] + for j in range(1, sk): + if lts[j] != lts[j - 1]: + k_starts_tmp.append(j) + elif (ute[j - 1] >= lts[j - 1]) != (ute[j] >= lts[j]): + k_starts_tmp.append(j) + q_bounds_tmp = [np.int32(0)] + k_bounds_tmp = [np.int32(k) for k in k_starts_tmp] + [np.int32(sk)] + for idx in range(len(k_starts_tmp)): + k_start_idx = k_starts_tmp[idx] + if ute[k_start_idx] >= lts[k_start_idx]: + q_bounds_tmp.append(q_bounds_tmp[-1]) + else: + q_bounds_tmp.append(np.int32(lts[k_start_idx])) + # No merge: q_bounds_tmp[-1] = q_padding_starts[bi] + + for i in range(len(q_bounds_tmp) - 1): + dsq = int(q_bounds_tmp[i + 1]) - int(q_bounds_tmp[i]) + dsk = int(k_bounds_tmp[i + 1]) - int(k_bounds_tmp[i]) + max_doc_sq_nopad = max(max_doc_sq_nopad, dsq) + max_doc_sk_nopad = max(max_doc_sk_nopad, dsk) + + # Offsets based on accumulated real Q tokens + cu_seqlens_q_list_nopad.append( + np.array(q_bounds_tmp[:-1], dtype=np.int32) + np.int32(q_offset_acc) + ) + cu_seqlens_k_list_nopad.append( + np.array(k_bounds_tmp[:-1], dtype=np.int32) + np.int32(bi * sk) + ) + q_offset_acc += q_padding_starts[bi] + + total_real_q = q_offset_acc + cu_seqlens_q_nopad = paddle.to_tensor(np.concatenate( + cu_seqlens_q_list_nopad + [np.array([total_real_q], dtype=np.int32)] + )) + cu_seqlens_k_nopad = paddle.to_tensor(np.concatenate( + cu_seqlens_k_list_nopad + [np.array([b * sk], dtype=np.int32)] + )) + + # Trim Q: keep only real tokens [0:q_padding_starts[bi]] from + # each batch. Concatenate since lengths may differ per batch. 
+ q_parts = [] + for bi in range(b): + q_parts.append(query[bi, :q_padding_starts[bi], :, :]) + q_trimmed = paddle.concat(q_parts, axis=0) # [total_real_q, hq, d] + + # output_to_padded: expand trimmed output back to [b, sq, nh, dv] + _b, _sq = b, sq + _q_padding_starts = list(q_padding_starts) + + def output_to_padded(out_varlen_pt): + nh = out_varlen_pt.shape[1] + dv_out = out_varlen_pt.shape[2] + out_padded = paddle.zeros([_b, _sq, nh, dv_out], dtype=out_varlen_pt.dtype) + offset = 0 + for bi_inner in range(_b): + real_len = _q_padding_starts[bi_inner] + out_padded[bi_inner, :real_len, :, :] = out_varlen_pt[offset:offset + real_len, :, :] + offset += real_len + return out_padded + + result = { + "q": q_trimmed, + "k": k_varlen, + "v": v_varlen, + "cu_seqlens_q": cu_seqlens_q_nopad, + "cu_seqlens_k": cu_seqlens_k_nopad, + "max_seqlen_q": max_doc_sq_nopad, + "max_seqlen_k": max_doc_sk_nopad, + "causal": varlen_causal, + "output_to_padded": output_to_padded, + } + return result + else: + # Single causal failed. Try dual-split for dual-chunk CP patterns. 
def _try_dual_split_all_batches(
    query, key, value, lts_np, ute_np, b, sq, sk, hq, hkv, d, dv
):
    """Attempt to split into two varlen calls for dual-chunk CP patterns.

    Every batch item is scanned for documents that are not already causal.
    Each such document is handed to ``_find_dual_split``; the resulting
    ``(q_mid, k_split)`` pair and the document's K end must be identical for
    every non-causal document in every batch item, otherwise the split is
    rejected.

    Handles dead K tails (``ute >= lts``) by excluding them from both calls,
    and handles Q padding (when active docs don't cover the full Q range) by
    attaching an ``output_to_padded`` callback that zeroes the padded rows.

    Args:
        query: Indexed as ``[batch, seqlen_q, heads, dim]`` and flattened to
            ``[b * rows, hq, d]``.
        key: Indexed as ``[batch, seqlen_k, heads, dim]`` and flattened to
            ``[b * rows, hkv, d]``.
        value: Indexed like ``key`` and flattened to ``[b * rows, hkv, dv]``.
        lts_np: Per-batch, per-K-column ``lts`` bounds, indexed ``[bi][j]``.
        ute_np: Per-batch, per-K-column ``ute`` bounds, indexed ``[bi][j]``.
        b: Batch size.
        sq: Padded Q sequence length.
        sk: Padded K sequence length.
        hq: Number of query heads.
        hkv: Number of key/value heads.
        d: Head dimension of ``query`` / ``key``.
        dv: Head dimension of ``value``.

    Returns:
        ``None`` when no consistent dual split exists. Otherwise a dict with
        ``"multi_call": True``, a two-element ``"calls"`` list (each entry
        holds ``q``, ``k``, ``v``, ``cu_seqlens_q``, ``cu_seqlens_k``,
        ``max_seqlen_q``, ``max_seqlen_k`` and ``causal``), ``"q_split"``
        (the Q row where call 1 starts) and, only when some batch item has
        Q padding, ``"output_to_padded"``.
    """
    split_q_mid = None
    split_k_split = None
    split_k_end = None  # end of the active K range (excluding dead tail)
    # First padded Q row of every batch item, taken BEFORE the padding merge.
    # Kept per batch because different batch items may carry different
    # amounts of Q padding.
    real_q_ends = []

    for bi in range(b):
        lts = lts_np[bi]
        ute = ute_np[bi]

        # Reconstruct doc boundaries: a new document starts where lts changes
        # or where the column switches between dead (ute >= lts) and live.
        k_starts_bi = [0]
        for j in range(1, sk):
            if lts[j] != lts[j - 1]:
                k_starts_bi.append(j)
            elif (ute[j - 1] >= lts[j - 1]) != (ute[j] >= lts[j]):
                k_starts_bi.append(j)

        # Dead segments get a zero-length Q range; live ones end at their lts.
        q_bounds_bi = [0]
        for k_start_idx in k_starts_bi:
            if ute[k_start_idx] >= lts[k_start_idx]:
                q_bounds_bi.append(q_bounds_bi[-1])
            else:
                q_bounds_bi.append(int(lts[k_start_idx]))

        # Remember where this batch item's real Q rows stop, then merge the
        # Q padding into the last document so the bounds cover all of sq.
        real_q_ends.append(min(q_bounds_bi[-1], sq))
        if q_bounds_bi[-1] < sq:
            q_bounds_bi[-1] = sq
        k_bounds_bi = k_starts_bi + [sk]

        # Find the non-causal document(s) and try a dual split on each.
        for doc_idx in range(len(k_starts_bi)):
            k_start = k_bounds_bi[doc_idx]
            k_end = k_bounds_bi[doc_idx + 1]
            doc_sk = k_end - k_start
            q_offset = q_bounds_bi[doc_idx]
            doc_sq = q_bounds_bi[doc_idx + 1] - q_offset

            if doc_sq == 0:
                # Dead doc: nothing to split.
                continue

            # Check if this doc is already causal.
            doc_is_causal = True
            for j_local in range(doc_sk):
                # Columns with ute == lts are skipped by the causal check.
                if int(ute[k_start + j_local]) == int(lts[k_start + j_local]):
                    continue
                expected = q_offset + max(0, j_local - (doc_sk - doc_sq))
                if int(ute[k_start + j_local]) != expected:
                    doc_is_causal = False
                    break

            if doc_is_causal:
                continue

            # Try dual split on this document.
            ute_doc = ute[k_start:k_end]
            split_result = _find_dual_split(
                ute_doc, q_offset, q_offset + doc_sq, doc_sk
            )
            if split_result is None:
                return None

            q_mid, k_split_doc = split_result

            if split_q_mid is None:
                split_q_mid = q_mid
                split_k_split = k_split_doc
                split_k_end = k_end
            elif split_q_mid != q_mid or split_k_split != k_split_doc:
                # Inconsistent splits across batches/docs.
                return None
            elif split_k_end != k_end:
                # The active K end must also agree across batches/docs.
                return None

    if split_q_mid is None:
        return None

    # Build the two-call result.
    # NOTE(review): both calls slice K from column 0 and describe each batch
    # item as a single sequence, so this presumably expects the split
    # document to start at Q row 0 / K column 0 -- confirm against
    # _find_dual_split and the callers.
    q_mid = split_q_mid
    k_split = split_k_split
    k_active_end = split_k_end  # end of active K (excludes dead tail)
    sq1 = q_mid  # Q[0:q_mid]
    sq2 = sq - q_mid  # Q[q_mid:sq]

    # Call 0: Q[0:q_mid] vs K[0:k_split]
    q0 = query[:, :q_mid, :, :].reshape([b * sq1, hq, d])
    k0 = key[:, :k_split, :, :].reshape([b * k_split, hkv, d])
    v0 = value[:, :k_split, :, :].reshape([b * k_split, hkv, dv])
    cu_seqlens_q0 = paddle.to_tensor(
        np.arange(b + 1, dtype=np.int32) * np.int32(sq1)
    )
    cu_seqlens_k0 = paddle.to_tensor(
        np.arange(b + 1, dtype=np.int32) * np.int32(k_split)
    )

    # Call 1: Q[q_mid:sq] vs K[0:k_active_end] (active K only, no dead tail)
    q1 = query[:, q_mid:, :, :].reshape([b * sq2, hq, d])
    k1 = key[:, :k_active_end, :, :].reshape([b * k_active_end, hkv, d])
    v1 = value[:, :k_active_end, :, :].reshape([b * k_active_end, hkv, dv])
    cu_seqlens_q1 = paddle.to_tensor(
        np.arange(b + 1, dtype=np.int32) * np.int32(sq2)
    )
    cu_seqlens_k1 = paddle.to_tensor(
        np.arange(b + 1, dtype=np.int32) * np.int32(k_active_end)
    )

    result = {
        "multi_call": True,
        "calls": [
            {
                "q": q0,
                "k": k0,
                "v": v0,
                "cu_seqlens_q": cu_seqlens_q0,
                "cu_seqlens_k": cu_seqlens_k0,
                "max_seqlen_q": sq1,
                "max_seqlen_k": k_split,
                "causal": True,
            },
            {
                "q": q1,
                "k": k1,
                "v": v1,
                "cu_seqlens_q": cu_seqlens_q1,
                "cu_seqlens_k": cu_seqlens_k1,
                "max_seqlen_q": sq2,
                "max_seqlen_k": k_active_end,
                "causal": True,
            },
        ],
        "q_split": q_mid,
    }

    # Handle Q padding: rows at or beyond a batch item's real Q end are
    # zeroed after the two outputs have been concatenated.
    if any(real_end < sq for real_end in real_q_ends):
        _sq = sq
        _real_q_ends = list(real_q_ends)

        def output_to_padded(out):
            """Zero the Q padding rows of ``out`` in place and return it."""
            # out is [batch, sq, nheads, dv] after concat in flashmask_attention
            for bi_inner, pad_start in enumerate(_real_q_ends):
                if pad_start < _sq:
                    out[bi_inner, pad_start:, :, :] = 0
            return out

        result["output_to_padded"] = output_to_padded

    return result
+ """ + b, sq, hq, d = query.shape + _, skv, hkv, dv = value.shape + assert sq == skv, ( + f"bound_num==1 requires seqlen_q == seqlen_k, " + f"got seqlen_q={sq}, seqlen_k={skv}" + ) + + sri_np = startend_row_indices.numpy() s_np = sri_np[:, 0, :, 0] # (batch, seqlen_k) - # ── Vectorised boundary detection in numpy ── - # Compare consecutive elements per batch: (b, skv-1) + # Boundary detection from lts changes diffs = s_np[:, 1:] != s_np[:, :-1] - # Real ends per batch - real_ends_np = s_np.max(axis=1) # (b,) + real_ends_np = s_np.max(axis=1) real_ends = real_ends_np.tolist() needs_padding_fixup = bool(np.any(real_ends_np < skv)) - # Per-batch boundary extraction (b is typically 1-2, loop is trivial) cu_seqlens_list = [] max_doc_len = 0 for bi in range(b): @@ -117,26 +649,16 @@ def convert_to_varlen( max_doc_len = max(max_doc_len, int(doc_lens.max())) cu_seqlens_list.append(boundaries[:-1] + np.int32(bi * skv)) - # Build cu_seqlens: one numpy concat, one paddle tensor creation cu_seqlens_np = np.concatenate( cu_seqlens_list + [np.array([b * skv], dtype=np.int32)] ) cu_seqlens = paddle.to_tensor(cu_seqlens_np) - # ── Flatten q, k, v: (batch, seqlen, heads, dim) -> (total, heads, dim) q_varlen = query.reshape([b * sq, hq, d]) k_varlen = key.reshape([b * skv, hkv, d]) v_varlen = value.reshape([b * skv, hkv, dv]) - # ── Detect simulated causal masks (numpy) ──────────────────────── varlen_causal = causal - if not causal and bound_num == 2: - lts_all = s_np # (b, skv), already extracted - ute_all = sri_np[:, 0, :, 1] # (b, skv) - arange_ref = np.arange(skv, dtype=np.int32).reshape(1, skv) - expected_causal_ute = np.minimum(arange_ref, lts_all) - if np.array_equal(ute_all, expected_causal_ute): - varlen_causal = True result = { "q": q_varlen, @@ -149,10 +671,6 @@ def convert_to_varlen( "causal": varlen_causal, } - # For non-causal masks with trailing padding: padding rows attend to - # nothing in flashmask (zero output), but varlen computes non-zero - # output for them. 
def convert_to_varlen(
    query: paddle.Tensor,
    key: paddle.Tensor,
    value: paddle.Tensor,
    startend_row_indices: paddle.Tensor,
    causal: bool,
):
    """Convert padded q/k/v to varlen arguments, dispatching on ``bound_num``.

    ``bound_num`` is the last dimension of ``startend_row_indices``:

    * ``bound_num == 2`` is forwarded to ``_convert_to_varlen_bound2`` and
      requires ``causal`` to be false.
    * ``bound_num == 1`` is forwarded to ``_convert_to_varlen_bound1`` and
      requires ``seqlen_q == seqlen_k``.

    Args:
        query: 4-D tensor; its second dimension is read as ``seqlen_q``.
        key: Forwarded unchanged to the selected converter.
        value: Tensor whose second dimension is read as ``seqlen_k``.
        startend_row_indices: 4-D tensor whose second dimension must be 1 and
            whose last dimension is ``bound_num`` (1 or 2).
        causal: Causal flag forwarded to the selected converter.

    Returns:
        Whatever the selected converter returns (a dict of varlen arguments).

    Raises:
        AssertionError: If the mask has more than one head, ``bound_num`` is
            not 1 or 2, ``causal`` is set with ``bound_num == 2``, or the
            sequence lengths differ with ``bound_num == 1``.
    """
    _, hfm, _, bound_num = startend_row_indices.shape
    # Unpacking also checks that query is 4-D; only seqlen_q is needed here.
    _, sq, _, _ = query.shape
    sk = value.shape[1]

    assert hfm == 1, (
        f"startend_row_indices must have a single head dimension, got {hfm}"
    )
    assert bound_num in (1, 2), (
        f"bound_num must be 1 or 2, got {bound_num}"
    )

    if bound_num == 2:
        assert not causal, (
            f"bound_num == 2 requires causal=False, got causal={causal}"
        )
        return _convert_to_varlen_bound2(
            query, key, value, startend_row_indices, causal
        )

    assert sq == sk, (
        f"Asymmetric seqlen_q != seqlen_k requires bound_num == 2, "
        f"got bound_num=1 with seqlen_q={sq}, seqlen_k={sk}"
    )
    return _convert_to_varlen_bound1(
        query, key, value, startend_row_indices, causal
    )
cu_seqlens_q=varlen_args["cu_seqlens_q"], - cu_seqlens_k=varlen_args["cu_seqlens_k"], - max_seqlen_q=varlen_args["max_seqlen_q"], - max_seqlen_k=varlen_args["max_seqlen_k"], - softmax_scale=softmax_scale, - causal=varlen_args["causal"], - return_lse=return_softmax_lse, - ) - if "output_to_padded" in varlen_args: - out = varlen_args["output_to_padded"](out) + if varlen_args.get("multi_call", False): + # Dual-split: two separate varlen calls with disjoint Q ranges. + outputs = [] + lses = [] + for call_args in varlen_args["calls"]: + out_i, lse_i = flash_attn_varlen_func( + q=call_args["q"], + k=call_args["k"], + v=call_args["v"], + cu_seqlens_q=call_args["cu_seqlens_q"], + cu_seqlens_k=call_args["cu_seqlens_k"], + max_seqlen_q=call_args["max_seqlen_q"], + max_seqlen_k=call_args["max_seqlen_k"], + softmax_scale=softmax_scale, + causal=call_args["causal"], + return_lse=return_softmax_lse, + ) + outputs.append(out_i) + lses.append(lse_i) + + # Reshape each output to [batch, sq_i, nheads, dv] then concat along seq dim + q_split = varlen_args["q_split"] + sq1 = q_split + sq2 = seqlen_q - q_split + out0 = outputs[0].reshape([batch_size, sq1, nheads, dv]) + out1 = outputs[1].reshape([batch_size, sq2, nheads, dv]) + out = paddle.concat([out0, out1], axis=1) + + # Zero out Q padding positions if needed + if "output_to_padded" in varlen_args: + out = varlen_args["output_to_padded"](out) + + if return_softmax_lse: + lse0 = lses[0].reshape([batch_size, nheads, sq1]) + lse1 = lses[1].reshape([batch_size, nheads, sq2]) + lse = paddle.concat([lse0, lse1], axis=2) + return [out, lse] + else: + return out else: - out = out.reshape([batch_size, seqlen_q, nheads, dv]) + out, lse = flash_attn_varlen_func( + q=varlen_args["q"], + k=varlen_args["k"], + v=varlen_args["v"], + cu_seqlens_q=varlen_args["cu_seqlens_q"], + cu_seqlens_k=varlen_args["cu_seqlens_k"], + max_seqlen_q=varlen_args["max_seqlen_q"], + max_seqlen_k=varlen_args["max_seqlen_k"], + softmax_scale=softmax_scale, + 
causal=varlen_args["causal"], + return_lse=return_softmax_lse, + ) - if return_softmax_lse: - return [out, lse] - else: - return out + if "output_to_padded" in varlen_args: + out = varlen_args["output_to_padded"](out) + else: + out = out.reshape([batch_size, seqlen_q, nheads, dv]) + + if return_softmax_lse: + return [out, lse] + else: + return out else: return _cute_flashmask_attention( query=query,