diff --git a/flashmask/flash_mask/flash_attn_v4/copy_utils.py b/flashmask/flash_mask/flash_attn_v4/copy_utils.py index 7b7b86eb4a5..b6e5c6911fe 100644 --- a/flashmask/flash_mask/flash_attn_v4/copy_utils.py +++ b/flashmask/flash_mask/flash_attn_v4/copy_utils.py @@ -90,6 +90,24 @@ def tiled_copy_1d( def tiled_copy_2d( + dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = False +) -> cute.TiledCopy: + num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width + copy_elems = num_copy_bits // dtype.width + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + gmem_threads_per_row = major_mode_size // copy_elems + assert num_threads % gmem_threads_per_row == 0 + thr_layout = cute.make_ordered_layout( + (num_threads // gmem_threads_per_row, gmem_threads_per_row), + order=(1, 0), + ) + val_layout = cute.make_layout((1, copy_elems)) + return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) + +# TODO(umiswing): Support calling the Quack kernel directly +# from Paddle rather than duplicating it here. +def quack_tiled_copy_2d( dtype: Type[cutlass.Numeric], threads_per_row: int, num_threads: int, @@ -107,7 +125,6 @@ def tiled_copy_2d( val_layout = cute.make_layout((1, num_copy_elems)) return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) - @dsl_user_op def atomic_add_fp32x4( a: Float32, b: Float32, c: Float32, d: Float32, gmem_ptr: cute.Pointer, *, loc=None, ip=None diff --git a/flashmask/flash_mask/flash_attn_v4/flash_bwd_postprocess.py b/flashmask/flash_mask/flash_attn_v4/flash_bwd_postprocess.py index 2f7d2af50ca..ffe2384943b 100644 --- a/flashmask/flash_mask/flash_attn_v4/flash_bwd_postprocess.py +++ b/flashmask/flash_mask/flash_attn_v4/flash_bwd_postprocess.py @@ -177,7 +177,7 @@ def _setup_attributes(self): num_copy_elems = 128 // self.dtype.width threads_per_row = math.gcd(128, self.tile_hdim) // num_copy_elems - self.gmem_tiled_copy_dQ = copy_utils.tiled_copy_2d( + self.gmem_tiled_copy_dQ = copy_utils.quack_tiled_copy_2d( self.dtype, threads_per_row, self.num_threads, num_copy_elems ) # /////////////////////////////////////////////////////////////////////////////// diff --git a/flashmask/flash_mask/flash_attn_v4/flash_bwd_preprocess.py b/flashmask/flash_mask/flash_attn_v4/flash_bwd_preprocess.py index 76946556d4c..13f9c5c1da5 100644 --- a/flashmask/flash_mask/flash_attn_v4/flash_bwd_preprocess.py +++ b/flashmask/flash_mask/flash_attn_v4/flash_bwd_preprocess.py @@ -111,7 +111,7 @@ def _setup_attributes(self): ) num_copy_elems = 128 // self.dtype.width threads_per_row = gmem_k_block_size // num_copy_elems - self.gmem_tiled_copy_O = copy_utils.tiled_copy_2d( + self.gmem_tiled_copy_O = copy_utils.quack_tiled_copy_2d( self.dtype, threads_per_row, self.num_threads, num_copy_elems ) universal_copy_bits = 128 @@ -367,4 +367,4 @@ def kernel( gLSElog2 = cute.local_tile(mLSElog2_cur, (self.tile_m,), (m_block,)) LOG2_E = math.log2(math.e) if tidx < seqlen_q_rounded - m_block * self.tile_m: - gLSElog2[tidx] = lse * LOG2_E if lse != -Float32.inf else 0.0 \ No newline at end of file + gLSElog2[tidx] = lse * LOG2_E if lse != -Float32.inf else 0.0 diff --git a/flashmask/flash_mask/flash_attn_v4/flash_fwd_sm100.py b/flashmask/flash_mask/flash_attn_v4/flash_fwd_sm100.py index e46348ce6bb..d14f916c5f4 100644 --- a/flashmask/flash_mask/flash_attn_v4/flash_fwd_sm100.py +++ b/flashmask/flash_mask/flash_attn_v4/flash_fwd_sm100.py @@ -579,7 +579,7 @@ def __call__( async_copy_elems = 128 // self.q_dtype.width num_load_threads = cute.arch.WARP_SIZE * len(self.load_warp_ids) threads_per_row = math.gcd(self.head_dim_padded // async_copy_elems, num_load_threads) - gmem_tiled_copy_Q = copy_utils.tiled_copy_2d( + gmem_tiled_copy_Q = copy_utils.quack_tiled_copy_2d( self.q_dtype, threads_per_row, num_load_threads, async_copy_elems, is_async=True ) diff --git a/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_backward.py b/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_backward.py index 915aed0d25b..3b6cf2f462f 100644 --- a/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_backward.py +++ b/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_backward.py @@ -23,6 +23,8 @@ from flash_mask.flash_attn_v4.sm100_hd256_2cta_fmha_backward_dkdvkernel import ( BlackwellFusedMultiHeadAttentionBackwardDKDVKernel, ) +from flash_mask.flash_attn_v4.cute_dsl_utils import assume_tensor_aligned + def _as_bshkrd_tensor( tensor: cute.Tensor, @@ -252,6 +254,8 @@ def __call__( else: b = Q.shape[0] + Q, K, V, dQ, dK, dV, dO = [assume_tensor_aligned(t) for t in (Q, K, V, dQ, dK, dV, dO)] + Q = _as_bshkrd_tensor(Q, h_k, h_r, varlen) K = _as_bshkrd_tensor(K, h_k, 1, varlen) V = _as_bshkrd_tensor(V, h_k, 1, varlen) diff --git a/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_backward_dkdvkernel.py b/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_backward_dkdvkernel.py index 29b8f6750ad..2fa89fd9ef8 100644 --- a/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_backward_dkdvkernel.py +++ b/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_backward_dkdvkernel.py @@ -32,6 +32,7 @@ Sm100FmhaStaticTileSchedulerParams as FmhaStaticTileSchedulerParams, ) +import flash_mask.flash_attn_v4.copy_utils as fa_copy_utils LAYOUT_RANK_CONSTANT = 3 @@ -2811,13 +2812,11 @@ def epilogue_clear( dK.iterator + mdK_offset, cute.make_layout((K, self.tile_shape_dQ_K, HB), stride=dK.stride), ) - gdK = cute.local_tile( - mdK, (self.dSQ_mma_tiler[0], self.dSQ_mma_tiler[1]), (None, None, None) - ) + gdK = cute.local_tile(mdK, (self.cta_tiler[1], self.cta_tiler[2]), (None, None, None)) gdK = gdK[None, None, blk_coord_k, 0, blk_coord_batch] cdK = cute.domain_offset( (blk_coord_k * self.tile_shape_K, 0), - cute.make_identity_tensor((self.dSQ_mma_tiler[0], self.dSQ_mma_tiler[1])), + cute.make_identity_tensor((self.cta_tiler[1], self.cta_tiler[2])), ) mdV_offset = cute.assume(blk_offset[1] * dV.stride[0], divby=64) @@ -2825,24 +2824,41 @@ def epilogue_clear( dV.iterator + mdV_offset, cute.make_layout((K, self.tile_shape_dV_dO, HB), stride=dV.stride), ) - gdV = cute.local_tile( - mdV, (self.PdO_mma_tiler[0], self.PdO_mma_tiler[1]), (None, None, None) - ) + gdV = cute.local_tile(mdV, (self.cta_tiler[1], self.cta_tiler[2]), (None, None, None)) gdV = gdV[None, None, blk_coord_k, 0, blk_coord_batch] cdV = cute.domain_offset( (blk_coord_k * self.tile_shape_K, 0), - cute.make_identity_tensor((self.PdO_mma_tiler[0], self.PdO_mma_tiler[1])), + cute.make_identity_tensor((self.cta_tiler[1], self.cta_tiler[2])), ) - for i in cutlass.range(tidx * 8, cute.size(gdK), block_dim_x * 8): - if cute.elem_less(cdK[i], cute.select(problem_shape, mode=[1, 2])): - gdK_i = cute.make_tensor(gdK.iterator + cute.assume(i, divby=8), (8)) - gdK_i.fill(0) + num_zero_epi_threads = 256 + + tiled_copy_r2g = fa_copy_utils.tiled_copy_2d( + dK.element_type, self.cta_tiler[2], num_zero_epi_threads + ) + + thr_copy_r2g = tiled_copy_r2g.get_slice(tidx) + + tRG_gdK = thr_copy_r2g.partition_D(gdK) + tRG_cdK = thr_copy_r2g.partition_D(cdK) + tRG_gdV = thr_copy_r2g.partition_D(gdV) + tRG_cdV = thr_copy_r2g.partition_D(cdV) + + zero_frg = cute.make_rmem_tensor_like(tRG_gdK[None, 0, None]) + zero_frg.fill(dK.element_type(0.0)) + + # check we don't need zero fragment duplication + V_frg_size = cute.size(tRG_gdV[None, 0, None]) + assert cute.size(zero_frg) == V_frg_size + + if tidx < num_zero_epi_threads: + for n in cutlass.range(cute.size(tRG_gdK.shape[1]), unroll_full=True): + if cute.elem_less(tRG_cdK[0, n, 0][0], problem_shape[1]): + cute.copy(tiled_copy_r2g, zero_frg, tRG_gdK[None, n, None]) - for i in cutlass.range(tidx * 8, cute.size(gdV), block_dim_x * 8): - if cute.elem_less(cdV[i], cute.select(problem_shape, mode=[1, 2])): - gdV_i = cute.make_tensor(gdV.iterator + cute.assume(i, divby=8), (8)) - gdV_i.fill(0) + for n in cutlass.range(cute.size(tRG_gdV.shape[1]), unroll_full=True): + if cute.elem_less(tRG_cdV[0, n, 0][0], problem_shape[1]): + cute.copy(tiled_copy_r2g, zero_frg, tRG_gdV[None, n, None]) @cute.jit def epilogue( diff --git a/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_backward_dqkernel.py b/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_backward_dqkernel.py index d5be6b40f4d..0774fd23385 100644 --- a/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_backward_dqkernel.py +++ b/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_backward_dqkernel.py @@ -29,6 +29,7 @@ Sm100FusedMask as FusedMask, ) from flash_mask.flash_attn_v4.tile_scheduler import SM100_TMEM_CAPACITY_COLUMNS +import flash_mask.flash_attn_v4.copy_utils as fa_copy_utils class BlackwellFusedMultiHeadAttentionBackwardDQKernel: @@ -924,36 +925,45 @@ def kernel( curr_block_coord[1], curr_block_coord[2], ) - continue_cond = False batch_coord = curr_block_coord[2][1] seqlen_q = mQ_qdl.shape[0] seqlen_k = mK_kdl.shape[0] cuseqlen_q = Int32(0) cuseqlen_k = Int32(0) - block_offset = ( - Int32(0), - Int32(0), - Int32(0), - ((Int32(0), Int32(0)), Int32(0)), - ) + is_valid_q = True if cutlass.const_expr(cum_seqlen_q is not None): cuseqlen_q = cum_seqlen_q[batch_coord] seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q - if cutlass.const_expr(cum_seqlen_k is not None): - cuseqlen_k = cum_seqlen_k[batch_coord] - seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + is_valid_q = FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.qk_mma_tiler[0], + mma_block_coord[0], + seqlen_q, + ) + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + seqlen_kv_loop_start, seqlen_kv_loop_steps = ( + FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + ) + is_valid_k = seqlen_kv_loop_steps > 0 + has_work = is_valid_q and is_valid_k + + if has_work: block_offset = ( cuseqlen_q, cuseqlen_k, Int32(0), ((Int32(0), Int32(0)), Int32(0)), ) - continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( - self.qk_mma_tiler[0], - mma_block_coord[0], - seqlen_q, - ) - if not continue_cond: mQ_qdl_ = cute.domain_offset(cute.select(block_offset, mode=[0, 2, 3]), mQ_qdl) mK_kdl_ = cute.domain_offset(cute.select(block_offset, mode=[1, 2, 3]), mK_kdl) mdO_qdl_ = cute.domain_offset( @@ -1057,18 +1067,6 @@ def kernel( # ((atom_v, rest_v), RestN, RestK) tKTgKT = tKgK_dkl[None, None, None, mma_block_coord[2]] - seqlen_kv_loop_start, seqlen_kv_loop_steps = ( - FusedMask.get_trip_start_count_via_block_info( - mma_block_coord, - self.qk_mma_tiler, - seqlen_q, - seqlen_k, - self.is_causal, - self.is_local, - window_size_left, - window_size_right, - ) - ) # LSE lse_handle = load_lse_producer.acquire_and_advance() # 32 threads loading 128 values of 32b each @@ -1197,6 +1195,9 @@ def kernel( if warp_idx == self.mma_warp_id: cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + is_leader_cta = cta_rank_in_cluster % 2 == 0 + while work_tile.is_valid_tile: curr_block_coord = work_tile.tile_idx mma_block_coord = ( @@ -1204,41 +1205,37 @@ def kernel( curr_block_coord[1], curr_block_coord[2], ) - continue_cond = False seqlen_q = mQ_qdl.shape[0] seqlen_k = mK_kdl.shape[0] batch_coord = curr_block_coord[2][1] + is_valid_q = True if cutlass.const_expr(cum_seqlen_q is not None): cuseqlen_q = cum_seqlen_q[batch_coord] seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q - continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + is_valid_q = FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( self.qk_mma_tiler[0], mma_block_coord[0], seqlen_q, ) - - if not continue_cond: - if cutlass.const_expr(cum_seqlen_k is not None): - cuseqlen_k = cum_seqlen_k[batch_coord] - seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k - - seqlen_kv_loop_start, seqlen_kv_loop_steps = ( - FusedMask.get_trip_start_count_via_block_info( - mma_block_coord, - self.qk_mma_tiler, - seqlen_q, - seqlen_k, - self.is_causal, - self.is_local, - window_size_left, - window_size_right, - ) + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + seqlen_kv_loop_start, seqlen_kv_loop_steps = ( + FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, ) + ) + is_valid_k = seqlen_kv_loop_steps > 0 + has_work = is_valid_q and is_valid_k - cta_rank_in_cluster = cute.arch.make_warp_uniform( - cute.arch.block_idx_in_cluster() - ) - is_leader_cta = cta_rank_in_cluster % 2 == 0 + if has_work: # dq_handle = mma_dq_producer.acquire_and_advance() load_q_releaser = load_q_consumer.clone() load_do_releaser = load_do_consumer.clone() @@ -1836,33 +1833,35 @@ def kernel( curr_block_coord[2], ) batch_coord = curr_block_coord[2][1] - continue_cond = False seqlen_q = mQ_qdl.shape[0] seqlen_k = mK_kdl.shape[0] cuseqlen_q = Int32(0) + is_valid_q = True if cutlass.const_expr(cum_seqlen_q is not None): cuseqlen_q = cum_seqlen_q[batch_coord] seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q - continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + is_valid_q = FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( self.qk_mma_tiler[0], mma_block_coord[0], seqlen_q, ) - if not continue_cond: - if cutlass.const_expr(cum_seqlen_k is not None): - cuseqlen_k = cum_seqlen_k[batch_coord] - seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + start_count, trip_count = FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + is_valid_k = trip_count > 0 + has_work = is_valid_q and is_valid_k - start_count, trip_count = FusedMask.get_trip_start_count_via_block_info( - mma_block_coord, - self.qk_mma_tiler, - seqlen_q, - seqlen_k, - self.is_causal, - self.is_local, - window_size_left, - window_size_right, - ) + if has_work: end_count = start_count + trip_count if cutlass.const_expr(self.use_semantic_trip_range): n_block_min_causal_local_mask, n_block_min_before_local_mask = ( @@ -1932,6 +1931,7 @@ def kernel( ) lse_handle.release() sum_odo_handle.release() + work_tile = tile_sched.advance_to_next_work() ds_mma_producer.tail() @@ -1952,61 +1952,75 @@ def kernel( # cute.printf("batch_coord={}", batch_coord) seqlen_q = mQ_qdl.shape[0] seqlen_k = mK_kdl.shape[0] - continue_cond = False cuseqlen_q = Int32(0) + is_valid_q = True if cutlass.const_expr(cum_seqlen_q is not None): cuseqlen_q = cum_seqlen_q[batch_coord] seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q - continue_cond = not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + is_valid_q = FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( self.qk_mma_tiler[0], mma_block_coord[0], seqlen_q, ) + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + seqlen_kv_loop_start, seqlen_kv_loop_steps = ( + FusedMask.get_trip_start_count_via_block_info( + mma_block_coord, + self.qk_mma_tiler, + seqlen_q, + seqlen_k, + self.is_causal, + self.is_local, + window_size_left, + window_size_right, + ) + ) + is_valid_k = seqlen_kv_loop_steps > 0 + has_work = is_valid_q and is_valid_k - if not continue_cond: - if cutlass.const_expr(cum_seqlen_k is not None): - cuseqlen_k = cum_seqlen_k[batch_coord] - seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + mdQ_qdl_eff = mdQ_qdl + if cutlass.const_expr(cum_seqlen_q is not None): + block_offset_dQ = (cuseqlen_q,) + (None,) * 2 + mdQ_qdl_eff = cute.domain_offset(block_offset_dQ, mdQ_qdl) - mdQ_qdl_eff = mdQ_qdl - if cutlass.const_expr(cum_seqlen_q is not None): - block_offset_dQ = ( - cuseqlen_q, - Int32(0), - Int32(0), - ((Int32(0), Int32(0)), Int32(0)), - ) - mdQ_qdl_eff = cute.domain_offset( - cute.select(block_offset_dQ, mode=[0, 2, 3]), mdQ_qdl - ) + # (bM, bN, loopM, loopN, loopL) + gdQ_qdl = cute.flat_divide( + mdQ_qdl_eff, cute.select(self.dsk_block_tiler, mode=[0, 1]) + ) + cdQ_qdl = cute.flat_divide( + cute.make_identity_tensor(mdQ_qdl_eff.shape), + cute.select(self.dsk_block_tiler, mode=[0, 1]), + ) - # (bM, bN, loopM, loopN, loopL) - gdQ_qdl = cute.flat_divide( - mdQ_qdl_eff, cute.select(self.dsk_block_tiler, mode=[0, 1]) - ) - cdQ_qdl = cute.flat_divide( - cute.make_identity_tensor(mdQ_qdl_eff.shape), - cute.select(self.dsk_block_tiler, mode=[0, 1]), - ) + gdQ_staged = gdQ_qdl[None, None, curr_block_coord[0], None, curr_block_coord[2]] + cdQ_staged = cdQ_qdl[None, None, curr_block_coord[0], None, curr_block_coord[2]] + gdQ_tma_staged = gdQ_staged - gdQ_staged = gdQ_qdl[None, None, curr_block_coord[0], None, curr_block_coord[2]] - cdQ_staged = cdQ_qdl[None, None, curr_block_coord[0], None, curr_block_coord[2]] - gdQ_tma_staged = gdQ_staged - if cutlass.const_expr(not varlen): - gdQ_tma_qdl = cute.flat_divide( - mdQ_tma, cute.select(self.dsk_block_tiler, mode=[0, 1]) - ) - gdQ_tma_staged = gdQ_tma_qdl[ - None, None, curr_block_coord[0], None, curr_block_coord[2] - ] + if cutlass.const_expr(not varlen): + gdQ_tma_qdl = cute.flat_divide( + mdQ_tma, cute.select(self.dsk_block_tiler, mode=[0, 1]) + ) + gdQ_tma_staged = gdQ_tma_qdl[ + None, None, curr_block_coord[0], None, curr_block_coord[2] + ] + if has_work: # dQ TMEM to GMEM mma_dq_consumer = self.dQ_epilogue( - (seqlen_q, cuseqlen_q, mQ_qdl.shape[0], batch_coord), + seqlen_q, (mma_dq_consumer, gdQ_staged, cdQ_staged, tdQtdQ_staged), self.epi_tile, (tma_atom_dQ, gdQ_tma_staged, s_epi_dQ, varlen), ) + else: + self.dQ_epilogue_write_zero( + seqlen_q, + gdQ_staged, + cdQ_staged, + ) + work_tile = tile_sched.advance_to_next_work() # NOTE: tmem.free() moved to kernel end to enable cluster-wide sync @@ -2181,12 +2195,11 @@ def compute_step( @cute.jit def dQ_epilogue( self, - value_args: Tuple, + seqlen_q: int, dq_args: Tuple, epi_tile: cute.Tile, tma_args: Tuple, ) -> Tuple[pipeline.PipelineConsumer, pipeline.PipelineProducer]: - seqlen_q, cuseqlen_q, total_q, batch_coord = value_args (mma_dq_consumer, gdQ_staged, cdQ_staged, tdQtdQ_staged) = dq_args tma_atom_dQ, gdQ_tma_staged, s_epi_dQ, varlen = tma_args dq_handle = mma_dq_consumer.wait_and_advance() @@ -2274,3 +2287,31 @@ def dQ_epilogue( cute.autovec_copy(tSMrdQ, tTMEM_LOADgdQ_i) dq_handle.release() return mma_dq_consumer + + @cute.jit + def dQ_epilogue_write_zero( + self, + seqlen_q, + gdQ_staged, + cdQ_staged, + ): + num_epi_threads = self.threads_per_warp * len(self.epilogue_warp_ids) + tidx = cute.arch.thread_idx()[0] % num_epi_threads + + tiled_copy_r2g = fa_copy_utils.tiled_copy_2d( + self.dq_dtype, cute.size(gdQ_staged.shape[1]), num_epi_threads + ) + + thr_copy_r2g = tiled_copy_r2g.get_slice(tidx) + tdQgdQ_staged = thr_copy_r2g.partition_D(gdQ_staged) + tdQcdQ_staged = thr_copy_r2g.partition_D(cdQ_staged) + + tdQrdQ = cute.make_rmem_tensor_like(tdQgdQ_staged[None, 0, None, 0]) + tdQrdQ.fill(self.dq_dtype(0.0)) + + for iter in cutlass.range(self.iterations_dsk, unroll_full=True): + tdQgdQ = tdQgdQ_staged[None, None, None, iter] + tdQcdQ = tdQcdQ_staged[None, None, None, iter] + for m in cutlass.range(cute.size(tdQgdQ.shape[1]), unroll_full=True): + if cute.elem_less(tdQcdQ[0, m, 0][0], seqlen_q): + cute.copy(tiled_copy_r2g, tdQrdQ, tdQgdQ[None, m, None]) diff --git a/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_forward.py b/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_forward.py index 1ab8879375d..0c7b005f811 100644 --- a/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_forward.py +++ b/flashmask/flash_mask/flash_attn_v4/sm100_hd256_2cta_fmha_forward.py @@ -1030,6 +1030,9 @@ def kernel( if warp_idx == self.mma_warp_id: cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + is_leader_cta = cta_rank_in_cluster % 2 == 0 + while work_tile.is_valid_tile: curr_block_coord = work_tile.tile_idx mma_block_coord = ( @@ -1071,10 +1074,6 @@ def kernel( ) seqlen_kv_loop_end = seqlen_kv_loop_start + seqlen_kv_loop_steps - cta_rank_in_cluster = cute.arch.make_warp_uniform( - cute.arch.block_idx_in_cluster() - ) - is_leader_cta = cta_rank_in_cluster % 2 == 0 load_q_releaser = load_q_consumer.clone() pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, False) if seqlen_kv_loop_steps > 1: @@ -1323,6 +1322,11 @@ def kernel( window_size_right, ) end_count = start_count + trip_count + # require at least one softmax iteration for zero trip_count case; + # rely on masking this iteration for correctness + if end_count <= start_count: + start_count = 0 + end_count = 1 if cutlass.const_expr(self.use_semantic_trip_range): n_block_min_causal_local_mask, n_block_min_before_local_mask = ( FusedMask.get_trip_mask_bounds_via_block_info( @@ -1349,6 +1353,7 @@ def kernel( need_apply_mask = ( step >= n_block_min_causal_local_mask or step < n_block_min_before_local_mask + or step == end_count - 1 ) else: # Residual path only needs seqlen masking on the last K tile. @@ -1797,7 +1802,8 @@ def correction_epilog( row_sum = sSum[thread_idx] cute.arch.fence_view_async_shared() sum_handle.release() - scale = scale_output / row_sum + row_sum_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum + scale = scale_output / row_sum if not row_sum_is_zero_or_nan else 0.0 o_handle = mma_o_consumer.wait_and_advance() for iter in cutlass.range(self.iterations_pv): gO = gO_staged[None, None, iter] @@ -1855,6 +1861,7 @@ def store_sum_max( sSum[thread_idx] = row_sum cute.arch.fence_view_async_shared() sum_handle.commit() + row_sum_is_zero_or_nan = row_sum == 0.0 or row_sum != row_sum if cutlass.const_expr(mLSE is not None): q_idx = current_block_coord[0] * self.cta_tiler[0] + tidx @@ -1863,7 +1870,11 @@ def store_sum_max( if cutlass.const_expr(cum_seqlen_q is not None) else current_block_coord[2] ) - lse_value = scale_softmax * row_max + cute.math.log(row_sum, fastmath=True) + lse_value = ( + scale_softmax * row_max + cute.math.log(row_sum, fastmath=True) + if not row_sum_is_zero_or_nan + else -Float32.inf + ) if cute.elem_less(q_idx, seqlen_q): global_q_idx = ( q_idx + cuseqlen_q if cutlass.const_expr(cum_seqlen_q is not None) else q_idx diff --git a/flashmask/flash_mask/interface.py b/flashmask/flash_mask/interface.py index 02c471b8999..af4307fa197 100644 --- a/flashmask/flash_mask/interface.py +++ b/flashmask/flash_mask/interface.py @@ -75,102 +75,143 @@ flash_attn_varlen_func = _mod.flash_attn_varlen_func flash_attn_combine = _mod.flash_attn_combine -def convert_to_varlen( - query: paddle.Tensor, - key: paddle.Tensor, - value: paddle.Tensor, - startend_row_indices: paddle.Tensor, +def convert_to_varlen_np( + batch_size: int, + seqlen_q: int, + seqlen_kv: int, + startend_row_indices: np.ndarray, causal: bool, ): - b, sq, hq, d = query.shape - _, skv, hkv, dv = value.shape - assert sq == skv + assert isinstance(startend_row_indices, np.ndarray), ( + f"Expected np.ndarray, got {type(startend_row_indices)}" + ) + assert seqlen_q == seqlen_kv _, hfm, _, bound_num = startend_row_indices.shape assert hfm == 1 assert bound_num == 1 or bound_num == 2 - - # ── Move startend_row_indices to numpy (single GPU→CPU transfer) ── - sri_np = startend_row_indices.numpy() # (batch, hfm, seqlen_k, bound_num) - s_np = sri_np[:, 0, :, 0] # (batch, seqlen_k) + s_np = startend_row_indices[:, 0, :, 0] # (batch, seqlen_k) # ── Vectorised boundary detection in numpy ── - # Compare consecutive elements per batch: (b, skv-1) + # Compare consecutive elements per batch: (batch_size, seqlen_kv-1) diffs = s_np[:, 1:] != s_np[:, :-1] # Real ends per batch - real_ends_np = s_np.max(axis=1) # (b,) - real_ends = real_ends_np.tolist() - needs_padding_fixup = bool(np.any(real_ends_np < skv)) - - # Per-batch boundary extraction (b is typically 1-2, loop is trivial) - cu_seqlens_list = [] - max_doc_len = 0 - for bi in range(b): + real_ends_np = s_np.max(axis=1) # (batch_size,) + + # Per-batch boundary extraction (batch_size is typically 1-2, loop is trivial) + cu_seqlens_q_list = [] + cu_seqlens_k_list = [] + max_doc_len_q = 0 + max_doc_len_k = 0 + for bi in range(batch_size): change_idx = np.nonzero(diffs[bi])[0].astype(np.int32) + 1 boundaries = np.concatenate([ np.zeros(1, dtype=np.int32), change_idx, - np.array([skv], dtype=np.int32), + np.array([seqlen_kv], dtype=np.int32), ]) - doc_lens = boundaries[1:] - boundaries[:-1] - max_doc_len = max(max_doc_len, int(doc_lens.max())) - cu_seqlens_list.append(boundaries[:-1] + np.int32(bi * skv)) + real_end = int(real_ends_np[bi]) + + if real_end < seqlen_kv: + # Has padding: only include real documents, then add two + # padding documents at tail. + # Find the boundary index where padding starts + real_boundaries = boundaries[boundaries <= real_end] + if real_boundaries[-1] != real_end: + real_boundaries = np.concatenate([ + real_boundaries, np.array([real_end], dtype=np.int32) + ]) + + doc_lens = real_boundaries[1:] - real_boundaries[:-1] + max_doc_len_q = max(max_doc_len_q, int(doc_lens.max())) + max_doc_len_k = max(max_doc_len_k, int(doc_lens.max())) + + pad_len = seqlen_kv - real_end + + # cu_seqlens_q: [...real_doc_starts, real_end, real_end, real_end+pad_len] + # -> docs: real docs, then (0 seqlen_q, pad_len seqlen_k), then (pad_len seqlen_q, 0 seqlen_k) + # cu_seqlens_k: [...real_doc_starts, real_end, real_end+pad_len, real_end+pad_len] + offset = np.int32(bi * seqlen_kv) + cu_seqlens_q_list.append(np.concatenate([ + real_boundaries[:-1] + offset, + np.array([real_end, real_end], dtype=np.int32) + offset, + ])) + cu_seqlens_k_list.append(np.concatenate([ + real_boundaries[:-1] + offset, + np.array([real_end, real_end + pad_len], dtype=np.int32) + offset, + ])) + + max_doc_len_k = max(max_doc_len_k, pad_len) + max_doc_len_q = max(max_doc_len_q, pad_len) + else: + # No padding + doc_lens = boundaries[1:] - boundaries[:-1] + max_doc_len_q = max(max_doc_len_q, int(doc_lens.max())) + max_doc_len_k = max(max_doc_len_k, int(doc_lens.max())) + cu_seqlens_q_list.append(boundaries[:-1] + np.int32(bi * seqlen_kv)) + cu_seqlens_k_list.append(boundaries[:-1] + np.int32(bi * seqlen_kv)) # Build cu_seqlens: one numpy concat, one paddle tensor creation - cu_seqlens_np = np.concatenate( - cu_seqlens_list + [np.array([b * skv], dtype=np.int32)] + cu_seqlens_q_np = np.concatenate( + cu_seqlens_q_list + [np.array([batch_size * seqlen_kv], dtype=np.int32)] + ) + cu_seqlens_k_np = np.concatenate( + cu_seqlens_k_list + [np.array([batch_size * seqlen_kv], dtype=np.int32)] ) - cu_seqlens = paddle.to_tensor(cu_seqlens_np) - - # ── Flatten q, k, v: (batch, seqlen, heads, dim) -> (total, heads, dim) - q_varlen = query.reshape([b * sq, hq, d]) - k_varlen = key.reshape([b * skv, hkv, d]) - v_varlen = value.reshape([b * skv, hkv, dv]) # ── Detect simulated causal masks (numpy) ──────────────────────── varlen_causal = causal if not causal and bound_num == 2: - lts_all = s_np # (b, skv), already extracted - ute_all = sri_np[:, 0, :, 1] # (b, skv) - arange_ref = np.arange(skv, dtype=np.int32).reshape(1, skv) + lts_all = s_np # (batch_size, seqlen_kv), already extracted + ute_all = startend_row_indices[:, 0, :, 1] # (batch_size, seqlen_kv) + arange_ref = np.arange(seqlen_kv, dtype=np.int32).reshape(1, seqlen_kv) expected_causal_ute = np.minimum(arange_ref, lts_all) if np.array_equal(ute_all, expected_causal_ute): varlen_causal = True result = { - "q": q_varlen, - "k": k_varlen, - "v": v_varlen, - "cu_seqlens_q": cu_seqlens, - "cu_seqlens_k": cu_seqlens, - "max_seqlen_q": max_doc_len, - "max_seqlen_k": max_doc_len, + "cu_seqlens_q": cu_seqlens_q_np, + "cu_seqlens_k": cu_seqlens_k_np, + "max_seqlen_q": max_doc_len_q, + "max_seqlen_k": max_doc_len_k, "causal": varlen_causal, } - # For non-causal masks with trailing padding: padding rows attend to - # nothing in flashmask (zero output), but varlen computes non-zero - # output for them. Zero out padding rows per batch to match flashmask. - # Note: real_end can differ across batch items. - if needs_padding_fixup: - _b, _sq = b, sq - _real_ends = real_ends - - def output_to_padded(out_varlen_pt): - nh = out_varlen_pt.shape[1] - dv_out = out_varlen_pt.shape[2] - out_padded = out_varlen_pt.reshape(_b, _sq, nh, dv_out) - # Vectorised per-batch zeroing - row_idx = paddle.arange(_sq) - real_end_t = paddle.to_tensor( - _real_ends, dtype=paddle.int64, - ).unsqueeze(1) - padding_mask = row_idx.unsqueeze(0) >= real_end_t # (b, sq) - out_padded[padding_mask] = 0 - return out_padded - - result["output_to_padded"] = output_to_padded + return result + +def convert_to_varlen( + batch_size: int, + seqlen_q: int, + seqlen_kv: int, + startend_row_indices: paddle.Tensor, + causal: bool, +): + assert seqlen_q == seqlen_kv + + _, hfm, _, bound_num = startend_row_indices.shape + assert hfm == 1 + assert bound_num == 1 or bound_num == 2 + + # ── Move startend_row_indices to numpy (single GPU→CPU transfer) ── + sri_np = startend_row_indices.numpy() # (batch, hfm, seqlen_k, bound_num) + + result = convert_to_varlen_np( + batch_size, + seqlen_q, + seqlen_kv, + sri_np, + causal, + ) + + cu_seqlens_q_np = result["cu_seqlens_q"] + cu_seqlens_k_np = result["cu_seqlens_k"] + + cu_seqlens_q = paddle.to_tensor(cu_seqlens_q_np) + cu_seqlens_k = paddle.to_tensor(cu_seqlens_k_np) + + result["cu_seqlens_q"] = cu_seqlens_q + result["cu_seqlens_k"] = cu_seqlens_k return result @@ -210,10 +251,10 @@ def flashmask_attention( ) batch_size, seqlen_q, nheads, d = query.shape - seqlen_k = key.shape[1] - assert seqlen_q == seqlen_k, ( - f"use_varlen requires seqlen_q == seqlen_k, ", - f"currently seqlen_q={seqlen_q}, seqlen_k={seqlen_k}", + _, seqlen_kv, nheads_kv, dv = value.shape + assert seqlen_q == seqlen_kv, ( + f"use_varlen requires seqlen_q == seqlen_kv, ", + f"currently seqlen_q={seqlen_q}, seqlen_kv={seqlen_kv}", ) dv = value.shape[-1] @@ -225,15 +266,15 @@ def flashmask_attention( ) varlen_args = convert_to_varlen( - query=query, - key=key, - value=value, + batch_size=batch_size, + seqlen_q=seqlen_q, + seqlen_kv=seqlen_kv, startend_row_indices=startend_row_indices, causal=causal) out, lse = flash_attn_varlen_func( - q=varlen_args["q"], - k=varlen_args["k"], - v=varlen_args["v"], + q=query.reshape(batch_size * seqlen_q, nheads, d), + k=key.reshape(batch_size * seqlen_kv, nheads_kv, d), + v=value.reshape(batch_size * seqlen_kv, nheads_kv, dv), cu_seqlens_q=varlen_args["cu_seqlens_q"], cu_seqlens_k=varlen_args["cu_seqlens_k"], max_seqlen_q=varlen_args["max_seqlen_q"], @@ -243,10 +284,7 @@ def flashmask_attention( return_lse=return_softmax_lse, ) - if "output_to_padded" in varlen_args: - out = varlen_args["output_to_padded"](out) - else: - out = out.reshape([batch_size, seqlen_q, nheads, dv]) + out = out.reshape([batch_size, seqlen_q, nheads, dv]) if return_softmax_lse: return [out, lse]